/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_

#include <memory>
#include <ostream>
#include <string>

#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Class which encapsulates a buffer or set of buffers containing data of a
// particular XLA shape.
class ShapedBuffer {
 public:
  // Constructs a ShapedBuffer with null DeviceMemoryBases at each index. The
  // shape of the data on the host and the device may differ because the device
  // may have a different representation for different data types. Therefore,
  // both the on-host and on-device shapes are required. The on-device shape
  // determines the number of device allocations (DeviceMemoryBase) held by the
  // ShapedBuffer.
  ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape,
               const perftools::gputools::Platform* platform,
               int device_ordinal);
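
  // Example usage (a minimal sketch; `platform` is assumed to be a
  // perftools::gputools::Platform* obtained elsewhere, and `device_memory` an
  // existing DeviceMemoryBase allocation):
  //
  //   ShapedBuffer buffer(ShapeUtil::MakeShape(F32, {2, 3}),
  //                       ShapeUtil::MakeShape(F32, {2, 3}),
  //                       platform, /*device_ordinal=*/0);
  //   buffer.set_buffer(device_memory, /*index=*/{});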

  // Returns the shape of the on-host representation of the data held by this
  // ShapedBuffer.
  const Shape& on_host_shape() const { return on_host_shape_; }

  // Returns the shape of the on-device representation of the data held by this
  // ShapedBuffer.
  const Shape& on_device_shape() const { return on_device_shape_; }

  const perftools::gputools::Platform* platform() const { return platform_; }
  int device_ordinal() const { return device_ordinal_; }

  // Returns the root buffer of the shape (shape index {}).
  const perftools::gputools::DeviceMemoryBase& root_buffer() const {
    return buffer(/*index=*/{});
  }

  // Returns the buffer at the given shape index, where the index is defined as
  // in ShapeUtil::GetSubshape.
  const perftools::gputools::DeviceMemoryBase& buffer(
      const ShapeIndex& index) const {
    return buffers_.element(index);
  }

  // Sets the device memory buffer at the given index.
  void set_buffer(const perftools::gputools::DeviceMemoryBase& buffer,
                  const ShapeIndex& index) {
    *buffers_.mutable_element(index) = buffer;
  }
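
  // For example, the leaves of a two-element tuple can be populated one at a
  // time (a sketch; `mem0` and `mem1` are assumed DeviceMemoryBase values for
  // the tuple elements, with the root tuple table left at index {}):
  //
  //   tuple_buffer.set_buffer(mem0, /*index=*/{0});
  //   tuple_buffer.set_buffer(mem1, /*index=*/{1});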

  // Returns the underlying ShapeTree containing all the device addresses in
  // the ShapedBuffer.
  const ShapeTree<perftools::gputools::DeviceMemoryBase>& buffers() const {
    return buffers_;
  }
  ShapeTree<perftools::gputools::DeviceMemoryBase>& buffers() {
    return buffers_;
  }
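
  // Example: visiting every allocation in the tree (a sketch; the exact
  // ShapeTree::ForEachElement signature may differ, see shape_tree.h):
  //
  //   shaped_buffer.buffers().ForEachElement(
  //       [](const ShapeIndex& index,
  //          const perftools::gputools::DeviceMemoryBase& memory) {
  //         // `index` names a subshape; `memory` is its device allocation.
  //       });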

  // Sets all device memory pointers in the object to null.
  void clear();

  string ToString() const;

  ShapedBuffer(ShapedBuffer&& s);
  ShapedBuffer& operator=(ShapedBuffer&&);

 protected:
  ShapedBuffer(const ShapedBuffer&) = delete;
  ShapedBuffer& operator=(const ShapedBuffer&) = delete;

  // The shape of the data when represented on the host.
  Shape on_host_shape_;

  // The shape of the data on the device.
  Shape on_device_shape_;

  // The platform the memory is allocated on.
  const perftools::gputools::Platform* platform_;

  // The device the memory is allocated on.
  int device_ordinal_;

  // The tree of device buffers. Its shape is on_device_shape().
  ShapeTree<perftools::gputools::DeviceMemoryBase> buffers_;
};

std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);

// ShapedBuffer derived class which allocates all internal buffers on
// construction and deallocates the memory when the object is destructed.
class ScopedShapedBuffer : public ShapedBuffer {
 public:
  // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the
  // deallocation of the device memory held in the shaped buffer. All device
  // memory pointers in the given ShapedBuffer are set to null.
  static StatusOr<std::unique_ptr<ScopedShapedBuffer>> MakeScoped(
      ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator);
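
  // Example (a sketch; assumes `buffer` is a ShapedBuffer whose allocations
  // came from `allocator`, running in a function that returns a Status):
  //
  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<ScopedShapedBuffer> scoped,
  //                       ScopedShapedBuffer::MakeScoped(&buffer, allocator));
  //   // `buffer` now holds only null pointers; `scoped` owns the memory.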

  // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
  ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape,
                     DeviceMemoryAllocator* allocator, int device_ordinal);

  // Creates a ScopedShapedBuffer by taking over the memory from the incoming
  // ShapedBuffer.
  ScopedShapedBuffer(ShapedBuffer shaped_buffer,
                     DeviceMemoryAllocator* allocator);

  // Returns the allocator used to allocate the device memory held in this
  // ScopedShapedBuffer.
  DeviceMemoryAllocator* memory_allocator() const { return allocator_; }

  // Releases all device memory owned by this ScopedShapedBuffer and returns
  // the device memory pointers in the form of a ShapedBuffer. The returned
  // ShapedBuffer takes over the memory from this ScopedShapedBuffer, which
  // afterwards owns nothing and can only be destroyed.
  std::unique_ptr<ShapedBuffer> release();
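
  // Example (a sketch): handing ownership back to a caller that manages
  // deallocation itself:
  //
  //   std::unique_ptr<ShapedBuffer> unowned = scoped->release();
  //   // `scoped` now owns nothing; the caller must free `unowned`'s memory.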

  // All buffers in the shape are deallocated on destruction.
  virtual ~ScopedShapedBuffer();

 protected:
  ScopedShapedBuffer(const ScopedShapedBuffer&) = delete;
  void operator=(const ScopedShapedBuffer&) = delete;

  // The allocator used to allocate and deallocate this buffer's device memory.
  DeviceMemoryAllocator* allocator_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_