1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/socks_client_socket_pool.h"
6
7#include "base/callback.h"
8#include "base/compiler_specific.h"
9#include "base/time/time.h"
10#include "net/base/load_timing_info.h"
11#include "net/base/load_timing_info_test_util.h"
12#include "net/base/net_errors.h"
13#include "net/base/test_completion_callback.h"
14#include "net/dns/mock_host_resolver.h"
15#include "net/socket/client_socket_factory.h"
16#include "net/socket/client_socket_handle.h"
17#include "net/socket/client_socket_pool_histograms.h"
18#include "net/socket/socket_test_util.h"
19#include "testing/gtest/include/gtest/gtest.h"
20
21namespace net {
22
23namespace {
24
25const int kMaxSockets = 32;
26const int kMaxSocketsPerGroup = 6;
27
28// Make sure |handle|'s load times are set correctly.  Only connect times should
29// be set.
30void TestLoadTimingInfo(const ClientSocketHandle& handle) {
31  LoadTimingInfo load_timing_info;
32  EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
33
34  // None of these tests use a NetLog.
35  EXPECT_EQ(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
36
37  EXPECT_FALSE(load_timing_info.socket_reused);
38
39  ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
40                              CONNECT_TIMING_HAS_CONNECT_TIMES_ONLY);
41  ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
42}
43
44
45scoped_refptr<TransportSocketParams> CreateProxyHostParams() {
46  return new TransportSocketParams(
47      HostPortPair("proxy", 80), false, false, OnHostResolutionCallback(),
48      TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT);
49}
50
51scoped_refptr<SOCKSSocketParams> CreateSOCKSv4Params() {
52  return new SOCKSSocketParams(
53      CreateProxyHostParams(), false /* socks_v5 */,
54      HostPortPair("host", 80));
55}
56
57scoped_refptr<SOCKSSocketParams> CreateSOCKSv5Params() {
58  return new SOCKSSocketParams(
59      CreateProxyHostParams(), true /* socks_v5 */,
60      HostPortPair("host", 80));
61}
62
// Test fixture that stacks a SOCKSClientSocketPool (|pool_|) on top of a
// MockTransportClientSocketPool, whose sockets are produced by
// |transport_client_socket_factory_| from the data providers each test
// registers.
class SOCKSClientSocketPoolTest : public testing::Test {
 protected:
  // Scripts a successful SOCKS5 exchange for one mock transport socket: the
  // greet request/response pair, then the kSOCKS5Ok* request/response pair,
  // each direction ending with an extra entry whose result is 0. Every
  // scripted operation uses the same IoMode.
  class SOCKS5MockData {
   public:
    explicit SOCKS5MockData(IoMode mode) {
      writes_.reset(new MockWrite[3]);
      writes_[0] = MockWrite(mode, kSOCKS5GreetRequest,
                             kSOCKS5GreetRequestLength);
      writes_[1] = MockWrite(mode, kSOCKS5OkRequest, kSOCKS5OkRequestLength);
      writes_[2] = MockWrite(mode, 0);

      reads_.reset(new MockRead[3]);
      reads_[0] = MockRead(mode, kSOCKS5GreetResponse,
                           kSOCKS5GreetResponseLength);
      reads_[1] = MockRead(mode, kSOCKS5OkResponse, kSOCKS5OkResponseLength);
      reads_[2] = MockRead(mode, 0);

      // The provider receives raw pointers to the arrays above, which is why
      // this object keeps owning them (presumably the provider does not copy
      // them -- NOTE(review): confirm).
      data_.reset(new StaticSocketDataProvider(reads_.get(), 3,
                                               writes_.get(), 3));
    }

    // Returns the provider to register with a MockClientSocketFactory. It
    // stays owned by this object, so this object should presumably outlive
    // any socket that consumes it.
    SocketDataProvider* data_provider() { return data_.get(); }

   private:
    scoped_ptr<StaticSocketDataProvider> data_;
    scoped_ptr<MockWrite[]> writes_;
    scoped_ptr<MockRead[]> reads_;
  };

  // Both pools use the same kMaxSockets / kMaxSocketsPerGroup limits.
  SOCKSClientSocketPoolTest()
      : transport_histograms_("MockTCP"),
        transport_socket_pool_(
            kMaxSockets, kMaxSocketsPerGroup,
            &transport_histograms_,
            &transport_client_socket_factory_),
        socks_histograms_("SOCKSUnitTest"),
        // NOTE(review): the trailing NULL is presumably the NetLog (the load
        // timing helper relies on no NetLog being used); confirm against the
        // SOCKSClientSocketPool constructor.
        pool_(kMaxSockets, kMaxSocketsPerGroup,
              &socks_histograms_,
              &host_resolver_,
              &transport_socket_pool_,
              NULL) {
  }

  virtual ~SOCKSClientSocketPoolTest() {}

  // Starts a request for |group_name| on |pool_| with fresh SOCKS v5 params
  // through |test_base_|, and returns the net error code it reports.
  int StartRequestV5(const std::string& group_name, RequestPriority priority) {
    return test_base_.StartRequestUsingPool(
        &pool_, group_name, priority, CreateSOCKSv5Params());
  }

  // Forwards to ClientSocketPoolTest::GetOrderOfRequest.
  int GetOrderOfRequest(size_t index) const {
    return test_base_.GetOrderOfRequest(index);
  }

  // Requests started through StartRequestV5, as tracked by |test_base_|.
  ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); }

  // Declaration order matters: members are constructed top to bottom, and
  // |transport_socket_pool_| and |pool_| are handed pointers to members that
  // are declared above them.
  ClientSocketPoolHistograms transport_histograms_;
  MockClientSocketFactory transport_client_socket_factory_;
  MockTransportClientSocketPool transport_socket_pool_;

  ClientSocketPoolHistograms socks_histograms_;
  MockHostResolver host_resolver_;
  SOCKSClientSocketPool pool_;
  ClientSocketPoolTest test_base_;
};
128
129TEST_F(SOCKSClientSocketPoolTest, Simple) {
130  SOCKS5MockData data(SYNCHRONOUS);
131  data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
132  transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
133
134  ClientSocketHandle handle;
135  int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(),
136                       &pool_, BoundNetLog());
137  EXPECT_EQ(OK, rv);
138  EXPECT_TRUE(handle.is_initialized());
139  EXPECT_TRUE(handle.socket());
140  TestLoadTimingInfo(handle);
141}
142
143// Make sure that SOCKSConnectJob passes on its priority to its
144// socket request on Init.
145TEST_F(SOCKSClientSocketPoolTest, SetSocketRequestPriorityOnInit) {
146  for (int i = MINIMUM_PRIORITY; i <= MAXIMUM_PRIORITY; ++i) {
147    RequestPriority priority = static_cast<RequestPriority>(i);
148    SOCKS5MockData data(SYNCHRONOUS);
149    data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
150    transport_client_socket_factory_.AddSocketDataProvider(
151        data.data_provider());
152
153    ClientSocketHandle handle;
154    EXPECT_EQ(OK,
155              handle.Init("a", CreateSOCKSv5Params(), priority,
156                          CompletionCallback(), &pool_, BoundNetLog()));
157    EXPECT_EQ(priority, transport_socket_pool_.last_request_priority());
158    handle.socket()->Disconnect();
159  }
160}
161
162// Make sure that SOCKSConnectJob passes on its priority to its
163// HostResolver request (for non-SOCKS5) on Init.
164TEST_F(SOCKSClientSocketPoolTest, SetResolvePriorityOnInit) {
165  for (int i = MINIMUM_PRIORITY; i <= MAXIMUM_PRIORITY; ++i) {
166    RequestPriority priority = static_cast<RequestPriority>(i);
167    SOCKS5MockData data(SYNCHRONOUS);
168    data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
169    transport_client_socket_factory_.AddSocketDataProvider(
170        data.data_provider());
171
172    ClientSocketHandle handle;
173    EXPECT_EQ(ERR_IO_PENDING,
174              handle.Init("a", CreateSOCKSv4Params(), priority,
175                          CompletionCallback(), &pool_, BoundNetLog()));
176    EXPECT_EQ(priority, transport_socket_pool_.last_request_priority());
177    EXPECT_EQ(priority, host_resolver_.last_request_priority());
178    EXPECT_TRUE(handle.socket() == NULL);
179  }
180}
181
182TEST_F(SOCKSClientSocketPoolTest, Async) {
183  SOCKS5MockData data(ASYNC);
184  transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
185
186  TestCompletionCallback callback;
187  ClientSocketHandle handle;
188  int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(),
189                       &pool_, BoundNetLog());
190  EXPECT_EQ(ERR_IO_PENDING, rv);
191  EXPECT_FALSE(handle.is_initialized());
192  EXPECT_FALSE(handle.socket());
193
194  EXPECT_EQ(OK, callback.WaitForResult());
195  EXPECT_TRUE(handle.is_initialized());
196  EXPECT_TRUE(handle.socket());
197  TestLoadTimingInfo(handle);
198}
199
200TEST_F(SOCKSClientSocketPoolTest, TransportConnectError) {
201  StaticSocketDataProvider socket_data;
202  socket_data.set_connect_data(MockConnect(SYNCHRONOUS,
203                                           ERR_CONNECTION_REFUSED));
204  transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
205
206  ClientSocketHandle handle;
207  int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(),
208                       &pool_, BoundNetLog());
209  EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv);
210  EXPECT_FALSE(handle.is_initialized());
211  EXPECT_FALSE(handle.socket());
212}
213
214TEST_F(SOCKSClientSocketPoolTest, AsyncTransportConnectError) {
215  StaticSocketDataProvider socket_data;
216  socket_data.set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_REFUSED));
217  transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
218
219  TestCompletionCallback callback;
220  ClientSocketHandle handle;
221  int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(),
222                       &pool_, BoundNetLog());
223  EXPECT_EQ(ERR_IO_PENDING, rv);
224  EXPECT_FALSE(handle.is_initialized());
225  EXPECT_FALSE(handle.socket());
226
227  EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, callback.WaitForResult());
228  EXPECT_FALSE(handle.is_initialized());
229  EXPECT_FALSE(handle.socket());
230}
231
232TEST_F(SOCKSClientSocketPoolTest, SOCKSConnectError) {
233  MockRead failed_read[] = {
234    MockRead(SYNCHRONOUS, 0),
235  };
236  StaticSocketDataProvider socket_data(
237      failed_read, arraysize(failed_read), NULL, 0);
238  socket_data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
239  transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
240
241  ClientSocketHandle handle;
242  EXPECT_EQ(0, transport_socket_pool_.release_count());
243  int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(),
244                       &pool_, BoundNetLog());
245  EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv);
246  EXPECT_FALSE(handle.is_initialized());
247  EXPECT_FALSE(handle.socket());
248  EXPECT_EQ(1, transport_socket_pool_.release_count());
249}
250
251TEST_F(SOCKSClientSocketPoolTest, AsyncSOCKSConnectError) {
252  MockRead failed_read[] = {
253    MockRead(ASYNC, 0),
254  };
255  StaticSocketDataProvider socket_data(
256        failed_read, arraysize(failed_read), NULL, 0);
257  socket_data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
258  transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
259
260  TestCompletionCallback callback;
261  ClientSocketHandle handle;
262  EXPECT_EQ(0, transport_socket_pool_.release_count());
263  int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(),
264                       &pool_, BoundNetLog());
265  EXPECT_EQ(ERR_IO_PENDING, rv);
266  EXPECT_FALSE(handle.is_initialized());
267  EXPECT_FALSE(handle.socket());
268
269  EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, callback.WaitForResult());
270  EXPECT_FALSE(handle.is_initialized());
271  EXPECT_FALSE(handle.socket());
272  EXPECT_EQ(1, transport_socket_pool_.release_count());
273}
274
275TEST_F(SOCKSClientSocketPoolTest, CancelDuringTransportConnect) {
276  SOCKS5MockData data(SYNCHRONOUS);
277  transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
278  // We need two connections because the pool base lets one cancelled
279  // connect job proceed for potential future use.
280  SOCKS5MockData data2(SYNCHRONOUS);
281  transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider());
282
283  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
284  int rv = StartRequestV5("a", LOW);
285  EXPECT_EQ(ERR_IO_PENDING, rv);
286
287  rv = StartRequestV5("a", LOW);
288  EXPECT_EQ(ERR_IO_PENDING, rv);
289
290  pool_.CancelRequest("a", (*requests())[0]->handle());
291  pool_.CancelRequest("a", (*requests())[1]->handle());
292  // Requests in the connect phase don't actually get cancelled.
293  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
294
295  // Now wait for the TCP sockets to connect.
296  base::MessageLoop::current()->RunUntilIdle();
297
298  EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(1));
299  EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(2));
300  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
301  EXPECT_EQ(2, pool_.IdleSocketCount());
302
303  (*requests())[0]->handle()->Reset();
304  (*requests())[1]->handle()->Reset();
305}
306
307TEST_F(SOCKSClientSocketPoolTest, CancelDuringSOCKSConnect) {
308  SOCKS5MockData data(ASYNC);
309  data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
310  transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
311  // We need two connections because the pool base lets one cancelled
312  // connect job proceed for potential future use.
313  SOCKS5MockData data2(ASYNC);
314  data2.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
315  transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider());
316
317  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
318  EXPECT_EQ(0, transport_socket_pool_.release_count());
319  int rv = StartRequestV5("a", LOW);
320  EXPECT_EQ(ERR_IO_PENDING, rv);
321
322  rv = StartRequestV5("a", LOW);
323  EXPECT_EQ(ERR_IO_PENDING, rv);
324
325  pool_.CancelRequest("a", (*requests())[0]->handle());
326  pool_.CancelRequest("a", (*requests())[1]->handle());
327  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
328  // Requests in the connect phase don't actually get cancelled.
329  EXPECT_EQ(0, transport_socket_pool_.release_count());
330
331  // Now wait for the async data to reach the SOCKS connect jobs.
332  base::MessageLoop::current()->RunUntilIdle();
333
334  EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(1));
335  EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(2));
336  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
337  EXPECT_EQ(0, transport_socket_pool_.release_count());
338  EXPECT_EQ(2, pool_.IdleSocketCount());
339
340  (*requests())[0]->handle()->Reset();
341  (*requests())[1]->handle()->Reset();
342}
343
344// It would be nice to also test the timeouts in SOCKSClientSocketPool.
345
346}  // namespace
347
348}  // namespace net
349