1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifdef TENSORFLOW_USE_VERBS 17 18#include "grpc++/alarm.h" 19#include "grpc++/grpc++.h" 20#include "grpc++/server_builder.h" 21 22#include "tensorflow/contrib/verbs/grpc_verbs_service.h" 23#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 24#include "tensorflow/core/distributed_runtime/session_mgr.h" 25 26namespace tensorflow { 27 28GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env, 29 ::grpc::ServerBuilder* builder) 30 : is_shutdown_(false), worker_env_(worker_env) { 31 builder->RegisterService(&verbs_service_); 32 cq_ = builder->AddCompletionQueue().release(); 33} 34 35GrpcVerbsService::~GrpcVerbsService() { 36 delete shutdown_alarm_; 37 delete cq_; 38} 39 40void GrpcVerbsService::Shutdown() { 41 bool did_shutdown = false; 42 { 43 mutex_lock l(shutdown_mu_); 44 if (!is_shutdown_) { 45 LOG(INFO) << "Shutting down GrpcWorkerService."; 46 is_shutdown_ = true; 47 did_shutdown = true; 48 } 49 } 50 if (did_shutdown) { 51 shutdown_alarm_ = 52 new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr); 53 } 54} 55 56// This macro creates a new request for the given RPC method name 57// (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on 58// `this->cq_`. 59// 60// This macro is invoked one or more times for each RPC method to 61// ensure that there are sufficient completion queue entries to 62// handle incoming requests without blocking. 63// 64// The implementation of the request handler for each RPC method 65// must ensure that it calls ENQUEUE_REQUEST() for that RPC method, 66// to keep accepting new requests. 67#define ENQUEUE_REQUEST(method, supports_cancel) \ 68 do { \ 69 mutex_lock l(shutdown_mu_); \ 70 if (!is_shutdown_) { \ 71 Call<GrpcVerbsService, grpc::VerbsService::AsyncService, \ 72 method##Request, method##Response>:: \ 73 EnqueueRequest(&verbs_service_, cq_, \ 74 &grpc::VerbsService::AsyncService::Request##method, \ 75 &GrpcVerbsService::method##Handler, \ 76 (supports_cancel)); \ 77 } \ 78 } while (0) 79 80// This method blocks forever handling requests from the completion queue. 81void GrpcVerbsService::HandleRPCsLoop() { 82 for (int i = 0; i < 10; ++i) { 83 ENQUEUE_REQUEST(GetRemoteAddress, false); 84 } 85 86 void* tag; 87 bool ok; 88 89 while (cq_->Next(&tag, &ok)) { 90 UntypedCall<GrpcVerbsService>::Tag* callback_tag = 91 static_cast<UntypedCall<GrpcVerbsService>::Tag*>(tag); 92 if (callback_tag) { 93 callback_tag->OnCompleted(this, ok); 94 } else { 95 cq_->Shutdown(); 96 } 97 } 98} 99 100void GrpcVerbsService::GetRemoteAddressHandler( 101 WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) { 102 Status s = GetRemoteAddressSync(&call->request, &call->response); 103 call->SendResponse(ToGrpcStatus(s)); 104 ENQUEUE_REQUEST(GetRemoteAddress, false); 105} 106 107// synchronous method 108Status GrpcVerbsService::GetRemoteAddressSync( 109 const GetRemoteAddressRequest* request, 110 GetRemoteAddressResponse* response) { 111 // analyzing request 112 // the channel setting part is redundant. 113 const string remote_host_name = request->host_name(); 114 RdmaChannel* rc = rdma_mgr_->FindChannel(remote_host_name); 115 CHECK(rc); 116 RdmaAddress ra; 117 ra.lid = request->channel().lid(); 118 ra.qpn = request->channel().qpn(); 119 ra.psn = request->channel().psn(); 120 ra.snp = request->channel().snp(); 121 ra.iid = request->channel().iid(); 122 rc->SetRemoteAddress(ra, false); 123 rc->Connect(); 124 int i = 0; 125 int idx[] = {1, 0}; 126 std::vector<RdmaMessageBuffer*> mb(rc->message_buffers()); 127 CHECK_EQ(request->mr_size(), RdmaChannel::kNumMessageBuffers); 128 for (const auto& mr : request->mr()) { 129 // the connections are crossed, i.e. 130 // local tx_message_buffer <---> remote rx_message_buffer_ 131 // local rx_message_buffer <---> remote tx_message_buffer_ 132 // hence idx[] = {1, 0}. 133 RdmaMessageBuffer* rb = mb[idx[i]]; 134 RemoteMR rmr; 135 rmr.remote_addr = mr.remote_addr(); 136 rmr.rkey = mr.rkey(); 137 rb->SetRemoteMR(rmr, false); 138 i++; 139 } 140 CHECK(i == RdmaChannel::kNumMessageBuffers); 141 142 // setting up response 143 response->set_host_name( 144 worker_env_->session_mgr->LegacySession()->worker_name); 145 Channel* channel_info = response->mutable_channel(); 146 channel_info->set_lid(rc->self().lid); 147 channel_info->set_qpn(rc->self().qpn); 148 channel_info->set_psn(rc->self().psn); 149 channel_info->set_snp(rc->self().snp); 150 channel_info->set_iid(rc->self().iid); 151 for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) { 152 MemoryRegion* mr = response->add_mr(); 153 mr->set_remote_addr(reinterpret_cast<uint64>(mb[i]->buffer())); 154 mr->set_rkey(mb[i]->self()->rkey); 155 } 156 return Status::OK(); 157} 158 159// Create a GrpcVerbsService, then assign it to a given handle. 160void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env, 161 ::grpc::ServerBuilder* builder) { 162 *handle = new GrpcVerbsService(worker_env, builder); 163} 164 165} // namespace tensorflow 166 167#endif // TENSORFLOW_USE_VERBS 168