/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef TENSORFLOW_USE_VERBS

#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include <cerrno>
#include <cstring>
#include <fstream>
#include <functional>
#include <set>
#include <string>
#include <vector>
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu/process_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Creates the RDMA adapter and one RdmaChannel for every worker reported by
// the legacy session's worker cache whose name differs from the local worker.
// No connection is established here; that happens in SetupChannels().
//
// Args:
//   worker_env: worker environment; used to reach the session manager and,
//     later, the Env used for sleeping between connection retries.
//   channel_cache: gRPC channel cache used by SetupChannels() to look up the
//     channel of each remote worker.
RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
                 GrpcChannelCache* const channel_cache)
    : worker_env_(worker_env), channel_cache_(channel_cache) {
  rdma_adapter_ = new RdmaAdapter(worker_env_);
  // hardcoded to default session (legacy_session_)
  // TODO: use WorkerSessionForSession
  // need to pass in session handle
  local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name;
  std::vector<string> workers;
  worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers(
      &workers);
  // NOTE(review): assumes the local worker is always part of `workers`;
  // confirm against the worker cache implementation.
  num_remote_workers_ = workers.size() - 1;
  VLOG(2) << "rmda_mgr on local worker: " << local_worker_;
  for (const string& worker : workers) {
    if (local_worker_.compare(worker) != 0) {
      channel_table_.insert(
          {worker, new RdmaChannel(rdma_adapter_, local_worker_, worker)});
    }
  }
}

// Setup Rdma channels between peers.
// This is done at the beginning of the server setup.
//
// For every channel in channel_table_ this sends the local channel address and
// message-buffer memory regions to the peer through GetRemoteAddress, and on
// success stores the peer's address and memory regions in the channel and
// connects it. A failed call is retried up to max_num_attempts times with a
// two second sleep in between; after that the peer is skipped with an error
// log.
void RdmaMgr::SetupChannels() {
  for (const auto& p : channel_table_) {
    const string& worker_name = p.first;
    RDMA_LOG(2) << "Connecting to remote node " << worker_name;
    RdmaChannel* rc = p.second;
    GetRemoteAddressRequest req;
    GetRemoteAddressResponse resp;
    // get the channel cache
    SharedGrpcChannelPtr client_channel =
        channel_cache_->FindWorkerChannel(worker_name);
    // The lookup result is what can actually be missing; the client object
    // built from it below cannot be null.
    CHECK(client_channel != nullptr) << "No worker known as " << worker_name;
    GrpcVerbsClient client(client_channel);

    // setting up request
    req.set_host_name(local_worker_);
    Channel* channel_info = req.mutable_channel();
    channel_info->set_lid(rc->self_.lid);
    channel_info->set_qpn(rc->self_.qpn);
    channel_info->set_psn(rc->self_.psn);
    channel_info->set_snp(rc->self_.snp);
    channel_info->set_iid(rc->self_.iid);
    for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
      MemoryRegion* mr = req.add_mr();
      mr->set_remote_addr(
          reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
      mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
    }
    // synchronous call
    Status s;
    int attempts = 0;
    static const int max_num_attempts = 5;
    do {
      s = client.GetRemoteAddress(&req, &resp);
      // save obtained remote addresses
      // connect to the remote channel
      if (s.ok()) {
        CHECK(worker_name.compare(resp.host_name()) == 0);
        RdmaAddress ra;
        ra.lid = resp.channel().lid();
        ra.qpn = resp.channel().qpn();
        ra.psn = resp.channel().psn();
        ra.snp = resp.channel().snp();
        ra.iid = resp.channel().iid();
        rc->SetRemoteAddress(ra, false);
        rc->Connect();
        int i = 0;
        int idx[] = {1, 0};
        for (const auto& mr : resp.mr()) {
          // Reject an oversized response before it is used to index idx[]
          // and message_buffers_.
          CHECK(i < RdmaChannel::kNumMessageBuffers)
              << "Too many memory regions received from " << worker_name;
          // the connections are crossed, i.e.
          // local tx_message_buffer <---> remote rx_message_buffer_
          // local rx_message_buffer <---> remote tx_message_buffer_
          // hence idx[] = {1, 0}.
          RdmaMessageBuffer* rb = rc->message_buffers_[idx[i]];
          RemoteMR rmr;
          rmr.remote_addr = mr.remote_addr();
          rmr.rkey = mr.rkey();
          rb->SetRemoteMR(rmr, false);
          i++;
        }
        CHECK(i == RdmaChannel::kNumMessageBuffers);
      } else {
        LOG(ERROR) << "Connecting to " << worker_name << ": Got "
                   << s.error_message() << ". Retrying (" << (attempts + 1)
                   << "/" << max_num_attempts << ")...";
        if (++attempts == max_num_attempts) {
          break;
        }
        worker_env_->env->SleepForMicroseconds(2000000);
      }
    } while (!s.ok());
    // Only report success when the exchange actually succeeded.
    if (s.ok()) {
      RDMA_LOG(0) << "Connected to remote node " << worker_name;
    } else {
      LOG(ERROR) << "Giving up connecting to " << worker_name << " after "
                 << max_num_attempts << " attempts: " << s.error_message();
    }
  }
}

