1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
17#define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
18
19#include <map>
20#include <set>
21#include <vector>
22
23#include "tensorflow/compiler/xla/literal_util.h"
24#include "tensorflow/compiler/xla/service/shaped_buffer.h"
25#include "tensorflow/compiler/xla/statusor.h"
26#include "tensorflow/compiler/xla/types.h"
27#include "tensorflow/compiler/xla/xla_data.pb.h"
28#include "tensorflow/core/lib/gtl/array_slice.h"
29#include "tensorflow/core/platform/mutex.h"
30#include "tensorflow/core/platform/stream_executor_no_cuda.h"
31#include "tensorflow/core/platform/thread_annotations.h"
32#include "tensorflow/core/platform/types.h"
33
34namespace xla {
35
36// The TransferManager interface lets backends provide platform-specific
37// mechanisms for constructing literals from given device memory handles.
38// This lets each platform customize how literals are transferred to/from the
39// device in terms of padding, leading dimension, etc.
40class TransferManager {
41 public:
42  virtual ~TransferManager() {}
43
44  // Returns the ID of the platform that this transfer manager acts on.
45  virtual perftools::gputools::Platform::Id PlatformId() const = 0;
46
47  // Returns the shape of the on-device representation for the given shape on
48  // the host. This is intended for use with ShapedBuffer where buffers are
49  // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
50  // needing to consider device-specific behaviors.
51  virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
52    return host_shape;
53  }
54
55  // Returns a literal containing the data held in the given ShapedBuffer.
56  // using the provided executor. The optional literal_shape will be the shape
57  // for the literal. The shape of the ShapedBuffer and
58  // DeviceShape(literal_shape) must be compatible, but need not have the same
59  // layout.
60  virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice(
61      perftools::gputools::StreamExecutor* executor,
62      const ShapedBuffer& device_buffer) = 0;
63
64  // Transfers the given literal into the previously allocated device memory
65  // represented by the given ShapedBuffer using the given executor. The shape
66  // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
67  // but need not have the same layout
68  virtual Status TransferLiteralToDevice(
69      perftools::gputools::StreamExecutor* executor, const Literal& literal,
70      const ShapedBuffer& device_buffer) = 0;
71
72  // Convenience methods for transferring an array to or from the device at a
73  // known address. This avoids having to construct a ShapedBuffer just to
74  // transfer an array at a known address.
75  Status TransferArrayToDevice(
76      perftools::gputools::StreamExecutor* executor, const Literal& literal,
77      const perftools::gputools::DeviceMemoryBase& dest);
78  StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice(
79      perftools::gputools::StreamExecutor* executor, const Shape& shape,
80      const perftools::gputools::DeviceMemoryBase& source);
81
82  // Transfers the given literal into the Infeed interface of the device,
83  // using the given executor.
84  virtual Status TransferLiteralToInfeed(
85      perftools::gputools::StreamExecutor* executor,
86      const Literal& literal) = 0;
87
88  // Transfers the given literal from the Outfeed interface of the device,
89  // using the given executor.
90  virtual Status TransferLiteralFromOutfeed(
91      perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
92      Literal* literal) = 0;
93
94  // Resets the devices associated with this transfer manager.
95  virtual Status ResetDevices(
96      tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
97          executor) = 0;
98
99  // Given an allocated ShapedBuffer, constructs the tuple index table(s) in
100  // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the
101  // ShapedBuffer is array-shaped this method does nothing.
102  Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor,
103                               const ShapedBuffer& device_buffer);
104
105  // Determines the byte size requirement for the given shape on the underlying
106  // architecture. This will be used to allocate an appropriately sized memory
107  // region for a host-to-device transfer.
108  virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0;
109
110  // Allocate a ShapedBuffer which can hold data with the given on-host
111  // shape. The on-device shape may be different as indicated by
112  // HostShapeToDeviceShape.
113  StatusOr<std::unique_ptr<ShapedBuffer>> AllocateShapedBuffer(
114      const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
115      int device_ordinal);
116  StatusOr<std::unique_ptr<ScopedShapedBuffer>> AllocateScopedShapedBuffer(
117      const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
118      int device_ordinal);
119
120  /////
121  // The TransferManager class also serves as a point to register objects for
122  // the various platforms.
123
124  // Registers the TransferManager singleton for the platform kind. This is
125  // assumed to be a singleton, so no ownership is transferred.
126  //
127  // Precondition: a platform kind must not be registered more than once.
128  typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)();
129  static void RegisterTransferManager(
130      perftools::gputools::Platform::Id platform_id,
131      TransferManagerCreationFunction transfer_manager);
132
133  // Returns the transfer manager singleton pointer if it is available for the
134  // given platform, or an error status if it is not.
135  static StatusOr<TransferManager*> GetForPlatform(
136      const perftools::gputools::Platform* platform);
137
138 protected:
139  // Transfer a memory block of the given size from 'source' buffer to the
140  // Infeed interface of the device using the given executor.
141  //
142  // size is the size to transfer from source in bytes.
143  //
144  // source is the source data that must be in the target-dependent layout that
145  // the Infeed HLO used in the computation expects.
146  virtual Status TransferBufferToInfeed(
147      perftools::gputools::StreamExecutor* executor, int64 size,
148      const void* source) = 0;
149
150  // Transfer a memory block of the given size from the device source into the
151  // 'destination' buffer.
152  //
153  // size is the size to transfer to destination in bytes.
154  virtual Status TransferBufferFromDevice(
155      perftools::gputools::StreamExecutor* executor,
156      const perftools::gputools::DeviceMemoryBase& source, int64 size,
157      void* destination);
158
159  // Transfer a memory block of the given size from 'source' buffer to the given
160  // destination of the device.
161  //
162  // size is the size to transfer from source in bytes.
163  virtual Status TransferBufferToDevice(
164      perftools::gputools::StreamExecutor* executor, int64 size,
165      const void* source, perftools::gputools::DeviceMemoryBase* destination);
166
167  // Writes the given device-memory pointers in 'elements' to the given region
168  // to construct a tuple index table in the platform-specific tuple
169  // representation.
170  virtual Status WriteSingleTupleIndexTable(
171      perftools::gputools::StreamExecutor* executor,
172      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
173          elements,
174      const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0;
175
176 private:
177  // The mutex that guards the platform-to-transfer manager map.
178  static tensorflow::mutex platform_transfer_manager_mutex_;
179
180  // State kept for each kind of TransferManager.  Registration functions
181  // set up creation_function, and then we use that to lazily create
182  // "manager" the first time GetForPlatform is invoked for a particular id.
183  struct State {
184    std::unique_ptr<TransferManager> manager;
185    TransferManagerCreationFunction creation_function = nullptr;
186  };
187
188  // Map from platform kind to transfer manager singleton.
189  static std::map<perftools::gputools::Platform::Id, State>*
190  GetPlatformTransferManagers();
191};
192
193}  // namespace xla
194
195#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
196