/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_VERBS

#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include <cstring>
#include <fstream>
#include <functional>
#include <set>
#include <vector>
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

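// The RdmaMgr owns the RDMA adapter for this process and one RdmaChannel per
// remote worker; the channels are created here and connected later in
// SetupChannels().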
RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
                 GrpcChannelCache* const channel_cache)
    : worker_env_(worker_env), channel_cache_(channel_cache) {
  rdma_adapter_ = new RdmaAdapter(worker_env_);
  // Hardcoded to the default session (legacy_session_).
  // TODO: use WorkerSessionForSession; this requires passing in the session
  // handle.
  local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name;
  std::vector<string> workers;
  worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers(
      &workers);
  num_remote_workers_ = workers.size() - 1;
  VLOG(2) << "rdma_mgr on local worker: " << local_worker_;
  for (size_t i = 0; i < workers.size(); i++) {
    if (local_worker_.compare(workers[i]) != 0) {
      channel_table_.insert(
          {workers[i],
           new RdmaChannel(rdma_adapter_, local_worker_, workers[i])});
    }
  }
}

// Set up RDMA channels between peers.
// This is done at the beginning of the server setup.
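// For each remote worker, the local queue pair address (lid/qpn/psn/snp/iid)
// and the memory regions of the local message buffers are exchanged with the
// peer via a synchronous GrpcVerbsClient::GetRemoteAddress call.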

void RdmaMgr::SetupChannels() {
  for (const auto& p : channel_table_) {
    string worker_name = p.first;
    RDMA_LOG(2) << "Connecting to remote node " << worker_name;
    RdmaChannel* rc = p.second;
    GetRemoteAddressRequest req;
    GetRemoteAddressResponse resp;
    // Look up the peer's gRPC channel in the channel cache.
    SharedGrpcChannelPtr client_channel =
        channel_cache_->FindWorkerChannel(worker_name);
    CHECK(client_channel != nullptr) << "No worker known as " << worker_name;
    GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);

    // Fill in the request with the local channel's address and the memory
    // regions of its message buffers.
    req.set_host_name(local_worker_);
    Channel* channel_info = req.mutable_channel();
    channel_info->set_lid(rc->self_.lid);
    channel_info->set_qpn(rc->self_.qpn);
    channel_info->set_psn(rc->self_.psn);
    channel_info->set_snp(rc->self_.snp);
    channel_info->set_iid(rc->self_.iid);
    for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
      MemoryRegion* mr = req.add_mr();
      mr->set_remote_addr(
          reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
      mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
    }
    // Synchronous call, retried up to max_num_attempts times with a two
    // second pause between failed attempts.
    Status s;
    int attempts = 0;
    static const int max_num_attempts = 5;
    do {
      s = client->GetRemoteAddress(&req, &resp);
      // Save the obtained remote address and connect to the remote channel.
      if (s.ok()) {
        CHECK(worker_name.compare(resp.host_name()) == 0);
        RdmaAddress ra;
        ra.lid = resp.channel().lid();
        ra.qpn = resp.channel().qpn();
        ra.psn = resp.channel().psn();
        ra.snp = resp.channel().snp();
        ra.iid = resp.channel().iid();
        rc->SetRemoteAddress(ra, false);
        rc->Connect();
        int i = 0;
        int idx[] = {1, 0};
        for (const auto& mr : resp.mr()) {
          // The connections are crossed, i.e.
          // local tx_message_buffer <---> remote rx_message_buffer_
          // local rx_message_buffer <---> remote tx_message_buffer_
          // hence idx[] = {1, 0}.
          RdmaMessageBuffer* rb = rc->message_buffers_[idx[i]];
          RemoteMR rmr;
          rmr.remote_addr = mr.remote_addr();
          rmr.rkey = mr.rkey();
          rb->SetRemoteMR(rmr, false);
          i++;
        }
        CHECK(i == RdmaChannel::kNumMessageBuffers);
      } else {
        LOG(ERROR) << "Connecting to " << worker_name << ": Got "
                   << s.error_message() << ". Retrying (" << (attempts + 1)
                   << "/" << max_num_attempts << ")...";
        if (++attempts == max_num_attempts) {
          break;
        }
        worker_env_->env->SleepForMicroseconds(2000000);
      }
    } while (!s.ok());
    if (s.ok()) {
      RDMA_LOG(0) << "Connected to remote node " << worker_name;
    } else {
      LOG(ERROR) << "Failed to connect to remote node " << worker_name;
    }
    delete client;
  }
}

// Check connectivity by pinging every channel
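// Each channel posts one ping send and pre-posts receive work requests; the
// completion queue is then drained until every ping has completed on both
// the send and the receive side. Regular completion-queue polling
// (rdma_adapter_->StartPolling) begins only after the check completes.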
bool RdmaMgr::ConnectivityCheck() {
  int i, rcnt = 0, scnt = 0;

  for (const auto& p : channel_table_) {
    string worker_name = p.first;
    RdmaChannel* rc = p.second;

    VLOG(2) << "Ping to " << worker_name;
    CHECK(rc->PingPostSend() == 0) << "Couldn't post send to " << worker_name
                                   << " with error: " << std::strerror(errno);
    for (i = 0; i < rc->adapter_->params_.queue_depth - 1; i++) {
      rc->Recv();
    }
  }

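  // Wait for all ping sends and receives to complete.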
  while (rcnt < num_remote_workers_ || scnt < num_remote_workers_) {
    int ne;
    do {
      ne = ibv_poll_cq(rdma_adapter_->cq_, 2 * num_remote_workers_,
                       rdma_adapter_->wc_);
      CHECK(ne >= 0) << "poll CQ failed " << ne << " with error "
                     << std::strerror(errno);
    } while (ne < 1);

    for (i = 0; i < ne; ++i) {
      ibv_wc_status s = rdma_adapter_->wc_[i].status;
      if ((int)rdma_adapter_->wc_[i].wr_id == RdmaChannel::kPingRecvWrid) {
        // Receive completed.
        CHECK(s == IBV_WC_SUCCESS)
            << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
            << rdma_adapter_->wc_[i].status << ") for PING_RECV_WRID";
        ++rcnt;
      } else {
        // Send completed.
        RdmaChannel* rc =
            reinterpret_cast<RdmaChannel*>(rdma_adapter_->wc_[i].wr_id);
        CHECK(s == IBV_WC_SUCCESS)
            << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
            << rdma_adapter_->wc_[i].status << ") to " << rc->remote_name_;
        ++scnt;
      }
    }  // for
  }    // while
  CHECK(rcnt == scnt) << "Connectivity check failed!";
  rdma_adapter_->StartPolling();
  return (num_remote_workers_ == rcnt) && (num_remote_workers_ == scnt);
}

RdmaMgr::~RdmaMgr() {
  for (const auto& p : channel_table_) delete p.second;
  channel_table_.clear();
  delete rdma_adapter_;
}

// Find a channel via the given name.
// Args:
//   name: peer name, e.g. worker1
// Returns:
//   channel object that is connected to the named peer.
RdmaChannel* RdmaMgr::FindChannel(const string& name) {
  ChannelTable::iterator iter = channel_table_.find(name);
  CHECK(iter != channel_table_.end());
  return iter->second;
}

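// Returns true if GPUDirect RDMA is available, i.e. the nv_peer_mem kernel
// module is loaded (checked by scanning /proc/modules on Linux).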
bool IsGDRAvailable() {
#if defined(__APPLE__)
  return false;
#elif defined(PLATFORM_WINDOWS)
  return false;
#else
  std::ifstream ifs("/proc/modules");
  string line;
  while (std::getline(ifs, line)) {
    auto sep = line.find(' ');
    CHECK_NE(sep, std::string::npos);
    if (line.substr(0, sep) == "nv_peer_mem") {
      return true;
    }
  }
  return false;
#endif
}

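// Returns the NUMA node of the given RDMA device as reported by sysfs
// (<ibdev_path>/device/numa_node), 0 on platforms without NUMA support, and
// -1 if the value cannot be parsed.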
int TryToReadNumaNode(ibv_device* device) {
#if defined(__APPLE__)
  LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
  return 0;
#elif defined(PLATFORM_WINDOWS)
  // Windows support for NUMA is not currently implemented. Return node 0.
  return 0;
#else
  VLOG(2) << "Trying to read NUMA node for device: " << device->name;
  static const int kUnknownNumaNode = -1;

  auto filename = string(device->ibdev_path) + "/device/numa_node";

  std::ifstream ifs(filename.c_str());
  string content;
  CHECK(std::getline(ifs, content));

  int32 value;
  if (strings::safe_strto32(content, &value)) {
    if (value < 0) {
      LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
                << value
                << "), but there must be at least one NUMA node"
                   ", so returning NUMA node zero";
      return 0;
    }
    LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
    return value;
  }
  return kUnknownNumaNode;
#endif
}

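// Deleter for ibv_mr handles: deregisters the memory region if it is set.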
void MRDeleter(ibv_mr* mr) {
  if (mr) {
    ibv_dereg_mr(mr);
  }
}

// TODO(byronyi): remove this class duplicated from the one in
// common_runtime/gpu/pool_allocator.h when it is available in common_runtime
class BasicCPUAllocator : public SubAllocator {
 public:
  ~BasicCPUAllocator() override {}

  void* Alloc(size_t alignment, size_t num_bytes) override {
    return port::AlignedMalloc(num_bytes, alignment);
  }
  void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
};

// TODO(byronyi): remove this class and its registration when the default
// cpu_allocator() returns a visitable allocator
class BFCRdmaAllocator : public BFCAllocator {
 public:
  BFCRdmaAllocator()
      : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
  }
};

REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);

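// Register memory-region visitors on the host (and, when available, CUDA)
// allocators so that every buffer they hand out is registered with the
// RdmaMemoryMgr and can be used as an RDMA source or destination.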
void RdmaMgr::InitAllocators() {
  RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_;

  Allocator* allocators[] = {
#if GOOGLE_CUDA
    ProcessState::singleton()->GetCUDAHostAllocator(0),
    ProcessState::singleton()->GetCPUAllocator(0),
#endif  // GOOGLE_CUDA
    cpu_allocator(),
  };

  using namespace std::placeholders;

  std::set<Allocator*> instrumented;

  // Host memory allocators
  for (Allocator* allocator : allocators) {
    VisitableAllocator::Visitor alloc_visitor =
        std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
                  &RdmaMemoryMgr::Singleton(), _1, _2, allocator->Name());
    VisitableAllocator::Visitor free_visitor = std::bind(
        &RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2);

    auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
    CHECK(visitable_allocator)
        << "Allocator " << allocator->Name()
        << " is not visitable for instrumentation";
    // Make sure we don't instrument the same allocator twice
    if (instrumented.find(allocator) == std::end(instrumented)) {
      visitable_allocator->AddAllocVisitor(alloc_visitor);
      visitable_allocator->AddFreeVisitor(free_visitor);
      instrumented.insert(allocator);
      LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
    }
  }

#if GOOGLE_CUDA
  if (IsGDRAvailable()) {
    // Note we don't free allocated GPU memory so there is no free visitor
    int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;

    VisitableAllocator::Visitor cuda_alloc_visitor =
        std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
                  &RdmaMemoryMgr::Singleton(), _1, _2, std::string("gpu"));

    ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
    LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
  }
#endif  // GOOGLE_CUDA
}

}  // end namespace tensorflow

#endif  // TENSORFLOW_USE_VERBS