/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/gdr/gdr_worker.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#endif  // GOOGLE_CUDA
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"

namespace tensorflow {

// A GrpcWorker variant that can answer RecvTensor RPCs via GDR (direct
// remote-memory transport) instead of serializing tensor contents into the
// gRPC response, when the client allows it and the tensor qualifies.
//
// `remote_memory_manager` is stored as a raw pointer; ownership appears to
// remain with the caller (not deleted here) — TODO(review): confirm lifetime
// outlasts this worker. The 100000 is the capacity of the recent-request-id
// tracker used by GrpcRecvTensorAsync below to reject duplicate RPCs.
GdrWorker::GdrWorker(WorkerEnv* worker_env,
                     RemoteMemoryManager* remote_memory_manager)
    : GrpcWorker(worker_env),
      remote_memory_manager_(remote_memory_manager),
      recv_tensor_recent_request_ids_(100000) {}

// Serves one RecvTensor RPC:
//   1. Rejects duplicate request ids (e.g. retried RPCs).
//   2. Parses the rendezvous key and resolves the source device.
//   3. Asynchronously receives the tensor from the local rendezvous and
//      encodes the result into `response`, preferring the DMA path when
//      possible, then invokes `done` exactly once with the final status.
//
// `opts`, `request`, and `response` must stay alive until `done` is called;
// presumably the gRPC call object guarantees this — TODO(review): confirm.
void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                    const RecvTensorRequest* request,
                                    ::grpc::ByteBuffer* response,
                                    StatusCallback done) {
  // Drop duplicate RecvTensor requests (same request_id seen recently).
  Status s = recv_tensor_recent_request_ids_.TrackUnique(
      request->request_id(), "RecvTensor (GdrWorker)", *request);
  if (!s.ok()) {
    done(s);
    return;
  }

  const int64 step_id = request->step_id();
  const string& key = request->rendezvous_key();
  TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
  Rendezvous::ParsedKey parsed;
  s = Rendezvous::ParseKey(key, &parsed);
  Device* src_dev = nullptr;
  if (s.ok()) {
    // Resolves `parsed.src_device` to a local Device*; fails if the key does
    // not name a device on this worker.
    s = PrepareRecvTensor(parsed, &src_dev);
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  // Request the tensor associated with the rendezvous key. Any time
  // while waiting for the tensor to be produced, up until the start
  // of execution of the callback lambda body below, an RPC
  // cancellation should abort the rendezvous.
  opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
  // Whether the client is able to receive transport options (GDR) instead of
  // inline tensor bytes.
  const bool dma_ok = request->dma_ok();
  env_->rendezvous_mgr->RecvLocalAsync(
      step_id, parsed,
      [this, opts, response, done, src_dev, dma_ok](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args&, const Tensor& val, const bool is_dead) {
        // From here on cancellation can no longer abort the rendezvous; the
        // response (success or error) is delivered through `done`.
        opts->ClearCancelCallback();
        if (status.ok()) {
          // DMA can only be used for Tensors that do not fall into
          // the following three odd edge cases: 1) a zero-size
          // buffer, 2) a dead tensor which has an uninit value, and
          // 3) the tensor has the on_host allocation attribute,
          // i.e. it's in CPU RAM *independent of its assigned
          // device type*.
          const bool on_host =
              (src_dev->tensorflow_gpu_device_info() == nullptr) ||
              send_args.alloc_attrs.on_host();
          if (val.TotalBytes() > 0 && (!is_dead) &&
              DMAHelper::CanUseDMA(&val) && dma_ok) {
            // DMA cases.
            // The response carries only metadata (dtype/shape) plus
            // transport options describing the remote memory region; the
            // tensor bytes themselves are fetched by the client via GDR.
            RecvTensorResponse* proto = new RecvTensorResponse;
            proto->set_is_dead(is_dead);
            proto->set_send_start_micros(Env::Default()->NowMicros());
            TensorProto* tensor_proto = proto->mutable_tensor();
            tensor_proto->set_dtype(val.dtype());
            val.shape().AsProto(tensor_proto->mutable_tensor_shape());
            auto transport_options = proto->mutable_transport_options();
            remote_memory_manager_->TransportOptionsFromTensor(
                transport_options, val, src_dev, send_args.device_context,
                on_host, [proto, done, response](const Status& s) {
                  // Encode only on success; `proto` is freed on every path.
                  if (s.ok()) {
                    grpc::EncodeRecvTensorResponseToByteBuffer(*proto,
                                                               response);
                    done(Status::OK());
                  } else {
                    done(s);
                  }
                  delete proto;
                });
          } else {
            // Non-DMA cases.
            if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
#if GOOGLE_CUDA
              // Tensor bytes live in GPU memory: stage a copy into
              // gpu-compatible host memory before encoding on the wire.
              const DeviceContext* send_dev_context = send_args.device_context;
              AllocatorAttributes alloc_attrs;
              alloc_attrs.set_gpu_compatible(true);
              alloc_attrs.set_on_host(true);
              Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
              Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
              CHECK(send_dev_context)
                  << "send dev name: " << src_dev->name()
                  << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
              // "val" is on a GPU. Uses GPUUtil to fill the response proto.
              StatusCallback copy_ready = [response, done, copy,
                                           is_dead](const Status& s) {
                // The value is now ready to be returned on the wire.
                grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
                done(s);
                // `copy` was heap-allocated above solely to outlive the
                // asynchronous device-to-host copy.
                delete copy;
              };

              GPUUtil::CopyGPUTensorToCPU(src_dev, send_dev_context, &val, copy,
                                          copy_ready);
#else
              // Built without CUDA support, yet the tensor claims to be on a
              // GPU device: report an internal error.
              done(errors::Internal("No GPU device in process"));
#endif  // GOOGLE_CUDA
            } else {
              // Tensor already in host memory: encode it inline directly.
              grpc::EncodeTensorToByteBuffer(is_dead, val, response);
              done(Status::OK());
            }
          }
        } else {
          // !s.ok()
          done(status);
        }
      });
}

}  // namespace tensorflow