/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
#endif

#include "tensorflow/core/kernels/ops_testutil.h"

#include <utility>

23d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowernamespace tensorflow {
24d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower
25d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowervoid OpsTestBase::SetDevice(const DeviceType& device_type,
26d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower                            std::unique_ptr<Device> device) {
27d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  CHECK(device_.get()) << "No device provided";
28d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  device_type_ = device_type;
29d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  device_ = std::move(device);
30d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#ifdef GOOGLE_CUDA
31d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  if (device_type == DEVICE_GPU) {
32d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    managed_allocator_.reset(new GpuManagedAllocator());
33d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    allocator_ = managed_allocator_.get();
34d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  } else {
35d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    managed_allocator_.reset();
36d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    allocator_ = device_->GetAllocator(AllocatorAttributes());
37d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  }
38d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#else
39d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  CHECK_NE(device_type, DEVICE_GPU)
40d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      << "Requesting GPU on binary compiled without GOOGLE_CUDA.";
41d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#endif
42d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower}
43d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower
44d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlowerTensor* OpsTestBase::GetOutput(int output_index) {
45d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  CHECK_LT(output_index, context_->num_outputs());
46d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  Tensor* output = context_->mutable_output(output_index);
47d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#ifdef GOOGLE_CUDA
48d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  if (device_type_ == DEVICE_GPU) {
49d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    managed_outputs_.resize(context_->num_outputs());
50d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    // Copy the output tensor to managed memory if we haven't done so.
51d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    if (!managed_outputs_[output_index]) {
52d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      Tensor* managed_output =
53d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower          new Tensor(allocator(), output->dtype(), output->shape());
54d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      auto src = output->tensor_data();
55d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      auto dst = managed_output->tensor_data();
56d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      context_->eigen_gpu_device().memcpy(const_cast<char*>(dst.data()),
57d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower                                          src.data(), src.size());
58d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      context_->eigen_gpu_device().synchronize();
59d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower      managed_outputs_[output_index] = managed_output;
60d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    }
61d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower    output = managed_outputs_[output_index];
62d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  }
63d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower#endif
64d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower  return output;
65d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower}
66d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower
67d244ffb69ceb971d34e330cae45ed944d488e9ecA. Unique TensorFlower}  // namespace tensorflow
68