1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/socks5_client_socket.h"
6
7#include <algorithm>
8#include <iterator>
9#include <map>
10
11#include "base/sys_byteorder.h"
12#include "net/base/address_list.h"
13#include "net/base/net_log.h"
14#include "net/base/net_log_unittest.h"
15#include "net/base/test_completion_callback.h"
16#include "net/base/winsock_init.h"
17#include "net/dns/mock_host_resolver.h"
18#include "net/socket/client_socket_factory.h"
19#include "net/socket/socket_test_util.h"
20#include "net/socket/tcp_client_socket.h"
21#include "testing/gtest/include/gtest/gtest.h"
22#include "testing/platform_test.h"
23
24//-----------------------------------------------------------------------------
25
26namespace net {
27
28namespace {
29
30// Base class to test SOCKS5ClientSocket
31class SOCKS5ClientSocketTest : public PlatformTest {
32 public:
33  SOCKS5ClientSocketTest();
34  // Create a SOCKSClientSocket on top of a MockSocket.
35  SOCKS5ClientSocket* BuildMockSocket(MockRead reads[],
36                                      size_t reads_count,
37                                      MockWrite writes[],
38                                      size_t writes_count,
39                                      const std::string& hostname,
40                                      int port,
41                                      NetLog* net_log);
42
43  virtual void SetUp();
44
45 protected:
46  const uint16 kNwPort;
47  CapturingNetLog net_log_;
48  scoped_ptr<SOCKS5ClientSocket> user_sock_;
49  AddressList address_list_;
50  StreamSocket* tcp_sock_;
51  TestCompletionCallback callback_;
52  scoped_ptr<MockHostResolver> host_resolver_;
53  scoped_ptr<SocketDataProvider> data_;
54
55 private:
56  DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest);
57};
58
59SOCKS5ClientSocketTest::SOCKS5ClientSocketTest()
60  : kNwPort(base::HostToNet16(80)),
61    host_resolver_(new MockHostResolver) {
62}
63
64// Set up platform before every test case
65void SOCKS5ClientSocketTest::SetUp() {
66  PlatformTest::SetUp();
67
68  // Resolve the "localhost" AddressList used by the TCP connection to connect.
69  HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080));
70  TestCompletionCallback callback;
71  int rv = host_resolver_->Resolve(info, &address_list_, callback.callback(),
72                                   NULL, BoundNetLog());
73  ASSERT_EQ(ERR_IO_PENDING, rv);
74  rv = callback.WaitForResult();
75  ASSERT_EQ(OK, rv);
76}
77
78SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket(
79    MockRead reads[],
80    size_t reads_count,
81    MockWrite writes[],
82    size_t writes_count,
83    const std::string& hostname,
84    int port,
85    NetLog* net_log) {
86  TestCompletionCallback callback;
87  data_.reset(new StaticSocketDataProvider(reads, reads_count,
88                                           writes, writes_count));
89  tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get());
90
91  int rv = tcp_sock_->Connect(callback.callback());
92  EXPECT_EQ(ERR_IO_PENDING, rv);
93  rv = callback.WaitForResult();
94  EXPECT_EQ(OK, rv);
95  EXPECT_TRUE(tcp_sock_->IsConnected());
96
97  return new SOCKS5ClientSocket(tcp_sock_,
98      HostResolver::RequestInfo(HostPortPair(hostname, port)));
99}
100
101// Tests a complete SOCKS5 handshake and the disconnection.
102TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) {
103  const std::string payload_write = "random data";
104  const std::string payload_read = "moar random data";
105
106  const char kOkRequest[] = {
107    0x05,  // Version
108    0x01,  // Command (CONNECT)
109    0x00,  // Reserved.
110    0x03,  // Address type (DOMAINNAME).
111    0x09,  // Length of domain (9)
112    // Domain string:
113    'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
114    0x00, 0x50,  // 16-bit port (80)
115  };
116
117  MockWrite data_writes[] = {
118      MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
119      MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)),
120      MockWrite(ASYNC, payload_write.data(), payload_write.size()) };
121  MockRead data_reads[] = {
122      MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
123      MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength),
124      MockRead(ASYNC, payload_read.data(), payload_read.size()) };
125
126  user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
127                                   data_writes, arraysize(data_writes),
128                                   "localhost", 80, &net_log_));
129
130  // At this state the TCP connection is completed but not the SOCKS handshake.
131  EXPECT_TRUE(tcp_sock_->IsConnected());
132  EXPECT_FALSE(user_sock_->IsConnected());
133
134  int rv = user_sock_->Connect(callback_.callback());
135  EXPECT_EQ(ERR_IO_PENDING, rv);
136  EXPECT_FALSE(user_sock_->IsConnected());
137
138  CapturingNetLog::CapturedEntryList net_log_entries;
139  net_log_.GetEntries(&net_log_entries);
140  EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
141                                    NetLog::TYPE_SOCKS5_CONNECT));
142
143  rv = callback_.WaitForResult();
144
145  EXPECT_EQ(OK, rv);
146  EXPECT_TRUE(user_sock_->IsConnected());
147
148  net_log_.GetEntries(&net_log_entries);
149  EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
150                                  NetLog::TYPE_SOCKS5_CONNECT));
151
152  scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size()));
153  memcpy(buffer->data(), payload_write.data(), payload_write.size());
154  rv = user_sock_->Write(
155      buffer.get(), payload_write.size(), callback_.callback());
156  EXPECT_EQ(ERR_IO_PENDING, rv);
157  rv = callback_.WaitForResult();
158  EXPECT_EQ(static_cast<int>(payload_write.size()), rv);
159
160  buffer = new IOBuffer(payload_read.size());
161  rv =
162      user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback());
163  EXPECT_EQ(ERR_IO_PENDING, rv);
164  rv = callback_.WaitForResult();
165  EXPECT_EQ(static_cast<int>(payload_read.size()), rv);
166  EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size()));
167
168  user_sock_->Disconnect();
169  EXPECT_FALSE(tcp_sock_->IsConnected());
170  EXPECT_FALSE(user_sock_->IsConnected());
171}
172
173// Test that you can call Connect() again after having called Disconnect().
174TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) {
175  const std::string hostname = "my-host-name";
176  const char kSOCKS5DomainRequest[] = {
177      0x05,  // VER
178      0x01,  // CMD
179      0x00,  // RSV
180      0x03,  // ATYPE
181  };
182
183  std::string request(kSOCKS5DomainRequest, arraysize(kSOCKS5DomainRequest));
184  request.push_back(hostname.size());
185  request.append(hostname);
186  request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort));
187
188  for (int i = 0; i < 2; ++i) {
189    MockWrite data_writes[] = {
190        MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
191        MockWrite(SYNCHRONOUS, request.data(), request.size())
192    };
193    MockRead data_reads[] = {
194        MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
195        MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength)
196    };
197
198    user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
199                                     data_writes, arraysize(data_writes),
200                                     hostname, 80, NULL));
201
202    int rv = user_sock_->Connect(callback_.callback());
203    EXPECT_EQ(OK, rv);
204    EXPECT_TRUE(user_sock_->IsConnected());
205
206    user_sock_->Disconnect();
207    EXPECT_FALSE(user_sock_->IsConnected());
208  }
209}
210
211// Test that we fail trying to connect to a hosname longer than 255 bytes.
212TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) {
213  // Create a string of length 256, where each character is 'x'.
214  std::string large_host_name;
215  std::fill_n(std::back_inserter(large_host_name), 256, 'x');
216
217  // Create a SOCKS socket, with mock transport socket.
218  MockWrite data_writes[] = {MockWrite()};
219  MockRead data_reads[] = {MockRead()};
220  user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
221                                   data_writes, arraysize(data_writes),
222                                   large_host_name, 80, NULL));
223
224  // Try to connect -- should fail (without having read/written anything to
225  // the transport socket first) because the hostname is too long.
226  TestCompletionCallback callback;
227  int rv = user_sock_->Connect(callback.callback());
228  EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv);
229}
230
231TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) {
232  const std::string hostname = "www.google.com";
233
234  const char kOkRequest[] = {
235    0x05,  // Version
236    0x01,  // Command (CONNECT)
237    0x00,  // Reserved.
238    0x03,  // Address type (DOMAINNAME).
239    0x0E,  // Length of domain (14)
240    // Domain string:
241    'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm',
242    0x00, 0x50,  // 16-bit port (80)
243  };
244
245  // Test for partial greet request write
246  {
247    const char partial1[] = { 0x05, 0x01 };
248    const char partial2[] = { 0x00 };
249    MockWrite data_writes[] = {
250        MockWrite(ASYNC, arraysize(partial1)),
251        MockWrite(ASYNC, partial2, arraysize(partial2)),
252        MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) };
253    MockRead data_reads[] = {
254        MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
255        MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
256    user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
257                                     data_writes, arraysize(data_writes),
258                                     hostname, 80, &net_log_));
259    int rv = user_sock_->Connect(callback_.callback());
260    EXPECT_EQ(ERR_IO_PENDING, rv);
261
262    CapturingNetLog::CapturedEntryList net_log_entries;
263    net_log_.GetEntries(&net_log_entries);
264    EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
265                NetLog::TYPE_SOCKS5_CONNECT));
266
267    rv = callback_.WaitForResult();
268    EXPECT_EQ(OK, rv);
269    EXPECT_TRUE(user_sock_->IsConnected());
270
271    net_log_.GetEntries(&net_log_entries);
272    EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
273                NetLog::TYPE_SOCKS5_CONNECT));
274  }
275
276  // Test for partial greet response read
277  {
278    const char partial1[] = { 0x05 };
279    const char partial2[] = { 0x00 };
280    MockWrite data_writes[] = {
281        MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
282        MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) };
283    MockRead data_reads[] = {
284        MockRead(ASYNC, partial1, arraysize(partial1)),
285        MockRead(ASYNC, partial2, arraysize(partial2)),
286        MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
287    user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
288                                     data_writes, arraysize(data_writes),
289                                     hostname, 80, &net_log_));
290    int rv = user_sock_->Connect(callback_.callback());
291    EXPECT_EQ(ERR_IO_PENDING, rv);
292
293    CapturingNetLog::CapturedEntryList net_log_entries;
294    net_log_.GetEntries(&net_log_entries);
295    EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
296                                      NetLog::TYPE_SOCKS5_CONNECT));
297    rv = callback_.WaitForResult();
298    EXPECT_EQ(OK, rv);
299    EXPECT_TRUE(user_sock_->IsConnected());
300    net_log_.GetEntries(&net_log_entries);
301    EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
302                                    NetLog::TYPE_SOCKS5_CONNECT));
303  }
304
305  // Test for partial handshake request write.
306  {
307    const int kSplitPoint = 3;  // Break handshake write into two parts.
308    MockWrite data_writes[] = {
309        MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
310        MockWrite(ASYNC, kOkRequest, kSplitPoint),
311        MockWrite(ASYNC, kOkRequest + kSplitPoint,
312                  arraysize(kOkRequest) - kSplitPoint)
313    };
314    MockRead data_reads[] = {
315        MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
316        MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) };
317    user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
318                                     data_writes, arraysize(data_writes),
319                                     hostname, 80, &net_log_));
320    int rv = user_sock_->Connect(callback_.callback());
321    EXPECT_EQ(ERR_IO_PENDING, rv);
322    CapturingNetLog::CapturedEntryList net_log_entries;
323    net_log_.GetEntries(&net_log_entries);
324    EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
325                                      NetLog::TYPE_SOCKS5_CONNECT));
326    rv = callback_.WaitForResult();
327    EXPECT_EQ(OK, rv);
328    EXPECT_TRUE(user_sock_->IsConnected());
329    net_log_.GetEntries(&net_log_entries);
330    EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
331                                    NetLog::TYPE_SOCKS5_CONNECT));
332  }
333
334  // Test for partial handshake response read
335  {
336    const int kSplitPoint = 6;  // Break the handshake read into two parts.
337    MockWrite data_writes[] = {
338        MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength),
339        MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest))
340    };
341    MockRead data_reads[] = {
342        MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength),
343        MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint),
344        MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint,
345                 kSOCKS5OkResponseLength - kSplitPoint)
346    };
347
348    user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads),
349                                     data_writes, arraysize(data_writes),
350                                     hostname, 80, &net_log_));
351    int rv = user_sock_->Connect(callback_.callback());
352    EXPECT_EQ(ERR_IO_PENDING, rv);
353    CapturingNetLog::CapturedEntryList net_log_entries;
354    net_log_.GetEntries(&net_log_entries);
355    EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0,
356                                      NetLog::TYPE_SOCKS5_CONNECT));
357    rv = callback_.WaitForResult();
358    EXPECT_EQ(OK, rv);
359    EXPECT_TRUE(user_sock_->IsConnected());
360    net_log_.GetEntries(&net_log_entries);
361    EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1,
362                                    NetLog::TYPE_SOCKS5_CONNECT));
363  }
364}
365
366}  // namespace
367
368}  // namespace net
369