1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 17#define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 18 19#include <map> 20#include <set> 21#include <vector> 22 23#include "tensorflow/compiler/xla/literal_util.h" 24#include "tensorflow/compiler/xla/service/shaped_buffer.h" 25#include "tensorflow/compiler/xla/statusor.h" 26#include "tensorflow/compiler/xla/types.h" 27#include "tensorflow/compiler/xla/xla_data.pb.h" 28#include "tensorflow/core/lib/gtl/array_slice.h" 29#include "tensorflow/core/platform/mutex.h" 30#include "tensorflow/core/platform/stream_executor_no_cuda.h" 31#include "tensorflow/core/platform/thread_annotations.h" 32#include "tensorflow/core/platform/types.h" 33 34namespace xla { 35 36// The TransferManager interface lets backends provide platform-specific 37// mechanisms for constructing literals from given device memory handles. 38// This lets each platform customize how literals are transferred to/from the 39// device in terms of padding, leading dimension, etc. 40class TransferManager { 41 public: 42 virtual ~TransferManager() {} 43 44 // Returns the ID of the platform that this transfer manager acts on. 45 virtual perftools::gputools::Platform::Id PlatformId() const = 0; 46 47 // Returns the shape of the on-device representation for the given shape on 48 // the host. This is intended for use with ShapedBuffer where buffers are 49 // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user 50 // needing to consider device-specific behaviors. 51 virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { 52 return host_shape; 53 } 54 55 // Returns a literal containing the data held in the given ShapedBuffer. 56 // using the provided executor. The optional literal_shape will be the shape 57 // for the literal. The shape of the ShapedBuffer and 58 // DeviceShape(literal_shape) must be compatible, but need not have the same 59 // layout. 60 virtual StatusOr<std::unique_ptr<Literal>> TransferLiteralFromDevice( 61 perftools::gputools::StreamExecutor* executor, 62 const ShapedBuffer& device_buffer) = 0; 63 64 // Transfers the given literal into the previously allocated device memory 65 // represented by the given ShapedBuffer using the given executor. The shape 66 // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, 67 // but need not have the same layout 68 virtual Status TransferLiteralToDevice( 69 perftools::gputools::StreamExecutor* executor, const Literal& literal, 70 const ShapedBuffer& device_buffer) = 0; 71 72 // Convenience methods for transferring an array to or from the device at a 73 // known address. This avoids having to construct a ShapedBuffer just to 74 // transfer an array at a known address. 75 Status TransferArrayToDevice( 76 perftools::gputools::StreamExecutor* executor, const Literal& literal, 77 const perftools::gputools::DeviceMemoryBase& dest); 78 StatusOr<std::unique_ptr<Literal>> TransferArrayFromDevice( 79 perftools::gputools::StreamExecutor* executor, const Shape& shape, 80 const perftools::gputools::DeviceMemoryBase& source); 81 82 // Transfers the given literal into the Infeed interface of the device, 83 // using the given executor. 84 virtual Status TransferLiteralToInfeed( 85 perftools::gputools::StreamExecutor* executor, 86 const Literal& literal) = 0; 87 88 // Transfers the given literal from the Outfeed interface of the device, 89 // using the given executor. 90 virtual Status TransferLiteralFromOutfeed( 91 perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, 92 Literal* literal) = 0; 93 94 // Resets the devices associated with this transfer manager. 95 virtual Status ResetDevices( 96 tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> 97 executor) = 0; 98 99 // Given an allocated ShapedBuffer, constructs the tuple index table(s) in 100 // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the 101 // ShapedBuffer is array-shaped this method does nothing. 102 Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor, 103 const ShapedBuffer& device_buffer); 104 105 // Determines the byte size requirement for the given shape on the underlying 106 // architecture. This will be used to allocate an appropriately sized memory 107 // region for a host-to-device transfer. 108 virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; 109 110 // Allocate a ShapedBuffer which can hold data with the given on-host 111 // shape. The on-device shape may be different as indicated by 112 // HostShapeToDeviceShape. 113 StatusOr<std::unique_ptr<ShapedBuffer>> AllocateShapedBuffer( 114 const Shape& on_host_shape, DeviceMemoryAllocator* allocator, 115 int device_ordinal); 116 StatusOr<std::unique_ptr<ScopedShapedBuffer>> AllocateScopedShapedBuffer( 117 const Shape& on_host_shape, DeviceMemoryAllocator* allocator, 118 int device_ordinal); 119 120 ///// 121 // The TransferManager class also serves as a point to register objects for 122 // the various platforms. 123 124 // Registers the TransferManager singleton for the platform kind. This is 125 // assumed to be a singleton, so no ownership is transferred. 126 // 127 // Precondition: a platform kind must not be registered more than once. 128 typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)(); 129 static void RegisterTransferManager( 130 perftools::gputools::Platform::Id platform_id, 131 TransferManagerCreationFunction transfer_manager); 132 133 // Returns the transfer manager singleton pointer if it is available for the 134 // given platform, or an error status if it is not. 135 static StatusOr<TransferManager*> GetForPlatform( 136 const perftools::gputools::Platform* platform); 137 138 protected: 139 // Transfer a memory block of the given size from 'source' buffer to the 140 // Infeed interface of the device using the given executor. 141 // 142 // size is the size to transfer from source in bytes. 143 // 144 // source is the source data that must be in the target-dependent layout that 145 // the Infeed HLO used in the computation expects. 146 virtual Status TransferBufferToInfeed( 147 perftools::gputools::StreamExecutor* executor, int64 size, 148 const void* source) = 0; 149 150 // Transfer a memory block of the given size from the device source into the 151 // 'destination' buffer. 152 // 153 // size is the size to transfer to destination in bytes. 154 virtual Status TransferBufferFromDevice( 155 perftools::gputools::StreamExecutor* executor, 156 const perftools::gputools::DeviceMemoryBase& source, int64 size, 157 void* destination); 158 159 // Transfer a memory block of the given size from 'source' buffer to the given 160 // destination of the device. 161 // 162 // size is the size to transfer from source in bytes. 163 virtual Status TransferBufferToDevice( 164 perftools::gputools::StreamExecutor* executor, int64 size, 165 const void* source, perftools::gputools::DeviceMemoryBase* destination); 166 167 // Writes the given device-memory pointers in 'elements' to the given region 168 // to construct a tuple index table in the platform-specific tuple 169 // representation. 170 virtual Status WriteSingleTupleIndexTable( 171 perftools::gputools::StreamExecutor* executor, 172 tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase> 173 elements, 174 const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0; 175 176 private: 177 // The mutex that guards the platform-to-transfer manager map. 178 static tensorflow::mutex platform_transfer_manager_mutex_; 179 180 // State kept for each kind of TransferManager. Registration functions 181 // set up creation_function, and then we use that to lazily create 182 // "manager" the first time GetForPlatform is invoked for a particular id. 183 struct State { 184 std::unique_ptr<TransferManager> manager; 185 TransferManagerCreationFunction creation_function = nullptr; 186 }; 187 188 // Map from platform kind to transfer manager singleton. 189 static std::map<perftools::gputools::Platform::Id, State>* 190 GetPlatformTransferManagers(); 191}; 192 193} // namespace xla 194 195#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 196