1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/contrib/gdr/gdr_worker.h"
17
18#include "tensorflow/core/common_runtime/device.h"
19#include "tensorflow/core/common_runtime/device_mgr.h"
20#include "tensorflow/core/common_runtime/dma_helper.h"
21#if GOOGLE_CUDA
22#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
23#endif  // GOOGLE_CUDA
24#include "tensorflow/core/common_runtime/process_util.h"
25#include "tensorflow/core/common_runtime/step_stats_collector.h"
26#include "tensorflow/core/distributed_runtime/graph_mgr.h"
27#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
28#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
29#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
30#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
31#include "tensorflow/core/distributed_runtime/worker.h"
32#include "tensorflow/core/distributed_runtime/worker_cache.h"
33#include "tensorflow/core/distributed_runtime/worker_session.h"
34#include "tensorflow/core/framework/cancellation.h"
35#include "tensorflow/core/framework/tensor.h"
36#include "tensorflow/core/lib/core/errors.h"
37#include "tensorflow/core/platform/logging.h"
38#include "tensorflow/core/platform/tracing.h"
39
40namespace tensorflow {
41
// Constructs a GDR-capable worker on top of the plain gRPC worker.
// `remote_memory_manager` provides the RDMA transport used by the DMA path of
// GrpcRecvTensorAsync. NOTE(review): it is stored as a raw pointer, so it is
// presumably owned by the caller and must outlive this worker — confirm at the
// construction site.
// The recent-request-id cache remembers the last 100000 RecvTensor request
// ids so that retried/duplicate RPCs are rejected (see TrackUnique in
// GrpcRecvTensorAsync).
GdrWorker::GdrWorker(WorkerEnv* worker_env,
                     RemoteMemoryManager* remote_memory_manager)
    : GrpcWorker(worker_env),
      remote_memory_manager_(remote_memory_manager),
      recv_tensor_recent_request_ids_(100000) {}
47
// Serves one RecvTensor RPC. The tensor named by the request's rendezvous key
// is fetched from the local rendezvous and returned to the caller in one of
// three ways:
//   1) DMA path: when the client allows it (dma_ok) and the tensor is
//      DMA-able, only transport options (RDMA metadata) are serialized into
//      the response; the payload itself is fetched out-of-band by the peer.
//   2) GPU path (non-DMA): the tensor is copied from GPU to host memory and
//      then encoded into the gRPC byte buffer.
//   3) CPU path (non-DMA): the tensor is encoded directly.
// `done` is invoked exactly once on every path, with the first error
// encountered (or OK). `opts` and `response` are raw pointers captured into
// async callbacks; they are presumably owned by the in-flight RPC call and
// valid until `done` runs — confirm against the gRPC service plumbing.
void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                    const RecvTensorRequest* request,
                                    ::grpc::ByteBuffer* response,
                                    StatusCallback done) {
  // Reject duplicate (retried) RPCs: TrackUnique fails if this request_id was
  // seen among the last 100000 requests.
  Status s = recv_tensor_recent_request_ids_.TrackUnique(
      request->request_id(), "RecvTensor (GdrWorker)", *request);
  if (!s.ok()) {
    done(s);
    return;
  }

  const int64 step_id = request->step_id();
  const string& key = request->rendezvous_key();
  TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
  // Parse the rendezvous key and resolve the source device it names; either
  // failure aborts the RPC before any rendezvous state is touched.
  Rendezvous::ParsedKey parsed;
  s = Rendezvous::ParseKey(key, &parsed);
  Device* src_dev = nullptr;
  if (s.ok()) {
    s = PrepareRecvTensor(parsed, &src_dev);
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  // Request the tensor associated with the rendezvous key. Any time
  // while waiting for the tensor to be produced, up until the start
  // of execution of the callback lambda body below, an RPC
  // cancellation should abort the rendezvous.
  opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
  const bool dma_ok = request->dma_ok();
  env_->rendezvous_mgr->RecvLocalAsync(
      step_id, parsed,
      [this, opts, response, done, src_dev, dma_ok](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args&, const Tensor& val, const bool is_dead) {
        // From here on the RPC can no longer be aborted via cancellation.
        opts->ClearCancelCallback();
        if (status.ok()) {
          // DMA can only be used for Tensors that do not fall into
          // the following three odd edge cases: 1) a zero-size
          // buffer, 2) a dead tensor which has an uninit value, and
          // 3) the tensor has the on_host allocation attribute,
          // i.e. it's in CPU RAM *independent of its assigned
          // device type*.
          const bool on_host =
              (src_dev->tensorflow_gpu_device_info() == nullptr) ||
              send_args.alloc_attrs.on_host();
          if (val.TotalBytes() > 0 && (!is_dead) &&
              DMAHelper::CanUseDMA(&val) && dma_ok) {
            // DMA cases.
            // Only dtype/shape and RDMA transport options are serialized;
            // the tensor content travels out-of-band. `proto` is heap-
            // allocated because it must outlive this callback: it is
            // deleted in the completion lambda below.
            RecvTensorResponse* proto = new RecvTensorResponse;
            proto->set_is_dead(is_dead);
            proto->set_send_start_micros(Env::Default()->NowMicros());
            TensorProto* tensor_proto = proto->mutable_tensor();
            tensor_proto->set_dtype(val.dtype());
            val.shape().AsProto(tensor_proto->mutable_tensor_shape());
            auto transport_options = proto->mutable_transport_options();
            remote_memory_manager_->TransportOptionsFromTensor(
                transport_options, val, src_dev, send_args.device_context,
                on_host, [proto, done, response](const Status& s) {
                  if (s.ok()) {
                    grpc::EncodeRecvTensorResponseToByteBuffer(*proto,
                                                               response);
                    done(Status::OK());
                  } else {
                    done(s);
                  }
                  delete proto;
                });
          } else {
            // Non-DMA cases.
            if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
#if GOOGLE_CUDA
              const DeviceContext* send_dev_context = send_args.device_context;
              // Stage the GPU tensor into gpu-compatible (pinned) host
              // memory before encoding it on the wire.
              AllocatorAttributes alloc_attrs;
              alloc_attrs.set_gpu_compatible(true);
              alloc_attrs.set_on_host(true);
              Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
              // `copy` is deleted by copy_ready once the async device-to-host
              // copy completes.
              Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
              CHECK(send_dev_context)
                  << "send dev name: " << src_dev->name()
                  << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
              // "val" is on a GPU. Uses GPUUtil to fill the response proto.
              // NOTE(review): copy_ready encodes `copy` and reports `s` even
              // when the copy failed (s != OK); the client sees the error via
              // `done(s)`, but `response` is still populated — confirm this
              // matches the gRPC worker's expectations.
              StatusCallback copy_ready = [response, done, copy,
                                           is_dead](const Status& s) {
                // The value is now ready to be returned on the wire.
                grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
                done(s);
                delete copy;
              };

              // Asynchronous device-to-host copy; `val`'s buffer is
              // referenced until copy_ready fires — presumably kept alive by
              // the rendezvous/tensor refcount, verify against
              // CopyGPUTensorToCPU's contract.
              GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy,
                                          copy_ready);
#else
              // Built without CUDA: a GPU-resident tensor cannot be served.
              done(errors::Internal("No GPU device in process"));
#endif  // GOOGLE_CUDA
            } else {
              // Tensor is already in host memory; encode it directly.
              grpc::EncodeTensorToByteBuffer(is_dead, val, response);
              done(Status::OK());
            }
          }
        } else {
          //  !s.ok()
          done(status);
        }
      });
}
155
156}  // namespace tensorflow
157