1/*
2 *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "webrtc/base/natsocketfactory.h"
12
13#include "webrtc/base/arraysize.h"
14#include "webrtc/base/logging.h"
15#include "webrtc/base/natserver.h"
16#include "webrtc/base/virtualsocketserver.h"
17
18namespace rtc {
19
20// Packs the given socketaddress into the buffer in buf, in the quasi-STUN
21// format that the natserver uses.
22// Returns 0 if an invalid address is passed.
23size_t PackAddressForNAT(char* buf, size_t buf_size,
24                         const SocketAddress& remote_addr) {
25  const IPAddress& ip = remote_addr.ipaddr();
26  int family = ip.family();
27  buf[0] = 0;
28  buf[1] = family;
29  // Writes the port.
30  *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port());
31  if (family == AF_INET) {
32    ASSERT(buf_size >= kNATEncodedIPv4AddressSize);
33    in_addr v4addr = ip.ipv4_address();
34    memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
35    return kNATEncodedIPv4AddressSize;
36  } else if (family == AF_INET6) {
37    ASSERT(buf_size >= kNATEncodedIPv6AddressSize);
38    in6_addr v6addr = ip.ipv6_address();
39    memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
40    return kNATEncodedIPv6AddressSize;
41  }
42  return 0U;
43}
44
45// Decodes the remote address from a packet that has been encoded with the nat's
46// quasi-STUN format. Returns the length of the address (i.e., the offset into
47// data where the original packet starts).
48size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
49                            SocketAddress* remote_addr) {
50  ASSERT(buf_size >= 8);
51  ASSERT(buf[0] == 0);
52  int family = buf[1];
53  uint16_t port =
54      NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
55  if (family == AF_INET) {
56    const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
57    *remote_addr = SocketAddress(IPAddress(*v4addr), port);
58    return kNATEncodedIPv4AddressSize;
59  } else if (family == AF_INET6) {
60    ASSERT(buf_size >= 20);
61    const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
62    *remote_addr = SocketAddress(IPAddress(*v6addr), port);
63    return kNATEncodedIPv6AddressSize;
64  }
65  return 0U;
66}
67
68
69// NATSocket
70class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
71 public:
72  explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
73      : sf_(sf), family_(family), type_(type), connected_(false),
74        socket_(NULL), buf_(NULL), size_(0) {
75  }
76
77  ~NATSocket() override {
78    delete socket_;
79    delete[] buf_;
80  }
81
82  SocketAddress GetLocalAddress() const override {
83    return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
84  }
85
86  SocketAddress GetRemoteAddress() const override {
87    return remote_addr_;  // will be NIL if not connected
88  }
89
90  int Bind(const SocketAddress& addr) override {
91    if (socket_) {  // already bound, bubble up error
92      return -1;
93    }
94
95    int result;
96    socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
97    result = (socket_) ? socket_->Bind(addr) : -1;
98    if (result >= 0) {
99      socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
100      socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
101      socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
102      socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
103    } else {
104      server_addr_.Clear();
105      delete socket_;
106      socket_ = NULL;
107    }
108
109    return result;
110  }
111
112  int Connect(const SocketAddress& addr) override {
113    if (!socket_) {  // socket must be bound, for now
114      return -1;
115    }
116
117    int result = 0;
118    if (type_ == SOCK_STREAM) {
119      result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
120    } else {
121      connected_ = true;
122    }
123
124    if (result >= 0) {
125      remote_addr_ = addr;
126    }
127
128    return result;
129  }
130
131  int Send(const void* data, size_t size) override {
132    ASSERT(connected_);
133    return SendTo(data, size, remote_addr_);
134  }
135
136  int SendTo(const void* data,
137             size_t size,
138             const SocketAddress& addr) override {
139    ASSERT(!connected_ || addr == remote_addr_);
140    if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
141      return socket_->SendTo(data, size, addr);
142    }
143    // This array will be too large for IPv4 packets, but only by 12 bytes.
144    scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
145    size_t addrlength = PackAddressForNAT(buf.get(),
146                                          size + kNATEncodedIPv6AddressSize,
147                                          addr);
148    size_t encoded_size = size + addrlength;
149    memcpy(buf.get() + addrlength, data, size);
150    int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
151    if (result >= 0) {
152      ASSERT(result == static_cast<int>(encoded_size));
153      result = result - static_cast<int>(addrlength);
154    }
155    return result;
156  }
157
158  int Recv(void* data, size_t size) override {
159    SocketAddress addr;
160    return RecvFrom(data, size, &addr);
161  }
162
163  int RecvFrom(void* data, size_t size, SocketAddress* out_addr) override {
164    if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
165      return socket_->RecvFrom(data, size, out_addr);
166    }
167    // Make sure we have enough room to read the requested amount plus the
168    // largest possible header address.
169    SocketAddress remote_addr;
170    Grow(size + kNATEncodedIPv6AddressSize);
171
172    // Read the packet from the socket.
173    int result = socket_->RecvFrom(buf_, size_, &remote_addr);
174    if (result >= 0) {
175      ASSERT(remote_addr == server_addr_);
176
177      // TODO: we need better framing so we know how many bytes we can
178      // return before we need to read the next address. For UDP, this will be
179      // fine as long as the reader always reads everything in the packet.
180      ASSERT((size_t)result < size_);
181
182      // Decode the wire packet into the actual results.
183      SocketAddress real_remote_addr;
184      size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr);
185      memcpy(data, buf_ + addrlength, result - addrlength);
186
187      // Make sure this packet should be delivered before returning it.
188      if (!connected_ || (real_remote_addr == remote_addr_)) {
189        if (out_addr)
190          *out_addr = real_remote_addr;
191        result = result - static_cast<int>(addrlength);
192      } else {
193        LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
194                      << real_remote_addr.ToString();
195        result = 0;  // Tell the caller we didn't read anything
196      }
197    }
198
199    return result;
200  }
201
202  int Close() override {
203    int result = 0;
204    if (socket_) {
205      result = socket_->Close();
206      if (result >= 0) {
207        connected_ = false;
208        remote_addr_ = SocketAddress();
209        delete socket_;
210        socket_ = NULL;
211      }
212    }
213    return result;
214  }
215
216  int Listen(int backlog) override { return socket_->Listen(backlog); }
217  AsyncSocket* Accept(SocketAddress* paddr) override {
218    return socket_->Accept(paddr);
219  }
220  int GetError() const override { return socket_->GetError(); }
221  void SetError(int error) override { socket_->SetError(error); }
222  ConnState GetState() const override {
223    return connected_ ? CS_CONNECTED : CS_CLOSED;
224  }
225  int EstimateMTU(uint16_t* mtu) override { return socket_->EstimateMTU(mtu); }
226  int GetOption(Option opt, int* value) override {
227    return socket_->GetOption(opt, value);
228  }
229  int SetOption(Option opt, int value) override {
230    return socket_->SetOption(opt, value);
231  }
232
233  void OnConnectEvent(AsyncSocket* socket) {
234    // If we're NATed, we need to send a message with the real addr to use.
235    ASSERT(socket == socket_);
236    if (server_addr_.IsNil()) {
237      connected_ = true;
238      SignalConnectEvent(this);
239    } else {
240      SendConnectRequest();
241    }
242  }
243  void OnReadEvent(AsyncSocket* socket) {
244    // If we're NATed, we need to process the connect reply.
245    ASSERT(socket == socket_);
246    if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
247      HandleConnectReply();
248    } else {
249      SignalReadEvent(this);
250    }
251  }
252  void OnWriteEvent(AsyncSocket* socket) {
253    ASSERT(socket == socket_);
254    SignalWriteEvent(this);
255  }
256  void OnCloseEvent(AsyncSocket* socket, int error) {
257    ASSERT(socket == socket_);
258    SignalCloseEvent(this, error);
259  }
260
261 private:
262  // Makes sure the buffer is at least the given size.
263  void Grow(size_t new_size) {
264    if (size_ < new_size) {
265      delete[] buf_;
266      size_ = new_size;
267      buf_ = new char[size_];
268    }
269  }
270
271  // Sends the destination address to the server to tell it to connect.
272  void SendConnectRequest() {
273    char buf[kNATEncodedIPv6AddressSize];
274    size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
275    socket_->Send(buf, length);
276  }
277
278  // Handles the byte sent back from the server and fires the appropriate event.
279  void HandleConnectReply() {
280    char code;
281    socket_->Recv(&code, sizeof(code));
282    if (code == 0) {
283      connected_ = true;
284      SignalConnectEvent(this);
285    } else {
286      Close();
287      SignalCloseEvent(this, code);
288    }
289  }
290
291  NATInternalSocketFactory* sf_;
292  int family_;
293  int type_;
294  bool connected_;
295  SocketAddress remote_addr_;
296  SocketAddress server_addr_;  // address of the NAT server
297  AsyncSocket* socket_;
298  char* buf_;
299  size_t size_;
300};
301
302// NATSocketFactory
303NATSocketFactory::NATSocketFactory(SocketFactory* factory,
304                                   const SocketAddress& nat_udp_addr,
305                                   const SocketAddress& nat_tcp_addr)
306    : factory_(factory), nat_udp_addr_(nat_udp_addr),
307      nat_tcp_addr_(nat_tcp_addr) {
308}
309
310Socket* NATSocketFactory::CreateSocket(int type) {
311  return CreateSocket(AF_INET, type);
312}
313
314Socket* NATSocketFactory::CreateSocket(int family, int type) {
315  return new NATSocket(this, family, type);
316}
317
318AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
319  return CreateAsyncSocket(AF_INET, type);
320}
321
322AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
323  return new NATSocket(this, family, type);
324}
325
326AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
327    const SocketAddress& local_addr, SocketAddress* nat_addr) {
328  if (type == SOCK_STREAM) {
329    *nat_addr = nat_tcp_addr_;
330  } else {
331    *nat_addr = nat_udp_addr_;
332  }
333  return factory_->CreateAsyncSocket(family, type);
334}
335
336// NATSocketServer
337NATSocketServer::NATSocketServer(SocketServer* server)
338    : server_(server), msg_queue_(NULL) {
339}
340
341NATSocketServer::Translator* NATSocketServer::GetTranslator(
342    const SocketAddress& ext_ip) {
343  return nats_.Get(ext_ip);
344}
345
346NATSocketServer::Translator* NATSocketServer::AddTranslator(
347    const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
348  // Fail if a translator already exists with this extternal address.
349  if (nats_.Get(ext_ip))
350    return NULL;
351
352  return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
353}
354
355void NATSocketServer::RemoveTranslator(
356    const SocketAddress& ext_ip) {
357  nats_.Remove(ext_ip);
358}
359
360Socket* NATSocketServer::CreateSocket(int type) {
361  return CreateSocket(AF_INET, type);
362}
363
364Socket* NATSocketServer::CreateSocket(int family, int type) {
365  return new NATSocket(this, family, type);
366}
367
368AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
369  return CreateAsyncSocket(AF_INET, type);
370}
371
372AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
373  return new NATSocket(this, family, type);
374}
375
376void NATSocketServer::SetMessageQueue(MessageQueue* queue) {
377  msg_queue_ = queue;
378  server_->SetMessageQueue(queue);
379}
380
381bool NATSocketServer::Wait(int cms, bool process_io) {
382  return server_->Wait(cms, process_io);
383}
384
385void NATSocketServer::WakeUp() {
386  server_->WakeUp();
387}
388
389AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
390    const SocketAddress& local_addr, SocketAddress* nat_addr) {
391  AsyncSocket* socket = NULL;
392  Translator* nat = nats_.FindClient(local_addr);
393  if (nat) {
394    socket = nat->internal_factory()->CreateAsyncSocket(family, type);
395    *nat_addr = (type == SOCK_STREAM) ?
396        nat->internal_tcp_address() : nat->internal_udp_address();
397  } else {
398    socket = server_->CreateAsyncSocket(family, type);
399  }
400  return socket;
401}
402
403// NATSocketServer::Translator
404NATSocketServer::Translator::Translator(
405    NATSocketServer* server, NATType type, const SocketAddress& int_ip,
406    SocketFactory* ext_factory, const SocketAddress& ext_ip)
407    : server_(server) {
408  // Create a new private network, and a NATServer running on the private
409  // network that bridges to the external network. Also tell the private
410  // network to use the same message queue as us.
411  VirtualSocketServer* internal_server = new VirtualSocketServer(server_);
412  internal_server->SetMessageQueue(server_->queue());
413  internal_factory_.reset(internal_server);
414  nat_server_.reset(new NATServer(type, internal_server, int_ip, int_ip,
415                                  ext_factory, ext_ip));
416}
417
418NATSocketServer::Translator::~Translator() = default;
419
420NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
421    const SocketAddress& ext_ip) {
422  return nats_.Get(ext_ip);
423}
424
425NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
426    const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
427  // Fail if a translator already exists with this extternal address.
428  if (nats_.Get(ext_ip))
429    return NULL;
430
431  AddClient(ext_ip);
432  return nats_.Add(ext_ip,
433                   new Translator(server_, type, int_ip, server_, ext_ip));
434}
435void NATSocketServer::Translator::RemoveTranslator(
436    const SocketAddress& ext_ip) {
437  nats_.Remove(ext_ip);
438  RemoveClient(ext_ip);
439}
440
441bool NATSocketServer::Translator::AddClient(
442    const SocketAddress& int_ip) {
443  // Fail if a client already exists with this internal address.
444  if (clients_.find(int_ip) != clients_.end())
445    return false;
446
447  clients_.insert(int_ip);
448  return true;
449}
450
451void NATSocketServer::Translator::RemoveClient(
452    const SocketAddress& int_ip) {
453  std::set<SocketAddress>::iterator it = clients_.find(int_ip);
454  if (it != clients_.end()) {
455    clients_.erase(it);
456  }
457}
458
459NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
460    const SocketAddress& int_ip) {
461  // See if we have the requested IP, or any of our children do.
462  return (clients_.find(int_ip) != clients_.end()) ?
463      this : nats_.FindClient(int_ip);
464}
465
466// NATSocketServer::TranslatorMap
467NATSocketServer::TranslatorMap::~TranslatorMap() {
468  for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
469    delete it->second;
470  }
471}
472
473NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
474    const SocketAddress& ext_ip) {
475  TranslatorMap::iterator it = find(ext_ip);
476  return (it != end()) ? it->second : NULL;
477}
478
479NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
480    const SocketAddress& ext_ip, Translator* nat) {
481  (*this)[ext_ip] = nat;
482  return nat;
483}
484
485void NATSocketServer::TranslatorMap::Remove(
486    const SocketAddress& ext_ip) {
487  TranslatorMap::iterator it = find(ext_ip);
488  if (it != end()) {
489    delete it->second;
490    erase(it);
491  }
492}
493
494NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
495    const SocketAddress& int_ip) {
496  Translator* nat = NULL;
497  for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
498    nat = it->second->FindClient(int_ip);
499  }
500  return nat;
501}
502
503}  // namespace rtc
504