1// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/base/listen_socket_unittest.h"
6
7#include <fcntl.h>
8
9#include "base/eintr_wrapper.h"
10#include "net/base/net_util.h"
11#include "testing/platform_test.h"
12
13const int ListenSocketTester::kTestPort = 9999;
14
15static const int kReadBufSize = 1024;
16static const char* kHelloWorld = "HELLO, WORLD";
17static const int kMaxQueueSize = 20;
18static const char* kLoopback = "127.0.0.1";
19static const int kDefaultTimeoutMs = 5000;
20#if defined(OS_POSIX)
21static const char* kSemaphoreName = "chromium.listen_socket";
22#endif
23
24
25ListenSocket* ListenSocketTester::DoListen() {
26  return ListenSocket::Listen(kLoopback, kTestPort, this);
27}
28
29void ListenSocketTester::SetUp() {
30#if defined(OS_WIN)
31  InitializeCriticalSection(&lock_);
32  semaphore_ = CreateSemaphore(NULL, 0, kMaxQueueSize, NULL);
33  server_ = NULL;
34  net::EnsureWinsockInit();
35#elif defined(OS_POSIX)
36  ASSERT_EQ(0, pthread_mutex_init(&lock_, NULL));
37  sem_unlink(kSemaphoreName);
38  semaphore_ = sem_open(kSemaphoreName, O_CREAT, 0, 0);
39  ASSERT_NE(SEM_FAILED, semaphore_);
40#endif
41  base::Thread::Options options;
42  options.message_loop_type = MessageLoop::TYPE_IO;
43  thread_.reset(new base::Thread("socketio_test"));
44  thread_->StartWithOptions(options);
45  loop_ = reinterpret_cast<MessageLoopForIO*>(thread_->message_loop());
46
47  loop_->PostTask(FROM_HERE, NewRunnableMethod(
48      this, &ListenSocketTester::Listen));
49
50  // verify Listen succeeded
51  ASSERT_TRUE(NextAction(kDefaultTimeoutMs));
52  ASSERT_FALSE(server_ == NULL);
53  ASSERT_EQ(ACTION_LISTEN, last_action_.type());
54
55  // verify the connect/accept and setup test_socket_
56  test_socket_ = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
57  ASSERT_NE(-1, test_socket_);
58  struct sockaddr_in client;
59  client.sin_family = AF_INET;
60  client.sin_addr.s_addr = inet_addr(kLoopback);
61  client.sin_port = htons(kTestPort);
62  int ret =
63      HANDLE_EINTR(connect(test_socket_, reinterpret_cast<sockaddr*>(&client),
64                           sizeof(client)));
65  ASSERT_NE(ret, SOCKET_ERROR);
66
67  net::SetNonBlocking(test_socket_);
68  ASSERT_TRUE(NextAction(kDefaultTimeoutMs));
69  ASSERT_EQ(ACTION_ACCEPT, last_action_.type());
70}
71
72void ListenSocketTester::TearDown() {
73  // verify close
74#if defined(OS_WIN)
75  closesocket(test_socket_);
76#elif defined(OS_POSIX)
77  close(test_socket_);
78#endif
79  ASSERT_TRUE(NextAction(kDefaultTimeoutMs));
80  ASSERT_EQ(ACTION_CLOSE, last_action_.type());
81
82  loop_->PostTask(FROM_HERE, NewRunnableMethod(
83      this, &ListenSocketTester::Shutdown));
84  ASSERT_TRUE(NextAction(kDefaultTimeoutMs));
85  ASSERT_EQ(ACTION_SHUTDOWN, last_action_.type());
86
87#if defined(OS_WIN)
88  CloseHandle(semaphore_);
89  semaphore_ = 0;
90  DeleteCriticalSection(&lock_);
91#elif defined(OS_POSIX)
92  ASSERT_EQ(0, pthread_mutex_lock(&lock_));
93  semaphore_ = NULL;
94  ASSERT_EQ(0, pthread_mutex_unlock(&lock_));
95  ASSERT_EQ(0, sem_unlink(kSemaphoreName));
96  ASSERT_EQ(0, pthread_mutex_destroy(&lock_));
97#endif
98
99  thread_.reset();
100  loop_ = NULL;
101}
102
103void ListenSocketTester::ReportAction(const ListenSocketTestAction& action) {
104#if defined(OS_WIN)
105  EnterCriticalSection(&lock_);
106  queue_.push_back(action);
107  LeaveCriticalSection(&lock_);
108  ReleaseSemaphore(semaphore_, 1, NULL);
109#elif defined(OS_POSIX)
110  ASSERT_EQ(0, pthread_mutex_lock(&lock_));
111  queue_.push_back(action);
112  ASSERT_EQ(0, pthread_mutex_unlock(&lock_));
113  ASSERT_EQ(0, sem_post(semaphore_));
114#endif
115}
116
117bool ListenSocketTester::NextAction(int timeout) {
118#if defined(OS_WIN)
119  DWORD ret = ::WaitForSingleObject(semaphore_, timeout);
120  if (ret != WAIT_OBJECT_0)
121    return false;
122  EnterCriticalSection(&lock_);
123  if (queue_.size() == 0) {
124    LeaveCriticalSection(&lock_);
125    return false;
126  }
127  last_action_ = queue_.front();
128  queue_.pop_front();
129  LeaveCriticalSection(&lock_);
130  return true;
131#elif defined(OS_POSIX)
132  if (semaphore_ == SEM_FAILED)
133    return false;
134  while (true) {
135    int result = sem_trywait(semaphore_);
136    PlatformThread::Sleep(1);  // 1MS sleep
137    timeout--;
138    if (timeout <= 0)
139      return false;
140    if (result == 0)
141      break;
142  }
143  pthread_mutex_lock(&lock_);
144  if (queue_.size() == 0) {
145    pthread_mutex_unlock(&lock_);
146    return false;
147  }
148  last_action_ = queue_.front();
149  queue_.pop_front();
150  pthread_mutex_unlock(&lock_);
151  return true;
152#endif
153}
154
155int ListenSocketTester::ClearTestSocket() {
156  char buf[kReadBufSize];
157  int len_ret = 0;
158  int time_out = 0;
159  do {
160    int len = HANDLE_EINTR(recv(test_socket_, buf, kReadBufSize, 0));
161#if defined(OS_WIN)
162    if (len == SOCKET_ERROR) {
163      int err = WSAGetLastError();
164      if (err == WSAEWOULDBLOCK) {
165#elif defined(OS_POSIX)
166    if (len == SOCKET_ERROR) {
167      if (errno == EWOULDBLOCK || errno == EAGAIN) {
168#endif
169        PlatformThread::Sleep(1);
170        time_out++;
171        if (time_out > 10)
172          break;
173        continue;  // still trying
174      }
175    } else if (len == 0) {
176      // socket closed
177      break;
178    } else {
179      time_out = 0;
180      len_ret += len;
181    }
182  } while (true);
183  return len_ret;
184}
185
186void ListenSocketTester::Shutdown() {
187  connection_->Release();
188  connection_ = NULL;
189  server_->Release();
190  server_ = NULL;
191  ReportAction(ListenSocketTestAction(ACTION_SHUTDOWN));
192}
193
194void ListenSocketTester::Listen() {
195  server_ = DoListen();
196  if (server_) {
197    server_->AddRef();
198    ReportAction(ListenSocketTestAction(ACTION_LISTEN));
199  }
200}
201
202void ListenSocketTester::SendFromTester() {
203  connection_->Send(kHelloWorld);
204  ReportAction(ListenSocketTestAction(ACTION_SEND));
205}
206
207void ListenSocketTester::DidAccept(ListenSocket *server,
208                                   ListenSocket *connection) {
209  connection_ = connection;
210  connection_->AddRef();
211  ReportAction(ListenSocketTestAction(ACTION_ACCEPT));
212}
213
214void ListenSocketTester::DidRead(ListenSocket *connection,
215                                 const std::string& data) {
216  ReportAction(ListenSocketTestAction(ACTION_READ, data));
217}
218
219void ListenSocketTester::DidClose(ListenSocket *sock) {
220  ReportAction(ListenSocketTestAction(ACTION_CLOSE));
221}
222
223bool ListenSocketTester::Send(SOCKET sock, const std::string& str) {
224  int len = static_cast<int>(str.length());
225  int send_len = HANDLE_EINTR(send(sock, str.data(), len, 0));
226  if (send_len == SOCKET_ERROR) {
227    LOG(ERROR) << "send failed: " << errno;
228    return false;
229  } else if (send_len != len) {
230    return false;
231  }
232  return true;
233}
234
235void ListenSocketTester::TestClientSend() {
236  ASSERT_TRUE(Send(test_socket_, kHelloWorld));
237  ASSERT_TRUE(NextAction(kDefaultTimeoutMs));
238  ASSERT_EQ(ACTION_READ, last_action_.type());
239  ASSERT_EQ(last_action_.data(), kHelloWorld);
240}
241
242void ListenSocketTester::TestClientSendLong() {
243  int hello_len = strlen(kHelloWorld);
244  std::string long_string;
245  int long_len = 0;
246  for (int i = 0; i < 200; i++) {
247    long_string += kHelloWorld;
248    long_len += hello_len;
249  }
250  ASSERT_TRUE(Send(test_socket_, long_string));
251  int read_len = 0;
252  while (read_len < long_len) {
253    ASSERT_TRUE(NextAction(kDefaultTimeoutMs));
254    ASSERT_EQ(ACTION_READ, last_action_.type());
255    std::string last_data = last_action_.data();
256    size_t len = last_data.length();
257    if (long_string.compare(read_len, len, last_data)) {
258      ASSERT_EQ(long_string.compare(read_len, len, last_data), 0);
259    }
260    read_len += static_cast<int>(last_data.length());
261  }
262  ASSERT_EQ(read_len, long_len);
263}
264
265void ListenSocketTester::TestServerSend() {
266  loop_->PostTask(FROM_HERE, NewRunnableMethod(
267      this, &ListenSocketTester::SendFromTester));
268  ASSERT_TRUE(NextAction(kDefaultTimeoutMs));
269  ASSERT_EQ(ACTION_SEND, last_action_.type());
270  // TODO(erikkay): Without this sleep, the recv seems to fail a small amount
271  // of the time.  I could fix this by making the socket blocking, but then
272  // this test might hang in the case of errors.  It would be nice to do
273  // something that felt more reliable here.
274  PlatformThread::Sleep(10);  // sleep for 10ms
275  const int buf_len = 200;
276  char buf[buf_len+1];
277  int recv_len;
278  do {
279    recv_len = HANDLE_EINTR(recv(test_socket_, buf, buf_len, 0));
280#if defined(OS_POSIX)
281  } while (recv_len == SOCKET_ERROR && errno == EINTR);
282#else
283  } while (false);
284#endif
285  ASSERT_NE(recv_len, SOCKET_ERROR);
286  buf[recv_len] = 0;
287  ASSERT_STREQ(buf, kHelloWorld);
288}
289
290
291class ListenSocketTest: public PlatformTest {
292 public:
293  ListenSocketTest() {
294    tester_ = NULL;
295  }
296
297  virtual void SetUp() {
298    PlatformTest::SetUp();
299    tester_ = new ListenSocketTester();
300    tester_->SetUp();
301  }
302
303  virtual void TearDown() {
304    PlatformTest::TearDown();
305    tester_->TearDown();
306    tester_ = NULL;
307  }
308
309  scoped_refptr<ListenSocketTester> tester_;
310};
311
312TEST_F(ListenSocketTest, ClientSend) {
313  tester_->TestClientSend();
314}
315
316TEST_F(ListenSocketTest, ClientSendLong) {
317  tester_->TestClientSendLong();
318}
319
320TEST_F(ListenSocketTest, ServerSend) {
321  tester_->TestServerSend();
322}
323