/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "tensorflow/compiler/xla/service/hlo_runner.h"

#include <set>
#include <string>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace se = ::perftools::gputools;

namespace xla {

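// Parses the given HLO-text string into an HloModule whose config carries
// `debug_options`.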
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string,
                                  const DebugOptions& debug_options) {
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return tools::Parse(hlo_string, config);
}
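
// Example usage (a sketch, not verbatim from this codebase): assuming the
// `Execute` overload declared in hlo_runner.h that takes the module plus an
// ArraySlice of Literal* arguments, and a caller-provided `hlo_text` string
// and `argument` Literal, running a parsed module looks like:
//
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloModule> module,
//       HloRunner::CreateModuleFromString(hlo_text, DebugOptions()));
//   HloRunner runner;
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> result,
//                       runner.Execute(std::move(module), {&argument}));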

namespace {

// Creates an HloModule from the given proto.
StatusOr<std::unique_ptr<HloModule>> HloProtoToModule(
    const HloProto& proto, const DebugOptions& debug_options) {
  TF_ASSIGN_OR_RETURN(
      HloModuleConfig config,
      HloModule::CreateModuleConfigFromProto(proto.hlo_module()));
  config.set_debug_options(debug_options);
  TF_ASSIGN_OR_RETURN(auto module,
                      HloModule::CreateFromProto(proto.hlo_module(), config));
  return std::move(module);
}

}  // namespace

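// The three ReadModuleFrom* helpers below differ only in the on-disk format
// they accept: binary-serialized HloProto, text-format HloProto, and HLO
// text, respectively.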
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename,
                                         const DebugOptions& debug_options) {
  HloProto proto;
  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                                 filename, &proto));
  return HloProtoToModule(proto, debug_options);
}

/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromTextProtoFile(const std::string& filename,
                                       const DebugOptions& debug_options) {
  HloProto proto;
  TF_RETURN_IF_ERROR(
      tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto));
  return HloProtoToModule(proto, debug_options);
}

/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromHloTextFile(const std::string& filename,
                                     const DebugOptions& debug_options) {
  string hlo_string;
  TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(),
                                                  filename, &hlo_string));
  HloModuleConfig config;
  config.set_debug_options(debug_options);
  return tools::Parse(hlo_string, config);
}

// Define this in the .cc file to avoid having to include Eigen or
// forward-declare these types in the header. The `pool` member is the
// tensorflow::EigenThreadPoolWrapper adapter from eigen_thread_pool.h (the
// unqualified name would recursively denote this struct itself).
struct HloRunner::EigenThreadPoolWrapper {
  std::unique_ptr<tensorflow::EigenThreadPoolWrapper> pool;
  std::unique_ptr<Eigen::ThreadPoolDevice> device;
};

HloRunner::HloRunner() {}

HloRunner::HloRunner(se::Platform* platform) {
  BackendOptions backend_options;
  backend_options.set_platform(platform);
  backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
  VLOG(1) << "Created HloRunner for platform: " << platform->Name();
}

HloRunner::~HloRunner() {}

StatusOr<std::unique_ptr<Literal>> HloRunner::ExecuteInternal(
    std::unique_ptr<HloModule> module,
    const tensorflow::gtl::ArraySlice<Literal*> arguments,
    bool run_hlo_passes) {
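  // Optionally run the HLO optimization pipeline over the module before
  // lowering it; when `run_hlo_passes` is false the module is compiled
  // exactly as given.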
  if (run_hlo_passes) {
    TF_ASSIGN_OR_RETURN(
        module, backend().compiler()->RunHloPasses(
                    std::move(module), backend().default_stream_executor(),
                    /*device_allocator=*/nullptr));
  }
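  // Lower the (possibly optimized) module to a platform-specific Executable.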
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      backend().compiler()->RunBackend(std::move(module),
                                       backend().default_stream_executor(),
                                       /*device_allocator=*/nullptr));

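  // Create a stream on the default device and assemble the run options
  // (device ordinal, allocator, and thread pools) used for this execution.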
  se::Stream stream(backend().default_stream_executor());
  stream.Init();

  ExecutableRunOptions run_options;
  run_options.set_device_ordinal(backend().default_device_ordinal());
  run_options.set_stream(&stream);
  run_options.set_allocator(backend().memory_allocator());
  run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
  run_options.set_intra_op_thread_pool(
      backend().eigen_intra_op_thread_pool_device());

  ServiceExecutableRunOptions service_run_options(
      run_options, backend().StreamBorrower(),
      backend().inter_op_thread_pool());

  // Copy arguments to device.
  std::vector<std::unique_ptr<ScopedShapedBuffer>> argument_buffers;
  std::vector<ShapedBuffer*> argument_buffer_ptrs;
  for (Literal* argument : arguments) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<ScopedShapedBuffer> argument_buffer,
        backend().transfer_manager()->AllocateScopedShapedBuffer(
            argument->shape(), run_options.allocator(),
            run_options.device_ordinal()));
    TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice(
        stream.parent(), *argument, *argument_buffer));
    argument_buffers.push_back(std::move(argument_buffer));
    argument_buffer_ptrs.push_back(argument_buffers.back().get());
  }

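  // Run the executable on the stream; the result is a buffer that lives in
  // device memory.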
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ShapedBuffer> result,
      executable->ExecuteOnStream(&service_run_options, argument_buffer_ptrs,
                                  /*hlo_execution_profile=*/nullptr));

  // Create a ScopedShapedBuffer of the result to manage deallocation. This
  // will deallocate all the device memory when it goes out of scope.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ScopedShapedBuffer> scoped_result,
      ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator()));

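  // Copy the result from device memory back to the host as a Literal.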
  auto result_literal =
      backend().transfer_manager()->TransferLiteralFromDevice(
          stream.parent(), *scoped_result);
  if (result_literal.ok()) {
    VLOG(4) << "Executed binary and got result: "
            << result_literal.ValueOrDie()->ToString();
  } else {
    VLOG(4) << "Executed binary and got status: "
            << result_literal.status().ToString();
  }
  return result_literal;
}

Backend& HloRunner::backend() {
  if (!backend_) {
    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
    VLOG(1) << "Executing on platform " << backend().platform()->Name();
  }
  return *backend_;
}

}  // namespace xla