1/*
2 *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "webrtc/base/natsocketfactory.h"
12#include "webrtc/base/natserver.h"
13#include "webrtc/base/logging.h"
14#include "webrtc/base/socketadapters.h"
15
16namespace rtc {
17
18RouteCmp::RouteCmp(NAT* nat) : symmetric(nat->IsSymmetric()) {
19}
20
21size_t RouteCmp::operator()(const SocketAddressPair& r) const {
22  size_t h = r.source().Hash();
23  if (symmetric)
24    h ^= r.destination().Hash();
25  return h;
26}
27
28bool RouteCmp::operator()(
29      const SocketAddressPair& r1, const SocketAddressPair& r2) const {
30  if (r1.source() < r2.source())
31    return true;
32  if (r2.source() < r1.source())
33    return false;
34  if (symmetric && (r1.destination() < r2.destination()))
35    return true;
36  if (symmetric && (r2.destination() < r1.destination()))
37    return false;
38  return false;
39}
40
41AddrCmp::AddrCmp(NAT* nat)
42    : use_ip(nat->FiltersIP()), use_port(nat->FiltersPort()) {
43}
44
45size_t AddrCmp::operator()(const SocketAddress& a) const {
46  size_t h = 0;
47  if (use_ip)
48    h ^= HashIP(a.ipaddr());
49  if (use_port)
50    h ^= a.port() | (a.port() << 16);
51  return h;
52}
53
54bool AddrCmp::operator()(
55      const SocketAddress& a1, const SocketAddress& a2) const {
56  if (use_ip && (a1.ipaddr() < a2.ipaddr()))
57    return true;
58  if (use_ip && (a2.ipaddr() < a1.ipaddr()))
59    return false;
60  if (use_port && (a1.port() < a2.port()))
61    return true;
62  if (use_port && (a2.port() < a1.port()))
63    return false;
64  return false;
65}
66
67// Proxy socket that will capture the external destination address intended for
68// a TCP connection to the NAT server.
69class NATProxyServerSocket : public AsyncProxyServerSocket {
70 public:
71  NATProxyServerSocket(AsyncSocket* socket)
72      : AsyncProxyServerSocket(socket, kNATEncodedIPv6AddressSize) {
73    BufferInput(true);
74  }
75
76  void SendConnectResult(int err, const SocketAddress& addr) override {
77    char code = err ? 1 : 0;
78    BufferedReadAdapter::DirectSend(&code, sizeof(char));
79  }
80
81 protected:
82  void ProcessInput(char* data, size_t* len) override {
83    if (*len < 2) {
84      return;
85    }
86
87    int family = data[1];
88    ASSERT(family == AF_INET || family == AF_INET6);
89    if ((family == AF_INET && *len < kNATEncodedIPv4AddressSize) ||
90        (family == AF_INET6 && *len < kNATEncodedIPv6AddressSize)) {
91      return;
92    }
93
94    SocketAddress dest_addr;
95    size_t address_length = UnpackAddressFromNAT(data, *len, &dest_addr);
96
97    *len -= address_length;
98    if (*len > 0) {
99      memmove(data, data + address_length, *len);
100    }
101
102    bool remainder = (*len > 0);
103    BufferInput(false);
104    SignalConnectRequest(this, dest_addr);
105    if (remainder) {
106      SignalReadEvent(this);
107    }
108  }
109
110};
111
112class NATProxyServer : public ProxyServer {
113 public:
114  NATProxyServer(SocketFactory* int_factory, const SocketAddress& int_addr,
115                 SocketFactory* ext_factory, const SocketAddress& ext_ip)
116      : ProxyServer(int_factory, int_addr, ext_factory, ext_ip) {
117  }
118
119 protected:
120  AsyncProxyServerSocket* WrapSocket(AsyncSocket* socket) override {
121    return new NATProxyServerSocket(socket);
122  }
123};
124
125NATServer::NATServer(
126    NATType type, SocketFactory* internal,
127    const SocketAddress& internal_udp_addr,
128    const SocketAddress& internal_tcp_addr,
129    SocketFactory* external, const SocketAddress& external_ip)
130    : external_(external), external_ip_(external_ip.ipaddr(), 0) {
131  nat_ = NAT::Create(type);
132
133  udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr);
134  udp_server_socket_->SignalReadPacket.connect(this,
135                                               &NATServer::OnInternalUDPPacket);
136  tcp_proxy_server_ = new NATProxyServer(internal, internal_tcp_addr, external,
137                                         external_ip);
138
139  int_map_ = new InternalMap(RouteCmp(nat_));
140  ext_map_ = new ExternalMap();
141}
142
143NATServer::~NATServer() {
144  for (InternalMap::iterator iter = int_map_->begin();
145       iter != int_map_->end();
146       iter++)
147    delete iter->second;
148
149  delete nat_;
150  delete udp_server_socket_;
151  delete tcp_proxy_server_;
152  delete int_map_;
153  delete ext_map_;
154}
155
156void NATServer::OnInternalUDPPacket(
157    AsyncPacketSocket* socket, const char* buf, size_t size,
158    const SocketAddress& addr, const PacketTime& packet_time) {
159  // Read the intended destination from the wire.
160  SocketAddress dest_addr;
161  size_t length = UnpackAddressFromNAT(buf, size, &dest_addr);
162
163  // Find the translation for these addresses (allocating one if necessary).
164  SocketAddressPair route(addr, dest_addr);
165  InternalMap::iterator iter = int_map_->find(route);
166  if (iter == int_map_->end()) {
167    Translate(route);
168    iter = int_map_->find(route);
169  }
170  ASSERT(iter != int_map_->end());
171
172  // Allow the destination to send packets back to the source.
173  iter->second->WhitelistInsert(dest_addr);
174
175  // Send the packet to its intended destination.
176  rtc::PacketOptions options;
177  iter->second->socket->SendTo(buf + length, size - length, dest_addr, options);
178}
179
180void NATServer::OnExternalUDPPacket(
181    AsyncPacketSocket* socket, const char* buf, size_t size,
182    const SocketAddress& remote_addr, const PacketTime& packet_time) {
183  SocketAddress local_addr = socket->GetLocalAddress();
184
185  // Find the translation for this addresses.
186  ExternalMap::iterator iter = ext_map_->find(local_addr);
187  ASSERT(iter != ext_map_->end());
188
189  // Allow the NAT to reject this packet.
190  if (ShouldFilterOut(iter->second, remote_addr)) {
191    LOG(LS_INFO) << "Packet from " << remote_addr.ToSensitiveString()
192                 << " was filtered out by the NAT.";
193    return;
194  }
195
196  // Forward this packet to the internal address.
197  // First prepend the address in a quasi-STUN format.
198  scoped_ptr<char[]> real_buf(new char[size + kNATEncodedIPv6AddressSize]);
199  size_t addrlength = PackAddressForNAT(real_buf.get(),
200                                        size + kNATEncodedIPv6AddressSize,
201                                        remote_addr);
202  // Copy the data part after the address.
203  rtc::PacketOptions options;
204  memcpy(real_buf.get() + addrlength, buf, size);
205  udp_server_socket_->SendTo(real_buf.get(), size + addrlength,
206                             iter->second->route.source(), options);
207}
208
209void NATServer::Translate(const SocketAddressPair& route) {
210  AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_);
211
212  if (!socket) {
213    LOG(LS_ERROR) << "Couldn't find a free port!";
214    return;
215  }
216
217  TransEntry* entry = new TransEntry(route, socket, nat_);
218  (*int_map_)[route] = entry;
219  (*ext_map_)[socket->GetLocalAddress()] = entry;
220  socket->SignalReadPacket.connect(this, &NATServer::OnExternalUDPPacket);
221}
222
223bool NATServer::ShouldFilterOut(TransEntry* entry,
224                                const SocketAddress& ext_addr) {
225  return entry->WhitelistContains(ext_addr);
226}
227
228NATServer::TransEntry::TransEntry(
229    const SocketAddressPair& r, AsyncUDPSocket* s, NAT* nat)
230    : route(r), socket(s) {
231  whitelist = new AddressSet(AddrCmp(nat));
232}
233
234NATServer::TransEntry::~TransEntry() {
235  delete whitelist;
236  delete socket;
237}
238
239void NATServer::TransEntry::WhitelistInsert(const SocketAddress& addr) {
240  CritScope cs(&crit_);
241  whitelist->insert(addr);
242}
243
244bool NATServer::TransEntry::WhitelistContains(const SocketAddress& ext_addr) {
245  CritScope cs(&crit_);
246  return whitelist->find(ext_addr) == whitelist->end();
247}
248
249}  // namespace rtc
250