1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/tcp_server_socket.h"
6
7#include <string>
8#include <vector>
9
10#include "base/compiler_specific.h"
11#include "base/memory/ref_counted.h"
12#include "base/memory/scoped_ptr.h"
13#include "net/base/address_list.h"
14#include "net/base/io_buffer.h"
15#include "net/base/ip_endpoint.h"
16#include "net/base/net_errors.h"
17#include "net/base/test_completion_callback.h"
18#include "net/socket/tcp_client_socket.h"
19#include "testing/gtest/include/gtest/gtest.h"
20#include "testing/platform_test.h"
21
22namespace net {
23
24namespace {
25const int kListenBacklog = 5;
26
27class TCPServerSocketTest : public PlatformTest {
28 protected:
29  TCPServerSocketTest()
30      : socket_(NULL, NetLog::Source()) {
31  }
32
33  void SetUpIPv4() {
34    IPEndPoint address;
35    ParseAddress("127.0.0.1", 0, &address);
36    ASSERT_EQ(OK, socket_.Listen(address, kListenBacklog));
37    ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
38  }
39
40  void SetUpIPv6(bool* success) {
41    *success = false;
42    IPEndPoint address;
43    ParseAddress("::1", 0, &address);
44    if (socket_.Listen(address, kListenBacklog) != 0) {
45      LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
46          "disabled. Skipping the test";
47      return;
48    }
49    ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
50    *success = true;
51  }
52
53  void ParseAddress(std::string ip_str, int port, IPEndPoint* address) {
54    IPAddressNumber ip_number;
55    bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
56    if (!rv)
57      return;
58    *address = IPEndPoint(ip_number, port);
59  }
60
61  static IPEndPoint GetPeerAddress(StreamSocket* socket) {
62    IPEndPoint address;
63    EXPECT_EQ(OK, socket->GetPeerAddress(&address));
64    return address;
65  }
66
67  AddressList local_address_list() const {
68    return AddressList(local_address_);
69  }
70
71  TCPServerSocket socket_;
72  IPEndPoint local_address_;
73};
74
75TEST_F(TCPServerSocketTest, Accept) {
76  ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
77
78  TestCompletionCallback connect_callback;
79  TCPClientSocket connecting_socket(local_address_list(),
80                                    NULL, NetLog::Source());
81  connecting_socket.Connect(connect_callback.callback());
82
83  TestCompletionCallback accept_callback;
84  scoped_ptr<StreamSocket> accepted_socket;
85  int result = socket_.Accept(&accepted_socket, accept_callback.callback());
86  if (result == ERR_IO_PENDING)
87    result = accept_callback.WaitForResult();
88  ASSERT_EQ(OK, result);
89
90  ASSERT_TRUE(accepted_socket.get() != NULL);
91
92  // Both sockets should be on the loopback network interface.
93  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
94            local_address_.address());
95
96  EXPECT_EQ(OK, connect_callback.WaitForResult());
97}
98
99// Test Accept() callback.
100TEST_F(TCPServerSocketTest, AcceptAsync) {
101  ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
102
103  TestCompletionCallback accept_callback;
104  scoped_ptr<StreamSocket> accepted_socket;
105
106  ASSERT_EQ(ERR_IO_PENDING,
107            socket_.Accept(&accepted_socket, accept_callback.callback()));
108
109  TestCompletionCallback connect_callback;
110  TCPClientSocket connecting_socket(local_address_list(),
111                                    NULL, NetLog::Source());
112  connecting_socket.Connect(connect_callback.callback());
113
114  EXPECT_EQ(OK, connect_callback.WaitForResult());
115  EXPECT_EQ(OK, accept_callback.WaitForResult());
116
117  EXPECT_TRUE(accepted_socket != NULL);
118
119  // Both sockets should be on the loopback network interface.
120  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
121            local_address_.address());
122}
123
124// Accept two connections simultaneously.
125TEST_F(TCPServerSocketTest, Accept2Connections) {
126  ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
127
128  TestCompletionCallback accept_callback;
129  scoped_ptr<StreamSocket> accepted_socket;
130
131  ASSERT_EQ(ERR_IO_PENDING,
132            socket_.Accept(&accepted_socket, accept_callback.callback()));
133
134  TestCompletionCallback connect_callback;
135  TCPClientSocket connecting_socket(local_address_list(),
136                                    NULL, NetLog::Source());
137  connecting_socket.Connect(connect_callback.callback());
138
139  TestCompletionCallback connect_callback2;
140  TCPClientSocket connecting_socket2(local_address_list(),
141                                     NULL, NetLog::Source());
142  connecting_socket2.Connect(connect_callback2.callback());
143
144  EXPECT_EQ(OK, accept_callback.WaitForResult());
145
146  TestCompletionCallback accept_callback2;
147  scoped_ptr<StreamSocket> accepted_socket2;
148  int result = socket_.Accept(&accepted_socket2, accept_callback2.callback());
149  if (result == ERR_IO_PENDING)
150    result = accept_callback2.WaitForResult();
151  ASSERT_EQ(OK, result);
152
153  EXPECT_EQ(OK, connect_callback.WaitForResult());
154
155  EXPECT_TRUE(accepted_socket != NULL);
156  EXPECT_TRUE(accepted_socket2 != NULL);
157  EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
158
159  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
160            local_address_.address());
161  EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(),
162            local_address_.address());
163}
164
165TEST_F(TCPServerSocketTest, AcceptIPv6) {
166  bool initialized = false;
167  ASSERT_NO_FATAL_FAILURE(SetUpIPv6(&initialized));
168  if (!initialized)
169    return;
170
171  TestCompletionCallback connect_callback;
172  TCPClientSocket connecting_socket(local_address_list(),
173                                    NULL, NetLog::Source());
174  connecting_socket.Connect(connect_callback.callback());
175
176  TestCompletionCallback accept_callback;
177  scoped_ptr<StreamSocket> accepted_socket;
178  int result = socket_.Accept(&accepted_socket, accept_callback.callback());
179  if (result == ERR_IO_PENDING)
180    result = accept_callback.WaitForResult();
181  ASSERT_EQ(OK, result);
182
183  ASSERT_TRUE(accepted_socket.get() != NULL);
184
185  // Both sockets should be on the loopback network interface.
186  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
187            local_address_.address());
188
189  EXPECT_EQ(OK, connect_callback.WaitForResult());
190}
191
192TEST_F(TCPServerSocketTest, AcceptIO) {
193  ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
194
195  TestCompletionCallback connect_callback;
196  TCPClientSocket connecting_socket(local_address_list(),
197                                    NULL, NetLog::Source());
198  connecting_socket.Connect(connect_callback.callback());
199
200  TestCompletionCallback accept_callback;
201  scoped_ptr<StreamSocket> accepted_socket;
202  int result = socket_.Accept(&accepted_socket, accept_callback.callback());
203  ASSERT_EQ(OK, accept_callback.GetResult(result));
204
205  ASSERT_TRUE(accepted_socket.get() != NULL);
206
207  // Both sockets should be on the loopback network interface.
208  EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
209            local_address_.address());
210
211  EXPECT_EQ(OK, connect_callback.WaitForResult());
212
213  const std::string message("test message");
214  std::vector<char> buffer(message.size());
215
216  size_t bytes_written = 0;
217  while (bytes_written < message.size()) {
218    scoped_refptr<IOBufferWithSize> write_buffer(
219        new IOBufferWithSize(message.size() - bytes_written));
220    memmove(write_buffer->data(), message.data(), message.size());
221
222    TestCompletionCallback write_callback;
223    int write_result = accepted_socket->Write(
224        write_buffer.get(), write_buffer->size(), write_callback.callback());
225    write_result = write_callback.GetResult(write_result);
226    ASSERT_TRUE(write_result >= 0);
227    ASSERT_TRUE(bytes_written + write_result <= message.size());
228    bytes_written += write_result;
229  }
230
231  size_t bytes_read = 0;
232  while (bytes_read < message.size()) {
233    scoped_refptr<IOBufferWithSize> read_buffer(
234        new IOBufferWithSize(message.size() - bytes_read));
235    TestCompletionCallback read_callback;
236    int read_result = connecting_socket.Read(
237        read_buffer.get(), read_buffer->size(), read_callback.callback());
238    read_result = read_callback.GetResult(read_result);
239    ASSERT_TRUE(read_result >= 0);
240    ASSERT_TRUE(bytes_read + read_result <= message.size());
241    memmove(&buffer[bytes_read], read_buffer->data(), read_result);
242    bytes_read += read_result;
243  }
244
245  std::string received_message(buffer.begin(), buffer.end());
246  ASSERT_EQ(message, received_message);
247}
248
249}  // namespace
250
251}  // namespace net
252