1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/tcp_client_socket.h"
6
7#include "base/basictypes.h"
8#include "base/memory/ref_counted.h"
9#include "base/memory/scoped_ptr.h"
10#include "net/base/address_list.h"
11#include "net/base/io_buffer.h"
12#include "net/base/net_errors.h"
13#include "net/base/net_log.h"
14#include "net/base/net_log_unittest.h"
15#include "net/base/test_completion_callback.h"
16#include "net/base/winsock_init.h"
17#include "net/dns/mock_host_resolver.h"
18#include "net/socket/client_socket_factory.h"
19#include "net/socket/tcp_listen_socket.h"
20#include "testing/gtest/include/gtest/gtest.h"
21#include "testing/platform_test.h"
22
23namespace net {
24
namespace {

// Canned reply the test server sends back to the client.  The trailing NUL is
// not transmitted (senders use arraysize(kServerReply) - 1).
const char kServerReply[] = "HTTP/1.1 404 Not Found";

// Transport flavours the parameterized suite can be instantiated with.
enum ClientSocketTestTypes { TCP, SCTP };

}  // namespace
35
36class TransportClientSocketTest
37    : public StreamListenSocket::Delegate,
38      public ::testing::TestWithParam<ClientSocketTestTypes> {
39 public:
40  TransportClientSocketTest()
41      : listen_port_(0),
42        socket_factory_(ClientSocketFactory::GetDefaultFactory()),
43        close_server_socket_on_next_send_(false) {
44  }
45
46  virtual ~TransportClientSocketTest() {
47  }
48
49  // Implement StreamListenSocket::Delegate methods
50  virtual void DidAccept(StreamListenSocket* server,
51                         scoped_ptr<StreamListenSocket> connection) OVERRIDE {
52    connected_sock_.reset(
53        static_cast<TCPListenSocket*>(connection.release()));
54  }
55  virtual void DidRead(StreamListenSocket*, const char* str, int len) OVERRIDE {
56    // TODO(dkegel): this might not be long enough to tickle some bugs.
57    connected_sock_->Send(kServerReply, arraysize(kServerReply) - 1,
58                          false /* Don't append line feed */);
59    if (close_server_socket_on_next_send_)
60      CloseServerSocket();
61  }
62  virtual void DidClose(StreamListenSocket* sock) OVERRIDE {}
63
64  // Testcase hooks
65  virtual void SetUp();
66
67  void CloseServerSocket() {
68    // delete the connected_sock_, which will close it.
69    connected_sock_.reset();
70  }
71
72  void PauseServerReads() {
73    connected_sock_->PauseReads();
74  }
75
76  void ResumeServerReads() {
77    connected_sock_->ResumeReads();
78  }
79
80  int DrainClientSocket(IOBuffer* buf,
81                        uint32 buf_len,
82                        uint32 bytes_to_read,
83                        TestCompletionCallback* callback);
84
85  void SendClientRequest();
86
87  void set_close_server_socket_on_next_send(bool close) {
88    close_server_socket_on_next_send_ = close;
89  }
90
91 protected:
92  int listen_port_;
93  CapturingNetLog net_log_;
94  ClientSocketFactory* const socket_factory_;
95  scoped_ptr<StreamSocket> sock_;
96
97 private:
98  scoped_ptr<TCPListenSocket> listen_sock_;
99  scoped_ptr<TCPListenSocket> connected_sock_;
100  bool close_server_socket_on_next_send_;
101};
102
103void TransportClientSocketTest::SetUp() {
104  ::testing::TestWithParam<ClientSocketTestTypes>::SetUp();
105
106  // Find a free port to listen on
107  scoped_ptr<TCPListenSocket> sock;
108  int port;
109  // Range of ports to listen on.  Shouldn't need to try many.
110  const int kMinPort = 10100;
111  const int kMaxPort = 10200;
112#if defined(OS_WIN)
113  EnsureWinsockInit();
114#endif
115  for (port = kMinPort; port < kMaxPort; port++) {
116    sock = TCPListenSocket::CreateAndListen("127.0.0.1", port, this);
117    if (sock.get())
118      break;
119  }
120  ASSERT_TRUE(sock.get() != NULL);
121  listen_sock_ = sock.Pass();
122  listen_port_ = port;
123
124  AddressList addr;
125  // MockHostResolver resolves everything to 127.0.0.1.
126  scoped_ptr<HostResolver> resolver(new MockHostResolver());
127  HostResolver::RequestInfo info(HostPortPair("localhost", listen_port_));
128  TestCompletionCallback callback;
129  int rv = resolver->Resolve(
130      info, DEFAULT_PRIORITY, &addr, callback.callback(), NULL, BoundNetLog());
131  CHECK_EQ(ERR_IO_PENDING, rv);
132  rv = callback.WaitForResult();
133  CHECK_EQ(rv, OK);
134  sock_ =
135      socket_factory_->CreateTransportClientSocket(addr,
136                                                   &net_log_,
137                                                   NetLog::Source());
138}
139
140int TransportClientSocketTest::DrainClientSocket(
141    IOBuffer* buf, uint32 buf_len,
142    uint32 bytes_to_read, TestCompletionCallback* callback) {
143  int rv = OK;
144  uint32 bytes_read = 0;
145
146  while (bytes_read < bytes_to_read) {
147    rv = sock_->Read(buf, buf_len, callback->callback());
148    EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
149
150    if (rv == ERR_IO_PENDING)
151      rv = callback->WaitForResult();
152
153    EXPECT_GE(rv, 0);
154    bytes_read += rv;
155  }
156
157  return static_cast<int>(bytes_read);
158}
159
160void TransportClientSocketTest::SendClientRequest() {
161  const char request_text[] = "GET / HTTP/1.0\r\n\r\n";
162  scoped_refptr<IOBuffer> request_buffer(
163      new IOBuffer(arraysize(request_text) - 1));
164  TestCompletionCallback callback;
165  int rv;
166
167  memcpy(request_buffer->data(), request_text, arraysize(request_text) - 1);
168  rv = sock_->Write(
169      request_buffer.get(), arraysize(request_text) - 1, callback.callback());
170  EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
171
172  if (rv == ERR_IO_PENDING)
173    rv = callback.WaitForResult();
174  EXPECT_EQ(rv, static_cast<int>(arraysize(request_text) - 1));
175}
176
177// TODO(leighton):  Add SCTP to this list when it is ready.
178INSTANTIATE_TEST_CASE_P(StreamSocket,
179                        TransportClientSocketTest,
180                        ::testing::Values(TCP));
181
182TEST_P(TransportClientSocketTest, Connect) {
183  TestCompletionCallback callback;
184  EXPECT_FALSE(sock_->IsConnected());
185
186  int rv = sock_->Connect(callback.callback());
187
188  net::CapturingNetLog::CapturedEntryList net_log_entries;
189  net_log_.GetEntries(&net_log_entries);
190  EXPECT_TRUE(net::LogContainsBeginEvent(
191      net_log_entries, 0, net::NetLog::TYPE_SOCKET_ALIVE));
192  EXPECT_TRUE(net::LogContainsBeginEvent(
193      net_log_entries, 1, net::NetLog::TYPE_TCP_CONNECT));
194  if (rv != OK) {
195    ASSERT_EQ(rv, ERR_IO_PENDING);
196    rv = callback.WaitForResult();
197    EXPECT_EQ(rv, OK);
198  }
199
200  EXPECT_TRUE(sock_->IsConnected());
201  net_log_.GetEntries(&net_log_entries);
202  EXPECT_TRUE(net::LogContainsEndEvent(
203      net_log_entries, -1, net::NetLog::TYPE_TCP_CONNECT));
204
205  sock_->Disconnect();
206  EXPECT_FALSE(sock_->IsConnected());
207}
208
209TEST_P(TransportClientSocketTest, IsConnected) {
210  scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
211  TestCompletionCallback callback;
212  uint32 bytes_read;
213
214  EXPECT_FALSE(sock_->IsConnected());
215  EXPECT_FALSE(sock_->IsConnectedAndIdle());
216  int rv = sock_->Connect(callback.callback());
217  if (rv != OK) {
218    ASSERT_EQ(rv, ERR_IO_PENDING);
219    rv = callback.WaitForResult();
220    EXPECT_EQ(rv, OK);
221  }
222  EXPECT_TRUE(sock_->IsConnected());
223  EXPECT_TRUE(sock_->IsConnectedAndIdle());
224
225  // Send the request and wait for the server to respond.
226  SendClientRequest();
227
228  // Drain a single byte so we know we've received some data.
229  bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback);
230  ASSERT_EQ(bytes_read, 1u);
231
232  // Socket should be considered connected, but not idle, due to
233  // pending data.
234  EXPECT_TRUE(sock_->IsConnected());
235  EXPECT_FALSE(sock_->IsConnectedAndIdle());
236
237  bytes_read = DrainClientSocket(
238      buf.get(), 4096, arraysize(kServerReply) - 2, &callback);
239  ASSERT_EQ(bytes_read, arraysize(kServerReply) - 2);
240
241  // After draining the data, the socket should be back to connected
242  // and idle.
243  EXPECT_TRUE(sock_->IsConnected());
244  EXPECT_TRUE(sock_->IsConnectedAndIdle());
245
246  // This time close the server socket immediately after the server response.
247  set_close_server_socket_on_next_send(true);
248  SendClientRequest();
249
250  bytes_read = DrainClientSocket(buf.get(), 1, 1, &callback);
251  ASSERT_EQ(bytes_read, 1u);
252
253  // As above because of data.
254  EXPECT_TRUE(sock_->IsConnected());
255  EXPECT_FALSE(sock_->IsConnectedAndIdle());
256
257  bytes_read = DrainClientSocket(
258      buf.get(), 4096, arraysize(kServerReply) - 2, &callback);
259  ASSERT_EQ(bytes_read, arraysize(kServerReply) - 2);
260
261  // Once the data is drained, the socket should now be seen as not
262  // connected.
263  if (sock_->IsConnected()) {
264    // In the unlikely event that the server's connection closure is not
265    // processed in time, wait for the connection to be closed.
266    rv = sock_->Read(buf.get(), 4096, callback.callback());
267    EXPECT_EQ(0, callback.GetResult(rv));
268    EXPECT_FALSE(sock_->IsConnected());
269  }
270  EXPECT_FALSE(sock_->IsConnectedAndIdle());
271}
272
273TEST_P(TransportClientSocketTest, Read) {
274  TestCompletionCallback callback;
275  int rv = sock_->Connect(callback.callback());
276  if (rv != OK) {
277    ASSERT_EQ(rv, ERR_IO_PENDING);
278
279    rv = callback.WaitForResult();
280    EXPECT_EQ(rv, OK);
281  }
282  SendClientRequest();
283
284  scoped_refptr<IOBuffer> buf(new IOBuffer(4096));
285  uint32 bytes_read = DrainClientSocket(
286      buf.get(), 4096, arraysize(kServerReply) - 1, &callback);
287  ASSERT_EQ(bytes_read, arraysize(kServerReply) - 1);
288
289  // All data has been read now.  Read once more to force an ERR_IO_PENDING, and
290  // then close the server socket, and note the close.
291
292  rv = sock_->Read(buf.get(), 4096, callback.callback());
293  ASSERT_EQ(ERR_IO_PENDING, rv);
294  CloseServerSocket();
295  EXPECT_EQ(0, callback.WaitForResult());
296}
297
298TEST_P(TransportClientSocketTest, Read_SmallChunks) {
299  TestCompletionCallback callback;
300  int rv = sock_->Connect(callback.callback());
301  if (rv != OK) {
302    ASSERT_EQ(rv, ERR_IO_PENDING);
303
304    rv = callback.WaitForResult();
305    EXPECT_EQ(rv, OK);
306  }
307  SendClientRequest();
308
309  scoped_refptr<IOBuffer> buf(new IOBuffer(1));
310  uint32 bytes_read = 0;
311  while (bytes_read < arraysize(kServerReply) - 1) {
312    rv = sock_->Read(buf.get(), 1, callback.callback());
313    EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
314
315    if (rv == ERR_IO_PENDING)
316      rv = callback.WaitForResult();
317
318    ASSERT_EQ(1, rv);
319    bytes_read += rv;
320  }
321
322  // All data has been read now.  Read once more to force an ERR_IO_PENDING, and
323  // then close the server socket, and note the close.
324
325  rv = sock_->Read(buf.get(), 1, callback.callback());
326  ASSERT_EQ(ERR_IO_PENDING, rv);
327  CloseServerSocket();
328  EXPECT_EQ(0, callback.WaitForResult());
329}
330
331TEST_P(TransportClientSocketTest, Read_Interrupted) {
332  TestCompletionCallback callback;
333  int rv = sock_->Connect(callback.callback());
334  if (rv != OK) {
335    ASSERT_EQ(ERR_IO_PENDING, rv);
336
337    rv = callback.WaitForResult();
338    EXPECT_EQ(rv, OK);
339  }
340  SendClientRequest();
341
342  // Do a partial read and then exit.  This test should not crash!
343  scoped_refptr<IOBuffer> buf(new IOBuffer(16));
344  rv = sock_->Read(buf.get(), 16, callback.callback());
345  EXPECT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
346
347  if (rv == ERR_IO_PENDING)
348    rv = callback.WaitForResult();
349
350  EXPECT_NE(0, rv);
351}
352
353TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_ReadFirst) {
354  TestCompletionCallback callback;
355  int rv = sock_->Connect(callback.callback());
356  if (rv != OK) {
357    ASSERT_EQ(rv, ERR_IO_PENDING);
358
359    rv = callback.WaitForResult();
360    EXPECT_EQ(rv, OK);
361  }
362
363  // Read first.  There's no data, so it should return ERR_IO_PENDING.
364  const int kBufLen = 4096;
365  scoped_refptr<IOBuffer> buf(new IOBuffer(kBufLen));
366  rv = sock_->Read(buf.get(), kBufLen, callback.callback());
367  EXPECT_EQ(ERR_IO_PENDING, rv);
368
369  PauseServerReads();
370  const int kWriteBufLen = 64 * 1024;
371  scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kWriteBufLen));
372  char* request_data = request_buffer->data();
373  memset(request_data, 'A', kWriteBufLen);
374  TestCompletionCallback write_callback;
375
376  while (true) {
377    rv = sock_->Write(
378        request_buffer.get(), kWriteBufLen, write_callback.callback());
379    ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
380
381    if (rv == ERR_IO_PENDING) {
382      ResumeServerReads();
383      rv = write_callback.WaitForResult();
384      break;
385    }
386  }
387
388  // At this point, both read and write have returned ERR_IO_PENDING, and the
389  // write callback has executed.  We wait for the read callback to run now to
390  // make sure that the socket can handle full duplex communications.
391
392  rv = callback.WaitForResult();
393  EXPECT_GE(rv, 0);
394}
395
396TEST_P(TransportClientSocketTest, DISABLED_FullDuplex_WriteFirst) {
397  TestCompletionCallback callback;
398  int rv = sock_->Connect(callback.callback());
399  if (rv != OK) {
400    ASSERT_EQ(ERR_IO_PENDING, rv);
401
402    rv = callback.WaitForResult();
403    EXPECT_EQ(OK, rv);
404  }
405
406  PauseServerReads();
407  const int kWriteBufLen = 64 * 1024;
408  scoped_refptr<IOBuffer> request_buffer(new IOBuffer(kWriteBufLen));
409  char* request_data = request_buffer->data();
410  memset(request_data, 'A', kWriteBufLen);
411  TestCompletionCallback write_callback;
412
413  while (true) {
414    rv = sock_->Write(
415        request_buffer.get(), kWriteBufLen, write_callback.callback());
416    ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
417
418    if (rv == ERR_IO_PENDING)
419      break;
420  }
421
422  // Now we have the Write() blocked on ERR_IO_PENDING.  It's time to force the
423  // Read() to block on ERR_IO_PENDING too.
424
425  const int kBufLen = 4096;
426  scoped_refptr<IOBuffer> buf(new IOBuffer(kBufLen));
427  while (true) {
428    rv = sock_->Read(buf.get(), kBufLen, callback.callback());
429    ASSERT_TRUE(rv >= 0 || rv == ERR_IO_PENDING);
430    if (rv == ERR_IO_PENDING)
431      break;
432  }
433
434  // At this point, both read and write have returned ERR_IO_PENDING.  Now we
435  // run the write and read callbacks to make sure they can handle full duplex
436  // communications.
437
438  ResumeServerReads();
439  rv = write_callback.WaitForResult();
440  EXPECT_GE(rv, 0);
441
442  // It's possible the read is blocked because it's already read all the data.
443  // Close the server socket, so there will at least be a 0-byte read.
444  CloseServerSocket();
445
446  rv = callback.WaitForResult();
447  EXPECT_GE(rv, 0);
448}
449
450}  // namespace net
451