// Check connectivity by pinging every channel
//
// Posts a ping send plus (queue_depth - 1) receives on every channel, then
// polls the adapter's completion queue until one receive completion and one
// send completion per remote worker have been counted. Any failed completion
// aborts via CHECK. Starts the adapter's polling before returning.
//
// Returns true when both counters equal the number of remote workers.
bool RdmaMgr::ConnectivityCheck() {
  int rcnt = 0;  // completed ping receives
  int scnt = 0;  // completed ping sends

  for (const auto& p : channel_table_) {
    const string& worker_name = p.first;
    RdmaChannel* rc = p.second;

    VLOG(2) << "Ping to " << worker_name;
    CHECK(rc->PingPostSend() == 0) << "Couldn't post send to " << worker_name
                                   << " with error: " << std::strerror(errno);
    for (int i = 0; i < rc->adapter_->params_.queue_depth - 1; i++) {
      rc->Recv();
    }
  }

  while (rcnt < num_remote_workers_ || scnt < num_remote_workers_) {
    int ne;
    do {
      ne = ibv_poll_cq(rdma_adapter_->cq_, 2 * num_remote_workers_,
                       rdma_adapter_->wc_);
      CHECK(ne >= 0) << "poll CQ failed " << ne
                     << " with error: " << std::strerror(errno);
    } while (ne < 1);

    for (int i = 0; i < ne; ++i) {
      const ibv_wc_status s = rdma_adapter_->wc_[i].status;
      // recv complete
      if (static_cast<int>(rdma_adapter_->wc_[i].wr_id) ==
          RdmaChannel::kPingRecvWrid) {
        CHECK(s == IBV_WC_SUCCESS)
            << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
            << rdma_adapter_->wc_[i].status << ") for PING_RECV_WRID";
        ++rcnt;
        // send complete
      } else {
        // Any other wr_id is treated as the address of the sending channel.
        RdmaChannel* rc =
            reinterpret_cast<RdmaChannel*>(rdma_adapter_->wc_[i].wr_id);
        CHECK(s == IBV_WC_SUCCESS)
            << ": " << ibv_wc_status_str(rdma_adapter_->wc_[i].status) << "("
            << rdma_adapter_->wc_[i].status << ") to " << rc->remote_name_;
        ++scnt;
      }
    }  // for
  }    // while
  CHECK(rcnt == scnt) << "Connectivity check failed!";
  rdma_adapter_->StartPolling();
  return (num_remote_workers_ == rcnt) && (num_remote_workers_ == scnt);
}

// Deletes every channel created by the constructor, then the adapter.
RdmaMgr::~RdmaMgr() {
  for (const auto& p : channel_table_) delete p.second;
  channel_table_.clear();
  delete rdma_adapter_;
}

// Find a channel via the given name.
// Args:
//   name: peer name, e.g. worker1
// Returns
//   channel object that is connected to the named peer.
// Aborts via CHECK when no channel is registered under `name`.
RdmaChannel* RdmaMgr::FindChannel(const string& name) {
  ChannelTable::iterator iter = channel_table_.find(name);
  CHECK(iter != channel_table_.end()) << "No RDMA channel for " << name;
  return iter->second;
}

// Returns true when /proc/modules lists a module named "nv_peer_mem".
// Always returns false on Apple and Windows builds.
bool IsGDRAvailable() {
#if defined(__APPLE__)
  return false;
#elif defined(PLATFORM_WINDOWS)
  return false;
#else
  std::ifstream ifs("/proc/modules");
  string line;
  while (std::getline(ifs, line)) {
    // The module name is the first space-separated field. A line without a
    // space is compared as a whole (substr(0, npos)) instead of aborting.
    if (line.substr(0, line.find(' ')) == "nv_peer_mem") {
      return true;
    }
  }
  return false;
#endif
}

