1/*
2 *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "webrtc/base/natsocketfactory.h"
12#include "webrtc/base/natserver.h"
13#include "webrtc/base/logging.h"
14
15namespace rtc {
16
17RouteCmp::RouteCmp(NAT* nat) : symmetric(nat->IsSymmetric()) {
18}
19
20size_t RouteCmp::operator()(const SocketAddressPair& r) const {
21  size_t h = r.source().Hash();
22  if (symmetric)
23    h ^= r.destination().Hash();
24  return h;
25}
26
27bool RouteCmp::operator()(
28      const SocketAddressPair& r1, const SocketAddressPair& r2) const {
29  if (r1.source() < r2.source())
30    return true;
31  if (r2.source() < r1.source())
32    return false;
33  if (symmetric && (r1.destination() < r2.destination()))
34    return true;
35  if (symmetric && (r2.destination() < r1.destination()))
36    return false;
37  return false;
38}
39
40AddrCmp::AddrCmp(NAT* nat)
41    : use_ip(nat->FiltersIP()), use_port(nat->FiltersPort()) {
42}
43
44size_t AddrCmp::operator()(const SocketAddress& a) const {
45  size_t h = 0;
46  if (use_ip)
47    h ^= HashIP(a.ipaddr());
48  if (use_port)
49    h ^= a.port() | (a.port() << 16);
50  return h;
51}
52
53bool AddrCmp::operator()(
54      const SocketAddress& a1, const SocketAddress& a2) const {
55  if (use_ip && (a1.ipaddr() < a2.ipaddr()))
56    return true;
57  if (use_ip && (a2.ipaddr() < a1.ipaddr()))
58    return false;
59  if (use_port && (a1.port() < a2.port()))
60    return true;
61  if (use_port && (a2.port() < a1.port()))
62    return false;
63  return false;
64}
65
66NATServer::NATServer(
67    NATType type, SocketFactory* internal, const SocketAddress& internal_addr,
68    SocketFactory* external, const SocketAddress& external_ip)
69    : external_(external), external_ip_(external_ip.ipaddr(), 0) {
70  nat_ = NAT::Create(type);
71
72  server_socket_ = AsyncUDPSocket::Create(internal, internal_addr);
73  server_socket_->SignalReadPacket.connect(this, &NATServer::OnInternalPacket);
74
75  int_map_ = new InternalMap(RouteCmp(nat_));
76  ext_map_ = new ExternalMap();
77}
78
79NATServer::~NATServer() {
80  for (InternalMap::iterator iter = int_map_->begin();
81       iter != int_map_->end();
82       iter++)
83    delete iter->second;
84
85  delete nat_;
86  delete server_socket_;
87  delete int_map_;
88  delete ext_map_;
89}
90
91void NATServer::OnInternalPacket(
92    AsyncPacketSocket* socket, const char* buf, size_t size,
93    const SocketAddress& addr, const PacketTime& packet_time) {
94
95  // Read the intended destination from the wire.
96  SocketAddress dest_addr;
97  size_t length = UnpackAddressFromNAT(buf, size, &dest_addr);
98
99  // Find the translation for these addresses (allocating one if necessary).
100  SocketAddressPair route(addr, dest_addr);
101  InternalMap::iterator iter = int_map_->find(route);
102  if (iter == int_map_->end()) {
103    Translate(route);
104    iter = int_map_->find(route);
105  }
106  ASSERT(iter != int_map_->end());
107
108  // Allow the destination to send packets back to the source.
109  iter->second->WhitelistInsert(dest_addr);
110
111  // Send the packet to its intended destination.
112  rtc::PacketOptions options;
113  iter->second->socket->SendTo(buf + length, size - length, dest_addr, options);
114}
115
116void NATServer::OnExternalPacket(
117    AsyncPacketSocket* socket, const char* buf, size_t size,
118    const SocketAddress& remote_addr, const PacketTime& packet_time) {
119
120  SocketAddress local_addr = socket->GetLocalAddress();
121
122  // Find the translation for this addresses.
123  ExternalMap::iterator iter = ext_map_->find(local_addr);
124  ASSERT(iter != ext_map_->end());
125
126  // Allow the NAT to reject this packet.
127  if (ShouldFilterOut(iter->second, remote_addr)) {
128    LOG(LS_INFO) << "Packet from " << remote_addr.ToSensitiveString()
129                 << " was filtered out by the NAT.";
130    return;
131  }
132
133  // Forward this packet to the internal address.
134  // First prepend the address in a quasi-STUN format.
135  scoped_ptr<char[]> real_buf(new char[size + kNATEncodedIPv6AddressSize]);
136  size_t addrlength = PackAddressForNAT(real_buf.get(),
137                                        size + kNATEncodedIPv6AddressSize,
138                                        remote_addr);
139  // Copy the data part after the address.
140  rtc::PacketOptions options;
141  memcpy(real_buf.get() + addrlength, buf, size);
142  server_socket_->SendTo(real_buf.get(), size + addrlength,
143                         iter->second->route.source(), options);
144}
145
146void NATServer::Translate(const SocketAddressPair& route) {
147  AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_);
148
149  if (!socket) {
150    LOG(LS_ERROR) << "Couldn't find a free port!";
151    return;
152  }
153
154  TransEntry* entry = new TransEntry(route, socket, nat_);
155  (*int_map_)[route] = entry;
156  (*ext_map_)[socket->GetLocalAddress()] = entry;
157  socket->SignalReadPacket.connect(this, &NATServer::OnExternalPacket);
158}
159
160bool NATServer::ShouldFilterOut(TransEntry* entry,
161                                const SocketAddress& ext_addr) {
162  return entry->WhitelistContains(ext_addr);
163}
164
165NATServer::TransEntry::TransEntry(
166    const SocketAddressPair& r, AsyncUDPSocket* s, NAT* nat)
167    : route(r), socket(s) {
168  whitelist = new AddressSet(AddrCmp(nat));
169}
170
171NATServer::TransEntry::~TransEntry() {
172  delete whitelist;
173  delete socket;
174}
175
176void NATServer::TransEntry::WhitelistInsert(const SocketAddress& addr) {
177  CritScope cs(&crit_);
178  whitelist->insert(addr);
179}
180
181bool NATServer::TransEntry::WhitelistContains(const SocketAddress& ext_addr) {
182  CritScope cs(&crit_);
183  return whitelist->find(ext_addr) == whitelist->end();
184}
185
186}  // namespace rtc
187