11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License"); 41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License. 51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at 61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins http://www.apache.org/licenses/LICENSE-2.0 81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software 101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS, 111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and 131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License. 141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/ 151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/transfer_manager.h" 171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <string> 191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <utility> 201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h" 221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/status_macros.h" 231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/types.h" 241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/util.h" 251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h" 261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/macros.h" 271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace se = ::perftools::gputools; 291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla { 31b0bcf675a4b5d6217f3b58fd27b344f20e7bf25dSanjoy Das/* static */ tensorflow::mutex 32b0bcf675a4b5d6217f3b58fd27b344f20e7bf25dSanjoy Das TransferManager::platform_transfer_manager_mutex_( 33b0bcf675a4b5d6217f3b58fd27b344f20e7bf25dSanjoy Das tensorflow::LINKER_INITIALIZED); 341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ std::map<perftools::gputools::Platform::Id, 361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TransferManager::State>* 371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsTransferManager::GetPlatformTransferManagers() { 381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins static auto* r = 391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins new std::map<perftools::gputools::Platform::Id, TransferManager::State>; 401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return r; 411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 421e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 43fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlowerStatus TransferManager::TransferArrayToDevice( 44fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower perftools::gputools::StreamExecutor* executor, const Literal& literal, 45fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower const perftools::gputools::DeviceMemoryBase& dest) { 46fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); 47fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) 48fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower << "On-device representation of " 49fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower << ShapeUtil::HumanString(literal.shape()) 50fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower << " is not an array: " << ShapeUtil::HumanString(on_device_shape); 51fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower if (dest.size() < GetByteSizeRequirement(on_device_shape)) { 52fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower return FailedPrecondition( 53fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower "Allocation on device not large enough for array: " 54fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower "%lld < %lld", 55fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower dest.size(), GetByteSizeRequirement(on_device_shape)); 56fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower } 57fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, 58fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower executor->platform(), executor->device_ordinal()); 59fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower shaped_buffer.set_buffer(dest, /*index=*/{}); 60fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower return TransferLiteralToDevice(executor, literal, shaped_buffer); 61fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower} 62fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower 63fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlowerStatusOr<std::unique_ptr<Literal>> TransferManager::TransferArrayFromDevice( 64fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower perftools::gputools::StreamExecutor* executor, const Shape& shape, 65fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower const perftools::gputools::DeviceMemoryBase& source) { 66fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) 67fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower << "Shape " << ShapeUtil::HumanString(shape) 68fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower << " has a differently shaped representation on-device: " 69fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower << ShapeUtil::HumanString(HostShapeToDeviceShape(shape)); 70fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower if (source.size() < GetByteSizeRequirement(shape)) { 71fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower return FailedPrecondition( 72fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower "Allocation on device not large enough for array: " 73fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower "%lld < %lld", 74fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower source.size(), GetByteSizeRequirement(shape)); 75fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower } 76fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, 77fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower executor->platform(), executor->device_ordinal()); 78fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower shaped_buffer.set_buffer(source, /*index=*/{}); 79fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower return TransferLiteralFromDevice(executor, shaped_buffer); 80fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower} 81fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower 821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ void TransferManager::RegisterTransferManager( 831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins se::Platform::Id platform_id, 841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins TransferManagerCreationFunction creation_function) { 851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::mutex_lock lock( 86b0bcf675a4b5d6217f3b58fd27b344f20e7bf25dSanjoy Das TransferManager::platform_transfer_manager_mutex_); 871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto* managers = GetPlatformTransferManagers(); 881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins CHECK(managers->find(platform_id) == managers->end()); 891e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins (*managers)[platform_id].creation_function = creation_function; 901e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform( 931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins const se::Platform* platform) { 941e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins tensorflow::mutex_lock lock( 95b0bcf675a4b5d6217f3b58fd27b344f20e7bf25dSanjoy Das TransferManager::platform_transfer_manager_mutex_); 961e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto* managers = GetPlatformTransferManagers(); 971e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 981e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto it = managers->find(platform->id()); 991e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (it == managers->end()) { 1001e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return NotFound( 1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "could not find registered transfer manager for platform %s -- check " 1021e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "target linkage", 1031e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins platform->Name().c_str()); 1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (it->second.manager == nullptr) { 1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins // Lazily create the transfer manager the first time it is needed 1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins it->second.manager = (*it->second.creation_function)(); 1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1118cb5e9867482a8e05f756fad35634e1674fe7f16A. Unique TensorFlower return it->second.manager.get(); 1121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 11422d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark HeffernanStatus TransferManager::WriteTupleIndexTables( 11522d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan perftools::gputools::StreamExecutor* executor, 11622d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan const ShapedBuffer& device_buffer) { 117fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower VLOG(2) << "Writing tuple index tables for " << device_buffer; 11822d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan 11922d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); 12022d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan 12122d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan return ShapeUtil::ForEachSubshapeWithStatus( 122fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower device_buffer.on_device_shape(), 12322d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { 12422d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan if (ShapeUtil::IsTuple(device_subshape)) { 12522d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan se::DeviceMemoryBase device_memory = device_buffer.buffer(index); 12622d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == 12722d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan device_memory.size()); 12822d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan 12922d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan std::vector<se::DeviceMemoryBase> elements; 13022d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan ShapeIndex element_index = index; 13122d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan for (int64 i = 0; i < ShapeUtil::TupleElementCount(device_subshape); 13222d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan ++i) { 13322d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan element_index.push_back(i); 13422d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan elements.push_back(device_buffer.buffer(element_index)); 13522d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan element_index.pop_back(); 13622d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan } 137fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower return WriteSingleTupleIndexTable(executor, elements, device_subshape, 13822d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan &device_memory); 13922d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan } 14022d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan 14122d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan return Status::OK(); 14222d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan }); 14322d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan} 14422d948d2739ecaadfb4091302f2050ba9cf0d0c1Mark Heffernan 1451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsStatus TransferManager::TransferBufferFromDevice( 1461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins se::StreamExecutor* executor, const se::DeviceMemoryBase& source, 1471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins int64 size, void* destination) { 1481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (source.size() < size) { 1491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return FailedPrecondition( 1501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "Source allocation on device not large enough for data tranfer: " 1511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "%lld < %lld", 1521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins source.size(), size); 1531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto copy_status = executor->SynchronousMemcpyD2H(source, size, destination); 1551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (!copy_status.ok()) { 1561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return AddStatus( 1571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Status(static_cast<tensorflow::error::Code>(copy_status.code()), 1581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins copy_status.error_message()), 1591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "failed transfer from device to buffer"); 1601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return Status::OK(); 1621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 1641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsStatus TransferManager::TransferBufferToDevice( 1651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins se::StreamExecutor* executor, int64 size, const void* source, 1661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins se::DeviceMemoryBase* destination) { 1671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (destination->size() < size) { 1681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return FailedPrecondition( 1691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "Destination allocation on device not large enough for data tranfer: " 1701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "%lld < %lld", 1711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins destination->size(), size); 1721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins auto copy_status = executor->SynchronousMemcpyH2D(source, size, destination); 1741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins if (!copy_status.ok()) { 1751e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return AddStatus( 1761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins Status(static_cast<tensorflow::error::Code>(copy_status.code()), 1771e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins copy_status.error_message()), 1781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins "failed transfer of buffer to device"); 1791e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 1801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins return Status::OK(); 1811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 1821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 183fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlowerStatusOr<std::unique_ptr<ShapedBuffer>> TransferManager::AllocateShapedBuffer( 184fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower const Shape& on_host_shape, DeviceMemoryAllocator* allocator, 185fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower int device_ordinal) { 186fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower if (!LayoutUtil::HasLayout(on_host_shape)) { 187fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower return InvalidArgument( 188fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower "Shape must have a layout: %s", 189fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower ShapeUtil::HumanStringWithLayout(on_host_shape).c_str()); 190fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower } 191fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); 192fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); 193fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); 194fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower 195fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower auto shaped_buffer = WrapUnique(new ShapedBuffer( 196fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower on_host_shape, on_device_shape, allocator->platform(), device_ordinal)); 197fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower 198fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower // Allocate an appropriate sized buffer for each element in the shape 199fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower // including the tuple pointer arrays. 200fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower for (auto& pair : shaped_buffer->buffers()) { 201fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower const ShapeIndex& index = pair.first; 202fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower se::DeviceMemoryBase& memory_base = pair.second; 203fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); 204fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower TF_ASSIGN_OR_RETURN(memory_base, 205fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower allocator->Allocate(shaped_buffer->device_ordinal(), 206fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower GetByteSizeRequirement(subshape))); 2071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins } 208fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower 209fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower return std::move(shaped_buffer); 210fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower} 211fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower 212fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlowerStatusOr<std::unique_ptr<ScopedShapedBuffer>> 213fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlowerTransferManager::AllocateScopedShapedBuffer(const Shape& on_host_shape, 214fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower DeviceMemoryAllocator* allocator, 215fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower int device_ordinal) { 216fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower TF_ASSIGN_OR_RETURN( 217fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower std::unique_ptr<ShapedBuffer> unscoped_buffer, 218fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower AllocateShapedBuffer(on_host_shape, allocator, device_ordinal)); 219fc2526a8c1cf0bc2a93c8cc819ff7209eb4628c9A. Unique TensorFlower return ScopedShapedBuffer::MakeScoped(unscoped_buffer.get(), allocator); 2201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} 2211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins 2221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins} // namespace xla 223