1/*
2 * libjingle
3 * Copyright 2004, Google Inc.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 *  1. Redistributions of source code must retain the above copyright notice,
9 *     this list of conditions and the following disclaimer.
10 *  2. Redistributions in binary form must reproduce the above copyright notice,
11 *     this list of conditions and the following disclaimer in the documentation
12 *     and/or other materials provided with the distribution.
13 *  3. The name of the author may not be used to endorse or promote products
14 *     derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#include <string>
29
30#include "talk/base/gunit.h"
31#include "talk/base/logging.h"
32#include "talk/base/natserver.h"
33#include "talk/base/natsocketfactory.h"
34#include "talk/base/nethelpers.h"
35#include "talk/base/network.h"
36#include "talk/base/physicalsocketserver.h"
37#include "talk/base/testclient.h"
38#include "talk/base/virtualsocketserver.h"
39
40using namespace talk_base;
41
42bool CheckReceive(
43    TestClient* client, bool should_receive, const char* buf, size_t size) {
44  return (should_receive) ?
45      client->CheckNextPacket(buf, size, 0) :
46      client->CheckNoPacket();
47}
48
49TestClient* CreateTestClient(
50      SocketFactory* factory, const SocketAddress& local_addr) {
51  AsyncUDPSocket* socket = AsyncUDPSocket::Create(factory, local_addr);
52  return new TestClient(socket);
53}
54
55// Tests that when sending from internal_addr to external_addrs through the
56// NAT type specified by nat_type, all external addrs receive the sent packet
57// and, if exp_same is true, all use the same mapped-address on the NAT.
58void TestSend(
59      SocketServer* internal, const SocketAddress& internal_addr,
60      SocketServer* external, const SocketAddress external_addrs[4],
61      NATType nat_type, bool exp_same) {
62  Thread th_int(internal);
63  Thread th_ext(external);
64
65  SocketAddress server_addr = internal_addr;
66  server_addr.SetPort(0);  // Auto-select a port
67  NATServer* nat = new NATServer(
68      nat_type, internal, server_addr, external, external_addrs[0]);
69  NATSocketFactory* natsf = new NATSocketFactory(internal,
70                                                 nat->internal_address());
71
72  TestClient* in = CreateTestClient(natsf, internal_addr);
73  TestClient* out[4];
74  for (int i = 0; i < 4; i++)
75    out[i] = CreateTestClient(external, external_addrs[i]);
76
77  th_int.Start();
78  th_ext.Start();
79
80  const char* buf = "filter_test";
81  size_t len = strlen(buf);
82
83  in->SendTo(buf, len, out[0]->address());
84  SocketAddress trans_addr;
85  EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
86
87  for (int i = 1; i < 4; i++) {
88    in->SendTo(buf, len, out[i]->address());
89    SocketAddress trans_addr2;
90    EXPECT_TRUE(out[i]->CheckNextPacket(buf, len, &trans_addr2));
91    bool are_same = (trans_addr == trans_addr2);
92    ASSERT_EQ(are_same, exp_same) << "same translated address";
93    ASSERT_NE(AF_UNSPEC, trans_addr.family());
94    ASSERT_NE(AF_UNSPEC, trans_addr2.family());
95  }
96
97  th_int.Stop();
98  th_ext.Stop();
99
100  delete nat;
101  delete natsf;
102  delete in;
103  for (int i = 0; i < 4; i++)
104    delete out[i];
105}
106
107// Tests that when sending from external_addrs to internal_addr, the packet
108// is delivered according to the specified filter_ip and filter_port rules.
109void TestRecv(
110      SocketServer* internal, const SocketAddress& internal_addr,
111      SocketServer* external, const SocketAddress external_addrs[4],
112      NATType nat_type, bool filter_ip, bool filter_port) {
113  Thread th_int(internal);
114  Thread th_ext(external);
115
116  SocketAddress server_addr = internal_addr;
117  server_addr.SetPort(0);  // Auto-select a port
118  NATServer* nat = new NATServer(
119      nat_type, internal, server_addr, external, external_addrs[0]);
120  NATSocketFactory* natsf = new NATSocketFactory(internal,
121                                                 nat->internal_address());
122
123  TestClient* in = CreateTestClient(natsf, internal_addr);
124  TestClient* out[4];
125  for (int i = 0; i < 4; i++)
126    out[i] = CreateTestClient(external, external_addrs[i]);
127
128  th_int.Start();
129  th_ext.Start();
130
131  const char* buf = "filter_test";
132  size_t len = strlen(buf);
133
134  in->SendTo(buf, len, out[0]->address());
135  SocketAddress trans_addr;
136  EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
137
138  out[1]->SendTo(buf, len, trans_addr);
139  EXPECT_TRUE(CheckReceive(in, !filter_ip, buf, len));
140
141  out[2]->SendTo(buf, len, trans_addr);
142  EXPECT_TRUE(CheckReceive(in, !filter_port, buf, len));
143
144  out[3]->SendTo(buf, len, trans_addr);
145  EXPECT_TRUE(CheckReceive(in, !filter_ip && !filter_port, buf, len));
146
147  th_int.Stop();
148  th_ext.Stop();
149
150  delete nat;
151  delete natsf;
152  delete in;
153  for (int i = 0; i < 4; i++)
154    delete out[i];
155}
156
157// Tests that NATServer allocates bindings properly.
158void TestBindings(
159    SocketServer* internal, const SocketAddress& internal_addr,
160    SocketServer* external, const SocketAddress external_addrs[4]) {
161  TestSend(internal, internal_addr, external, external_addrs,
162           NAT_OPEN_CONE, true);
163  TestSend(internal, internal_addr, external, external_addrs,
164           NAT_ADDR_RESTRICTED, true);
165  TestSend(internal, internal_addr, external, external_addrs,
166           NAT_PORT_RESTRICTED, true);
167  TestSend(internal, internal_addr, external, external_addrs,
168           NAT_SYMMETRIC, false);
169}
170
171// Tests that NATServer filters packets properly.
172void TestFilters(
173    SocketServer* internal, const SocketAddress& internal_addr,
174    SocketServer* external, const SocketAddress external_addrs[4]) {
175  TestRecv(internal, internal_addr, external, external_addrs,
176           NAT_OPEN_CONE, false, false);
177  TestRecv(internal, internal_addr, external, external_addrs,
178           NAT_ADDR_RESTRICTED, true, false);
179  TestRecv(internal, internal_addr, external, external_addrs,
180           NAT_PORT_RESTRICTED, true, true);
181  TestRecv(internal, internal_addr, external, external_addrs,
182           NAT_SYMMETRIC, true, true);
183}
184
185bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) {
186  // The physical NAT tests require connectivity to the selected ip from the
187  // internal address used for the NAT. Things like firewalls can break that, so
188  // check to see if it's worth even trying with this ip.
189  scoped_ptr<PhysicalSocketServer> pss(new PhysicalSocketServer());
190  scoped_ptr<AsyncSocket> client(pss->CreateAsyncSocket(src.family(),
191                                                        SOCK_DGRAM));
192  scoped_ptr<AsyncSocket> server(pss->CreateAsyncSocket(src.family(),
193                                                        SOCK_DGRAM));
194  if (client->Bind(SocketAddress(src.ipaddr(), 0)) != 0 ||
195      server->Bind(SocketAddress(dst, 0)) != 0) {
196    return false;
197  }
198  const char* buf = "hello other socket";
199  size_t len = strlen(buf);
200  int sent = client->SendTo(buf, len, server->GetLocalAddress());
201  SocketAddress addr;
202  const size_t kRecvBufSize = 64;
203  char recvbuf[kRecvBufSize];
204  Thread::Current()->SleepMs(100);
205  int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr);
206  return received == sent && ::memcmp(buf, recvbuf, len) == 0;
207}
208
209void TestPhysicalInternal(const SocketAddress& int_addr) {
210  BasicNetworkManager network_manager;
211  network_manager.set_ipv6_enabled(true);
212  network_manager.StartUpdating();
213  // Process pending messages so the network list is updated.
214  Thread::Current()->ProcessMessages(0);
215
216  std::vector<Network*> networks;
217  network_manager.GetNetworks(&networks);
218  if (networks.empty()) {
219    LOG(LS_WARNING) << "Not enough network adapters for test.";
220    return;
221  }
222
223  SocketAddress ext_addr1(int_addr);
224  SocketAddress ext_addr2;
225  // Find an available IP with matching family. The test breaks if int_addr
226  // can't talk to ip, so check for connectivity as well.
227  for (std::vector<Network*>::iterator it = networks.begin();
228      it != networks.end(); ++it) {
229    const IPAddress& ip = (*it)->ip();
230    if (ip.family() == int_addr.family() && TestConnectivity(int_addr, ip)) {
231      ext_addr2.SetIP(ip);
232      break;
233    }
234  }
235  if (ext_addr2.IsNil()) {
236    LOG(LS_WARNING) << "No available IP of same family as " << int_addr;
237    return;
238  }
239
240  LOG(LS_INFO) << "selected ip " << ext_addr2.ipaddr();
241
242  SocketAddress ext_addrs[4] = {
243      SocketAddress(ext_addr1),
244      SocketAddress(ext_addr2),
245      SocketAddress(ext_addr1),
246      SocketAddress(ext_addr2)
247  };
248
249  scoped_ptr<PhysicalSocketServer> int_pss(new PhysicalSocketServer());
250  scoped_ptr<PhysicalSocketServer> ext_pss(new PhysicalSocketServer());
251
252  TestBindings(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
253  TestFilters(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
254}
255
256TEST(NatTest, TestPhysicalIPv4) {
257  TestPhysicalInternal(SocketAddress("127.0.0.1", 0));
258}
259
260TEST(NatTest, TestPhysicalIPv6) {
261  if (HasIPv6Enabled()) {
262    TestPhysicalInternal(SocketAddress("::1", 0));
263  } else {
264    LOG(LS_WARNING) << "No IPv6, skipping";
265  }
266}
267
268class TestVirtualSocketServer : public VirtualSocketServer {
269 public:
270  explicit TestVirtualSocketServer(SocketServer* ss)
271      : VirtualSocketServer(ss),
272        ss_(ss) {}
273  // Expose this publicly
274  IPAddress GetNextIP(int af) { return VirtualSocketServer::GetNextIP(af); }
275
276 private:
277  scoped_ptr<SocketServer> ss_;
278};
279
280void TestVirtualInternal(int family) {
281  scoped_ptr<TestVirtualSocketServer> int_vss(new TestVirtualSocketServer(
282      new PhysicalSocketServer()));
283  scoped_ptr<TestVirtualSocketServer> ext_vss(new TestVirtualSocketServer(
284      new PhysicalSocketServer()));
285
286  SocketAddress int_addr;
287  SocketAddress ext_addrs[4];
288  int_addr.SetIP(int_vss->GetNextIP(family));
289  ext_addrs[0].SetIP(ext_vss->GetNextIP(int_addr.family()));
290  ext_addrs[1].SetIP(ext_vss->GetNextIP(int_addr.family()));
291  ext_addrs[2].SetIP(ext_addrs[0].ipaddr());
292  ext_addrs[3].SetIP(ext_addrs[1].ipaddr());
293
294  TestBindings(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
295  TestFilters(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
296}
297
298TEST(NatTest, TestVirtualIPv4) {
299  TestVirtualInternal(AF_INET);
300}
301
302TEST(NatTest, TestVirtualIPv6) {
303  if (HasIPv6Enabled()) {
304    TestVirtualInternal(AF_INET6);
305  } else {
306    LOG(LS_WARNING) << "No IPv6, skipping";
307  }
308}
309
310// TODO: Finish this test
311class NatTcpTest : public testing::Test, public sigslot::has_slots<> {
312 public:
313  NatTcpTest() : connected_(false) {}
314  virtual void SetUp() {
315    int_vss_ = new TestVirtualSocketServer(new PhysicalSocketServer());
316    ext_vss_ = new TestVirtualSocketServer(new PhysicalSocketServer());
317    nat_ = new NATServer(NAT_OPEN_CONE, int_vss_, SocketAddress(),
318                         ext_vss_, SocketAddress());
319    natsf_ = new NATSocketFactory(int_vss_, nat_->internal_address());
320  }
321  void OnConnectEvent(AsyncSocket* socket) {
322    connected_ = true;
323  }
324  void OnAcceptEvent(AsyncSocket* socket) {
325    accepted_ = server_->Accept(NULL);
326  }
327  void OnCloseEvent(AsyncSocket* socket, int error) {
328  }
329  void ConnectEvents() {
330    server_->SignalReadEvent.connect(this, &NatTcpTest::OnAcceptEvent);
331    client_->SignalConnectEvent.connect(this, &NatTcpTest::OnConnectEvent);
332  }
333  TestVirtualSocketServer* int_vss_;
334  TestVirtualSocketServer* ext_vss_;
335  NATServer* nat_;
336  NATSocketFactory* natsf_;
337  AsyncSocket* client_;
338  AsyncSocket* server_;
339  AsyncSocket* accepted_;
340  bool connected_;
341};
342
343TEST_F(NatTcpTest, DISABLED_TestConnectOut) {
344  server_ = ext_vss_->CreateAsyncSocket(SOCK_STREAM);
345  server_->Bind(SocketAddress());
346  server_->Listen(5);
347
348  client_ = int_vss_->CreateAsyncSocket(SOCK_STREAM);
349  EXPECT_GE(0, client_->Bind(SocketAddress()));
350  EXPECT_GE(0, client_->Connect(server_->GetLocalAddress()));
351
352
353  ConnectEvents();
354
355  EXPECT_TRUE_WAIT(connected_, 1000);
356  EXPECT_EQ(client_->GetRemoteAddress(), server_->GetLocalAddress());
357  EXPECT_EQ(client_->GetRemoteAddress(), accepted_->GetLocalAddress());
358  EXPECT_EQ(client_->GetLocalAddress(), accepted_->GetRemoteAddress());
359
360  client_->Close();
361}
362//#endif
363