1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/tcp_socket.h"
6
7#include <string.h>
8
9#include <string>
10#include <vector>
11
12#include "base/memory/ref_counted.h"
13#include "base/memory/scoped_ptr.h"
14#include "net/base/address_list.h"
15#include "net/base/io_buffer.h"
16#include "net/base/ip_endpoint.h"
17#include "net/base/net_errors.h"
18#include "net/base/test_completion_callback.h"
19#include "net/socket/tcp_client_socket.h"
20#include "testing/gtest/include/gtest/gtest.h"
21#include "testing/platform_test.h"
22
23namespace net {
24
25namespace {
26const int kListenBacklog = 5;
27
28class TCPSocketTest : public PlatformTest {
29 protected:
30  TCPSocketTest() : socket_(NULL, NetLog::Source()) {
31  }
32
33  void SetUpListenIPv4() {
34    IPEndPoint address;
35    ParseAddress("127.0.0.1", 0, &address);
36
37    ASSERT_EQ(OK, socket_.Open(ADDRESS_FAMILY_IPV4));
38    ASSERT_EQ(OK, socket_.Bind(address));
39    ASSERT_EQ(OK, socket_.Listen(kListenBacklog));
40    ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
41  }
42
43  void SetUpListenIPv6(bool* success) {
44    *success = false;
45    IPEndPoint address;
46    ParseAddress("::1", 0, &address);
47
48    if (socket_.Open(ADDRESS_FAMILY_IPV6) != OK ||
49        socket_.Bind(address) != OK ||
50        socket_.Listen(kListenBacklog) != OK) {
51      LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
52          "disabled. Skipping the test";
53      return;
54    }
55    ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
56    *success = true;
57  }
58
59  void ParseAddress(const std::string& ip_str, int port, IPEndPoint* address) {
60    IPAddressNumber ip_number;
61    bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
62    if (!rv)
63      return;
64    *address = IPEndPoint(ip_number, port);
65  }
66
67  void TestAcceptAsync() {
68    TestCompletionCallback accept_callback;
69    scoped_ptr<TCPSocket> accepted_socket;
70    IPEndPoint accepted_address;
71    ASSERT_EQ(ERR_IO_PENDING,
72              socket_.Accept(&accepted_socket, &accepted_address,
73                             accept_callback.callback()));
74
75    TestCompletionCallback connect_callback;
76    TCPClientSocket connecting_socket(local_address_list(),
77                                      NULL, NetLog::Source());
78    connecting_socket.Connect(connect_callback.callback());
79
80    EXPECT_EQ(OK, connect_callback.WaitForResult());
81    EXPECT_EQ(OK, accept_callback.WaitForResult());
82
83    EXPECT_TRUE(accepted_socket.get());
84
85    // Both sockets should be on the loopback network interface.
86    EXPECT_EQ(accepted_address.address(), local_address_.address());
87  }
88
89  AddressList local_address_list() const {
90    return AddressList(local_address_);
91  }
92
93  TCPSocket socket_;
94  IPEndPoint local_address_;
95};
96
97// Test listening and accepting with a socket bound to an IPv4 address.
98TEST_F(TCPSocketTest, Accept) {
99  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
100
101  TestCompletionCallback connect_callback;
102  // TODO(yzshen): Switch to use TCPSocket when it supports client socket
103  // operations.
104  TCPClientSocket connecting_socket(local_address_list(),
105                                    NULL, NetLog::Source());
106  connecting_socket.Connect(connect_callback.callback());
107
108  TestCompletionCallback accept_callback;
109  scoped_ptr<TCPSocket> accepted_socket;
110  IPEndPoint accepted_address;
111  int result = socket_.Accept(&accepted_socket, &accepted_address,
112                              accept_callback.callback());
113  if (result == ERR_IO_PENDING)
114    result = accept_callback.WaitForResult();
115  ASSERT_EQ(OK, result);
116
117  EXPECT_TRUE(accepted_socket.get());
118
119  // Both sockets should be on the loopback network interface.
120  EXPECT_EQ(accepted_address.address(), local_address_.address());
121
122  EXPECT_EQ(OK, connect_callback.WaitForResult());
123}
124
125// Test Accept() callback.
126TEST_F(TCPSocketTest, AcceptAsync) {
127  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
128  TestAcceptAsync();
129}
130
#if defined(OS_WIN)
// Test Accept() for AdoptListenSocket.
TEST_F(TCPSocketTest, AcceptForAdoptedListenSocket) {
  // Create a socket to be used with AdoptListenSocket.
  SOCKET existing_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
  // Fail clearly if socket creation failed rather than handing INVALID_SOCKET
  // to AdoptListenSocket() and bind().
  ASSERT_NE(INVALID_SOCKET, existing_socket);
  ASSERT_EQ(OK, socket_.AdoptListenSocket(existing_socket));

  IPEndPoint address;
  ASSERT_NO_FATAL_FAILURE(ParseAddress("127.0.0.1", 0, &address));
  SockaddrStorage storage;
  ASSERT_TRUE(address.ToSockAddr(storage.addr, &storage.addr_len));
  // The raw handle is bound directly, bypassing TCPSocket::Bind(), to
  // exercise a socket that was adopted rather than opened by TCPSocket.
  ASSERT_EQ(0, bind(existing_socket, storage.addr, storage.addr_len));

  ASSERT_EQ(OK, socket_.Listen(kListenBacklog));
  ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));

  TestAcceptAsync();
}
#endif
150
151// Accept two connections simultaneously.
152TEST_F(TCPSocketTest, Accept2Connections) {
153  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
154
155  TestCompletionCallback accept_callback;
156  scoped_ptr<TCPSocket> accepted_socket;
157  IPEndPoint accepted_address;
158
159  ASSERT_EQ(ERR_IO_PENDING,
160            socket_.Accept(&accepted_socket, &accepted_address,
161                           accept_callback.callback()));
162
163  TestCompletionCallback connect_callback;
164  TCPClientSocket connecting_socket(local_address_list(),
165                                    NULL, NetLog::Source());
166  connecting_socket.Connect(connect_callback.callback());
167
168  TestCompletionCallback connect_callback2;
169  TCPClientSocket connecting_socket2(local_address_list(),
170                                     NULL, NetLog::Source());
171  connecting_socket2.Connect(connect_callback2.callback());
172
173  EXPECT_EQ(OK, accept_callback.WaitForResult());
174
175  TestCompletionCallback accept_callback2;
176  scoped_ptr<TCPSocket> accepted_socket2;
177  IPEndPoint accepted_address2;
178
179  int result = socket_.Accept(&accepted_socket2, &accepted_address2,
180                              accept_callback2.callback());
181  if (result == ERR_IO_PENDING)
182    result = accept_callback2.WaitForResult();
183  ASSERT_EQ(OK, result);
184
185  EXPECT_EQ(OK, connect_callback.WaitForResult());
186  EXPECT_EQ(OK, connect_callback2.WaitForResult());
187
188  EXPECT_TRUE(accepted_socket.get());
189  EXPECT_TRUE(accepted_socket2.get());
190  EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
191
192  EXPECT_EQ(accepted_address.address(), local_address_.address());
193  EXPECT_EQ(accepted_address2.address(), local_address_.address());
194}
195
196// Test listening and accepting with a socket bound to an IPv6 address.
197TEST_F(TCPSocketTest, AcceptIPv6) {
198  bool initialized = false;
199  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv6(&initialized));
200  if (!initialized)
201    return;
202
203  TestCompletionCallback connect_callback;
204  TCPClientSocket connecting_socket(local_address_list(),
205                                    NULL, NetLog::Source());
206  connecting_socket.Connect(connect_callback.callback());
207
208  TestCompletionCallback accept_callback;
209  scoped_ptr<TCPSocket> accepted_socket;
210  IPEndPoint accepted_address;
211  int result = socket_.Accept(&accepted_socket, &accepted_address,
212                              accept_callback.callback());
213  if (result == ERR_IO_PENDING)
214    result = accept_callback.WaitForResult();
215  ASSERT_EQ(OK, result);
216
217  EXPECT_TRUE(accepted_socket.get());
218
219  // Both sockets should be on the loopback network interface.
220  EXPECT_EQ(accepted_address.address(), local_address_.address());
221
222  EXPECT_EQ(OK, connect_callback.WaitForResult());
223}
224
225TEST_F(TCPSocketTest, ReadWrite) {
226  ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
227
228  TestCompletionCallback connect_callback;
229  TCPSocket connecting_socket(NULL, NetLog::Source());
230  int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4);
231  ASSERT_EQ(OK, result);
232  connecting_socket.Connect(local_address_, connect_callback.callback());
233
234  TestCompletionCallback accept_callback;
235  scoped_ptr<TCPSocket> accepted_socket;
236  IPEndPoint accepted_address;
237  result = socket_.Accept(&accepted_socket, &accepted_address,
238                          accept_callback.callback());
239  ASSERT_EQ(OK, accept_callback.GetResult(result));
240
241  ASSERT_TRUE(accepted_socket.get());
242
243  // Both sockets should be on the loopback network interface.
244  EXPECT_EQ(accepted_address.address(), local_address_.address());
245
246  EXPECT_EQ(OK, connect_callback.WaitForResult());
247
248  const std::string message("test message");
249  std::vector<char> buffer(message.size());
250
251  size_t bytes_written = 0;
252  while (bytes_written < message.size()) {
253    scoped_refptr<IOBufferWithSize> write_buffer(
254        new IOBufferWithSize(message.size() - bytes_written));
255    memmove(write_buffer->data(), message.data() + bytes_written,
256            message.size() - bytes_written);
257
258    TestCompletionCallback write_callback;
259    int write_result = accepted_socket->Write(
260        write_buffer.get(), write_buffer->size(), write_callback.callback());
261    write_result = write_callback.GetResult(write_result);
262    ASSERT_TRUE(write_result >= 0);
263    bytes_written += write_result;
264    ASSERT_TRUE(bytes_written <= message.size());
265  }
266
267  size_t bytes_read = 0;
268  while (bytes_read < message.size()) {
269    scoped_refptr<IOBufferWithSize> read_buffer(
270        new IOBufferWithSize(message.size() - bytes_read));
271    TestCompletionCallback read_callback;
272    int read_result = connecting_socket.Read(
273        read_buffer.get(), read_buffer->size(), read_callback.callback());
274    read_result = read_callback.GetResult(read_result);
275    ASSERT_TRUE(read_result >= 0);
276    ASSERT_TRUE(bytes_read + read_result <= message.size());
277    memmove(&buffer[bytes_read], read_buffer->data(), read_result);
278    bytes_read += read_result;
279  }
280
281  std::string received_message(buffer.begin(), buffer.end());
282  ASSERT_EQ(message, received_message);
283}
284
285}  // namespace
286}  // namespace net
287