1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/tcp_listen_socket_unittest.h"
6
#include <fcntl.h>
#include <string.h>
#include <sys/types.h>
9
10#include "base/bind.h"
11#include "base/posix/eintr_wrapper.h"
12#include "base/sys_byteorder.h"
13#include "net/base/ip_endpoint.h"
14#include "net/base/net_errors.h"
15#include "net/base/net_util.h"
16#include "net/socket/socket_descriptor.h"
17#include "testing/platform_test.h"
18
19namespace net {
20
// Address used both for the listening server and for the client connection.
const char kLoopback[] = "127.0.0.1";

// Message exchanged between the client socket and the server connection.
const char kHelloWorld[] = "HELLO, WORLD";

// Size of the scratch buffer ClearTestSocket() reads into.
const int kReadBufSize = 1024;
24
25TCPListenSocketTester::TCPListenSocketTester()
26    : loop_(NULL),
27      cv_(&lock_),
28      server_port_(0) {}
29
30void TCPListenSocketTester::SetUp() {
31  base::Thread::Options options;
32  options.message_loop_type = base::MessageLoop::TYPE_IO;
33  thread_.reset(new base::Thread("socketio_test"));
34  thread_->StartWithOptions(options);
35  loop_ = reinterpret_cast<base::MessageLoopForIO*>(thread_->message_loop());
36
37  loop_->PostTask(FROM_HERE, base::Bind(
38      &TCPListenSocketTester::Listen, this));
39
40  // verify Listen succeeded
41  NextAction();
42  ASSERT_FALSE(server_.get() == NULL);
43  ASSERT_EQ(ACTION_LISTEN, last_action_.type());
44
45  int server_port = GetServerPort();
46  ASSERT_GT(server_port, 0);
47
48  // verify the connect/accept and setup test_socket_
49  test_socket_ = CreatePlatformSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
50  ASSERT_NE(kInvalidSocket, test_socket_);
51  struct sockaddr_in client;
52  client.sin_family = AF_INET;
53  client.sin_addr.s_addr = inet_addr(kLoopback);
54  client.sin_port = base::HostToNet16(server_port);
55  int ret = HANDLE_EINTR(
56      connect(test_socket_, reinterpret_cast<sockaddr*>(&client),
57              sizeof(client)));
58#if defined(OS_POSIX)
59  // The connect() call may be interrupted by a signal. When connect()
60  // is retried on EINTR, it fails with EISCONN.
61  if (ret == StreamListenSocket::kSocketError)
62    ASSERT_EQ(EISCONN, errno);
63#else
64  // Don't have signals.
65  ASSERT_NE(StreamListenSocket::kSocketError, ret);
66#endif
67
68  NextAction();
69  ASSERT_EQ(ACTION_ACCEPT, last_action_.type());
70}
71
72void TCPListenSocketTester::TearDown() {
73#if defined(OS_WIN)
74  ASSERT_EQ(0, closesocket(test_socket_));
75#elif defined(OS_POSIX)
76  ASSERT_EQ(0, IGNORE_EINTR(close(test_socket_)));
77#endif
78  NextAction();
79  ASSERT_EQ(ACTION_CLOSE, last_action_.type());
80
81  loop_->PostTask(FROM_HERE, base::Bind(
82      &TCPListenSocketTester::Shutdown, this));
83  NextAction();
84  ASSERT_EQ(ACTION_SHUTDOWN, last_action_.type());
85
86  thread_.reset();
87  loop_ = NULL;
88}
89
90void TCPListenSocketTester::ReportAction(
91    const TCPListenSocketTestAction& action) {
92  base::AutoLock locked(lock_);
93  queue_.push_back(action);
94  cv_.Broadcast();
95}
96
97void TCPListenSocketTester::NextAction() {
98  base::AutoLock locked(lock_);
99  while (queue_.empty())
100    cv_.Wait();
101  last_action_ = queue_.front();
102  queue_.pop_front();
103}
104
105int TCPListenSocketTester::ClearTestSocket() {
106  char buf[kReadBufSize];
107  int len_ret = 0;
108  do {
109    int len = HANDLE_EINTR(recv(test_socket_, buf, kReadBufSize, 0));
110    if (len == StreamListenSocket::kSocketError || len == 0) {
111      break;
112    } else {
113      len_ret += len;
114    }
115  } while (true);
116  return len_ret;
117}
118
119void TCPListenSocketTester::Shutdown() {
120  connection_.reset();
121  server_.reset();
122  ReportAction(TCPListenSocketTestAction(ACTION_SHUTDOWN));
123}
124
125void TCPListenSocketTester::Listen() {
126  server_ = DoListen();
127  ASSERT_TRUE(server_.get());
128
129  // The server's port will be needed to open the client socket.
130  IPEndPoint local_address;
131  ASSERT_EQ(OK, server_->GetLocalAddress(&local_address));
132  SetServerPort(local_address.port());
133
134  ReportAction(TCPListenSocketTestAction(ACTION_LISTEN));
135}
136
137void TCPListenSocketTester::SendFromTester() {
138  connection_->Send(kHelloWorld);
139  ReportAction(TCPListenSocketTestAction(ACTION_SEND));
140}
141
142void TCPListenSocketTester::TestClientSend() {
143  ASSERT_TRUE(Send(test_socket_, kHelloWorld));
144  NextAction();
145  ASSERT_EQ(ACTION_READ, last_action_.type());
146  ASSERT_EQ(last_action_.data(), kHelloWorld);
147}
148
149void TCPListenSocketTester::TestClientSendLong() {
150  size_t hello_len = strlen(kHelloWorld);
151  std::string long_string;
152  size_t long_len = 0;
153  for (int i = 0; i < 200; i++) {
154    long_string += kHelloWorld;
155    long_len += hello_len;
156  }
157  ASSERT_TRUE(Send(test_socket_, long_string));
158  size_t read_len = 0;
159  while (read_len < long_len) {
160    NextAction();
161    ASSERT_EQ(ACTION_READ, last_action_.type());
162    std::string last_data = last_action_.data();
163    size_t len = last_data.length();
164    if (long_string.compare(read_len, len, last_data)) {
165      ASSERT_EQ(long_string.compare(read_len, len, last_data), 0);
166    }
167    read_len += last_data.length();
168  }
169  ASSERT_EQ(read_len, long_len);
170}
171
172void TCPListenSocketTester::TestServerSend() {
173  loop_->PostTask(FROM_HERE, base::Bind(
174      &TCPListenSocketTester::SendFromTester, this));
175  NextAction();
176  ASSERT_EQ(ACTION_SEND, last_action_.type());
177  const int buf_len = 200;
178  char buf[buf_len+1];
179  unsigned recv_len = 0;
180  while (recv_len < strlen(kHelloWorld)) {
181    int r = HANDLE_EINTR(recv(test_socket_,
182                              buf + recv_len, buf_len - recv_len, 0));
183    ASSERT_GE(r, 0);
184    recv_len += static_cast<unsigned>(r);
185    if (!r)
186      break;
187  }
188  buf[recv_len] = 0;
189  ASSERT_STREQ(kHelloWorld, buf);
190}
191
192void TCPListenSocketTester::TestServerSendMultiple() {
193  // Send enough data to exceed the socket receive window. 20kb is probably a
194  // safe bet.
195  int send_count = (1024*20) / (sizeof(kHelloWorld)-1);
196
197  // Send multiple writes. Since no reading is occurring the data should be
198  // buffered in TCPListenSocket.
199  for (int i = 0; i < send_count; ++i) {
200    loop_->PostTask(FROM_HERE, base::Bind(
201        &TCPListenSocketTester::SendFromTester, this));
202    NextAction();
203    ASSERT_EQ(ACTION_SEND, last_action_.type());
204  }
205
206  // Make multiple reads. All of the data should eventually be returned.
207  char buf[sizeof(kHelloWorld)];
208  const int buf_len = sizeof(kHelloWorld);
209  for (int i = 0; i < send_count; ++i) {
210    unsigned recv_len = 0;
211    while (recv_len < buf_len-1) {
212      int r = HANDLE_EINTR(recv(test_socket_,
213                                buf + recv_len, buf_len - 1 - recv_len, 0));
214      ASSERT_GE(r, 0);
215      recv_len += static_cast<unsigned>(r);
216      if (!r)
217        break;
218    }
219    buf[recv_len] = 0;
220    ASSERT_STREQ(kHelloWorld, buf);
221  }
222}
223
224bool TCPListenSocketTester::Send(SocketDescriptor sock,
225                                 const std::string& str) {
226  int len = static_cast<int>(str.length());
227  int send_len = HANDLE_EINTR(send(sock, str.data(), len, 0));
228  if (send_len == StreamListenSocket::kSocketError) {
229    LOG(ERROR) << "send failed: " << errno;
230    return false;
231  } else if (send_len != len) {
232    return false;
233  }
234  return true;
235}
236
237void TCPListenSocketTester::DidAccept(
238    StreamListenSocket* server,
239    scoped_ptr<StreamListenSocket> connection) {
240  connection_ = connection.Pass();
241  ReportAction(TCPListenSocketTestAction(ACTION_ACCEPT));
242}
243
244void TCPListenSocketTester::DidRead(StreamListenSocket* connection,
245                                    const char* data,
246                                    int len) {
247  std::string str(data, len);
248  ReportAction(TCPListenSocketTestAction(ACTION_READ, str));
249}
250
251void TCPListenSocketTester::DidClose(StreamListenSocket* sock) {
252  ReportAction(TCPListenSocketTestAction(ACTION_CLOSE));
253}
254
255TCPListenSocketTester::~TCPListenSocketTester() {}
256
257scoped_ptr<TCPListenSocket> TCPListenSocketTester::DoListen() {
258  // Let the OS pick a free port.
259  return TCPListenSocket::CreateAndListen(kLoopback, 0, this);
260}
261
262int TCPListenSocketTester::GetServerPort() {
263  base::AutoLock locked(lock_);
264  return server_port_;
265}
266
267void TCPListenSocketTester::SetServerPort(int server_port) {
268  base::AutoLock locked(lock_);
269  server_port_ = server_port;
270}
271
272class TCPListenSocketTest : public PlatformTest {
273 public:
274  TCPListenSocketTest() {
275    tester_ = NULL;
276  }
277
278  virtual void SetUp() {
279    PlatformTest::SetUp();
280    tester_ = new TCPListenSocketTester();
281    tester_->SetUp();
282  }
283
284  virtual void TearDown() {
285    PlatformTest::TearDown();
286    tester_->TearDown();
287    tester_ = NULL;
288  }
289
290  scoped_refptr<TCPListenSocketTester> tester_;
291};
292
293TEST_F(TCPListenSocketTest, ClientSend) {
294  tester_->TestClientSend();
295}
296
297TEST_F(TCPListenSocketTest, ClientSendLong) {
298  tester_->TestClientSendLong();
299}
300
301TEST_F(TCPListenSocketTest, ServerSend) {
302  tester_->TestServerSend();
303}
304
305TEST_F(TCPListenSocketTest, ServerSendMultiple) {
306  tester_->TestServerSendMultiple();
307}
308
309}  // namespace net
310