/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_GDR

#include "tensorflow/contrib/gdr/gdr_memory_manager.h"

#include <atomic>
#include <cerrno>
#include <fstream>
#include <list>
#include <map>
#include <set>

#include <fcntl.h>
#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
#include <sys/epoll.h>

#include "tensorflow/contrib/gdr/gdr.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#endif  // GOOGLE_CUDA
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

namespace {

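// Returns true if GPUDirect RDMA is available, i.e. the nv_peer_mem kernel
// module is loaded (checked via /proc/modules). Always false on macOS and
// Windows.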
bool IsGDRAvailable() {
#if defined(__APPLE__)
  return false;
#elif defined(PLATFORM_WINDOWS)
  return false;
#else
  std::ifstream ifs("/proc/modules");
  string line;
  while (std::getline(ifs, line)) {
    auto sep = line.find(' ');
    CHECK_NE(sep, std::string::npos);
    if (line.substr(0, sep) == "nv_peer_mem") {
      return true;
    }
  }
  return false;
#endif
}

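// Tries to read the NUMA node of an RDMA device from sysfs
// (<ibdev_path>/device/numa_node). Returns -1 if the node cannot be
// determined, and 0 on platforms without NUMA support.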
int TryToReadNumaNode(ibv_device* device) {
#if defined(__APPLE__)
  LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
  return 0;
#elif defined(PLATFORM_WINDOWS)
  // Windows support for NUMA is not currently implemented. Return node 0.
  return 0;
#else
  VLOG(2) << "Trying to read NUMA node for device: " << device->name;
  static const int kUnknownNumaNode = -1;

  auto filename = string(device->ibdev_path) + "/device/numa_node";

  std::ifstream ifs(filename.c_str());
  string content;
  CHECK(std::getline(ifs, content));

  int32 value;
  if (strings::safe_strto32(content, &value)) {
    if (value < 0) {
      LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
                << value
                << "), but there must be at least one NUMA node"
                   ", so returning NUMA node zero";
      return 0;
    }
    LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
    return value;
  }
  return kUnknownNumaNode;
#endif
}

void EndpointDeleter(rdma_cm_id* id) {
  if (id) {
    rdma_destroy_ep(id);
  }
}

void MRDeleter(ibv_mr* mr) {
  if (mr) {
    rdma_dereg_mr(mr);
  }
}

using RdmaEndpointPtr = std::unique_ptr<rdma_cm_id, decltype(&EndpointDeleter)>;

using MemoryRegionPtr = std::unique_ptr<ibv_mr, decltype(&MRDeleter)>;

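// GdrMemoryManager implements RemoteMemoryManager on top of RDMA verbs.
// The server side registers instrumented allocator memory with the RDMA
// device and, for each tensor to be transferred, publishes a
// RemoteMemoryRegion proto (address, rkey, tensor key) while holding a
// reference on the tensor buffer. The client side performs a one-sided RDMA
// read into its own registered buffer and then posts an RDMA_WRITE_WITH_IMM
// carrying the tensor key, which tells the server it may release its
// reference on that buffer.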
class GdrMemoryManager : public RemoteMemoryManager {
 public:
  GdrMemoryManager(const string& host, const string& port);

  virtual ~GdrMemoryManager();

  virtual Status Init() override;

  virtual void Run() override;

  virtual void Stop() override;

  virtual void TransportOptionsFromTensor(
      ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
      Device* device, DeviceContext* device_context, bool on_host,
      StatusCallback done) override;

  virtual void TensorFromTransportOptions(
      Tensor* tensor, const ::google::protobuf::Any& transport_options,
      Device* device, DeviceContext* device_context, bool on_host,
      StatusCallback done) override;

 protected:
  Status CreateEndpoint(const string& host, const string& port,
                        RdmaEndpointPtr& endpoint);

  static bool Comparator(const void* ptr, const MemoryRegionPtr& other) {
    return ptr < reinterpret_cast<char*>(other->addr) + other->length;
  }

  ibv_mr* FindMemoryRegion(void* addr, size_t length);

  void InsertMemoryRegion(void* addr, size_t length);

  void EvictMemoryRegion(void* addr, size_t length);

 private:
  const string host_;
  const string port_;
  RdmaEndpointPtr listening_;
  std::atomic<bool> stopped_;
  int epfd_;

  // Server side endpoints
  // Accessed sequentially in Run() so not protected by lock
  std::list<RdmaEndpointPtr> server_clients_;

  using TensorKey = uint32_t;
  std::atomic<TensorKey> next_key_;

  // Server side on-the-fly tensor buffers
  mutex server_mu_;
  std::map<TensorKey, const TensorBuffer*> tensor_buffers_
      GUARDED_BY(server_mu_);

  // Client side endpoints
  mutex client_mu_;
  std::map<std::pair<string, string>, RdmaEndpointPtr> clients_
      GUARDED_BY(client_mu_);

  // Managed memory regions
  mutex alloc_mu_;
  std::vector<MemoryRegionPtr> mrs_ GUARDED_BY(alloc_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GdrMemoryManager);
};

// TODO(byronyi): remove this class duplicated from the one in
// common_runtime/gpu/pool_allocator.h when it is available in common_runtime
class BasicCPUAllocator : public SubAllocator {
 public:
  ~BasicCPUAllocator() override {}

  void* Alloc(size_t alignment, size_t num_bytes) override {
    return port::AlignedMalloc(num_bytes, alignment);
  }
  void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
};

// TODO(byronyi): remove this class and its registration when the default
// cpu_allocator() returns visitable allocator
class BFCRdmaAllocator : public BFCAllocator {
 public:
  BFCRdmaAllocator()
      : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
  }
};

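// Registered at a priority above the default CPU allocator so that
// cpu_allocator() returns this visitable BFC allocator; its allocations can
// then be instrumented and registered with the RDMA device in Init().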
REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);

GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
    : host_(host),
      port_(port),
      listening_(nullptr, EndpointDeleter),
      stopped_(true),
      epfd_(-1),
      next_key_(0) {}

GdrMemoryManager::~GdrMemoryManager() {
  // epfd_ is only valid once Init() has succeeded.
  if (epfd_ >= 0) {
    close(epfd_);
  }
}

Status GdrMemoryManager::Init() {
  epfd_ = epoll_create1(0);
  if (epfd_ == -1) {
    return errors::Unavailable(strerror(errno), ": ", "epoll_create");
  }

  rdma_addrinfo* addrinfo;
  rdma_addrinfo hints = {};
  hints.ai_port_space = RDMA_PS_TCP;
  hints.ai_flags = RAI_PASSIVE;
  if (rdma_getaddrinfo(const_cast<char*>(host_.c_str()),
                       const_cast<char*>(port_.c_str()), &hints, &addrinfo)) {
    return errors::Unavailable(strerror(errno), ": ", "cannot resolve rdma://",
                               host_, ":", port_);
  }

  ibv_qp_init_attr init_attr = {};
  init_attr.qp_type = IBV_QPT_RC;
  init_attr.cap.max_recv_wr = 32;
  init_attr.cap.max_send_wr = 1;
  init_attr.cap.max_recv_sge = 1;
  init_attr.cap.max_send_sge = 1;

  // Create listening endpoint
  rdma_cm_id* id;
  if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
    rdma_freeaddrinfo(addrinfo);
    return errors::Unavailable(strerror(errno), ": ", "cannot bind to rdma://",
                               host_, ":", port_);
  }
  listening_.reset(id);
  rdma_freeaddrinfo(addrinfo);

  // Listen without backlog
  if (rdma_listen(listening_.get(), 0)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot listen on rdma://", host_, ":", port_);
  }
  LOG(INFO) << "RDMA server is listening on " << host_ << ":" << port_;

  if (listening_->verbs == nullptr) {
    return errors::Unimplemented(
        "Unsupported address ", host_, ":", port_,
        " as it does not bind to a particular RDMA device");
  }

  int flags = fcntl(listening_->channel->fd, F_GETFL, 0);
  if (fcntl(listening_->channel->fd, F_SETFL, flags | O_NONBLOCK)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot set server to non-blocking mode");
  }

  epoll_event event = {};
  event.events = EPOLLIN | EPOLLPRI;
  event.data.ptr = listening_.get();
  if (epoll_ctl(epfd_, EPOLL_CTL_ADD, listening_->channel->fd, &event)) {
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot add server to epoll");
  }

  Allocator* allocators[] = {
#if GOOGLE_CUDA
    ProcessState::singleton()->GetCUDAHostAllocator(0),
    ProcessState::singleton()->GetCPUAllocator(0),
#endif  // GOOGLE_CUDA
    cpu_allocator(),
  };

  using namespace std::placeholders;
  VisitableAllocator::Visitor alloc_visitor =
      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
  VisitableAllocator::Visitor free_visitor =
      std::bind(&GdrMemoryManager::EvictMemoryRegion, this, _1, _2);

  std::set<Allocator*> instrumented_;

  // Host memory allocators
  for (Allocator* allocator : allocators) {
    auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
    CHECK(visitable_allocator)
        << allocator->Name() << " is not visitable for instrumentation";
    // Make sure we don't instrument the same allocator twice
    if (instrumented_.find(allocator) == std::end(instrumented_)) {
      visitable_allocator->AddAllocVisitor(alloc_visitor);
      visitable_allocator->AddFreeVisitor(free_visitor);
      instrumented_.insert(allocator);
      LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
    }
  }

#if GOOGLE_CUDA
  VisitableAllocator::Visitor cuda_alloc_visitor =
      std::bind(&GdrMemoryManager::InsertMemoryRegion, this, _1, _2);
  if (IsGDRAvailable()) {
    // Note we don't free allocated GPU memory so there is no free visitor
    int32_t bus_id = TryToReadNumaNode(listening_->verbs->device) + 1;
    ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
    LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
  }
#endif  // GOOGLE_CUDA

  return Status::OK();
}

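// Server event loop: epoll over the listening channel and each accepted
// connection's recv completion channel. New connections get 32 pre-posted
// zero-length receives; each IBV_WC_RECV_RDMA_WITH_IMM completion carries the
// tensor key of a buffer the remote reader is done with, whose reference is
// then released and the receive re-posted.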
void GdrMemoryManager::Run() {
  stopped_ = false;
  while (!stopped_) {
    epoll_event events[32];
    int ret = epoll_wait(epfd_, events, 32, 1);
    if (ret == -1) {
      LOG(ERROR) << "epoll_wait: " << strerror(errno);
      return;
    }
    for (int i = 0; i < ret; i++) {
      rdma_cm_id* id = static_cast<rdma_cm_id*>(events[i].data.ptr);
      if (id == listening_.get()) {
        // Accept incoming connections
        if (!rdma_get_request(listening_.get(), &id)) {
          if (!rdma_accept(id, nullptr)) {
            LOG(INFO) << "Accepted new RDMA connection";
            if (ibv_req_notify_cq(id->recv_cq, 0)) {
              LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
              EndpointDeleter(id);
              continue;
            }
            // Pre-post receives so that write-with-imm completions can be
            // delivered; drop the connection if any posting fails.
            bool recv_posted = true;
            for (int j = 0; j < 32; j++) {
              if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
                LOG(ERROR) << strerror(errno) << ": rdma_post_recvv failed";
                recv_posted = false;
                break;
              }
            }
            if (!recv_posted) {
              EndpointDeleter(id);
              continue;
            }
            int flags = fcntl(id->recv_cq_channel->fd, F_GETFL, 0);
            if (fcntl(id->recv_cq_channel->fd, F_SETFL, flags | O_NONBLOCK)) {
              LOG(ERROR) << strerror(errno)
                         << ": cannot set server_client to non-blocking mode";
              EndpointDeleter(id);
              continue;
            }
            epoll_event event = {};
            event.events = EPOLLIN | EPOLLPRI;
            event.data.ptr = id;
            if (epoll_ctl(epfd_, EPOLL_CTL_ADD, id->recv_cq_channel->fd,
                          &event)) {
              LOG(ERROR) << strerror(errno)
                         << ": cannot add server client to epoll";
              EndpointDeleter(id);
              continue;
            }
            server_clients_.push_back({id, EndpointDeleter});
          }
        }
      } else {
        // Polling work completions
        ibv_cq* cq;
        void* context;
        if (!ibv_get_cq_event(id->recv_cq_channel, &cq, &context)) {
          ibv_ack_cq_events(id->recv_cq, 1);
          if (ibv_req_notify_cq(id->recv_cq, 0)) {
            LOG(ERROR) << strerror(errno) << ": ibv_req_notify_cq failed";
            continue;
          }
          ibv_wc wc[32];
          int ret = ibv_poll_cq(id->recv_cq, 32, wc);
          if (ret < 0) {
            LOG(ERROR) << "ibv_poll_cq failed";
            continue;
          }
          for (int i = 0; i < ret; i++) {
            if (wc[i].opcode != IBV_WC_RECV_RDMA_WITH_IMM) {
              LOG(ERROR) << "Received unknown operation " << wc[i].opcode;
            }
            if (wc[i].status != 0) {
              LOG(ERROR) << ibv_wc_status_str(wc[i].status);
            }
            TensorKey tensor_key = ntohl(wc[i].imm_data);
            {
              mutex_lock l(server_mu_);
              auto iter = tensor_buffers_.find(tensor_key);
              if (iter == std::end(tensor_buffers_)) {
                LOG(ERROR) << "Cannot find tensor buffer for tensor key "
                           << tensor_key;
              } else {
                const TensorBuffer* buffer = iter->second;
                buffer->Unref();
                tensor_buffers_.erase(iter);
              }
            }
            if (rdma_post_recvv(id, nullptr, nullptr, 0)) {
              perror("rdma_post_recvv");
              LOG(ERROR) << "rdma_post_recvv failed";
              continue;
            }
          }
        }
      }
    }
  }
}

void GdrMemoryManager::Stop() { stopped_ = true; }

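// Server side of a transfer: pins down the tensor buffer (copying GPU tensors
// into CUDA pinned host memory first), assigns it a tensor key, takes a
// reference that is held until the client's write-with-imm arrives, and packs
// the buffer's address, rkey and tensor key into a RemoteMemoryRegion proto.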
void GdrMemoryManager::TransportOptionsFromTensor(
    ::google::protobuf::Any* mutable_transport_options, const Tensor& tensor,
    Device* device, DeviceContext* device_context, bool on_host,
    StatusCallback done) {
  auto buffer = DMAHelper::buffer(&tensor);
  void* addr = buffer->data();
  size_t length = buffer->size();
  if (length == 0) {
    done(errors::Unavailable("Cannot register tensor buffer of size 0"));
    return;
  }

  ibv_mr* mr = FindMemoryRegion(addr, length);

#if GOOGLE_CUDA
  if (!on_host) {
    Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
    Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape());
    GPUUtil::CopyGPUTensorToCPU(
        device, device_context, &tensor, host_copy,
        [done, host_copy, mutable_transport_options, this](const Status& s) {
          if (!s.ok()) {
            done(s);
            delete host_copy;
            return;
          }
          auto buffer = DMAHelper::buffer(host_copy);
          void* addr = buffer->data();
          size_t length = buffer->size();
          ibv_mr* mr = FindMemoryRegion(addr, length);

          if (mr == nullptr) {
            done(errors::Unavailable("Cannot find pinned memory region"));
            delete host_copy;
            return;
          }

          buffer->Ref();
          TensorKey tensor_key = next_key_++;
          {
            mutex_lock l(server_mu_);
            tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
          }

          uint64_t checksum = 0;
          if (VLOG_IS_ON(2)) {
            checksum = GPUUtil::Checksum(*host_copy);
          }

          RemoteMemoryRegion remote_mr;
          remote_mr.set_host(host_);
          remote_mr.set_port(port_);
          remote_mr.set_addr(reinterpret_cast<uint64_t>(addr));
          remote_mr.set_rkey(mr->rkey);
          remote_mr.set_tensor_key(tensor_key);
          remote_mr.set_checksum(checksum);
          mutable_transport_options->PackFrom(remote_mr);

          done(Status::OK());
          delete host_copy;
        });
    return;
  }
#endif

  if (mr == nullptr) {
    done(errors::Unavailable("Cannot find pinned memory region"));
    return;
  }

  buffer->Ref();
  TensorKey tensor_key = next_key_++;
  {
    mutex_lock l(server_mu_);
    tensor_buffers_.insert(std::make_pair(tensor_key, buffer));
  }

  uint64_t checksum = 0;
  if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
    if (!on_host) {
      checksum = GPUUtil::Checksum(device, device_context, tensor);
    } else {
      checksum = GPUUtil::Checksum(tensor);
    }
#endif
  }

  RemoteMemoryRegion remote_mr;
  remote_mr.set_host(host_);
  remote_mr.set_port(port_);
  remote_mr.set_addr(reinterpret_cast<uint64_t>(addr));
  remote_mr.set_rkey(mr->rkey);
  remote_mr.set_tensor_key(tensor_key);
  remote_mr.set_checksum(checksum);
  mutable_transport_options->PackFrom(remote_mr);

  done(Status::OK());
}

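// Client side of a transfer: unpacks the RemoteMemoryRegion, connects to (or
// reuses) an endpoint for the remote host:port, issues a one-sided RDMA read
// into a locally registered buffer (staging through CUDA pinned host memory
// for GPU destinations), then posts an RDMA_WRITE_WITH_IMM carrying the
// tensor key so the server can release its buffer reference.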
void GdrMemoryManager::TensorFromTransportOptions(
    Tensor* tensor, const ::google::protobuf::Any& transport_options,
    Device* device, DeviceContext* device_context, bool on_host,
    StatusCallback done) {
  RemoteMemoryRegion remote_mr;
  if (!transport_options.UnpackTo(&remote_mr)) {
    done(errors::NotFound("No RDMA transport options found"));
    return;
  }

  auto buffer = DMAHelper::buffer(tensor);
  void* addr = buffer->data();
  size_t length = buffer->size();
  ibv_mr* mr = FindMemoryRegion(addr, length);

  Tensor host_copy;
#if GOOGLE_CUDA
  if (mr == nullptr && !on_host) {
    Allocator* alloc = ProcessState::singleton()->GetCUDAHostAllocator(0);
    host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
    buffer = DMAHelper::buffer(&host_copy);
    addr = buffer->data();
    length = buffer->size();
    mr = FindMemoryRegion(addr, length);
  }
#endif  // GOOGLE_CUDA

  if (mr == nullptr) {
    done(errors::Unavailable("Cannot find pinned memory region"));
    return;
  }

  decltype(clients_)::iterator iter;
  bool success;
  {
    mutex_lock l(client_mu_);
    std::tie(iter, success) = clients_.insert(
        std::make_pair(std::make_pair(remote_mr.host(), remote_mr.port()),
                       RdmaEndpointPtr(nullptr, EndpointDeleter)));
    if (success || iter->second.get() == nullptr) {
      Status s =
          CreateEndpoint(remote_mr.host(), remote_mr.port(), iter->second);
      if (!s.ok()) {
        done(s);
        return;
      }
    }
  }
  rdma_cm_id* id = iter->second.get();

  uint64_t start = Env::Default()->NowMicros();

  if (rdma_post_read(id, nullptr, buffer->data(), buffer->size(), mr, 0,
                     remote_mr.addr(), remote_mr.rkey())) {
    done(errors::Unavailable(strerror(errno), ": ", "rdma_post_read failed"));
    return;
  }

  ibv_send_wr wr = {};
  wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
  wr.imm_data = htonl(remote_mr.tensor_key());
  wr.send_flags = IBV_SEND_SIGNALED;
  ibv_send_wr* bad_wr;
  if (ibv_post_send(id->qp, &wr, &bad_wr)) {
    done(errors::Unavailable(strerror(errno), ": ", "ibv_post_send failed"));
    return;
  }

  ibv_wc wc = {};
  int ret;
  while ((ret = ibv_poll_cq(id->send_cq, 1, &wc)) == 0)
    ;
  if (ret < 0 || wc.status) {
    done(errors::Unavailable(ibv_wc_status_str(wc.status)));
    return;
  }

#if GOOGLE_CUDA
  if (host_copy.NumElements() > 0) {
    uint64_t checksum = 0;
    if (VLOG_IS_ON(2)) {
      checksum = GPUUtil::Checksum(host_copy);
      CHECK(checksum == remote_mr.checksum())
          << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum();
    }
    Tensor* ref = new Tensor;
    std::swap(host_copy, *ref);
    GPUUtil::CopyCPUTensorToGPU(
        ref, device_context, device, tensor,
        [ref, done, buffer, remote_mr, start](const Status& s) {
          if (!s.ok()) {
            done(s);
            delete ref;
            return;
          }
          uint64_t end = Env::Default()->NowMicros();

          VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
                  << " of size " << buffer->size() << " with tensor key "
                  << remote_mr.tensor_key() << " took " << (end - start)
                  << " micros";
          done(Status::OK());
          delete ref;
        });
    return;
  }
#endif  // GOOGLE_CUDA

  uint64_t end = Env::Default()->NowMicros();

  VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
          << " of size " << buffer->size() << " with tensor key "
          << remote_mr.tensor_key() << " took " << (end - start) << " micros";

  uint64_t checksum = 0;
  if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
    if (device->tensorflow_gpu_device_info() && (!on_host)) {
      checksum = GPUUtil::Checksum(device, device_context, *tensor);
    } else {
      checksum = GPUUtil::Checksum(*tensor);
    }
    CHECK(checksum == remote_mr.checksum())
        << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum();
#endif
  }
  done(Status::OK());
}

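// Resolves host:port with rdma_getaddrinfo, creates a reliable-connection
// queue pair and connects it to the remote RDMA server.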
Status GdrMemoryManager::CreateEndpoint(const string& host, const string& port,
                                        RdmaEndpointPtr& endpoint) {
  rdma_addrinfo* addrinfo;
  rdma_addrinfo hints = {};
  hints.ai_port_space = RDMA_PS_TCP;
  if (rdma_getaddrinfo(const_cast<char*>(host.c_str()),
                       const_cast<char*>(port.c_str()), &hints, &addrinfo)) {
    return errors::InvalidArgument(
        strerror(errno), ": ", "cannot connect to rdma://", host, ":", port);
  }

  ibv_qp_init_attr init_attr = {};
  init_attr.qp_type = IBV_QPT_RC;
  init_attr.cap.max_recv_wr = 1;
  init_attr.cap.max_send_wr = 32;
  init_attr.cap.max_recv_sge = 1;
  init_attr.cap.max_send_sge = 1;

  rdma_cm_id* id;
  if (rdma_create_ep(&id, addrinfo, nullptr, &init_attr)) {
    rdma_freeaddrinfo(addrinfo);
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot create endpoint to rdma://", host, ":",
                               port);
  }
  rdma_freeaddrinfo(addrinfo);

  if (rdma_connect(id, nullptr)) {
    rdma_destroy_ep(id);
    return errors::Unavailable(strerror(errno), ": ",
                               "cannot connect to rdma://", host, ":", port);
  }

  LOG(INFO) << "RDMA endpoint connected to rdma://" << host << ":" << port;
  endpoint = RdmaEndpointPtr(id, EndpointDeleter);
  return Status::OK();
}

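// mrs_ is kept sorted by the end address of each registered region.
// Comparator together with std::upper_bound finds the first region whose end
// lies past the queried address; the region matches only if its start is at
// or before that address.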
ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) {
  if (length == 0) return nullptr;
  mutex_lock l(alloc_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter == std::end(mrs_) || iter->get()->addr > addr) {
    return nullptr;
  } else {
    return iter->get();
  }
}

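// Registers a freshly allocated buffer for remote read access and inserts it
// into mrs_ at the position that preserves the end-address ordering;
// EvictMemoryRegion drops the registration when the buffer is freed.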
void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length);
  if (mr != nullptr) {
    mutex_lock l(alloc_mu_);
    auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
    mrs_.insert(iter, {mr, &MRDeleter});
  } else {
    LOG(WARNING) << "Cannot register memory region";
  }
}

void GdrMemoryManager::EvictMemoryRegion(void* addr, size_t length) {
  if (length == 0) return;
  mutex_lock l(alloc_mu_);
  auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
  if (iter != std::end(mrs_) && iter->get()->addr == addr) {
    mrs_.erase(iter);
  } else {
    LOG(WARNING) << "Failed to de-register memory region";
  }
}

}  // namespace

RemoteMemoryManager* CreateRemoteMemoryManager(const string& host,
                                               const string& port) {
  return new GdrMemoryManager(host, port);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_GDR