1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifdef TENSORFLOW_USE_VERBS
17
18#include "grpc++/alarm.h"
19#include "grpc++/grpc++.h"
20#include "grpc++/server_builder.h"
21
22#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
23#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
24#include "tensorflow/core/distributed_runtime/session_mgr.h"
25
26namespace tensorflow {
27
28GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env,
29                                   ::grpc::ServerBuilder* builder)
30    : is_shutdown_(false), worker_env_(worker_env) {
31  builder->RegisterService(&verbs_service_);
32  cq_ = builder->AddCompletionQueue().release();
33}
34
// Releases the resources owned by the service: the shutdown alarm (allocated
// in Shutdown()) and the completion queue (released from the builder in the
// constructor).
GrpcVerbsService::~GrpcVerbsService() {
  // NOTE(review): within this file shutdown_alarm_ is only assigned by
  // Shutdown(); confirm it is null-initialized for the case where Shutdown()
  // was never called.
  delete shutdown_alarm_;
  // The alarm was created against cq_, so it is destroyed before the queue.
  delete cq_;
}
39
40void GrpcVerbsService::Shutdown() {
41  bool did_shutdown = false;
42  {
43    mutex_lock l(shutdown_mu_);
44    if (!is_shutdown_) {
45      LOG(INFO) << "Shutting down GrpcWorkerService.";
46      is_shutdown_ = true;
47      did_shutdown = true;
48    }
49  }
50  if (did_shutdown) {
51    shutdown_alarm_ =
52        new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
53  }
54}
55
// This macro creates a new request for the given RPC method name
// (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on
// `this->cq_`.
//
// This macro is invoked one or more times for each RPC method to
// ensure that there are sufficient completion queue entries to
// handle incoming requests without blocking.
//
// The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests.
//
// Usage notes:
//   * The expansion refers to the members `shutdown_mu_`, `is_shutdown_`,
//     `verbs_service_` and `cq_`, so it can only be used inside a
//     GrpcVerbsService member function.
//   * `method` is token-pasted to form `<method>Request`, `<method>Response`,
//     `AsyncService::Request<method>` and `GrpcVerbsService::<method>Handler`;
//     all four names must exist for the given method.
//   * `supports_cancel` is forwarded unchanged to `EnqueueRequest`.
//   * The expansion holds `shutdown_mu_` for its duration and enqueues
//     nothing once `is_shutdown_` is true.
//   * The `do { ... } while (0)` wrapper makes the expansion a single
//     statement; the caller supplies the trailing semicolon.
#define ENQUEUE_REQUEST(method, supports_cancel)                             \
  do {                                                                       \
    mutex_lock l(shutdown_mu_);                                              \
    if (!is_shutdown_) {                                                     \
      Call<GrpcVerbsService, grpc::VerbsService::AsyncService,               \
           method##Request, method##Response>::                              \
          EnqueueRequest(&verbs_service_, cq_,                               \
                         &grpc::VerbsService::AsyncService::Request##method, \
                         &GrpcVerbsService::method##Handler,                 \
                         (supports_cancel));                                 \
    }                                                                        \
  } while (0)
79
80// This method blocks forever handling requests from the completion queue.
81void GrpcVerbsService::HandleRPCsLoop() {
82  for (int i = 0; i < 10; ++i) {
83    ENQUEUE_REQUEST(GetRemoteAddress, false);
84  }
85
86  void* tag;
87  bool ok;
88
89  while (cq_->Next(&tag, &ok)) {
90    UntypedCall<GrpcVerbsService>::Tag* callback_tag =
91        static_cast<UntypedCall<GrpcVerbsService>::Tag*>(tag);
92    if (callback_tag) {
93      callback_tag->OnCompleted(this, ok);
94    } else {
95      cq_->Shutdown();
96    }
97  }
98}
99
100void GrpcVerbsService::GetRemoteAddressHandler(
101    WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
102  Status s = GetRemoteAddressSync(&call->request, &call->response);
103  call->SendResponse(ToGrpcStatus(s));
104  ENQUEUE_REQUEST(GetRemoteAddress, false);
105}
106
107// synchronous method
108Status GrpcVerbsService::GetRemoteAddressSync(
109    const GetRemoteAddressRequest* request,
110    GetRemoteAddressResponse* response) {
111  // analyzing request
112  // the channel setting part is redundant.
113  const string remote_host_name = request->host_name();
114  RdmaChannel* rc = rdma_mgr_->FindChannel(remote_host_name);
115  CHECK(rc);
116  RdmaAddress ra;
117  ra.lid = request->channel().lid();
118  ra.qpn = request->channel().qpn();
119  ra.psn = request->channel().psn();
120  ra.snp = request->channel().snp();
121  ra.iid = request->channel().iid();
122  rc->SetRemoteAddress(ra, false);
123  rc->Connect();
124  int i = 0;
125  int idx[] = {1, 0};
126  std::vector<RdmaMessageBuffer*> mb(rc->message_buffers());
127  CHECK_EQ(request->mr_size(), RdmaChannel::kNumMessageBuffers);
128  for (const auto& mr : request->mr()) {
129    // the connections are crossed, i.e.
130    // local tx_message_buffer <---> remote rx_message_buffer_
131    // local rx_message_buffer <---> remote tx_message_buffer_
132    // hence idx[] = {1, 0}.
133    RdmaMessageBuffer* rb = mb[idx[i]];
134    RemoteMR rmr;
135    rmr.remote_addr = mr.remote_addr();
136    rmr.rkey = mr.rkey();
137    rb->SetRemoteMR(rmr, false);
138    i++;
139  }
140  CHECK(i == RdmaChannel::kNumMessageBuffers);
141
142  // setting up response
143  response->set_host_name(
144      worker_env_->session_mgr->LegacySession()->worker_name);
145  Channel* channel_info = response->mutable_channel();
146  channel_info->set_lid(rc->self().lid);
147  channel_info->set_qpn(rc->self().qpn);
148  channel_info->set_psn(rc->self().psn);
149  channel_info->set_snp(rc->self().snp);
150  channel_info->set_iid(rc->self().iid);
151  for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
152    MemoryRegion* mr = response->add_mr();
153    mr->set_remote_addr(reinterpret_cast<uint64>(mb[i]->buffer()));
154    mr->set_rkey(mb[i]->self()->rkey);
155  }
156  return Status::OK();
157}
158
159// Create a GrpcVerbsService, then assign it to a given handle.
160void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
161                        ::grpc::ServerBuilder* builder) {
162  *handle = new GrpcVerbsService(worker_env, builder);
163}
164
165}  // namespace tensorflow
166
167#endif  // TENSORFLOW_USE_VERBS
168