1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "extensions/browser/api/socket/tcp_socket.h"
6
7#include "base/memory/scoped_ptr.h"
8#include "net/base/address_list.h"
9#include "net/base/completion_callback.h"
10#include "net/base/io_buffer.h"
11#include "net/base/net_errors.h"
12#include "net/base/rand_callback.h"
13#include "net/socket/tcp_client_socket.h"
14#include "net/socket/tcp_server_socket.h"
15#include "testing/gmock/include/gmock/gmock.h"
16
17using testing::_;
18using testing::DoAll;
19using testing::Return;
20using testing::SaveArg;
21
22namespace extensions {
23
24class MockTCPSocket : public net::TCPClientSocket {
25 public:
26  explicit MockTCPSocket(const net::AddressList& address_list)
27      : net::TCPClientSocket(address_list, NULL, net::NetLog::Source()) {
28  }
29
30  MOCK_METHOD3(Read, int(net::IOBuffer* buf, int buf_len,
31                         const net::CompletionCallback& callback));
32  MOCK_METHOD3(Write, int(net::IOBuffer* buf, int buf_len,
33                          const net::CompletionCallback& callback));
34  MOCK_METHOD2(SetKeepAlive, bool(bool enable, int delay));
35  MOCK_METHOD1(SetNoDelay, bool(bool no_delay));
36  virtual bool IsConnected() const OVERRIDE {
37    return true;
38  }
39
40 private:
41  DISALLOW_COPY_AND_ASSIGN(MockTCPSocket);
42};
43
44class MockTCPServerSocket : public net::TCPServerSocket {
45 public:
46  explicit MockTCPServerSocket()
47      : net::TCPServerSocket(NULL, net::NetLog::Source()) {
48  }
49  MOCK_METHOD2(Listen, int(const net::IPEndPoint& address, int backlog));
50  MOCK_METHOD2(Accept, int(scoped_ptr<net::StreamSocket>* socket,
51                            const net::CompletionCallback& callback));
52
53 private:
54  DISALLOW_COPY_AND_ASSIGN(MockTCPServerSocket);
55};
56
57class CompleteHandler {
58 public:
59  CompleteHandler() {}
60  MOCK_METHOD1(OnComplete, void(int result_code));
61  MOCK_METHOD2(OnReadComplete, void(int result_code,
62      scoped_refptr<net::IOBuffer> io_buffer));
63  MOCK_METHOD2(OnAccept, void(int, net::TCPClientSocket*));
64 private:
65  DISALLOW_COPY_AND_ASSIGN(CompleteHandler);
66};
67
// Placeholder identifier passed to the TCPSocket test factories below.
const std::string FAKE_ID("abcdefghijklmnopqrst");
69
70TEST(SocketTest, TestTCPSocketRead) {
71  net::AddressList address_list;
72  MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list);
73  CompleteHandler handler;
74
75  scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting(
76      tcp_client_socket, FAKE_ID, true));
77
78  EXPECT_CALL(*tcp_client_socket, Read(_, _, _))
79      .Times(1);
80  EXPECT_CALL(handler, OnReadComplete(_, _))
81      .Times(1);
82
83  const int count = 512;
84  socket->Read(count, base::Bind(&CompleteHandler::OnReadComplete,
85        base::Unretained(&handler)));
86}
87
88TEST(SocketTest, TestTCPSocketWrite) {
89  net::AddressList address_list;
90  MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list);
91  CompleteHandler handler;
92
93  scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting(
94      tcp_client_socket, FAKE_ID, true));
95
96  net::CompletionCallback callback;
97  EXPECT_CALL(*tcp_client_socket, Write(_, _, _))
98      .Times(2)
99      .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback),
100                                     Return(128)));
101  EXPECT_CALL(handler, OnComplete(_))
102      .Times(1);
103
104  scoped_refptr<net::IOBufferWithSize> io_buffer(
105      new net::IOBufferWithSize(256));
106  socket->Write(io_buffer.get(), io_buffer->size(),
107      base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler)));
108}
109
110TEST(SocketTest, TestTCPSocketBlockedWrite) {
111  net::AddressList address_list;
112  MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list);
113  CompleteHandler handler;
114
115  scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting(
116      tcp_client_socket, FAKE_ID, true));
117
118  net::CompletionCallback callback;
119  EXPECT_CALL(*tcp_client_socket, Write(_, _, _))
120      .Times(2)
121      .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback),
122                                     Return(net::ERR_IO_PENDING)));
123  scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(42));
124  socket->Write(io_buffer.get(), io_buffer->size(),
125      base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler)));
126
127  // Good. Original call came back unable to complete. Now pretend the socket
128  // finished, and confirm that we passed the error back.
129  EXPECT_CALL(handler, OnComplete(42))
130      .Times(1);
131  callback.Run(40);
132  callback.Run(2);
133}
134
135TEST(SocketTest, TestTCPSocketBlockedWriteReentry) {
136  net::AddressList address_list;
137  MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list);
138  CompleteHandler handlers[5];
139
140  scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting(
141      tcp_client_socket, FAKE_ID, true));
142
143  net::CompletionCallback callback;
144  EXPECT_CALL(*tcp_client_socket, Write(_, _, _))
145      .Times(5)
146      .WillRepeatedly(testing::DoAll(SaveArg<2>(&callback),
147                                     Return(net::ERR_IO_PENDING)));
148  scoped_refptr<net::IOBufferWithSize> io_buffers[5];
149  int i;
150  for (i = 0; i < 5; i++) {
151    io_buffers[i] = new net::IOBufferWithSize(128 + i * 50);
152    scoped_refptr<net::IOBufferWithSize> io_buffer1(
153        new net::IOBufferWithSize(42));
154    socket->Write(io_buffers[i].get(), io_buffers[i]->size(),
155        base::Bind(&CompleteHandler::OnComplete,
156            base::Unretained(&handlers[i])));
157
158    EXPECT_CALL(handlers[i], OnComplete(io_buffers[i]->size()))
159        .Times(1);
160  }
161
162  for (i = 0; i < 5; i++) {
163    callback.Run(128 + i * 50);
164  }
165}
166
167TEST(SocketTest, TestTCPSocketSetNoDelay) {
168  net::AddressList address_list;
169  MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list);
170
171  scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting(
172      tcp_client_socket, FAKE_ID));
173
174  bool no_delay = false;
175  EXPECT_CALL(*tcp_client_socket, SetNoDelay(_))
176      .WillOnce(testing::DoAll(SaveArg<0>(&no_delay), Return(true)));
177  int result = socket->SetNoDelay(true);
178  EXPECT_TRUE(result);
179  EXPECT_TRUE(no_delay);
180
181  EXPECT_CALL(*tcp_client_socket, SetNoDelay(_))
182      .WillOnce(testing::DoAll(SaveArg<0>(&no_delay), Return(false)));
183
184  result = socket->SetNoDelay(false);
185  EXPECT_FALSE(result);
186  EXPECT_FALSE(no_delay);
187}
188
189TEST(SocketTest, TestTCPSocketSetKeepAlive) {
190  net::AddressList address_list;
191  MockTCPSocket* tcp_client_socket = new MockTCPSocket(address_list);
192
193  scoped_ptr<TCPSocket> socket(TCPSocket::CreateSocketForTesting(
194      tcp_client_socket, FAKE_ID));
195
196  bool enable = false;
197  int delay = 0;
198  EXPECT_CALL(*tcp_client_socket, SetKeepAlive(_, _))
199      .WillOnce(testing::DoAll(SaveArg<0>(&enable),
200                               SaveArg<1>(&delay),
201                               Return(true)));
202  int result = socket->SetKeepAlive(true, 4500);
203  EXPECT_TRUE(result);
204  EXPECT_TRUE(enable);
205  EXPECT_EQ(4500, delay);
206
207  EXPECT_CALL(*tcp_client_socket, SetKeepAlive(_, _))
208      .WillOnce(testing::DoAll(SaveArg<0>(&enable),
209                               SaveArg<1>(&delay),
210                               Return(false)));
211  result = socket->SetKeepAlive(false, 0);
212  EXPECT_FALSE(result);
213  EXPECT_FALSE(enable);
214  EXPECT_EQ(0, delay);
215}
216
217TEST(SocketTest, TestTCPServerSocketListenAccept) {
218  MockTCPServerSocket* tcp_server_socket = new MockTCPServerSocket();
219  CompleteHandler handler;
220
221  scoped_ptr<TCPSocket> socket(TCPSocket::CreateServerSocketForTesting(
222      tcp_server_socket, FAKE_ID));
223
224  EXPECT_CALL(*tcp_server_socket, Accept(_, _)).Times(1);
225  EXPECT_CALL(*tcp_server_socket, Listen(_, _)).Times(1);
226  EXPECT_CALL(handler, OnAccept(_, _));
227
228  std::string err_msg;
229  EXPECT_EQ(net::OK, socket->Listen("127.0.0.1", 9999, 10, &err_msg));
230  socket->Accept(base::Bind(&CompleteHandler::OnAccept,
231        base::Unretained(&handler)));
232}
233
234}  // namespace extensions
235