/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/interpreter/platform_id.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

namespace se = ::perftools::gputools;

namespace xla {

GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id,
                                               size_t pointer_size)
    : platform_id_(platform_id), pointer_size_(pointer_size) {
  // We currently only support kHostPlatformId for CPU, kCudaPlatformId for
  // GPU and kInterpreterPlatformId for Interpreter. Before supporting other
  // platforms, we need to test this transfer manager on them.
  CHECK(platform_id_ == se::host::kHostPlatformId ||
        platform_id_ == se::interpreter::kInterpreterPlatformId ||
        platform_id_ == se::cuda::kCudaPlatformId);
}

se::Platform::Id GenericTransferManager::PlatformId() const {
  return platform_id_;
}

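// Writes the index table for a single tuple-shaped buffer: the opaque device
// address of each element buffer is packed, in order, into `region`. Nested
// tuples are handled by the caller (see WriteTupleIndexTables), which invokes
// this once per tuple-shaped subshape.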
Status GenericTransferManager::WriteSingleTupleIndexTable(
    perftools::gputools::StreamExecutor* executor,
    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> elements,
    const Shape& shape, perftools::gputools::DeviceMemoryBase* region) {
  TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape));

  std::vector<const void*> element_pointers;
  for (const se::DeviceMemoryBase& element : elements) {
    element_pointers.push_back(element.opaque());
  }
  return TransferBufferToDevice(executor, GetByteSizeRequirement(shape),
                                element_pointers.data(), region);
}

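// Reads a ShapedBuffer back into a newly allocated Literal of the
// corresponding host shape. Only array-shaped leaves are copied from the
// device; tuple index tables are not read back.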
StatusOr<std::unique_ptr<Literal>>
GenericTransferManager::TransferLiteralFromDevice(
    se::StreamExecutor* executor, const ShapedBuffer& device_buffer) {
  VLOG(2) << "transferring literal from device ordinal "
          << executor->device_ordinal() << "; device buffer: " << device_buffer;
  TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());

  // The on-host and on-device shapes should always be the same for the
  // generic transfer manager.
  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
                                device_buffer.on_host_shape()));

  std::unique_ptr<Literal> literal =
      Literal::CreateFromShape(device_buffer.on_host_shape());

  TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_host_shape(),
      [&](const Shape& subshape, const ShapeIndex& index) -> Status {
        if (!ShapeUtil::IsTuple(subshape)) {
          TF_RETURN_IF_ERROR(TransferBufferFromDevice(
              executor,
              /*source=*/device_buffer.buffer(index),
              /*size=*/GetByteSizeRequirement(subshape),
              /*destination=*/
              literal->untyped_data(index)));
        }

        return Status::OK();
      }));
  return std::move(literal);
}

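// Copies a Literal into an already-allocated ShapedBuffer on the device. The
// tuple index tables are written first, then each array-shaped leaf is
// transferred, re-laying out the data on the host first if the literal's
// layout differs from the device layout.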
Status GenericTransferManager::TransferLiteralToDevice(
    se::StreamExecutor* executor, const Literal& literal,
    const ShapedBuffer& device_buffer) {
  const Shape& shape = literal.shape();
  VLOG(2) << "transferring literal to device; shape: "
          << ShapeUtil::HumanString(shape)
          << "; device buffer: " << device_buffer;

  // The on-host and on-device shapes should always be the same for the
  // generic transfer manager.
  TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(),
                                device_buffer.on_host_shape()));

  TF_RET_CHECK(
      ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape()));
  TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());

  TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer));

  return ShapeUtil::ForEachSubshapeWithStatus(
      device_buffer.on_host_shape(),
      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
        se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
        if (ShapeUtil::IsArray(device_subshape)) {
          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
                       device_memory.size());
          // Element is array-shaped: transfer array data to device buffer.
          const auto subliteral = LiteralView::Create(literal, index);
          std::unique_ptr<Literal> relayed_out_literal;
          const void* source;
          if (LayoutUtil::Equal(device_subshape.layout(),
                                subliteral.shape().layout())) {
            source = subliteral.untyped_data();
          } else {
            // Relayout data before transferring.
            relayed_out_literal = subliteral.Relayout(device_subshape.layout(),
                                                      /*shape_index=*/{});
            source = relayed_out_literal->untyped_data();
          }
          return TransferBufferToDevice(
              executor,
              /*size=*/GetByteSizeRequirement(device_subshape), source,
              &device_memory);
        }
        return Status::OK();
      });
}

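// Infeed, outfeed, and device reset are not implemented by the generic
// transfer manager; the stubs below return Unimplemented errors.
// Platform-specific transfer managers are expected to override them.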
Status GenericTransferManager::TransferLiteralToInfeed(
    se::StreamExecutor* executor, const Literal& literal) {
  return Unimplemented("Generic transfer to Infeed");
}

Status GenericTransferManager::TransferBufferToInfeed(
    perftools::gputools::StreamExecutor* executor, int64 size,
    const void* source) {
  return Unimplemented("Generic transfer to Infeed");
}

Status GenericTransferManager::TransferLiteralFromOutfeed(
    perftools::gputools::StreamExecutor* executor, const Shape& literal_shape,
    Literal* literal) {
  return Unimplemented(
      "Outfeed is not supported on this platform (b/30467474)");
}

Status GenericTransferManager::ResetDevices(
    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
    /*executors*/) {
  return Unimplemented(
      "Device reset is not yet supported on this platform (b/30481585)");
}

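// Returns the number of device bytes needed to hold `shape`. Tuples are
// represented on device as arrays of pointers to their elements, which is why
// the size depends on pointer_size_ (see ShapeUtil::ByteSizeOf).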
int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const {
  return ShapeUtil::ByteSizeOf(shape, pointer_size_);
}

}  // namespace xla