1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/socks_client_socket_pool.h"
6
7#include "base/callback.h"
8#include "base/compiler_specific.h"
9#include "base/time/time.h"
10#include "net/base/load_timing_info.h"
11#include "net/base/load_timing_info_test_util.h"
12#include "net/base/net_errors.h"
13#include "net/base/test_completion_callback.h"
14#include "net/dns/mock_host_resolver.h"
15#include "net/socket/client_socket_factory.h"
16#include "net/socket/client_socket_handle.h"
17#include "net/socket/client_socket_pool_histograms.h"
18#include "net/socket/socket_test_util.h"
19#include "testing/gtest/include/gtest/gtest.h"
20
21namespace net {
22
23namespace {
24
// Socket limits handed to both the mock transport pool and the SOCKS pool
// constructed by the fixture below.
const int kMaxSockets = 32;
const int kMaxSocketsPerGroup = 6;
27
28// Make sure |handle|'s load times are set correctly.  Only connect times should
29// be set.
30void TestLoadTimingInfo(const ClientSocketHandle& handle) {
31  LoadTimingInfo load_timing_info;
32  EXPECT_TRUE(handle.GetLoadTimingInfo(false, &load_timing_info));
33
34  // None of these tests use a NetLog.
35  EXPECT_EQ(NetLog::Source::kInvalidId, load_timing_info.socket_log_id);
36
37  EXPECT_FALSE(load_timing_info.socket_reused);
38
39  ExpectConnectTimingHasTimes(load_timing_info.connect_timing,
40                              CONNECT_TIMING_HAS_CONNECT_TIMES_ONLY);
41  ExpectLoadTimingHasOnlyConnectionTimes(load_timing_info);
42}
43
44class SOCKSClientSocketPoolTest : public testing::Test {
45 protected:
46  class SOCKS5MockData {
47   public:
48    explicit SOCKS5MockData(IoMode mode) {
49      writes_.reset(new MockWrite[3]);
50      writes_[0] = MockWrite(mode, kSOCKS5GreetRequest,
51                             kSOCKS5GreetRequestLength);
52      writes_[1] = MockWrite(mode, kSOCKS5OkRequest, kSOCKS5OkRequestLength);
53      writes_[2] = MockWrite(mode, 0);
54
55      reads_.reset(new MockRead[3]);
56      reads_[0] = MockRead(mode, kSOCKS5GreetResponse,
57                           kSOCKS5GreetResponseLength);
58      reads_[1] = MockRead(mode, kSOCKS5OkResponse, kSOCKS5OkResponseLength);
59      reads_[2] = MockRead(mode, 0);
60
61      data_.reset(new StaticSocketDataProvider(reads_.get(), 3,
62                                               writes_.get(), 3));
63    }
64
65    SocketDataProvider* data_provider() { return data_.get(); }
66
67   private:
68    scoped_ptr<StaticSocketDataProvider> data_;
69    scoped_ptr<MockWrite[]> writes_;
70    scoped_ptr<MockRead[]> reads_;
71  };
72
73  SOCKSClientSocketPoolTest()
74      : ignored_transport_socket_params_(new TransportSocketParams(
75          HostPortPair("proxy", 80), MEDIUM, false, false,
76          OnHostResolutionCallback())),
77        transport_histograms_("MockTCP"),
78        transport_socket_pool_(
79            kMaxSockets, kMaxSocketsPerGroup,
80            &transport_histograms_,
81            &transport_client_socket_factory_),
82        ignored_socket_params_(new SOCKSSocketParams(
83            ignored_transport_socket_params_, true, HostPortPair("host", 80),
84            MEDIUM)),
85        socks_histograms_("SOCKSUnitTest"),
86        pool_(kMaxSockets, kMaxSocketsPerGroup,
87              &socks_histograms_,
88              NULL,
89              &transport_socket_pool_,
90              NULL) {
91  }
92
93  virtual ~SOCKSClientSocketPoolTest() {}
94
95  int StartRequest(const std::string& group_name, RequestPriority priority) {
96    return test_base_.StartRequestUsingPool(
97        &pool_, group_name, priority, ignored_socket_params_);
98  }
99
100  int GetOrderOfRequest(size_t index) const {
101    return test_base_.GetOrderOfRequest(index);
102  }
103
104  ScopedVector<TestSocketRequest>* requests() { return test_base_.requests(); }
105
106  scoped_refptr<TransportSocketParams> ignored_transport_socket_params_;
107  ClientSocketPoolHistograms transport_histograms_;
108  MockClientSocketFactory transport_client_socket_factory_;
109  MockTransportClientSocketPool transport_socket_pool_;
110
111  scoped_refptr<SOCKSSocketParams> ignored_socket_params_;
112  ClientSocketPoolHistograms socks_histograms_;
113  SOCKSClientSocketPool pool_;
114  ClientSocketPoolTest test_base_;
115};
116
117TEST_F(SOCKSClientSocketPoolTest, Simple) {
118  SOCKS5MockData data(SYNCHRONOUS);
119  data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
120  transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
121
122  ClientSocketHandle handle;
123  int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(),
124                       &pool_, BoundNetLog());
125  EXPECT_EQ(OK, rv);
126  EXPECT_TRUE(handle.is_initialized());
127  EXPECT_TRUE(handle.socket());
128  TestLoadTimingInfo(handle);
129}
130
131TEST_F(SOCKSClientSocketPoolTest, Async) {
132  SOCKS5MockData data(ASYNC);
133  transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
134
135  TestCompletionCallback callback;
136  ClientSocketHandle handle;
137  int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(),
138                       &pool_, BoundNetLog());
139  EXPECT_EQ(ERR_IO_PENDING, rv);
140  EXPECT_FALSE(handle.is_initialized());
141  EXPECT_FALSE(handle.socket());
142
143  EXPECT_EQ(OK, callback.WaitForResult());
144  EXPECT_TRUE(handle.is_initialized());
145  EXPECT_TRUE(handle.socket());
146  TestLoadTimingInfo(handle);
147}
148
149TEST_F(SOCKSClientSocketPoolTest, TransportConnectError) {
150  StaticSocketDataProvider socket_data;
151  socket_data.set_connect_data(MockConnect(SYNCHRONOUS,
152                                           ERR_CONNECTION_REFUSED));
153  transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
154
155  ClientSocketHandle handle;
156  int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(),
157                       &pool_, BoundNetLog());
158  EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv);
159  EXPECT_FALSE(handle.is_initialized());
160  EXPECT_FALSE(handle.socket());
161}
162
163TEST_F(SOCKSClientSocketPoolTest, AsyncTransportConnectError) {
164  StaticSocketDataProvider socket_data;
165  socket_data.set_connect_data(MockConnect(ASYNC, ERR_CONNECTION_REFUSED));
166  transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
167
168  TestCompletionCallback callback;
169  ClientSocketHandle handle;
170  int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(),
171                       &pool_, BoundNetLog());
172  EXPECT_EQ(ERR_IO_PENDING, rv);
173  EXPECT_FALSE(handle.is_initialized());
174  EXPECT_FALSE(handle.socket());
175
176  EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, callback.WaitForResult());
177  EXPECT_FALSE(handle.is_initialized());
178  EXPECT_FALSE(handle.socket());
179}
180
181TEST_F(SOCKSClientSocketPoolTest, SOCKSConnectError) {
182  MockRead failed_read[] = {
183    MockRead(SYNCHRONOUS, 0),
184  };
185  StaticSocketDataProvider socket_data(
186      failed_read, arraysize(failed_read), NULL, 0);
187  socket_data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
188  transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
189
190  ClientSocketHandle handle;
191  EXPECT_EQ(0, transport_socket_pool_.release_count());
192  int rv = handle.Init("a", ignored_socket_params_, LOW, CompletionCallback(),
193                       &pool_, BoundNetLog());
194  EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv);
195  EXPECT_FALSE(handle.is_initialized());
196  EXPECT_FALSE(handle.socket());
197  EXPECT_EQ(1, transport_socket_pool_.release_count());
198}
199
200TEST_F(SOCKSClientSocketPoolTest, AsyncSOCKSConnectError) {
201  MockRead failed_read[] = {
202    MockRead(ASYNC, 0),
203  };
204  StaticSocketDataProvider socket_data(
205        failed_read, arraysize(failed_read), NULL, 0);
206  socket_data.set_connect_data(MockConnect(SYNCHRONOUS, OK));
207  transport_client_socket_factory_.AddSocketDataProvider(&socket_data);
208
209  TestCompletionCallback callback;
210  ClientSocketHandle handle;
211  EXPECT_EQ(0, transport_socket_pool_.release_count());
212  int rv = handle.Init("a", ignored_socket_params_, LOW, callback.callback(),
213                       &pool_, BoundNetLog());
214  EXPECT_EQ(ERR_IO_PENDING, rv);
215  EXPECT_FALSE(handle.is_initialized());
216  EXPECT_FALSE(handle.socket());
217
218  EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, callback.WaitForResult());
219  EXPECT_FALSE(handle.is_initialized());
220  EXPECT_FALSE(handle.socket());
221  EXPECT_EQ(1, transport_socket_pool_.release_count());
222}
223
224TEST_F(SOCKSClientSocketPoolTest, CancelDuringTransportConnect) {
225  SOCKS5MockData data(SYNCHRONOUS);
226  transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
227  // We need two connections because the pool base lets one cancelled
228  // connect job proceed for potential future use.
229  SOCKS5MockData data2(SYNCHRONOUS);
230  transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider());
231
232  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
233  int rv = StartRequest("a", LOW);
234  EXPECT_EQ(ERR_IO_PENDING, rv);
235
236  rv = StartRequest("a", LOW);
237  EXPECT_EQ(ERR_IO_PENDING, rv);
238
239  pool_.CancelRequest("a", (*requests())[0]->handle());
240  pool_.CancelRequest("a", (*requests())[1]->handle());
241  // Requests in the connect phase don't actually get cancelled.
242  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
243
244  // Now wait for the TCP sockets to connect.
245  base::MessageLoop::current()->RunUntilIdle();
246
247  EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(1));
248  EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(2));
249  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
250  EXPECT_EQ(2, pool_.IdleSocketCount());
251
252  (*requests())[0]->handle()->Reset();
253  (*requests())[1]->handle()->Reset();
254}
255
256TEST_F(SOCKSClientSocketPoolTest, CancelDuringSOCKSConnect) {
257  SOCKS5MockData data(ASYNC);
258  data.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
259  transport_client_socket_factory_.AddSocketDataProvider(data.data_provider());
260  // We need two connections because the pool base lets one cancelled
261  // connect job proceed for potential future use.
262  SOCKS5MockData data2(ASYNC);
263  data2.data_provider()->set_connect_data(MockConnect(SYNCHRONOUS, OK));
264  transport_client_socket_factory_.AddSocketDataProvider(data2.data_provider());
265
266  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
267  EXPECT_EQ(0, transport_socket_pool_.release_count());
268  int rv = StartRequest("a", LOW);
269  EXPECT_EQ(ERR_IO_PENDING, rv);
270
271  rv = StartRequest("a", LOW);
272  EXPECT_EQ(ERR_IO_PENDING, rv);
273
274  pool_.CancelRequest("a", (*requests())[0]->handle());
275  pool_.CancelRequest("a", (*requests())[1]->handle());
276  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
277  // Requests in the connect phase don't actually get cancelled.
278  EXPECT_EQ(0, transport_socket_pool_.release_count());
279
280  // Now wait for the async data to reach the SOCKS connect jobs.
281  base::MessageLoop::current()->RunUntilIdle();
282
283  EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(1));
284  EXPECT_EQ(ClientSocketPoolTest::kRequestNotFound, GetOrderOfRequest(2));
285  EXPECT_EQ(0, transport_socket_pool_.cancel_count());
286  EXPECT_EQ(0, transport_socket_pool_.release_count());
287  EXPECT_EQ(2, pool_.IdleSocketCount());
288
289  (*requests())[0]->handle()->Reset();
290  (*requests())[1]->handle()->Reset();
291}
292
// TODO: Add tests covering the timeouts in SOCKSClientSocketPool.
294
295}  // namespace
296
297}  // namespace net
298