1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/socks_client_socket.h"
6
7#include "net/base/address_list.h"
8#include "net/base/net_log.h"
9#include "net/base/net_log_unittest.h"
10#include "net/base/mock_host_resolver.h"
11#include "net/base/test_completion_callback.h"
12#include "net/base/winsock_init.h"
13#include "net/socket/client_socket_factory.h"
14#include "net/socket/tcp_client_socket.h"
15#include "net/socket/socket_test_util.h"
16#include "testing/gtest/include/gtest/gtest.h"
17#include "testing/platform_test.h"
18
19//-----------------------------------------------------------------------------
20
21namespace net {
22
// SOCKS4 CONNECT request the client is expected to emit for the tests'
// "localhost":80 target, laid out field by field.
const char kSOCKSOkRequest[] = {
    0x04,            // protocol version 4
    0x01,            // command: CONNECT
    0x00, 0x50,      // destination port 80, network byte order
    127, 0, 0, 1,    // destination address 127.0.0.1
    0,               // empty user id (terminating NUL)
};
// SOCKS4 reply that the tests treat as a successful handshake.
const char kSOCKSOkReply[] = {
    0x00,            // null byte that leads every reply
    0x5A,            // status: request granted
    0x00, 0x00,      // port field (zeroed)
    0, 0, 0, 0,      // address field (zeroed)
};
25
26class SOCKSClientSocketTest : public PlatformTest {
27 public:
28  SOCKSClientSocketTest();
29  // Create a SOCKSClientSocket on top of a MockSocket.
30  SOCKSClientSocket* BuildMockSocket(MockRead reads[], size_t reads_count,
31                                     MockWrite writes[], size_t writes_count,
32                                     HostResolver* host_resolver,
33                                     const std::string& hostname, int port,
34                                     NetLog* net_log);
35  virtual void SetUp();
36
37 protected:
38  scoped_ptr<SOCKSClientSocket> user_sock_;
39  AddressList address_list_;
40  ClientSocket* tcp_sock_;
41  TestCompletionCallback callback_;
42  scoped_ptr<MockHostResolver> host_resolver_;
43  scoped_ptr<SocketDataProvider> data_;
44};
45
46SOCKSClientSocketTest::SOCKSClientSocketTest()
47  : host_resolver_(new MockHostResolver) {
48}
49
50// Set up platform before every test case
51void SOCKSClientSocketTest::SetUp() {
52  PlatformTest::SetUp();
53}
54
55SOCKSClientSocket* SOCKSClientSocketTest::BuildMockSocket(
56    MockRead reads[],
57    size_t reads_count,
58    MockWrite writes[],
59    size_t writes_count,
60    HostResolver* host_resolver,
61    const std::string& hostname,
62    int port,
63    NetLog* net_log) {
64
65  TestCompletionCallback callback;
66  data_.reset(new StaticSocketDataProvider(reads, reads_count,
67                                           writes, writes_count));
68  tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
69
70  int rv = tcp_sock_->Connect(&callback);
71  EXPECT_EQ(ERR_IO_PENDING, rv);
72  rv = callback.WaitForResult();
73  EXPECT_EQ(OK, rv);
74  EXPECT_TRUE(tcp_sock_->IsConnected());
75
76  return new SOCKSClientSocket(tcp_sock_,
77      HostResolver::RequestInfo(HostPortPair(hostname, port)),
78      host_resolver);
79}
80
81// Implementation of HostResolver that never completes its resolve request.
82// We use this in the test "DisconnectWhileHostResolveInProgress" to make
83// sure that the outstanding resolve request gets cancelled.
84class HangingHostResolver : public HostResolver {
85 public:
86  HangingHostResolver() : outstanding_request_(NULL) {}
87
88  virtual int Resolve(const RequestInfo& info,
89                      AddressList* addresses,
90                      CompletionCallback* callback,
91                      RequestHandle* out_req,
92                      const BoundNetLog& net_log) {
93    EXPECT_FALSE(HasOutstandingRequest());
94    outstanding_request_ = reinterpret_cast<RequestHandle>(1);
95    *out_req = outstanding_request_;
96    return ERR_IO_PENDING;
97  }
98
99  virtual void CancelRequest(RequestHandle req) {
100    EXPECT_TRUE(HasOutstandingRequest());
101    EXPECT_EQ(outstanding_request_, req);
102    outstanding_request_ = NULL;
103  }
104
105  virtual void AddObserver(Observer* observer) {}
106  virtual void RemoveObserver(Observer* observer) {}
107
108  bool HasOutstandingRequest() {
109    return outstanding_request_ != NULL;
110  }
111
112 private:
113  RequestHandle outstanding_request_;
114
115  DISALLOW_COPY_AND_ASSIGN(HangingHostResolver);
116};
117
118// Tests a complete handshake and the disconnection.
119TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
120  const std::string payload_write = "random data";
121  const std::string payload_read = "moar random data";
122
123  MockWrite data_writes[] = {
124      MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)),
125      MockWrite(true, payload_write.data(), payload_write.size()) };
126  MockRead data_reads[] = {
127      MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)),
128      MockRead(true, payload_read.data(), payload_read.size()) };
129  CapturingNetLog log(CapturingNetLog::kUnbounded);
130
131  user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
132                                   data_writes, arraysize(data_writes),
133                                   host_resolver_.get(),
134                                   "localhost", 80,
135                                   &log));
136
137  // At this state the TCP connection is completed but not the SOCKS handshake.
138  EXPECT_TRUE(tcp_sock_->IsConnected());
139  EXPECT_FALSE(user_sock_->IsConnected());
140
141  int rv = user_sock_->Connect(&callback_);
142  EXPECT_EQ(ERR_IO_PENDING, rv);
143
144  net::CapturingNetLog::EntryList entries;
145  log.GetEntries(&entries);
146  EXPECT_TRUE(
147      LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT));
148  EXPECT_FALSE(user_sock_->IsConnected());
149
150  rv = callback_.WaitForResult();
151  EXPECT_EQ(OK, rv);
152  EXPECT_TRUE(user_sock_->IsConnected());
153  log.GetEntries(&entries);
154  EXPECT_TRUE(LogContainsEndEvent(
155      entries, -1, NetLog::TYPE_SOCKS_CONNECT));
156
157  scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
158  memcpy(buffer->data(), payload_write.data(), payload_write.size());
159  rv = user_sock_->Write(buffer, payload_write.size(), &callback_);
160  EXPECT_EQ(ERR_IO_PENDING, rv);
161  rv = callback_.WaitForResult();
162  EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
163
164  buffer = new IOBuffer(payload_read.size());
165  rv = user_sock_->Read(buffer, payload_read.size(), &callback_);
166  EXPECT_EQ(ERR_IO_PENDING, rv);
167  rv = callback_.WaitForResult();
168  EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
169  EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
170
171  user_sock_->Disconnect();
172  EXPECT_FALSE(tcp_sock_->IsConnected());
173  EXPECT_FALSE(user_sock_->IsConnected());
174}
175
176// List of responses from the socks server and the errors they should
177// throw up are tested here.
178TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
179  const struct {
180    const char fail_reply[8];
181    Error fail_code;
182  } tests[] = {
183    // Failure of the server response code
184    {
185      { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 },
186      ERR_SOCKS_CONNECTION_FAILED,
187    },
188    // Failure of the null byte
189    {
190      { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 },
191      ERR_SOCKS_CONNECTION_FAILED,
192    },
193  };
194
195  //---------------------------------------
196
197  for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
198    MockWrite data_writes[] = {
199        MockWrite(false, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
200    MockRead data_reads[] = {
201        MockRead(false, tests[i].fail_reply, arraysize(tests[i].fail_reply)) };
202    CapturingNetLog log(CapturingNetLog::kUnbounded);
203
204    user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
205                                     data_writes, arraysize(data_writes),
206                                     host_resolver_.get(),
207                                     "localhost", 80,
208                                     &log));
209
210    int rv = user_sock_->Connect(&callback_);
211    EXPECT_EQ(ERR_IO_PENDING, rv);
212
213    net::CapturingNetLog::EntryList entries;
214    log.GetEntries(&entries);
215    EXPECT_TRUE(LogContainsBeginEvent(
216        entries, 0, NetLog::TYPE_SOCKS_CONNECT));
217
218    rv = callback_.WaitForResult();
219    EXPECT_EQ(tests[i].fail_code, rv);
220    EXPECT_FALSE(user_sock_->IsConnected());
221    EXPECT_TRUE(tcp_sock_->IsConnected());
222    log.GetEntries(&entries);
223    EXPECT_TRUE(LogContainsEndEvent(
224        entries, -1, NetLog::TYPE_SOCKS_CONNECT));
225  }
226}
227
228// Tests scenario when the server sends the handshake response in
229// more than one packet.
230TEST_F(SOCKSClientSocketTest, PartialServerReads) {
231  const char kSOCKSPartialReply1[] = { 0x00 };
232  const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
233
234  MockWrite data_writes[] = {
235      MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
236  MockRead data_reads[] = {
237      MockRead(true, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)),
238      MockRead(true, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
239  CapturingNetLog log(CapturingNetLog::kUnbounded);
240
241  user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
242                                   data_writes, arraysize(data_writes),
243                                   host_resolver_.get(),
244                                   "localhost", 80,
245                                   &log));
246
247  int rv = user_sock_->Connect(&callback_);
248  EXPECT_EQ(ERR_IO_PENDING, rv);
249  net::CapturingNetLog::EntryList entries;
250  log.GetEntries(&entries);
251  EXPECT_TRUE(LogContainsBeginEvent(
252      entries, 0, NetLog::TYPE_SOCKS_CONNECT));
253
254  rv = callback_.WaitForResult();
255  EXPECT_EQ(OK, rv);
256  EXPECT_TRUE(user_sock_->IsConnected());
257  log.GetEntries(&entries);
258  EXPECT_TRUE(LogContainsEndEvent(
259      entries, -1, NetLog::TYPE_SOCKS_CONNECT));
260}
261
262// Tests scenario when the client sends the handshake request in
263// more than one packet.
264TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
265  const char kSOCKSPartialRequest1[] = { 0x04, 0x01 };
266  const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 };
267
268  MockWrite data_writes[] = {
269      MockWrite(true, arraysize(kSOCKSPartialRequest1)),
270      // simulate some empty writes
271      MockWrite(true, 0),
272      MockWrite(true, 0),
273      MockWrite(true, kSOCKSPartialRequest2,
274                arraysize(kSOCKSPartialRequest2)) };
275  MockRead data_reads[] = {
276      MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
277  CapturingNetLog log(CapturingNetLog::kUnbounded);
278
279  user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
280                                   data_writes, arraysize(data_writes),
281                                   host_resolver_.get(),
282                                   "localhost", 80,
283                                   &log));
284
285  int rv = user_sock_->Connect(&callback_);
286  EXPECT_EQ(ERR_IO_PENDING, rv);
287  net::CapturingNetLog::EntryList entries;
288  log.GetEntries(&entries);
289  EXPECT_TRUE(LogContainsBeginEvent(
290      entries, 0, NetLog::TYPE_SOCKS_CONNECT));
291
292  rv = callback_.WaitForResult();
293  EXPECT_EQ(OK, rv);
294  EXPECT_TRUE(user_sock_->IsConnected());
295  log.GetEntries(&entries);
296  EXPECT_TRUE(LogContainsEndEvent(
297      entries, -1, NetLog::TYPE_SOCKS_CONNECT));
298}
299
300// Tests the case when the server sends a smaller sized handshake data
301// and closes the connection.
302TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
303  MockWrite data_writes[] = {
304      MockWrite(true, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
305  MockRead data_reads[] = {
306      MockRead(true, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2),
307      // close connection unexpectedly
308      MockRead(false, 0) };
309  CapturingNetLog log(CapturingNetLog::kUnbounded);
310
311  user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
312                                   data_writes, arraysize(data_writes),
313                                   host_resolver_.get(),
314                                   "localhost", 80,
315                                   &log));
316
317  int rv = user_sock_->Connect(&callback_);
318  EXPECT_EQ(ERR_IO_PENDING, rv);
319  net::CapturingNetLog::EntryList entries;
320  log.GetEntries(&entries);
321  EXPECT_TRUE(LogContainsBeginEvent(
322      entries, 0, NetLog::TYPE_SOCKS_CONNECT));
323
324  rv = callback_.WaitForResult();
325  EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
326  EXPECT_FALSE(user_sock_->IsConnected());
327  log.GetEntries(&entries);
328  EXPECT_TRUE(LogContainsEndEvent(
329      entries, -1, NetLog::TYPE_SOCKS_CONNECT));
330}
331
332// Tries to connect to an unknown hostname. Should fail rather than
333// falling back to SOCKS4a.
334TEST_F(SOCKSClientSocketTest, FailedDNS) {
335  const char hostname[] = "unresolved.ipv4.address";
336
337  host_resolver_->rules()->AddSimulatedFailure(hostname);
338
339  CapturingNetLog log(CapturingNetLog::kUnbounded);
340
341  user_sock_.reset(BuildMockSocket(NULL, 0,
342                                   NULL, 0,
343                                   host_resolver_.get(),
344                                   hostname, 80,
345                                   &log));
346
347  int rv = user_sock_->Connect(&callback_);
348  EXPECT_EQ(ERR_IO_PENDING, rv);
349  net::CapturingNetLog::EntryList entries;
350  log.GetEntries(&entries);
351  EXPECT_TRUE(LogContainsBeginEvent(
352      entries, 0, NetLog::TYPE_SOCKS_CONNECT));
353
354  rv = callback_.WaitForResult();
355  EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
356  EXPECT_FALSE(user_sock_->IsConnected());
357  log.GetEntries(&entries);
358  EXPECT_TRUE(LogContainsEndEvent(
359      entries, -1, NetLog::TYPE_SOCKS_CONNECT));
360}
361
362// Calls Disconnect() while a host resolve is in progress. The outstanding host
363// resolve should be cancelled.
364TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
365  scoped_ptr<HangingHostResolver> hanging_resolver(new HangingHostResolver());
366
367  // Doesn't matter what the socket data is, we will never use it -- garbage.
368  MockWrite data_writes[] = { MockWrite(false, "", 0) };
369  MockRead data_reads[] = { MockRead(false, "", 0) };
370
371  user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
372                                   data_writes, arraysize(data_writes),
373                                   hanging_resolver.get(),
374                                   "foo", 80,
375                                   NULL));
376
377  // Start connecting (will get stuck waiting for the host to resolve).
378  int rv = user_sock_->Connect(&callback_);
379  EXPECT_EQ(ERR_IO_PENDING, rv);
380
381  EXPECT_FALSE(user_sock_->IsConnected());
382  EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
383
384  // The host resolver should have received the resolve request.
385  EXPECT_TRUE(hanging_resolver->HasOutstandingRequest());
386
387  // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
388  user_sock_->Disconnect();
389
390  EXPECT_FALSE(hanging_resolver->HasOutstandingRequest());
391
392  EXPECT_FALSE(user_sock_->IsConnected());
393  EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
394}
395
396}  // namespace net
397