1/* 2 * Copyright 2004 The WebRTC Project Authors. All rights reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11#include "webrtc/base/natsocketfactory.h" 12 13#include "webrtc/base/arraysize.h" 14#include "webrtc/base/logging.h" 15#include "webrtc/base/natserver.h" 16#include "webrtc/base/virtualsocketserver.h" 17 18namespace rtc { 19 20// Packs the given socketaddress into the buffer in buf, in the quasi-STUN 21// format that the natserver uses. 22// Returns 0 if an invalid address is passed. 23size_t PackAddressForNAT(char* buf, size_t buf_size, 24 const SocketAddress& remote_addr) { 25 const IPAddress& ip = remote_addr.ipaddr(); 26 int family = ip.family(); 27 buf[0] = 0; 28 buf[1] = family; 29 // Writes the port. 30 *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port()); 31 if (family == AF_INET) { 32 ASSERT(buf_size >= kNATEncodedIPv4AddressSize); 33 in_addr v4addr = ip.ipv4_address(); 34 memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4); 35 return kNATEncodedIPv4AddressSize; 36 } else if (family == AF_INET6) { 37 ASSERT(buf_size >= kNATEncodedIPv6AddressSize); 38 in6_addr v6addr = ip.ipv6_address(); 39 memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4); 40 return kNATEncodedIPv6AddressSize; 41 } 42 return 0U; 43} 44 45// Decodes the remote address from a packet that has been encoded with the nat's 46// quasi-STUN format. Returns the length of the address (i.e., the offset into 47// data where the original packet starts). 48size_t UnpackAddressFromNAT(const char* buf, size_t buf_size, 49 SocketAddress* remote_addr) { 50 ASSERT(buf_size >= 8); 51 ASSERT(buf[0] == 0); 52 int family = buf[1]; 53 uint16_t port = 54 NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2]))); 55 if (family == AF_INET) { 56 const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]); 57 *remote_addr = SocketAddress(IPAddress(*v4addr), port); 58 return kNATEncodedIPv4AddressSize; 59 } else if (family == AF_INET6) { 60 ASSERT(buf_size >= 20); 61 const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]); 62 *remote_addr = SocketAddress(IPAddress(*v6addr), port); 63 return kNATEncodedIPv6AddressSize; 64 } 65 return 0U; 66} 67 68 69// NATSocket 70class NATSocket : public AsyncSocket, public sigslot::has_slots<> { 71 public: 72 explicit NATSocket(NATInternalSocketFactory* sf, int family, int type) 73 : sf_(sf), family_(family), type_(type), connected_(false), 74 socket_(NULL), buf_(NULL), size_(0) { 75 } 76 77 ~NATSocket() override { 78 delete socket_; 79 delete[] buf_; 80 } 81 82 SocketAddress GetLocalAddress() const override { 83 return (socket_) ? socket_->GetLocalAddress() : SocketAddress(); 84 } 85 86 SocketAddress GetRemoteAddress() const override { 87 return remote_addr_; // will be NIL if not connected 88 } 89 90 int Bind(const SocketAddress& addr) override { 91 if (socket_) { // already bound, bubble up error 92 return -1; 93 } 94 95 int result; 96 socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_); 97 result = (socket_) ? socket_->Bind(addr) : -1; 98 if (result >= 0) { 99 socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent); 100 socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent); 101 socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent); 102 socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent); 103 } else { 104 server_addr_.Clear(); 105 delete socket_; 106 socket_ = NULL; 107 } 108 109 return result; 110 } 111 112 int Connect(const SocketAddress& addr) override { 113 if (!socket_) { // socket must be bound, for now 114 return -1; 115 } 116 117 int result = 0; 118 if (type_ == SOCK_STREAM) { 119 result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_); 120 } else { 121 connected_ = true; 122 } 123 124 if (result >= 0) { 125 remote_addr_ = addr; 126 } 127 128 return result; 129 } 130 131 int Send(const void* data, size_t size) override { 132 ASSERT(connected_); 133 return SendTo(data, size, remote_addr_); 134 } 135 136 int SendTo(const void* data, 137 size_t size, 138 const SocketAddress& addr) override { 139 ASSERT(!connected_ || addr == remote_addr_); 140 if (server_addr_.IsNil() || type_ == SOCK_STREAM) { 141 return socket_->SendTo(data, size, addr); 142 } 143 // This array will be too large for IPv4 packets, but only by 12 bytes. 144 scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]); 145 size_t addrlength = PackAddressForNAT(buf.get(), 146 size + kNATEncodedIPv6AddressSize, 147 addr); 148 size_t encoded_size = size + addrlength; 149 memcpy(buf.get() + addrlength, data, size); 150 int result = socket_->SendTo(buf.get(), encoded_size, server_addr_); 151 if (result >= 0) { 152 ASSERT(result == static_cast<int>(encoded_size)); 153 result = result - static_cast<int>(addrlength); 154 } 155 return result; 156 } 157 158 int Recv(void* data, size_t size) override { 159 SocketAddress addr; 160 return RecvFrom(data, size, &addr); 161 } 162 163 int RecvFrom(void* data, size_t size, SocketAddress* out_addr) override { 164 if (server_addr_.IsNil() || type_ == SOCK_STREAM) { 165 return socket_->RecvFrom(data, size, out_addr); 166 } 167 // Make sure we have enough room to read the requested amount plus the 168 // largest possible header address. 169 SocketAddress remote_addr; 170 Grow(size + kNATEncodedIPv6AddressSize); 171 172 // Read the packet from the socket. 173 int result = socket_->RecvFrom(buf_, size_, &remote_addr); 174 if (result >= 0) { 175 ASSERT(remote_addr == server_addr_); 176 177 // TODO: we need better framing so we know how many bytes we can 178 // return before we need to read the next address. For UDP, this will be 179 // fine as long as the reader always reads everything in the packet. 180 ASSERT((size_t)result < size_); 181 182 // Decode the wire packet into the actual results. 183 SocketAddress real_remote_addr; 184 size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr); 185 memcpy(data, buf_ + addrlength, result - addrlength); 186 187 // Make sure this packet should be delivered before returning it. 188 if (!connected_ || (real_remote_addr == remote_addr_)) { 189 if (out_addr) 190 *out_addr = real_remote_addr; 191 result = result - static_cast<int>(addrlength); 192 } else { 193 LOG(LS_ERROR) << "Dropping packet from unknown remote address: " 194 << real_remote_addr.ToString(); 195 result = 0; // Tell the caller we didn't read anything 196 } 197 } 198 199 return result; 200 } 201 202 int Close() override { 203 int result = 0; 204 if (socket_) { 205 result = socket_->Close(); 206 if (result >= 0) { 207 connected_ = false; 208 remote_addr_ = SocketAddress(); 209 delete socket_; 210 socket_ = NULL; 211 } 212 } 213 return result; 214 } 215 216 int Listen(int backlog) override { return socket_->Listen(backlog); } 217 AsyncSocket* Accept(SocketAddress* paddr) override { 218 return socket_->Accept(paddr); 219 } 220 int GetError() const override { return socket_->GetError(); } 221 void SetError(int error) override { socket_->SetError(error); } 222 ConnState GetState() const override { 223 return connected_ ? CS_CONNECTED : CS_CLOSED; 224 } 225 int EstimateMTU(uint16_t* mtu) override { return socket_->EstimateMTU(mtu); } 226 int GetOption(Option opt, int* value) override { 227 return socket_->GetOption(opt, value); 228 } 229 int SetOption(Option opt, int value) override { 230 return socket_->SetOption(opt, value); 231 } 232 233 void OnConnectEvent(AsyncSocket* socket) { 234 // If we're NATed, we need to send a message with the real addr to use. 235 ASSERT(socket == socket_); 236 if (server_addr_.IsNil()) { 237 connected_ = true; 238 SignalConnectEvent(this); 239 } else { 240 SendConnectRequest(); 241 } 242 } 243 void OnReadEvent(AsyncSocket* socket) { 244 // If we're NATed, we need to process the connect reply. 245 ASSERT(socket == socket_); 246 if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) { 247 HandleConnectReply(); 248 } else { 249 SignalReadEvent(this); 250 } 251 } 252 void OnWriteEvent(AsyncSocket* socket) { 253 ASSERT(socket == socket_); 254 SignalWriteEvent(this); 255 } 256 void OnCloseEvent(AsyncSocket* socket, int error) { 257 ASSERT(socket == socket_); 258 SignalCloseEvent(this, error); 259 } 260 261 private: 262 // Makes sure the buffer is at least the given size. 263 void Grow(size_t new_size) { 264 if (size_ < new_size) { 265 delete[] buf_; 266 size_ = new_size; 267 buf_ = new char[size_]; 268 } 269 } 270 271 // Sends the destination address to the server to tell it to connect. 272 void SendConnectRequest() { 273 char buf[kNATEncodedIPv6AddressSize]; 274 size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_); 275 socket_->Send(buf, length); 276 } 277 278 // Handles the byte sent back from the server and fires the appropriate event. 279 void HandleConnectReply() { 280 char code; 281 socket_->Recv(&code, sizeof(code)); 282 if (code == 0) { 283 connected_ = true; 284 SignalConnectEvent(this); 285 } else { 286 Close(); 287 SignalCloseEvent(this, code); 288 } 289 } 290 291 NATInternalSocketFactory* sf_; 292 int family_; 293 int type_; 294 bool connected_; 295 SocketAddress remote_addr_; 296 SocketAddress server_addr_; // address of the NAT server 297 AsyncSocket* socket_; 298 char* buf_; 299 size_t size_; 300}; 301 302// NATSocketFactory 303NATSocketFactory::NATSocketFactory(SocketFactory* factory, 304 const SocketAddress& nat_udp_addr, 305 const SocketAddress& nat_tcp_addr) 306 : factory_(factory), nat_udp_addr_(nat_udp_addr), 307 nat_tcp_addr_(nat_tcp_addr) { 308} 309 310Socket* NATSocketFactory::CreateSocket(int type) { 311 return CreateSocket(AF_INET, type); 312} 313 314Socket* NATSocketFactory::CreateSocket(int family, int type) { 315 return new NATSocket(this, family, type); 316} 317 318AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) { 319 return CreateAsyncSocket(AF_INET, type); 320} 321 322AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) { 323 return new NATSocket(this, family, type); 324} 325 326AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type, 327 const SocketAddress& local_addr, SocketAddress* nat_addr) { 328 if (type == SOCK_STREAM) { 329 *nat_addr = nat_tcp_addr_; 330 } else { 331 *nat_addr = nat_udp_addr_; 332 } 333 return factory_->CreateAsyncSocket(family, type); 334} 335 336// NATSocketServer 337NATSocketServer::NATSocketServer(SocketServer* server) 338 : server_(server), msg_queue_(NULL) { 339} 340 341NATSocketServer::Translator* NATSocketServer::GetTranslator( 342 const SocketAddress& ext_ip) { 343 return nats_.Get(ext_ip); 344} 345 346NATSocketServer::Translator* NATSocketServer::AddTranslator( 347 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) { 348 // Fail if a translator already exists with this extternal address. 349 if (nats_.Get(ext_ip)) 350 return NULL; 351 352 return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip)); 353} 354 355void NATSocketServer::RemoveTranslator( 356 const SocketAddress& ext_ip) { 357 nats_.Remove(ext_ip); 358} 359 360Socket* NATSocketServer::CreateSocket(int type) { 361 return CreateSocket(AF_INET, type); 362} 363 364Socket* NATSocketServer::CreateSocket(int family, int type) { 365 return new NATSocket(this, family, type); 366} 367 368AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) { 369 return CreateAsyncSocket(AF_INET, type); 370} 371 372AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) { 373 return new NATSocket(this, family, type); 374} 375 376void NATSocketServer::SetMessageQueue(MessageQueue* queue) { 377 msg_queue_ = queue; 378 server_->SetMessageQueue(queue); 379} 380 381bool NATSocketServer::Wait(int cms, bool process_io) { 382 return server_->Wait(cms, process_io); 383} 384 385void NATSocketServer::WakeUp() { 386 server_->WakeUp(); 387} 388 389AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type, 390 const SocketAddress& local_addr, SocketAddress* nat_addr) { 391 AsyncSocket* socket = NULL; 392 Translator* nat = nats_.FindClient(local_addr); 393 if (nat) { 394 socket = nat->internal_factory()->CreateAsyncSocket(family, type); 395 *nat_addr = (type == SOCK_STREAM) ? 396 nat->internal_tcp_address() : nat->internal_udp_address(); 397 } else { 398 socket = server_->CreateAsyncSocket(family, type); 399 } 400 return socket; 401} 402 403// NATSocketServer::Translator 404NATSocketServer::Translator::Translator( 405 NATSocketServer* server, NATType type, const SocketAddress& int_ip, 406 SocketFactory* ext_factory, const SocketAddress& ext_ip) 407 : server_(server) { 408 // Create a new private network, and a NATServer running on the private 409 // network that bridges to the external network. Also tell the private 410 // network to use the same message queue as us. 411 VirtualSocketServer* internal_server = new VirtualSocketServer(server_); 412 internal_server->SetMessageQueue(server_->queue()); 413 internal_factory_.reset(internal_server); 414 nat_server_.reset(new NATServer(type, internal_server, int_ip, int_ip, 415 ext_factory, ext_ip)); 416} 417 418NATSocketServer::Translator::~Translator() = default; 419 420NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator( 421 const SocketAddress& ext_ip) { 422 return nats_.Get(ext_ip); 423} 424 425NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator( 426 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) { 427 // Fail if a translator already exists with this extternal address. 428 if (nats_.Get(ext_ip)) 429 return NULL; 430 431 AddClient(ext_ip); 432 return nats_.Add(ext_ip, 433 new Translator(server_, type, int_ip, server_, ext_ip)); 434} 435void NATSocketServer::Translator::RemoveTranslator( 436 const SocketAddress& ext_ip) { 437 nats_.Remove(ext_ip); 438 RemoveClient(ext_ip); 439} 440 441bool NATSocketServer::Translator::AddClient( 442 const SocketAddress& int_ip) { 443 // Fail if a client already exists with this internal address. 444 if (clients_.find(int_ip) != clients_.end()) 445 return false; 446 447 clients_.insert(int_ip); 448 return true; 449} 450 451void NATSocketServer::Translator::RemoveClient( 452 const SocketAddress& int_ip) { 453 std::set<SocketAddress>::iterator it = clients_.find(int_ip); 454 if (it != clients_.end()) { 455 clients_.erase(it); 456 } 457} 458 459NATSocketServer::Translator* NATSocketServer::Translator::FindClient( 460 const SocketAddress& int_ip) { 461 // See if we have the requested IP, or any of our children do. 462 return (clients_.find(int_ip) != clients_.end()) ? 463 this : nats_.FindClient(int_ip); 464} 465 466// NATSocketServer::TranslatorMap 467NATSocketServer::TranslatorMap::~TranslatorMap() { 468 for (TranslatorMap::iterator it = begin(); it != end(); ++it) { 469 delete it->second; 470 } 471} 472 473NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get( 474 const SocketAddress& ext_ip) { 475 TranslatorMap::iterator it = find(ext_ip); 476 return (it != end()) ? it->second : NULL; 477} 478 479NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add( 480 const SocketAddress& ext_ip, Translator* nat) { 481 (*this)[ext_ip] = nat; 482 return nat; 483} 484 485void NATSocketServer::TranslatorMap::Remove( 486 const SocketAddress& ext_ip) { 487 TranslatorMap::iterator it = find(ext_ip); 488 if (it != end()) { 489 delete it->second; 490 erase(it); 491 } 492} 493 494NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient( 495 const SocketAddress& int_ip) { 496 Translator* nat = NULL; 497 for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) { 498 nat = it->second->FindClient(int_ip); 499 } 500 return nat; 501} 502 503} // namespace rtc 504