11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/transfer_manager.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <string>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <utility>
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/status_macros.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/types.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/util.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/macros.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace se = ::perftools::gputools;
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
31b0bcf675a4b5d6217f3b58fd27b344f20e7bf25dSanjoy Das/* static */ tensorflow::mutex
32b0bcf675a4b5d6217f3b58fd27b344f20e7bf25dSanjoy Das    TransferManager::platform_transfer_manager_mutex_(
33b0bcf675a4b5d6217f3b58fd27b344f20e7bf25dSanjoy Das        tensorflow::LINKER_INITIALIZED);
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::map<perftools::gputools::Platform::Id,
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                      TransferManager::State>*
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTransferManager::GetPlatformTransferManagers() {
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  static auto* r =
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      new std::map<perftools::gputools::Platform::Id, TransferManager::State>;
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return r;
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
43fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlowerStatus TransferManager::TransferArrayToDevice(
44fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    perftools::gputools::StreamExecutor* executor, const Literal& literal,
45fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    const perftools::gputools::DeviceMemoryBase& dest) {
46fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  const Shape on_device_shape = HostShapeToDeviceShape(literal.shape());
47fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape))
48fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower      << "On-device representation of "
49fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower      << ShapeUtil::HumanString(literal.shape())
50fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower      << " is not an array: " << ShapeUtil::HumanString(on_device_shape);
51fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  if (dest.size() < GetByteSizeRequirement(on_device_shape)) {
52fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    return FailedPrecondition(
53fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower        "Allocation on device not large enough for array: "
54fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower        "%lld < %lld",
55fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower        dest.size(), GetByteSizeRequirement(on_device_shape));
56fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  }
57fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape,
58fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower                             executor->platform(), executor->device_ordinal());
59fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  shaped_buffer.set_buffer(dest, /*index=*/{});
60fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  return TransferLiteralToDevice(executor, literal, shaped_buffer);
61fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower}
62fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower
63fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlowerStatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice(
64fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    perftools::gputools::StreamExecutor* executor, const Shape& shape,
65fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    const perftools::gputools::DeviceMemoryBase& source) {
66fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape))
67fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower      << "Shape " << ShapeUtil::HumanString(shape)
68fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower      << " has a differently shaped representation on-device: "
69fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower      << ShapeUtil::HumanString(HostShapeToDeviceShape(shape));
70fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  if (source.size() < GetByteSizeRequirement(shape)) {
71fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    return FailedPrecondition(
72fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower        "Allocation on device not large enough for array: "
73fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower        "%lld < %lld",
74fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower        source.size(), GetByteSizeRequirement(shape));
75fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  }
76fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape,
77fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower                             executor->platform(), executor->device_ordinal());
78fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  shaped_buffer.set_buffer(source, /*index=*/{});
79fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  return TransferLiteralFromDevice(executor, shaped_buffer);
80fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower}
81fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ void TransferManager::RegisterTransferManager(
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    se::Platform::Id platform_id,
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    TransferManagerCreationFunction creation_function) {
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  tensorflow::mutex_lock lock(
86b0bcf675a4b5d6217f3b58fd27b344f20e7bf25dSanjoy Das      TransferManager::platform_transfer_manager_mutex_);
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto* managers = GetPlatformTransferManagers();
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK(managers->find(platform_id) == managers->end());
891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  (*managers)[platform_id].creation_function = creation_function;
901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    const se::Platform* platform) {
941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  tensorflow::mutex_lock lock(
95b0bcf675a4b5d6217f3b58fd27b344f20e7bf25dSanjoy Das      TransferManager::platform_transfer_manager_mutex_);
961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto* managers = GetPlatformTransferManagers();
971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto it = managers->find(platform->id());
991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (it == managers->end()) {
1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return NotFound(
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        "could not find registered transfer manager for platform %s -- check "
1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        "target linkage",
1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        platform->Name().c_str());
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (it->second.manager == nullptr) {
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    // Lazily create the transfer manager the first time it is needed
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    it->second.manager = (*it->second.creation_function)();
1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1118cb5e9867482a8e05f756fad35634e1674fe7f16A. Unique TensorFlower  return it->second.manager.get();
1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
11422d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark HeffernanStatus TransferManager::WriteTupleIndexTables(
11522d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan    perftools::gputools::StreamExecutor* executor,
11622d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan    const ShapedBuffer& device_buffer) {
117fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  VLOG(2) << "Writing tuple index tables for " << device_buffer;
11822d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan
11922d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan  TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal());
12022d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan
12122d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan  return ShapeUtil::ForEachSubshapeWithStatus(
122fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower      device_buffer.on_device_shape(),
12322d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan      [&](const Shape& device_subshape, const ShapeIndex& index) -> Status {
12422d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan        if (ShapeUtil::IsTuple(device_subshape)) {
12522d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan          se::DeviceMemoryBase device_memory = device_buffer.buffer(index);
12622d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan          TF_RET_CHECK(GetByteSizeRequirement(device_subshape) ==
12722d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan                       device_memory.size());
12822d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan
12922d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan          std::vector<se::DeviceMemoryBase> elements;
13022d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan          ShapeIndex element_index = index;
13122d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan          for (int64 i = 0; i < ShapeUtil::TupleElementCount(device_subshape);
13222d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan               ++i) {
13322d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan            element_index.push_back(i);
13422d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan            elements.push_back(device_buffer.buffer(element_index));
13522d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan            element_index.pop_back();
13622d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan          }
137fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower          return WriteSingleTupleIndexTable(executor, elements, device_subshape,
13822d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan                                            &device_memory);
13922d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan        }
14022d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan
14122d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan        return Status::OK();
14222d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan      });
14322d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan}
14422d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan
1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsStatus TransferManager::TransferBufferFromDevice(
1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    int64 size, void* destination) {
1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (source.size() < size) {
1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return FailedPrecondition(
1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        "Source allocation on device not large enough for data tranfer: "
1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        "%lld < %lld",
1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        source.size(), size);
1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto copy_status = executor->SynchronousMemcpyD2H(source, size, destination);
1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (!copy_status.ok()) {
1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return AddStatus(
1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        Status(static_cast<tensorflow::error::Code>(copy_status.code()),
1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins               copy_status.error_message()),
1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        "failed transfer from device to buffer");
1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return Status::OK();
1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsStatus TransferManager::TransferBufferToDevice(
1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    se::StreamExecutor* executor, int64 size, const void* source,
1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    se::DeviceMemoryBase* destination) {
1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (destination->size() < size) {
1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return FailedPrecondition(
1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        "Destination allocation on device not large enough for data tranfer: "
1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        "%lld < %lld",
1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        destination->size(), size);
1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  auto copy_status = executor->SynchronousMemcpyH2D(source, size, destination);
1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  if (!copy_status.ok()) {
1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    return AddStatus(
1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        Status(static_cast<tensorflow::error::Code>(copy_status.code()),
1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins               copy_status.error_message()),
1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        "failed transfer of buffer to device");
1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return Status::OK();
1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
183fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlowerStatusOr<std::unique_ptr<ShapedBuffer>> TransferManager::AllocateShapedBuffer(
184fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
185fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    int device_ordinal) {
186fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  if (!LayoutUtil::HasLayout(on_host_shape)) {
187fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    return InvalidArgument(
188fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower        "Shape must have a layout: %s",
189fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower        ShapeUtil::HumanStringWithLayout(on_host_shape).c_str());
190fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  }
191fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
192fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
193fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));
194fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower
195fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  auto shaped_buffer = WrapUnique(new ShapedBuffer(
196fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower      on_host_shape, on_device_shape, allocator->platform(), device_ordinal));
197fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower
198fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  // Allocate an appropriate sized buffer for each element in the shape
199fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  // including the tuple pointer arrays.
200fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  for (auto& pair : shaped_buffer->buffers()) {
201fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    const ShapeIndex& index = pair.first;
202fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    se::DeviceMemoryBase& memory_base = pair.second;
203fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index);
204fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower    TF_ASSIGN_OR_RETURN(memory_base,
205fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower                        allocator->Allocate(shaped_buffer->device_ordinal(),
206fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower                                            GetByteSizeRequirement(subshape)));
2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
208fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower
209fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  return std::move(shaped_buffer);
210fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower}
211fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower
212fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlowerStatusOr<std::unique_ptr<ScopedShapedBuffer>>
213fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlowerTransferManager::AllocateScopedShapedBuffer(const Shape& on_host_shape,
214fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower                                            DeviceMemoryAllocator* allocator,
215fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower                                            int device_ordinal) {
216fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  TF_ASSIGN_OR_RETURN(
217fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower      std::unique_ptr<ShapedBuffer> unscoped_buffer,
218fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower      AllocateShapedBuffer(on_host_shape, allocator, device_ordinal));
219fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower  return ScopedShapedBuffer::MakeScoped(unscoped_buffer.get(), allocator);
2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
223