1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/udp/udp_client_socket.h"
6#include "net/udp/udp_server_socket.h"
7
8#include "base/basictypes.h"
9#include "base/metrics/histogram.h"
10#include "net/base/io_buffer.h"
11#include "net/base/ip_endpoint.h"
12#include "net/base/net_errors.h"
13#include "net/base/net_test_suite.h"
14#include "net/base/net_util.h"
15#include "net/base/sys_addrinfo.h"
16#include "net/base/test_completion_callback.h"
17#include "testing/gtest/include/gtest/gtest.h"
18#include "testing/platform_test.h"
19
20namespace net {
21
22namespace {
23
24class UDPSocketTest : public PlatformTest {
25 public:
26  UDPSocketTest()
27      : buffer_(new IOBufferWithSize(kMaxRead)) {
28  }
29
30  // Blocks until data is read from the socket.
31  std::string RecvFromSocket(UDPServerSocket* socket) {
32    TestCompletionCallback callback;
33
34    int rv = socket->RecvFrom(buffer_, kMaxRead, &recv_from_address_,
35                              &callback);
36    if (rv == ERR_IO_PENDING)
37      rv = callback.WaitForResult();
38    if (rv < 0)
39      return "";  // error!
40    return std::string(buffer_->data(), rv);
41  }
42
43  // Loop until |msg| has been written to the socket or until an
44  // error occurs.
45  // If |address| is specified, then it is used for the destination
46  // to send to. Otherwise, will send to the last socket this server
47  // received from.
48  int SendToSocket(UDPServerSocket* socket, std::string msg) {
49    return SendToSocket(socket, msg, recv_from_address_);
50  }
51
52  int SendToSocket(UDPServerSocket* socket,
53                   std::string msg,
54                   const IPEndPoint& address) {
55    TestCompletionCallback callback;
56
57    int length = msg.length();
58    scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg));
59    scoped_refptr<DrainableIOBuffer> buffer(
60        new DrainableIOBuffer(io_buffer, length));
61
62    int bytes_sent = 0;
63    while (buffer->BytesRemaining()) {
64      int rv = socket->SendTo(buffer, buffer->BytesRemaining(),
65                              address, &callback);
66      if (rv == ERR_IO_PENDING)
67        rv = callback.WaitForResult();
68      if (rv <= 0)
69        return bytes_sent > 0 ? bytes_sent : rv;
70      bytes_sent += rv;
71      buffer->DidConsume(rv);
72    }
73    return bytes_sent;
74  }
75
76  std::string ReadSocket(UDPClientSocket* socket) {
77    TestCompletionCallback callback;
78
79    int rv = socket->Read(buffer_, kMaxRead, &callback);
80    if (rv == ERR_IO_PENDING)
81      rv = callback.WaitForResult();
82    if (rv < 0)
83      return "";  // error!
84    return std::string(buffer_->data(), rv);
85  }
86
87  // Loop until |msg| has been written to the socket or until an
88  // error occurs.
89  int WriteSocket(UDPClientSocket* socket, std::string msg) {
90    TestCompletionCallback callback;
91
92    int length = msg.length();
93    scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg));
94    scoped_refptr<DrainableIOBuffer> buffer(
95        new DrainableIOBuffer(io_buffer, length));
96
97    int bytes_sent = 0;
98    while (buffer->BytesRemaining()) {
99      int rv = socket->Write(buffer, buffer->BytesRemaining(), &callback);
100      if (rv == ERR_IO_PENDING)
101        rv = callback.WaitForResult();
102      if (rv <= 0)
103        return bytes_sent > 0 ? bytes_sent : rv;
104      bytes_sent += rv;
105      buffer->DidConsume(rv);
106    }
107    return bytes_sent;
108  }
109
110 protected:
111  static const int kMaxRead = 1024;
112  scoped_refptr<IOBufferWithSize> buffer_;
113  IPEndPoint recv_from_address_;
114};
115
116// Creates and address from an ip/port and returns it in |address|.
117void CreateUDPAddress(std::string ip_str, int port, IPEndPoint* address) {
118  IPAddressNumber ip_number;
119  bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
120  if (!rv)
121    return;
122  *address = IPEndPoint(ip_number, port);
123}
124
125TEST_F(UDPSocketTest, Connect) {
126  const int kPort = 9999;
127  std::string simple_message("hello world!");
128
129  // Setup the server to listen.
130  IPEndPoint bind_address;
131  CreateUDPAddress("0.0.0.0", kPort, &bind_address);
132  UDPServerSocket server(NULL, NetLog::Source());
133  int rv = server.Listen(bind_address);
134  EXPECT_EQ(OK, rv);
135
136  // Setup the client.
137  IPEndPoint server_address;
138  CreateUDPAddress("127.0.0.1", kPort, &server_address);
139  UDPClientSocket client(NULL, NetLog::Source());
140  rv = client.Connect(server_address);
141  EXPECT_EQ(OK, rv);
142
143  // Client sends to the server.
144  rv = WriteSocket(&client, simple_message);
145  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
146
147  // Server waits for message.
148  std::string str = RecvFromSocket(&server);
149  DCHECK(simple_message == str);
150
151  // Server echoes reply.
152  rv = SendToSocket(&server, simple_message);
153  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
154
155  // Client waits for response.
156  str = ReadSocket(&client);
157  DCHECK(simple_message == str);
158}
159
160// In this test, we verify that connect() on a socket will have the effect
161// of filtering reads on this socket only to data read from the destination
162// we connected to.
163//
164// The purpose of this test is that some documentation indicates that connect
165// binds the client's sends to send to a particular server endpoint, but does
166// not bind the client's reads to only be from that endpoint, and that we need
167// to always use recvfrom() to disambiguate.
168TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
169  const int kPort1 = 9999;
170  const int kPort2 = 10000;
171  std::string simple_message("hello world!");
172  std::string foreign_message("BAD MESSAGE TO GET!!");
173
174  // Setup the first server to listen.
175  IPEndPoint bind_address;
176  CreateUDPAddress("0.0.0.0", kPort1, &bind_address);
177  UDPServerSocket server1(NULL, NetLog::Source());
178  int rv = server1.Listen(bind_address);
179  EXPECT_EQ(OK, rv);
180
181  // Setup the second server to listen.
182  CreateUDPAddress("0.0.0.0", kPort2, &bind_address);
183  UDPServerSocket server2(NULL, NetLog::Source());
184  rv = server2.Listen(bind_address);
185  EXPECT_EQ(OK, rv);
186
187  // Setup the client, connected to server 1.
188  IPEndPoint server_address;
189  CreateUDPAddress("127.0.0.1", kPort1, &server_address);
190  UDPClientSocket client(NULL, NetLog::Source());
191  rv = client.Connect(server_address);
192  EXPECT_EQ(OK, rv);
193
194  // Client sends to server1.
195  rv = WriteSocket(&client, simple_message);
196  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
197
198  // Server1 waits for message.
199  std::string str = RecvFromSocket(&server1);
200  DCHECK(simple_message == str);
201
202  // Get the client's address.
203  IPEndPoint client_address;
204  rv = client.GetLocalAddress(&client_address);
205  EXPECT_EQ(OK, rv);
206
207  // Server2 sends reply.
208  rv = SendToSocket(&server2, foreign_message,
209                    client_address);
210  EXPECT_EQ(foreign_message.length(), static_cast<size_t>(rv));
211
212  // Server1 sends reply.
213  rv = SendToSocket(&server1, simple_message,
214                    client_address);
215  EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
216
217  // Client waits for response.
218  str = ReadSocket(&client);
219  DCHECK(simple_message == str);
220}
221
222TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
223  struct TestData {
224    std::string remote_address;
225    std::string local_address;
226    bool may_fail;
227  } tests[] = {
228    { "127.0.00.1", "127.0.0.1", false },
229    { "192.168.1.1", "127.0.0.1", false },
230    { "::1", "::1", true },
231    { "2001:db8:0::42", "::1", true },
232  };
233  for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); i++) {
234    SCOPED_TRACE(std::string("Connecting from ") +  tests[i].local_address +
235                 std::string(" to ") + tests[i].remote_address);
236
237    net::IPAddressNumber ip_number;
238    net::ParseIPLiteralToNumber(tests[i].remote_address, &ip_number);
239    net::IPEndPoint remote_address(ip_number, 80);
240    net::ParseIPLiteralToNumber(tests[i].local_address, &ip_number);
241    net::IPEndPoint local_address(ip_number, 80);
242
243    UDPClientSocket client(NULL, NetLog::Source());
244    int rv = client.Connect(remote_address);
245    if (tests[i].may_fail && rv == ERR_ADDRESS_UNREACHABLE) {
246      // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6
247      // addresses if IPv6 is not configured.
248      continue;
249    }
250
251    EXPECT_LE(ERR_IO_PENDING, rv);
252
253    IPEndPoint fetched_local_address;
254    rv = client.GetLocalAddress(&fetched_local_address);
255    EXPECT_EQ(OK, rv);
256
257    // TODO(mbelshe): figure out how to verify the IP and port.
258    //                The port is dynamically generated by the udp stack.
259    //                The IP is the real IP of the client, not necessarily
260    //                loopback.
261    //EXPECT_EQ(local_address.address(), fetched_local_address.address());
262
263    IPEndPoint fetched_remote_address;
264    rv = client.GetPeerAddress(&fetched_remote_address);
265    EXPECT_EQ(OK, rv);
266
267    EXPECT_EQ(remote_address, fetched_remote_address);
268  }
269}
270
271TEST_F(UDPSocketTest, ServerGetLocalAddress) {
272  IPEndPoint bind_address;
273  CreateUDPAddress("127.0.0.1", 0, &bind_address);
274  UDPServerSocket server(NULL, NetLog::Source());
275  int rv = server.Listen(bind_address);
276  EXPECT_EQ(OK, rv);
277
278  IPEndPoint local_address;
279  rv = server.GetLocalAddress(&local_address);
280  EXPECT_EQ(rv, 0);
281
282  // Verify that port was allocated.
283  EXPECT_GT(local_address.port(), 0);
284  EXPECT_EQ(local_address.address(), bind_address.address());
285}
286
287TEST_F(UDPSocketTest, ServerGetPeerAddress) {
288  IPEndPoint bind_address;
289  CreateUDPAddress("127.0.0.1", 0, &bind_address);
290  UDPServerSocket server(NULL, NetLog::Source());
291  int rv = server.Listen(bind_address);
292  EXPECT_EQ(OK, rv);
293
294  IPEndPoint peer_address;
295  rv = server.GetPeerAddress(&peer_address);
296  EXPECT_EQ(rv, ERR_SOCKET_NOT_CONNECTED);
297}
298
299// Close the socket while read is pending.
300TEST_F(UDPSocketTest, CloseWithPendingRead) {
301  IPEndPoint bind_address;
302  CreateUDPAddress("127.0.0.1", 0, &bind_address);
303  UDPServerSocket server(NULL, NetLog::Source());
304  int rv = server.Listen(bind_address);
305  EXPECT_EQ(OK, rv);
306
307  TestCompletionCallback callback;
308  IPEndPoint from;
309  rv = server.RecvFrom(buffer_, kMaxRead, &from, &callback);
310  EXPECT_EQ(rv, ERR_IO_PENDING);
311
312  server.Close();
313
314  EXPECT_FALSE(callback.have_result());
315}
316
317}  // namespace
318
319}  // namespace net
320