1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_ 17#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_ 18 19#include <memory> 20#include <ostream> 21#include <string> 22 23#include "tensorflow/compiler/xla/service/device_memory_allocator.h" 24#include "tensorflow/compiler/xla/shape_tree.h" 25#include "tensorflow/compiler/xla/statusor.h" 26#include "tensorflow/compiler/xla/xla_data.pb.h" 27#include "tensorflow/core/lib/gtl/array_slice.h" 28#include "tensorflow/core/platform/stream_executor_no_cuda.h" 29#include "tensorflow/core/platform/types.h" 30 31namespace xla { 32 33// Class which encapsulates a buffer or set of buffers containing data of a 34// particular XLA shape. 35class ShapedBuffer { 36 public: 37 // Construct a ShapedBuffer with null DeviceMemoryBases at each index. The 38 // shape of the data on the host and the device may differ because the device 39 // may have a different representation for different data types. Therefore, 40 // both the on-host and on-device shape are required. The on-device shape 41 // determines the number of device allocations (DeviceMemoryBase) held by the 42 // ShapedBuffer. 43 ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, 44 const perftools::gputools::Platform* platform, 45 int device_ordinal); 46 47 // Returns the shape of the on-host representation of the data held by this 48 // ShapedBuffer. 49 const Shape& on_host_shape() const { return on_host_shape_; } 50 51 // Returns the shape of the on-device representation of the data held by this 52 // ShapedBuffer. 53 const Shape& on_device_shape() const { return on_device_shape_; } 54 55 const perftools::gputools::Platform* platform() const { return platform_; } 56 int device_ordinal() const { return device_ordinal_; } 57 58 // Return the root buffer of the shape (shape index {}). 59 const perftools::gputools::DeviceMemoryBase& root_buffer() const { 60 return buffer(/*index=*/{}); 61 } 62 63 // Returns the buffer at the given shape index where index is defined as in 64 // ShapeUtil::GetSubshape. 65 const perftools::gputools::DeviceMemoryBase& buffer( 66 const ShapeIndex& index) const { 67 return buffers_.element(index); 68 } 69 70 // Sets the device memory buffer at the given index. 71 void set_buffer(const perftools::gputools::DeviceMemoryBase& buffer, 72 const ShapeIndex& index) { 73 *buffers_.mutable_element(index) = buffer; 74 } 75 76 // Returns the underlying ShapeTree containing all the device addresses in the 77 // ShapedBuffer. 78 const ShapeTree<perftools::gputools::DeviceMemoryBase>& buffers() const { 79 return buffers_; 80 } 81 ShapeTree<perftools::gputools::DeviceMemoryBase>& buffers() { 82 return buffers_; 83 } 84 85 // Set all device memory pointers in the object to null. 86 void clear(); 87 88 string ToString() const; 89 90 ShapedBuffer(ShapedBuffer&& s); 91 ShapedBuffer& operator=(ShapedBuffer&&); 92 93 protected: 94 ShapedBuffer(const ShapedBuffer&) = delete; 95 ShapedBuffer& operator=(const ShapedBuffer&) = delete; 96 97 // The shape of the data when represented on the host. 98 Shape on_host_shape_; 99 100 // The shape of the data on the device. 101 Shape on_device_shape_; 102 103 // The platform the memory is allocated on. 104 const perftools::gputools::Platform* platform_; 105 106 // The device the memory is allocated on. 107 int device_ordinal_; 108 109 // The tree of device buffers. Its shape is on_device_shape(). 110 ShapeTree<perftools::gputools::DeviceMemoryBase> buffers_; 111}; 112 113std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); 114 115// ShapedBuffer derived class which allocates all internal buffers on 116// construction and deallocates the memory when the object is 117// destructed. 118class ScopedShapedBuffer : public ShapedBuffer { 119 public: 120 // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the 121 // deallocation of the device memory held in the shaped buffer. All device 122 // memory pointers in the given ShapedBuffer are set to null. 123 static StatusOr<std::unique_ptr<ScopedShapedBuffer>> MakeScoped( 124 ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator); 125 126 // Create a ScopedShapedBuffer with null DeviceMemoryBases at each index. 127 ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, 128 DeviceMemoryAllocator* allocator, int device_ordinal); 129 130 // Create a ScopedShapedBuffer by taking over the memory from the incoming 131 // ShapedBuffer. 132 ScopedShapedBuffer(ShapedBuffer shaped_buffer, 133 DeviceMemoryAllocator* allocator); 134 135 // Return the allocator used to allocate the device memory held in this 136 // ScopedShapedBuffer. 137 DeviceMemoryAllocator* memory_allocator() const { return allocator_; } 138 139 // Release all device memory owned by this ScopedShapedBuffer and 140 // return the device memory pointers in the form of a 141 // ShapedBuffer. The returned ShapedBuffer takes over the memory 142 // from the ScopedShapedBuffer. The resulting ScopedShapedBuffer can 143 // only be destroyed. 144 std::unique_ptr<ShapedBuffer> release(); 145 146 // All buffers in the shape are deallocated on destruction. 147 virtual ~ScopedShapedBuffer(); 148 149 protected: 150 ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; 151 void operator=(const ScopedShapedBuffer&) = delete; 152 153 DeviceMemoryAllocator* allocator_; 154}; 155 156} // namespace xla 157 158#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_ 159