1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/socks_client_socket.h"
6
7#include "base/memory/scoped_ptr.h"
8#include "net/base/address_list.h"
9#include "net/base/net_log.h"
10#include "net/base/net_log_unittest.h"
11#include "net/base/test_completion_callback.h"
12#include "net/base/winsock_init.h"
13#include "net/dns/host_resolver.h"
14#include "net/dns/mock_host_resolver.h"
15#include "net/socket/client_socket_factory.h"
16#include "net/socket/socket_test_util.h"
17#include "net/socket/tcp_client_socket.h"
18#include "testing/gtest/include/gtest/gtest.h"
19#include "testing/platform_test.h"
20
21//-----------------------------------------------------------------------------
22
23namespace net {
24
// SOCKS4 CONNECT request the tests expect on the wire for 127.0.0.1:80
// (what the tests' "localhost", 80 target presumably resolves to via the
// mock resolver).
const char kSOCKSOkRequest[] = {
  0x04,          // SOCKS protocol version 4
  0x01,          // command code: CONNECT
  0x00, 0x50,    // destination port 80, network byte order
  127, 0, 0, 1,  // destination IPv4 address 127.0.0.1
  0,             // empty user id, NUL-terminated
};
// SOCKS4 reply granting the request above.
const char kSOCKSOkReply[] = {
  0x00,          // reply version byte (null byte)
  0x5A,          // status: request granted
  0x00, 0x00,    // port field (all zero here)
  0, 0, 0, 0,    // address field (all zero here)
};
27
28class SOCKSClientSocketTest : public PlatformTest {
29 public:
30  SOCKSClientSocketTest();
31  // Create a SOCKSClientSocket on top of a MockSocket.
32  scoped_ptr<SOCKSClientSocket> BuildMockSocket(
33      MockRead reads[], size_t reads_count,
34      MockWrite writes[], size_t writes_count,
35      HostResolver* host_resolver,
36      const std::string& hostname, int port,
37      NetLog* net_log);
38  virtual void SetUp();
39
40 protected:
41  scoped_ptr<SOCKSClientSocket> user_sock_;
42  AddressList address_list_;
43  // Filled in by BuildMockSocket() and owned by its return value
44  // (which |user_sock| is set to).
45  StreamSocket* tcp_sock_;
46  TestCompletionCallback callback_;
47  scoped_ptr<MockHostResolver> host_resolver_;
48  scoped_ptr<SocketDataProvider> data_;
49};
50
51SOCKSClientSocketTest::SOCKSClientSocketTest()
52  : host_resolver_(new MockHostResolver) {
53}
54
55// Set up platform before every test case
56void SOCKSClientSocketTest::SetUp() {
57  PlatformTest::SetUp();
58}
59
60scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket(
61    MockRead reads[],
62    size_t reads_count,
63    MockWrite writes[],
64    size_t writes_count,
65    HostResolver* host_resolver,
66    const std::string& hostname,
67    int port,
68    NetLog* net_log) {
69
70  TestCompletionCallback callback;
71  data_.reset(new StaticSocketDataProvider(reads, reads_count,
72                                           writes, writes_count));
73  tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
74
75  int rv = tcp_sock_->Connect(callback.callback());
76  EXPECT_EQ(ERR_IO_PENDING, rv);
77  rv = callback.WaitForResult();
78  EXPECT_EQ(OK, rv);
79  EXPECT_TRUE(tcp_sock_->IsConnected());
80
81  scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle);
82  // |connection| takes ownership of |tcp_sock_|, but keep a
83  // non-owning pointer to it.
84  connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_));
85  return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket(
86      connection.Pass(),
87      HostResolver::RequestInfo(HostPortPair(hostname, port)),
88      DEFAULT_PRIORITY,
89      host_resolver));
90}
91
92// Implementation of HostResolver that never completes its resolve request.
93// We use this in the test "DisconnectWhileHostResolveInProgress" to make
94// sure that the outstanding resolve request gets cancelled.
95class HangingHostResolverWithCancel : public HostResolver {
96 public:
97  HangingHostResolverWithCancel() : outstanding_request_(NULL) {}
98
99  virtual int Resolve(const RequestInfo& info,
100                      RequestPriority priority,
101                      AddressList* addresses,
102                      const CompletionCallback& callback,
103                      RequestHandle* out_req,
104                      const BoundNetLog& net_log) OVERRIDE {
105    DCHECK(addresses);
106    DCHECK_EQ(false, callback.is_null());
107    EXPECT_FALSE(HasOutstandingRequest());
108    outstanding_request_ = reinterpret_cast<RequestHandle>(1);
109    *out_req = outstanding_request_;
110    return ERR_IO_PENDING;
111  }
112
113  virtual int ResolveFromCache(const RequestInfo& info,
114                               AddressList* addresses,
115                               const BoundNetLog& net_log) OVERRIDE {
116    NOTIMPLEMENTED();
117    return ERR_UNEXPECTED;
118  }
119
120  virtual void CancelRequest(RequestHandle req) OVERRIDE {
121    EXPECT_TRUE(HasOutstandingRequest());
122    EXPECT_EQ(outstanding_request_, req);
123    outstanding_request_ = NULL;
124  }
125
126  bool HasOutstandingRequest() {
127    return outstanding_request_ != NULL;
128  }
129
130 private:
131  RequestHandle outstanding_request_;
132
133  DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel);
134};
135
136// Tests a complete handshake and the disconnection.
137TEST_F(SOCKSClientSocketTest, CompleteHandshake) {
138  const std::string payload_write = "random data";
139  const std::string payload_read = "moar random data";
140
141  MockWrite data_writes[] = {
142      MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)),
143      MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
144  MockRead data_reads[] = {
145      MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)),
146      MockRead(ASYNC, payload_read.data(), payload_read.size()) };
147  CapturingNetLog log;
148
149  user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
150                               data_writes, arraysize(data_writes),
151                               host_resolver_.get(),
152                               "localhost", 80,
153                               &log);
154
155  // At this state the TCP connection is completed but not the SOCKS handshake.
156  EXPECT_TRUE(tcp_sock_->IsConnected());
157  EXPECT_FALSE(user_sock_->IsConnected());
158
159  int rv = user_sock_->Connect(callback_.callback());
160  EXPECT_EQ(ERR_IO_PENDING, rv);
161
162  CapturingNetLog::CapturedEntryList entries;
163  log.GetEntries(&entries);
164  EXPECT_TRUE(
165      LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT));
166  EXPECT_FALSE(user_sock_->IsConnected());
167
168  rv = callback_.WaitForResult();
169  EXPECT_EQ(OK, rv);
170  EXPECT_TRUE(user_sock_->IsConnected());
171  log.GetEntries(&entries);
172  EXPECT_TRUE(LogContainsEndEvent(
173      entries, -1, NetLog::TYPE_SOCKS_CONNECT));
174
175  scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
176  memcpy(buffer->data(), payload_write.data(), payload_write.size());
177  rv = user_sock_->Write(
178      buffer.get(), payload_write.size(), callback_.callback());
179  EXPECT_EQ(ERR_IO_PENDING, rv);
180  rv = callback_.WaitForResult();
181  EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
182
183  buffer = new IOBuffer(payload_read.size());
184  rv =
185      user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
186  EXPECT_EQ(ERR_IO_PENDING, rv);
187  rv = callback_.WaitForResult();
188  EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
189  EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
190
191  user_sock_->Disconnect();
192  EXPECT_FALSE(tcp_sock_->IsConnected());
193  EXPECT_FALSE(user_sock_->IsConnected());
194}
195
196// List of responses from the socks server and the errors they should
197// throw up are tested here.
198TEST_F(SOCKSClientSocketTest, HandshakeFailures) {
199  const struct {
200    const char fail_reply[8];
201    Error fail_code;
202  } tests[] = {
203    // Failure of the server response code
204    {
205      { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 },
206      ERR_SOCKS_CONNECTION_FAILED,
207    },
208    // Failure of the null byte
209    {
210      { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 },
211      ERR_SOCKS_CONNECTION_FAILED,
212    },
213  };
214
215  //---------------------------------------
216
217  for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) {
218    MockWrite data_writes[] = {
219        MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
220    MockRead data_reads[] = {
221        MockRead(SYNCHRONOUS, tests[i].fail_reply,
222                 arraysize(tests[i].fail_reply)) };
223    CapturingNetLog log;
224
225    user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
226                                 data_writes, arraysize(data_writes),
227                                 host_resolver_.get(),
228                                 "localhost", 80,
229                                 &log);
230
231    int rv = user_sock_->Connect(callback_.callback());
232    EXPECT_EQ(ERR_IO_PENDING, rv);
233
234    CapturingNetLog::CapturedEntryList entries;
235    log.GetEntries(&entries);
236    EXPECT_TRUE(LogContainsBeginEvent(
237        entries, 0, NetLog::TYPE_SOCKS_CONNECT));
238
239    rv = callback_.WaitForResult();
240    EXPECT_EQ(tests[i].fail_code, rv);
241    EXPECT_FALSE(user_sock_->IsConnected());
242    EXPECT_TRUE(tcp_sock_->IsConnected());
243    log.GetEntries(&entries);
244    EXPECT_TRUE(LogContainsEndEvent(
245        entries, -1, NetLog::TYPE_SOCKS_CONNECT));
246  }
247}
248
249// Tests scenario when the server sends the handshake response in
250// more than one packet.
251TEST_F(SOCKSClientSocketTest, PartialServerReads) {
252  const char kSOCKSPartialReply1[] = { 0x00 };
253  const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 };
254
255  MockWrite data_writes[] = {
256      MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
257  MockRead data_reads[] = {
258      MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)),
259      MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) };
260  CapturingNetLog log;
261
262  user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
263                               data_writes, arraysize(data_writes),
264                               host_resolver_.get(),
265                               "localhost", 80,
266                               &log);
267
268  int rv = user_sock_->Connect(callback_.callback());
269  EXPECT_EQ(ERR_IO_PENDING, rv);
270  CapturingNetLog::CapturedEntryList entries;
271  log.GetEntries(&entries);
272  EXPECT_TRUE(LogContainsBeginEvent(
273      entries, 0, NetLog::TYPE_SOCKS_CONNECT));
274
275  rv = callback_.WaitForResult();
276  EXPECT_EQ(OK, rv);
277  EXPECT_TRUE(user_sock_->IsConnected());
278  log.GetEntries(&entries);
279  EXPECT_TRUE(LogContainsEndEvent(
280      entries, -1, NetLog::TYPE_SOCKS_CONNECT));
281}
282
283// Tests scenario when the client sends the handshake request in
284// more than one packet.
285TEST_F(SOCKSClientSocketTest, PartialClientWrites) {
286  const char kSOCKSPartialRequest1[] = { 0x04, 0x01 };
287  const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 };
288
289  MockWrite data_writes[] = {
290      MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)),
291      // simulate some empty writes
292      MockWrite(ASYNC, 0),
293      MockWrite(ASYNC, 0),
294      MockWrite(ASYNC, kSOCKSPartialRequest2,
295                arraysize(kSOCKSPartialRequest2)) };
296  MockRead data_reads[] = {
297      MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) };
298  CapturingNetLog log;
299
300  user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
301                               data_writes, arraysize(data_writes),
302                               host_resolver_.get(),
303                               "localhost", 80,
304                               &log);
305
306  int rv = user_sock_->Connect(callback_.callback());
307  EXPECT_EQ(ERR_IO_PENDING, rv);
308  CapturingNetLog::CapturedEntryList entries;
309  log.GetEntries(&entries);
310  EXPECT_TRUE(LogContainsBeginEvent(
311      entries, 0, NetLog::TYPE_SOCKS_CONNECT));
312
313  rv = callback_.WaitForResult();
314  EXPECT_EQ(OK, rv);
315  EXPECT_TRUE(user_sock_->IsConnected());
316  log.GetEntries(&entries);
317  EXPECT_TRUE(LogContainsEndEvent(
318      entries, -1, NetLog::TYPE_SOCKS_CONNECT));
319}
320
321// Tests the case when the server sends a smaller sized handshake data
322// and closes the connection.
323TEST_F(SOCKSClientSocketTest, FailedSocketRead) {
324  MockWrite data_writes[] = {
325      MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) };
326  MockRead data_reads[] = {
327      MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2),
328      // close connection unexpectedly
329      MockRead(SYNCHRONOUS, 0) };
330  CapturingNetLog log;
331
332  user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
333                               data_writes, arraysize(data_writes),
334                               host_resolver_.get(),
335                               "localhost", 80,
336                               &log);
337
338  int rv = user_sock_->Connect(callback_.callback());
339  EXPECT_EQ(ERR_IO_PENDING, rv);
340  CapturingNetLog::CapturedEntryList entries;
341  log.GetEntries(&entries);
342  EXPECT_TRUE(LogContainsBeginEvent(
343      entries, 0, NetLog::TYPE_SOCKS_CONNECT));
344
345  rv = callback_.WaitForResult();
346  EXPECT_EQ(ERR_CONNECTION_CLOSED, rv);
347  EXPECT_FALSE(user_sock_->IsConnected());
348  log.GetEntries(&entries);
349  EXPECT_TRUE(LogContainsEndEvent(
350      entries, -1, NetLog::TYPE_SOCKS_CONNECT));
351}
352
353// Tries to connect to an unknown hostname. Should fail rather than
354// falling back to SOCKS4a.
355TEST_F(SOCKSClientSocketTest, FailedDNS) {
356  const char hostname[] = "unresolved.ipv4.address";
357
358  host_resolver_->rules()->AddSimulatedFailure(hostname);
359
360  CapturingNetLog log;
361
362  user_sock_ = BuildMockSocket(NULL, 0,
363                               NULL, 0,
364                               host_resolver_.get(),
365                               hostname, 80,
366                               &log);
367
368  int rv = user_sock_->Connect(callback_.callback());
369  EXPECT_EQ(ERR_IO_PENDING, rv);
370  CapturingNetLog::CapturedEntryList entries;
371  log.GetEntries(&entries);
372  EXPECT_TRUE(LogContainsBeginEvent(
373      entries, 0, NetLog::TYPE_SOCKS_CONNECT));
374
375  rv = callback_.WaitForResult();
376  EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv);
377  EXPECT_FALSE(user_sock_->IsConnected());
378  log.GetEntries(&entries);
379  EXPECT_TRUE(LogContainsEndEvent(
380      entries, -1, NetLog::TYPE_SOCKS_CONNECT));
381}
382
383// Calls Disconnect() while a host resolve is in progress. The outstanding host
384// resolve should be cancelled.
385TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) {
386  scoped_ptr<HangingHostResolverWithCancel> hanging_resolver(
387    new HangingHostResolverWithCancel());
388
389  // Doesn't matter what the socket data is, we will never use it -- garbage.
390  MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) };
391  MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) };
392
393  user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads),
394                               data_writes, arraysize(data_writes),
395                               hanging_resolver.get(),
396                               "foo", 80,
397                               NULL);
398
399  // Start connecting (will get stuck waiting for the host to resolve).
400  int rv = user_sock_->Connect(callback_.callback());
401  EXPECT_EQ(ERR_IO_PENDING, rv);
402
403  EXPECT_FALSE(user_sock_->IsConnected());
404  EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
405
406  // The host resolver should have received the resolve request.
407  EXPECT_TRUE(hanging_resolver->HasOutstandingRequest());
408
409  // Disconnect the SOCKS socket -- this should cancel the outstanding resolve.
410  user_sock_->Disconnect();
411
412  EXPECT_FALSE(hanging_resolver->HasOutstandingRequest());
413
414  EXPECT_FALSE(user_sock_->IsConnected());
415  EXPECT_FALSE(user_sock_->IsConnectedAndIdle());
416}
417
418// Tries to connect to an IPv6 IP.  Should fail, as SOCKS4 does not support
419// IPv6.
420TEST_F(SOCKSClientSocketTest, NoIPv6) {
421  const char kHostName[] = "::1";
422
423  user_sock_ = BuildMockSocket(NULL, 0,
424                               NULL, 0,
425                               host_resolver_.get(),
426                               kHostName, 80,
427                               NULL);
428
429  EXPECT_EQ(ERR_NAME_NOT_RESOLVED,
430            callback_.GetResult(user_sock_->Connect(callback_.callback())));
431}
432
433// Same as above, but with a real resolver, to protect against regressions.
434TEST_F(SOCKSClientSocketTest, NoIPv6RealResolver) {
435  const char kHostName[] = "::1";
436
437  scoped_ptr<HostResolver> host_resolver(
438      HostResolver::CreateSystemResolver(HostResolver::Options(), NULL));
439
440  user_sock_ = BuildMockSocket(NULL, 0,
441                               NULL, 0,
442                               host_resolver.get(),
443                               kHostName, 80,
444                               NULL);
445
446  EXPECT_EQ(ERR_NAME_NOT_RESOLVED,
447            callback_.GetResult(user_sock_->Connect(callback_.callback())));
448}
449
450}  // namespace net
451