1d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower 3d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 4d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFloweryou may not use this file except in compliance with the License. 5d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowerYou may obtain a copy of the License at 6d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower 7d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 8d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower 9d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 10d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 11d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowerSee the License for the specific language governing permissions and 13d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowerlimitations under the License. 14d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower==============================================================================*/ 15d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower 16d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#ifdef GOOGLE_CUDA 17d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#define EIGEN_USE_GPU 18d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h" 19d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#endif 20d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower 21d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#include "tensorflow/core/kernels/ops_testutil.h" 22d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower 23d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowernamespace tensorflow { 24d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower 25d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowervoid OpsTestBase::SetDevice(const DeviceType& device_type, 26d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower std::unique_ptr<Device> device) { 27d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower CHECK(device_.get()) << "No device provided"; 28d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower device_type_ = device_type; 29d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower device_ = std::move(device); 30d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#ifdef GOOGLE_CUDA 31d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower if (device_type == DEVICE_GPU) { 32d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower managed_allocator_.reset(new GpuManagedAllocator()); 33d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower allocator_ = managed_allocator_.get(); 34d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower } else { 35d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower managed_allocator_.reset(); 36d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower allocator_ = device_->GetAllocator(AllocatorAttributes()); 37d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower } 38d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#else 39d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower CHECK_NE(device_type, DEVICE_GPU) 40d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower << "Requesting GPU on binary compiled without GOOGLE_CUDA."; 41d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#endif 42d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower} 43d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower 44d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowerTensor* OpsTestBase::GetOutput(int output_index) { 45d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower CHECK_LT(output_index, context_->num_outputs()); 46d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower Tensor* output = context_->mutable_output(output_index); 47d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#ifdef GOOGLE_CUDA 48d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower if (device_type_ == DEVICE_GPU) { 49d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower managed_outputs_.resize(context_->num_outputs()); 50d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower // Copy the output tensor to managed memory if we haven't done so. 51d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower if (!managed_outputs_[output_index]) { 52d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower Tensor* managed_output = 53d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower new Tensor(allocator(), output->dtype(), output->shape()); 54d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower auto src = output->tensor_data(); 55d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower auto dst = managed_output->tensor_data(); 56d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower context_->eigen_gpu_device().memcpy(const_cast<char*>(dst.data()), 57d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower src.data(), src.size()); 58d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower context_->eigen_gpu_device().synchronize(); 59d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower managed_outputs_[output_index] = managed_output; 60d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower } 61d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower output = managed_outputs_[output_index]; 62d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower } 63d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#endif 64d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower return output; 65d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower} 66d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower 67d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower} // namespace tensorflow 68