1/*
2 *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include <string>
12
13#include "webrtc/base/gunit.h"
14#include "webrtc/base/logging.h"
15#include "webrtc/base/natserver.h"
16#include "webrtc/base/natsocketfactory.h"
17#include "webrtc/base/nethelpers.h"
18#include "webrtc/base/network.h"
19#include "webrtc/base/physicalsocketserver.h"
20#include "webrtc/base/testclient.h"
21#include "webrtc/base/virtualsocketserver.h"
22#include "webrtc/test/testsupport/gtest_disable.h"
23
24using namespace rtc;
25
26bool CheckReceive(
27    TestClient* client, bool should_receive, const char* buf, size_t size) {
28  return (should_receive) ?
29      client->CheckNextPacket(buf, size, 0) :
30      client->CheckNoPacket();
31}
32
33TestClient* CreateTestClient(
34      SocketFactory* factory, const SocketAddress& local_addr) {
35  AsyncUDPSocket* socket = AsyncUDPSocket::Create(factory, local_addr);
36  return new TestClient(socket);
37}
38
39// Tests that when sending from internal_addr to external_addrs through the
40// NAT type specified by nat_type, all external addrs receive the sent packet
41// and, if exp_same is true, all use the same mapped-address on the NAT.
42void TestSend(
43      SocketServer* internal, const SocketAddress& internal_addr,
44      SocketServer* external, const SocketAddress external_addrs[4],
45      NATType nat_type, bool exp_same) {
46  Thread th_int(internal);
47  Thread th_ext(external);
48
49  SocketAddress server_addr = internal_addr;
50  server_addr.SetPort(0);  // Auto-select a port
51  NATServer* nat = new NATServer(
52      nat_type, internal, server_addr, external, external_addrs[0]);
53  NATSocketFactory* natsf = new NATSocketFactory(internal,
54                                                 nat->internal_address());
55
56  TestClient* in = CreateTestClient(natsf, internal_addr);
57  TestClient* out[4];
58  for (int i = 0; i < 4; i++)
59    out[i] = CreateTestClient(external, external_addrs[i]);
60
61  th_int.Start();
62  th_ext.Start();
63
64  const char* buf = "filter_test";
65  size_t len = strlen(buf);
66
67  in->SendTo(buf, len, out[0]->address());
68  SocketAddress trans_addr;
69  EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
70
71  for (int i = 1; i < 4; i++) {
72    in->SendTo(buf, len, out[i]->address());
73    SocketAddress trans_addr2;
74    EXPECT_TRUE(out[i]->CheckNextPacket(buf, len, &trans_addr2));
75    bool are_same = (trans_addr == trans_addr2);
76    ASSERT_EQ(are_same, exp_same) << "same translated address";
77    ASSERT_NE(AF_UNSPEC, trans_addr.family());
78    ASSERT_NE(AF_UNSPEC, trans_addr2.family());
79  }
80
81  th_int.Stop();
82  th_ext.Stop();
83
84  delete nat;
85  delete natsf;
86  delete in;
87  for (int i = 0; i < 4; i++)
88    delete out[i];
89}
90
91// Tests that when sending from external_addrs to internal_addr, the packet
92// is delivered according to the specified filter_ip and filter_port rules.
93void TestRecv(
94      SocketServer* internal, const SocketAddress& internal_addr,
95      SocketServer* external, const SocketAddress external_addrs[4],
96      NATType nat_type, bool filter_ip, bool filter_port) {
97  Thread th_int(internal);
98  Thread th_ext(external);
99
100  SocketAddress server_addr = internal_addr;
101  server_addr.SetPort(0);  // Auto-select a port
102  NATServer* nat = new NATServer(
103      nat_type, internal, server_addr, external, external_addrs[0]);
104  NATSocketFactory* natsf = new NATSocketFactory(internal,
105                                                 nat->internal_address());
106
107  TestClient* in = CreateTestClient(natsf, internal_addr);
108  TestClient* out[4];
109  for (int i = 0; i < 4; i++)
110    out[i] = CreateTestClient(external, external_addrs[i]);
111
112  th_int.Start();
113  th_ext.Start();
114
115  const char* buf = "filter_test";
116  size_t len = strlen(buf);
117
118  in->SendTo(buf, len, out[0]->address());
119  SocketAddress trans_addr;
120  EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
121
122  out[1]->SendTo(buf, len, trans_addr);
123  EXPECT_TRUE(CheckReceive(in, !filter_ip, buf, len));
124
125  out[2]->SendTo(buf, len, trans_addr);
126  EXPECT_TRUE(CheckReceive(in, !filter_port, buf, len));
127
128  out[3]->SendTo(buf, len, trans_addr);
129  EXPECT_TRUE(CheckReceive(in, !filter_ip && !filter_port, buf, len));
130
131  th_int.Stop();
132  th_ext.Stop();
133
134  delete nat;
135  delete natsf;
136  delete in;
137  for (int i = 0; i < 4; i++)
138    delete out[i];
139}
140
141// Tests that NATServer allocates bindings properly.
142void TestBindings(
143    SocketServer* internal, const SocketAddress& internal_addr,
144    SocketServer* external, const SocketAddress external_addrs[4]) {
145  TestSend(internal, internal_addr, external, external_addrs,
146           NAT_OPEN_CONE, true);
147  TestSend(internal, internal_addr, external, external_addrs,
148           NAT_ADDR_RESTRICTED, true);
149  TestSend(internal, internal_addr, external, external_addrs,
150           NAT_PORT_RESTRICTED, true);
151  TestSend(internal, internal_addr, external, external_addrs,
152           NAT_SYMMETRIC, false);
153}
154
155// Tests that NATServer filters packets properly.
156void TestFilters(
157    SocketServer* internal, const SocketAddress& internal_addr,
158    SocketServer* external, const SocketAddress external_addrs[4]) {
159  TestRecv(internal, internal_addr, external, external_addrs,
160           NAT_OPEN_CONE, false, false);
161  TestRecv(internal, internal_addr, external, external_addrs,
162           NAT_ADDR_RESTRICTED, true, false);
163  TestRecv(internal, internal_addr, external, external_addrs,
164           NAT_PORT_RESTRICTED, true, true);
165  TestRecv(internal, internal_addr, external, external_addrs,
166           NAT_SYMMETRIC, true, true);
167}
168
169bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) {
170  // The physical NAT tests require connectivity to the selected ip from the
171  // internal address used for the NAT. Things like firewalls can break that, so
172  // check to see if it's worth even trying with this ip.
173  scoped_ptr<PhysicalSocketServer> pss(new PhysicalSocketServer());
174  scoped_ptr<AsyncSocket> client(pss->CreateAsyncSocket(src.family(),
175                                                        SOCK_DGRAM));
176  scoped_ptr<AsyncSocket> server(pss->CreateAsyncSocket(src.family(),
177                                                        SOCK_DGRAM));
178  if (client->Bind(SocketAddress(src.ipaddr(), 0)) != 0 ||
179      server->Bind(SocketAddress(dst, 0)) != 0) {
180    return false;
181  }
182  const char* buf = "hello other socket";
183  size_t len = strlen(buf);
184  int sent = client->SendTo(buf, len, server->GetLocalAddress());
185  SocketAddress addr;
186  const size_t kRecvBufSize = 64;
187  char recvbuf[kRecvBufSize];
188  Thread::Current()->SleepMs(100);
189  int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr);
190  return received == sent && ::memcmp(buf, recvbuf, len) == 0;
191}
192
193void TestPhysicalInternal(const SocketAddress& int_addr) {
194  BasicNetworkManager network_manager;
195  network_manager.set_ipv6_enabled(true);
196  network_manager.StartUpdating();
197  // Process pending messages so the network list is updated.
198  Thread::Current()->ProcessMessages(0);
199
200  std::vector<Network*> networks;
201  network_manager.GetNetworks(&networks);
202  if (networks.empty()) {
203    LOG(LS_WARNING) << "Not enough network adapters for test.";
204    return;
205  }
206
207  SocketAddress ext_addr1(int_addr);
208  SocketAddress ext_addr2;
209  // Find an available IP with matching family. The test breaks if int_addr
210  // can't talk to ip, so check for connectivity as well.
211  for (std::vector<Network*>::iterator it = networks.begin();
212      it != networks.end(); ++it) {
213    const IPAddress& ip = (*it)->GetBestIP();
214    if (ip.family() == int_addr.family() && TestConnectivity(int_addr, ip)) {
215      ext_addr2.SetIP(ip);
216      break;
217    }
218  }
219  if (ext_addr2.IsNil()) {
220    LOG(LS_WARNING) << "No available IP of same family as " << int_addr;
221    return;
222  }
223
224  LOG(LS_INFO) << "selected ip " << ext_addr2.ipaddr();
225
226  SocketAddress ext_addrs[4] = {
227      SocketAddress(ext_addr1),
228      SocketAddress(ext_addr2),
229      SocketAddress(ext_addr1),
230      SocketAddress(ext_addr2)
231  };
232
233  scoped_ptr<PhysicalSocketServer> int_pss(new PhysicalSocketServer());
234  scoped_ptr<PhysicalSocketServer> ext_pss(new PhysicalSocketServer());
235
236  TestBindings(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
237  TestFilters(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
238}
239
240TEST(NatTest, DISABLED_ON_MAC(TestPhysicalIPv4)) {
241  TestPhysicalInternal(SocketAddress("127.0.0.1", 0));
242}
243
244TEST(NatTest, DISABLED_ON_MAC(TestPhysicalIPv6)) {
245  if (HasIPv6Enabled()) {
246    TestPhysicalInternal(SocketAddress("::1", 0));
247  } else {
248    LOG(LS_WARNING) << "No IPv6, skipping";
249  }
250}
251
252class TestVirtualSocketServer : public VirtualSocketServer {
253 public:
254  explicit TestVirtualSocketServer(SocketServer* ss)
255      : VirtualSocketServer(ss),
256        ss_(ss) {}
257  // Expose this publicly
258  IPAddress GetNextIP(int af) { return VirtualSocketServer::GetNextIP(af); }
259
260 private:
261  scoped_ptr<SocketServer> ss_;
262};
263
264void TestVirtualInternal(int family) {
265  scoped_ptr<TestVirtualSocketServer> int_vss(new TestVirtualSocketServer(
266      new PhysicalSocketServer()));
267  scoped_ptr<TestVirtualSocketServer> ext_vss(new TestVirtualSocketServer(
268      new PhysicalSocketServer()));
269
270  SocketAddress int_addr;
271  SocketAddress ext_addrs[4];
272  int_addr.SetIP(int_vss->GetNextIP(family));
273  ext_addrs[0].SetIP(ext_vss->GetNextIP(int_addr.family()));
274  ext_addrs[1].SetIP(ext_vss->GetNextIP(int_addr.family()));
275  ext_addrs[2].SetIP(ext_addrs[0].ipaddr());
276  ext_addrs[3].SetIP(ext_addrs[1].ipaddr());
277
278  TestBindings(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
279  TestFilters(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
280}
281
282TEST(NatTest, DISABLED_ON_MAC(TestVirtualIPv4)) {
283  TestVirtualInternal(AF_INET);
284}
285
286TEST(NatTest, DISABLED_ON_MAC(TestVirtualIPv6)) {
287  if (HasIPv6Enabled()) {
288    TestVirtualInternal(AF_INET6);
289  } else {
290    LOG(LS_WARNING) << "No IPv6, skipping";
291  }
292}
293
294// TODO: Finish this test
295class NatTcpTest : public testing::Test, public sigslot::has_slots<> {
296 public:
297  NatTcpTest() : connected_(false) {}
298  virtual void SetUp() {
299    int_vss_ = new TestVirtualSocketServer(new PhysicalSocketServer());
300    ext_vss_ = new TestVirtualSocketServer(new PhysicalSocketServer());
301    nat_ = new NATServer(NAT_OPEN_CONE, int_vss_, SocketAddress(),
302                         ext_vss_, SocketAddress());
303    natsf_ = new NATSocketFactory(int_vss_, nat_->internal_address());
304  }
305  void OnConnectEvent(AsyncSocket* socket) {
306    connected_ = true;
307  }
308  void OnAcceptEvent(AsyncSocket* socket) {
309    accepted_ = server_->Accept(NULL);
310  }
311  void OnCloseEvent(AsyncSocket* socket, int error) {
312  }
313  void ConnectEvents() {
314    server_->SignalReadEvent.connect(this, &NatTcpTest::OnAcceptEvent);
315    client_->SignalConnectEvent.connect(this, &NatTcpTest::OnConnectEvent);
316  }
317  TestVirtualSocketServer* int_vss_;
318  TestVirtualSocketServer* ext_vss_;
319  NATServer* nat_;
320  NATSocketFactory* natsf_;
321  AsyncSocket* client_;
322  AsyncSocket* server_;
323  AsyncSocket* accepted_;
324  bool connected_;
325};
326
327TEST_F(NatTcpTest, DISABLED_TestConnectOut) {
328  server_ = ext_vss_->CreateAsyncSocket(SOCK_STREAM);
329  server_->Bind(SocketAddress());
330  server_->Listen(5);
331
332  client_ = int_vss_->CreateAsyncSocket(SOCK_STREAM);
333  EXPECT_GE(0, client_->Bind(SocketAddress()));
334  EXPECT_GE(0, client_->Connect(server_->GetLocalAddress()));
335
336
337  ConnectEvents();
338
339  EXPECT_TRUE_WAIT(connected_, 1000);
340  EXPECT_EQ(client_->GetRemoteAddress(), server_->GetLocalAddress());
341  EXPECT_EQ(client_->GetRemoteAddress(), accepted_->GetLocalAddress());
342  EXPECT_EQ(client_->GetLocalAddress(), accepted_->GetRemoteAddress());
343
344  client_->Close();
345}
346//#endif
347