1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/socket/socks5_client_socket.h" 6 7#include <algorithm> 8#include <iterator> 9#include <map> 10 11#include "base/sys_byteorder.h" 12#include "net/base/address_list.h" 13#include "net/base/net_log.h" 14#include "net/base/net_log_unittest.h" 15#include "net/base/test_completion_callback.h" 16#include "net/base/winsock_init.h" 17#include "net/dns/mock_host_resolver.h" 18#include "net/socket/client_socket_factory.h" 19#include "net/socket/socket_test_util.h" 20#include "net/socket/tcp_client_socket.h" 21#include "testing/gtest/include/gtest/gtest.h" 22#include "testing/platform_test.h" 23 24//----------------------------------------------------------------------------- 25 26namespace net { 27 28namespace { 29 30// Base class to test SOCKS5ClientSocket 31class SOCKS5ClientSocketTest : public PlatformTest { 32 public: 33 SOCKS5ClientSocketTest(); 34 // Create a SOCKSClientSocket on top of a MockSocket. 35 SOCKS5ClientSocket* BuildMockSocket(MockRead reads[], 36 size_t reads_count, 37 MockWrite writes[], 38 size_t writes_count, 39 const std::string& hostname, 40 int port, 41 NetLog* net_log); 42 43 virtual void SetUp(); 44 45 protected: 46 const uint16 kNwPort; 47 CapturingNetLog net_log_; 48 scoped_ptr<SOCKS5ClientSocket> user_sock_; 49 AddressList address_list_; 50 StreamSocket* tcp_sock_; 51 TestCompletionCallback callback_; 52 scoped_ptr<MockHostResolver> host_resolver_; 53 scoped_ptr<SocketDataProvider> data_; 54 55 private: 56 DISALLOW_COPY_AND_ASSIGN(SOCKS5ClientSocketTest); 57}; 58 59SOCKS5ClientSocketTest::SOCKS5ClientSocketTest() 60 : kNwPort(base::HostToNet16(80)), 61 host_resolver_(new MockHostResolver) { 62} 63 64// Set up platform before every test case 65void SOCKS5ClientSocketTest::SetUp() { 66 PlatformTest::SetUp(); 67 68 // Resolve the "localhost" AddressList used by the TCP connection to connect. 69 HostResolver::RequestInfo info(HostPortPair("www.socks-proxy.com", 1080)); 70 TestCompletionCallback callback; 71 int rv = host_resolver_->Resolve(info, &address_list_, callback.callback(), 72 NULL, BoundNetLog()); 73 ASSERT_EQ(ERR_IO_PENDING, rv); 74 rv = callback.WaitForResult(); 75 ASSERT_EQ(OK, rv); 76} 77 78SOCKS5ClientSocket* SOCKS5ClientSocketTest::BuildMockSocket( 79 MockRead reads[], 80 size_t reads_count, 81 MockWrite writes[], 82 size_t writes_count, 83 const std::string& hostname, 84 int port, 85 NetLog* net_log) { 86 TestCompletionCallback callback; 87 data_.reset(new StaticSocketDataProvider(reads, reads_count, 88 writes, writes_count)); 89 tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); 90 91 int rv = tcp_sock_->Connect(callback.callback()); 92 EXPECT_EQ(ERR_IO_PENDING, rv); 93 rv = callback.WaitForResult(); 94 EXPECT_EQ(OK, rv); 95 EXPECT_TRUE(tcp_sock_->IsConnected()); 96 97 return new SOCKS5ClientSocket(tcp_sock_, 98 HostResolver::RequestInfo(HostPortPair(hostname, port))); 99} 100 101// Tests a complete SOCKS5 handshake and the disconnection. 102TEST_F(SOCKS5ClientSocketTest, CompleteHandshake) { 103 const std::string payload_write = "random data"; 104 const std::string payload_read = "moar random data"; 105 106 const char kOkRequest[] = { 107 0x05, // Version 108 0x01, // Command (CONNECT) 109 0x00, // Reserved. 110 0x03, // Address type (DOMAINNAME). 111 0x09, // Length of domain (9) 112 // Domain string: 113 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't', 114 0x00, 0x50, // 16-bit port (80) 115 }; 116 117 MockWrite data_writes[] = { 118 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 119 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)), 120 MockWrite(ASYNC, payload_write.data(), payload_write.size()) }; 121 MockRead data_reads[] = { 122 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 123 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength), 124 MockRead(ASYNC, payload_read.data(), payload_read.size()) }; 125 126 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 127 data_writes, arraysize(data_writes), 128 "localhost", 80, &net_log_)); 129 130 // At this state the TCP connection is completed but not the SOCKS handshake. 131 EXPECT_TRUE(tcp_sock_->IsConnected()); 132 EXPECT_FALSE(user_sock_->IsConnected()); 133 134 int rv = user_sock_->Connect(callback_.callback()); 135 EXPECT_EQ(ERR_IO_PENDING, rv); 136 EXPECT_FALSE(user_sock_->IsConnected()); 137 138 CapturingNetLog::CapturedEntryList net_log_entries; 139 net_log_.GetEntries(&net_log_entries); 140 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 141 NetLog::TYPE_SOCKS5_CONNECT)); 142 143 rv = callback_.WaitForResult(); 144 145 EXPECT_EQ(OK, rv); 146 EXPECT_TRUE(user_sock_->IsConnected()); 147 148 net_log_.GetEntries(&net_log_entries); 149 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 150 NetLog::TYPE_SOCKS5_CONNECT)); 151 152 scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size())); 153 memcpy(buffer->data(), payload_write.data(), payload_write.size()); 154 rv = user_sock_->Write( 155 buffer.get(), payload_write.size(), callback_.callback()); 156 EXPECT_EQ(ERR_IO_PENDING, rv); 157 rv = callback_.WaitForResult(); 158 EXPECT_EQ(static_cast<int>(payload_write.size()), rv); 159 160 buffer = new IOBuffer(payload_read.size()); 161 rv = 162 user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback()); 163 EXPECT_EQ(ERR_IO_PENDING, rv); 164 rv = callback_.WaitForResult(); 165 EXPECT_EQ(static_cast<int>(payload_read.size()), rv); 166 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); 167 168 user_sock_->Disconnect(); 169 EXPECT_FALSE(tcp_sock_->IsConnected()); 170 EXPECT_FALSE(user_sock_->IsConnected()); 171} 172 173// Test that you can call Connect() again after having called Disconnect(). 174TEST_F(SOCKS5ClientSocketTest, ConnectAndDisconnectTwice) { 175 const std::string hostname = "my-host-name"; 176 const char kSOCKS5DomainRequest[] = { 177 0x05, // VER 178 0x01, // CMD 179 0x00, // RSV 180 0x03, // ATYPE 181 }; 182 183 std::string request(kSOCKS5DomainRequest, arraysize(kSOCKS5DomainRequest)); 184 request.push_back(hostname.size()); 185 request.append(hostname); 186 request.append(reinterpret_cast<const char*>(&kNwPort), sizeof(kNwPort)); 187 188 for (int i = 0; i < 2; ++i) { 189 MockWrite data_writes[] = { 190 MockWrite(SYNCHRONOUS, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 191 MockWrite(SYNCHRONOUS, request.data(), request.size()) 192 }; 193 MockRead data_reads[] = { 194 MockRead(SYNCHRONOUS, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 195 MockRead(SYNCHRONOUS, kSOCKS5OkResponse, kSOCKS5OkResponseLength) 196 }; 197 198 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 199 data_writes, arraysize(data_writes), 200 hostname, 80, NULL)); 201 202 int rv = user_sock_->Connect(callback_.callback()); 203 EXPECT_EQ(OK, rv); 204 EXPECT_TRUE(user_sock_->IsConnected()); 205 206 user_sock_->Disconnect(); 207 EXPECT_FALSE(user_sock_->IsConnected()); 208 } 209} 210 211// Test that we fail trying to connect to a hosname longer than 255 bytes. 212TEST_F(SOCKS5ClientSocketTest, LargeHostNameFails) { 213 // Create a string of length 256, where each character is 'x'. 214 std::string large_host_name; 215 std::fill_n(std::back_inserter(large_host_name), 256, 'x'); 216 217 // Create a SOCKS socket, with mock transport socket. 218 MockWrite data_writes[] = {MockWrite()}; 219 MockRead data_reads[] = {MockRead()}; 220 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 221 data_writes, arraysize(data_writes), 222 large_host_name, 80, NULL)); 223 224 // Try to connect -- should fail (without having read/written anything to 225 // the transport socket first) because the hostname is too long. 226 TestCompletionCallback callback; 227 int rv = user_sock_->Connect(callback.callback()); 228 EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv); 229} 230 231TEST_F(SOCKS5ClientSocketTest, PartialReadWrites) { 232 const std::string hostname = "www.google.com"; 233 234 const char kOkRequest[] = { 235 0x05, // Version 236 0x01, // Command (CONNECT) 237 0x00, // Reserved. 238 0x03, // Address type (DOMAINNAME). 239 0x0E, // Length of domain (14) 240 // Domain string: 241 'w', 'w', 'w', '.', 'g', 'o', 'o', 'g', 'l', 'e', '.', 'c', 'o', 'm', 242 0x00, 0x50, // 16-bit port (80) 243 }; 244 245 // Test for partial greet request write 246 { 247 const char partial1[] = { 0x05, 0x01 }; 248 const char partial2[] = { 0x00 }; 249 MockWrite data_writes[] = { 250 MockWrite(ASYNC, arraysize(partial1)), 251 MockWrite(ASYNC, partial2, arraysize(partial2)), 252 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) }; 253 MockRead data_reads[] = { 254 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 255 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; 256 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 257 data_writes, arraysize(data_writes), 258 hostname, 80, &net_log_)); 259 int rv = user_sock_->Connect(callback_.callback()); 260 EXPECT_EQ(ERR_IO_PENDING, rv); 261 262 CapturingNetLog::CapturedEntryList net_log_entries; 263 net_log_.GetEntries(&net_log_entries); 264 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 265 NetLog::TYPE_SOCKS5_CONNECT)); 266 267 rv = callback_.WaitForResult(); 268 EXPECT_EQ(OK, rv); 269 EXPECT_TRUE(user_sock_->IsConnected()); 270 271 net_log_.GetEntries(&net_log_entries); 272 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 273 NetLog::TYPE_SOCKS5_CONNECT)); 274 } 275 276 // Test for partial greet response read 277 { 278 const char partial1[] = { 0x05 }; 279 const char partial2[] = { 0x00 }; 280 MockWrite data_writes[] = { 281 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 282 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) }; 283 MockRead data_reads[] = { 284 MockRead(ASYNC, partial1, arraysize(partial1)), 285 MockRead(ASYNC, partial2, arraysize(partial2)), 286 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; 287 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 288 data_writes, arraysize(data_writes), 289 hostname, 80, &net_log_)); 290 int rv = user_sock_->Connect(callback_.callback()); 291 EXPECT_EQ(ERR_IO_PENDING, rv); 292 293 CapturingNetLog::CapturedEntryList net_log_entries; 294 net_log_.GetEntries(&net_log_entries); 295 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 296 NetLog::TYPE_SOCKS5_CONNECT)); 297 rv = callback_.WaitForResult(); 298 EXPECT_EQ(OK, rv); 299 EXPECT_TRUE(user_sock_->IsConnected()); 300 net_log_.GetEntries(&net_log_entries); 301 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 302 NetLog::TYPE_SOCKS5_CONNECT)); 303 } 304 305 // Test for partial handshake request write. 306 { 307 const int kSplitPoint = 3; // Break handshake write into two parts. 308 MockWrite data_writes[] = { 309 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 310 MockWrite(ASYNC, kOkRequest, kSplitPoint), 311 MockWrite(ASYNC, kOkRequest + kSplitPoint, 312 arraysize(kOkRequest) - kSplitPoint) 313 }; 314 MockRead data_reads[] = { 315 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 316 MockRead(ASYNC, kSOCKS5OkResponse, kSOCKS5OkResponseLength) }; 317 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 318 data_writes, arraysize(data_writes), 319 hostname, 80, &net_log_)); 320 int rv = user_sock_->Connect(callback_.callback()); 321 EXPECT_EQ(ERR_IO_PENDING, rv); 322 CapturingNetLog::CapturedEntryList net_log_entries; 323 net_log_.GetEntries(&net_log_entries); 324 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 325 NetLog::TYPE_SOCKS5_CONNECT)); 326 rv = callback_.WaitForResult(); 327 EXPECT_EQ(OK, rv); 328 EXPECT_TRUE(user_sock_->IsConnected()); 329 net_log_.GetEntries(&net_log_entries); 330 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 331 NetLog::TYPE_SOCKS5_CONNECT)); 332 } 333 334 // Test for partial handshake response read 335 { 336 const int kSplitPoint = 6; // Break the handshake read into two parts. 337 MockWrite data_writes[] = { 338 MockWrite(ASYNC, kSOCKS5GreetRequest, kSOCKS5GreetRequestLength), 339 MockWrite(ASYNC, kOkRequest, arraysize(kOkRequest)) 340 }; 341 MockRead data_reads[] = { 342 MockRead(ASYNC, kSOCKS5GreetResponse, kSOCKS5GreetResponseLength), 343 MockRead(ASYNC, kSOCKS5OkResponse, kSplitPoint), 344 MockRead(ASYNC, kSOCKS5OkResponse + kSplitPoint, 345 kSOCKS5OkResponseLength - kSplitPoint) 346 }; 347 348 user_sock_.reset(BuildMockSocket(data_reads, arraysize(data_reads), 349 data_writes, arraysize(data_writes), 350 hostname, 80, &net_log_)); 351 int rv = user_sock_->Connect(callback_.callback()); 352 EXPECT_EQ(ERR_IO_PENDING, rv); 353 CapturingNetLog::CapturedEntryList net_log_entries; 354 net_log_.GetEntries(&net_log_entries); 355 EXPECT_TRUE(LogContainsBeginEvent(net_log_entries, 0, 356 NetLog::TYPE_SOCKS5_CONNECT)); 357 rv = callback_.WaitForResult(); 358 EXPECT_EQ(OK, rv); 359 EXPECT_TRUE(user_sock_->IsConnected()); 360 net_log_.GetEntries(&net_log_entries); 361 EXPECT_TRUE(LogContainsEndEvent(net_log_entries, -1, 362 NetLog::TYPE_SOCKS5_CONNECT)); 363 } 364} 365 366} // namespace 367 368} // namespace net 369