// Reads the NUMA node of `device` from <ibdev_path>/device/numa_node.
//
// Returns the parsed node, 0 when the parsed value is negative, and -1
// (kUnknownNumaNode) when the file cannot be read or parsed. Apple and
// Windows builds always return 0.
int TryToReadNumaNode(ibv_device* device) {
#if defined(__APPLE__)
  LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
  return 0;
#elif defined(PLATFORM_WINDOWS)
  // Windows support for NUMA is not currently implemented. Return node 0.
  return 0;
#else
  VLOG(2) << "Trying to read NUMA node for device: " << device->name;
  static const int kUnknownNumaNode = -1;

  auto filename = string(device->ibdev_path) + "/device/numa_node";

  std::ifstream ifs(filename.c_str());
  string content;
  // An unreadable sysfs entry is reported as "unknown" rather than aborting
  // the process; this is a best-effort query.
  if (!std::getline(ifs, content)) {
    LOG(WARNING) << "Could not read NUMA node from " << filename;
    return kUnknownNumaNode;
  }

  int32 value;
  if (strings::safe_strto32(content, &value)) {
    if (value < 0) {
      LOG(INFO) << "Successful NUMA node read from SysFS had negative value ("
                << value
                << "), but there must be at least one NUMA node"
                   ", so returning NUMA node zero";
      return 0;
    }
    LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
    return value;
  }
  return kUnknownNumaNode;
#endif
}

// Deregisters `mr` with ibv_dereg_mr; a null pointer is ignored.
void MRDeleter(ibv_mr* mr) {
  if (mr) {
    ibv_dereg_mr(mr);
  }
}

// TODO(byronyi): remove this class duplicated from the one in
// common/runtime/gpu/pool_allocator.h when it is available in common_runtime
//
// SubAllocator that forwards to port::AlignedMalloc / port::AlignedFree.
class BasicCPUAllocator : public SubAllocator {
 public:
  ~BasicCPUAllocator() override {}

  void* Alloc(size_t alignment, size_t num_bytes) override {
    return port::AlignedMalloc(num_bytes, alignment);
  }
  void Free(void* ptr, size_t) override { port::AlignedFree(ptr); }
};

// TODO(byronyi): remove this class and its registration when the default
// cpu_allocator() returns visitable allocator
//
// BFCAllocator named "cpu_rdma_bfc" on top of BasicCPUAllocator, constructed
// with a 1LL << 36 limit argument.
class BFCRdmaAllocator : public BFCAllocator {
 public:
  BFCRdmaAllocator()
      : BFCAllocator(new BasicCPUAllocator(), 1LL << 36, true, "cpu_rdma_bfc") {
  }
};

REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocator);

// Hooks RdmaMemoryMgr into the host allocators (and, with CUDA and GPUDirect
// available, the GPU allocator) so that allocations are passed to
// RdmaMemoryMgr::InsertMemoryRegion and host frees to
// RdmaMemoryMgr::EvictMemoryRegion.
//
// Aborts via CHECK if a host allocator is not a VisitableAllocator.
void RdmaMgr::InitAllocators() {
  RdmaMemoryMgr::Singleton().pd_ = rdma_adapter_->pd_;

  Allocator* allocators[] = {
#if GOOGLE_CUDA
    ProcessState::singleton()->GetCUDAHostAllocator(0),
    ProcessState::singleton()->GetCPUAllocator(0),
#endif  // GOOGLE_CUDA
    cpu_allocator(),
  };

  using namespace std::placeholders;

  std::set<Allocator*> instrumented;

  // Host memory allocators
  for (Allocator* allocator : allocators) {
    auto* visitable_allocator = dynamic_cast<VisitableAllocator*>(allocator);
    CHECK(visitable_allocator)
        << allocator->Name() << " is not visitable for instrumentation";
    // Make sure we don't instrument the same allocator twice
    if (!instrumented.insert(allocator).second) {
      continue;
    }

    VisitableAllocator::Visitor alloc_visitor =
        std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
                  &RdmaMemoryMgr::Singleton(), _1, _2, allocator->Name());
    VisitableAllocator::Visitor free_visitor = std::bind(
        &RdmaMemoryMgr::EvictMemoryRegion, &RdmaMemoryMgr::Singleton(), _1, _2);

    visitable_allocator->AddAllocVisitor(alloc_visitor);
    visitable_allocator->AddFreeVisitor(free_visitor);
    LOG(INFO) << "Instrumenting CPU allocator " << allocator->Name();
  }

#if GOOGLE_CUDA
  if (IsGDRAvailable()) {
    // Note we don't free allocated GPU memory so there is no free visitor
    // bus_id is the NUMA node plus one, so an unknown node (-1) becomes 0.
    int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;

    VisitableAllocator::Visitor cuda_alloc_visitor =
        std::bind(&RdmaMemoryMgr::InsertMemoryRegion,
                  &RdmaMemoryMgr::Singleton(), _1, _2, std::string("gpu"));

    ProcessState::singleton()->AddGPUAllocVisitor(bus_id, cuda_alloc_visitor);
    LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
  }
#endif  // GOOGLE_CUDA
}

}  // end namespace tensorflow

#endif