1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/socket/socks_client_socket.h" 6 7#include "base/memory/scoped_ptr.h" 8#include "net/base/address_list.h" 9#include "net/base/net_log.h" 10#include "net/base/net_log_unittest.h" 11#include "net/base/test_completion_callback.h" 12#include "net/base/winsock_init.h" 13#include "net/dns/host_resolver.h" 14#include "net/dns/mock_host_resolver.h" 15#include "net/socket/client_socket_factory.h" 16#include "net/socket/socket_test_util.h" 17#include "net/socket/tcp_client_socket.h" 18#include "testing/gtest/include/gtest/gtest.h" 19#include "testing/platform_test.h" 20 21//----------------------------------------------------------------------------- 22 23namespace net { 24 25const char kSOCKSOkRequest[] = { 0x04, 0x01, 0x00, 0x50, 127, 0, 0, 1, 0 }; 26const char kSOCKSOkReply[] = { 0x00, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; 27 28class SOCKSClientSocketTest : public PlatformTest { 29 public: 30 SOCKSClientSocketTest(); 31 // Create a SOCKSClientSocket on top of a MockSocket. 32 scoped_ptr<SOCKSClientSocket> BuildMockSocket( 33 MockRead reads[], size_t reads_count, 34 MockWrite writes[], size_t writes_count, 35 HostResolver* host_resolver, 36 const std::string& hostname, int port, 37 NetLog* net_log); 38 virtual void SetUp(); 39 40 protected: 41 scoped_ptr<SOCKSClientSocket> user_sock_; 42 AddressList address_list_; 43 // Filled in by BuildMockSocket() and owned by its return value 44 // (which |user_sock| is set to). 45 StreamSocket* tcp_sock_; 46 TestCompletionCallback callback_; 47 scoped_ptr<MockHostResolver> host_resolver_; 48 scoped_ptr<SocketDataProvider> data_; 49}; 50 51SOCKSClientSocketTest::SOCKSClientSocketTest() 52 : host_resolver_(new MockHostResolver) { 53} 54 55// Set up platform before every test case 56void SOCKSClientSocketTest::SetUp() { 57 PlatformTest::SetUp(); 58} 59 60scoped_ptr<SOCKSClientSocket> SOCKSClientSocketTest::BuildMockSocket( 61 MockRead reads[], 62 size_t reads_count, 63 MockWrite writes[], 64 size_t writes_count, 65 HostResolver* host_resolver, 66 const std::string& hostname, 67 int port, 68 NetLog* net_log) { 69 70 TestCompletionCallback callback; 71 data_.reset(new StaticSocketDataProvider(reads, reads_count, 72 writes, writes_count)); 73 tcp_sock_ = new MockTCPClientSocket(address_list_, net_log, data_.get()); 74 75 int rv = tcp_sock_->Connect(callback.callback()); 76 EXPECT_EQ(ERR_IO_PENDING, rv); 77 rv = callback.WaitForResult(); 78 EXPECT_EQ(OK, rv); 79 EXPECT_TRUE(tcp_sock_->IsConnected()); 80 81 scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); 82 // |connection| takes ownership of |tcp_sock_|, but keep a 83 // non-owning pointer to it. 84 connection->SetSocket(scoped_ptr<StreamSocket>(tcp_sock_)); 85 return scoped_ptr<SOCKSClientSocket>(new SOCKSClientSocket( 86 connection.Pass(), 87 HostResolver::RequestInfo(HostPortPair(hostname, port)), 88 DEFAULT_PRIORITY, 89 host_resolver)); 90} 91 92// Implementation of HostResolver that never completes its resolve request. 93// We use this in the test "DisconnectWhileHostResolveInProgress" to make 94// sure that the outstanding resolve request gets cancelled. 95class HangingHostResolverWithCancel : public HostResolver { 96 public: 97 HangingHostResolverWithCancel() : outstanding_request_(NULL) {} 98 99 virtual int Resolve(const RequestInfo& info, 100 RequestPriority priority, 101 AddressList* addresses, 102 const CompletionCallback& callback, 103 RequestHandle* out_req, 104 const BoundNetLog& net_log) OVERRIDE { 105 DCHECK(addresses); 106 DCHECK_EQ(false, callback.is_null()); 107 EXPECT_FALSE(HasOutstandingRequest()); 108 outstanding_request_ = reinterpret_cast<RequestHandle>(1); 109 *out_req = outstanding_request_; 110 return ERR_IO_PENDING; 111 } 112 113 virtual int ResolveFromCache(const RequestInfo& info, 114 AddressList* addresses, 115 const BoundNetLog& net_log) OVERRIDE { 116 NOTIMPLEMENTED(); 117 return ERR_UNEXPECTED; 118 } 119 120 virtual void CancelRequest(RequestHandle req) OVERRIDE { 121 EXPECT_TRUE(HasOutstandingRequest()); 122 EXPECT_EQ(outstanding_request_, req); 123 outstanding_request_ = NULL; 124 } 125 126 bool HasOutstandingRequest() { 127 return outstanding_request_ != NULL; 128 } 129 130 private: 131 RequestHandle outstanding_request_; 132 133 DISALLOW_COPY_AND_ASSIGN(HangingHostResolverWithCancel); 134}; 135 136// Tests a complete handshake and the disconnection. 137TEST_F(SOCKSClientSocketTest, CompleteHandshake) { 138 const std::string payload_write = "random data"; 139 const std::string payload_read = "moar random data"; 140 141 MockWrite data_writes[] = { 142 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)), 143 MockWrite(ASYNC, payload_write.data(), payload_write.size()) }; 144 MockRead data_reads[] = { 145 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)), 146 MockRead(ASYNC, payload_read.data(), payload_read.size()) }; 147 CapturingNetLog log; 148 149 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 150 data_writes, arraysize(data_writes), 151 host_resolver_.get(), 152 "localhost", 80, 153 &log); 154 155 // At this state the TCP connection is completed but not the SOCKS handshake. 156 EXPECT_TRUE(tcp_sock_->IsConnected()); 157 EXPECT_FALSE(user_sock_->IsConnected()); 158 159 int rv = user_sock_->Connect(callback_.callback()); 160 EXPECT_EQ(ERR_IO_PENDING, rv); 161 162 CapturingNetLog::CapturedEntryList entries; 163 log.GetEntries(&entries); 164 EXPECT_TRUE( 165 LogContainsBeginEvent(entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 166 EXPECT_FALSE(user_sock_->IsConnected()); 167 168 rv = callback_.WaitForResult(); 169 EXPECT_EQ(OK, rv); 170 EXPECT_TRUE(user_sock_->IsConnected()); 171 log.GetEntries(&entries); 172 EXPECT_TRUE(LogContainsEndEvent( 173 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 174 175 scoped_refptr<IOBuffer> buffer(new IOBuffer(payload_write.size())); 176 memcpy(buffer->data(), payload_write.data(), payload_write.size()); 177 rv = user_sock_->Write( 178 buffer.get(), payload_write.size(), callback_.callback()); 179 EXPECT_EQ(ERR_IO_PENDING, rv); 180 rv = callback_.WaitForResult(); 181 EXPECT_EQ(static_cast<int>(payload_write.size()), rv); 182 183 buffer = new IOBuffer(payload_read.size()); 184 rv = 185 user_sock_->Read(buffer.get(), payload_read.size(), callback_.callback()); 186 EXPECT_EQ(ERR_IO_PENDING, rv); 187 rv = callback_.WaitForResult(); 188 EXPECT_EQ(static_cast<int>(payload_read.size()), rv); 189 EXPECT_EQ(payload_read, std::string(buffer->data(), payload_read.size())); 190 191 user_sock_->Disconnect(); 192 EXPECT_FALSE(tcp_sock_->IsConnected()); 193 EXPECT_FALSE(user_sock_->IsConnected()); 194} 195 196// List of responses from the socks server and the errors they should 197// throw up are tested here. 198TEST_F(SOCKSClientSocketTest, HandshakeFailures) { 199 const struct { 200 const char fail_reply[8]; 201 Error fail_code; 202 } tests[] = { 203 // Failure of the server response code 204 { 205 { 0x01, 0x5A, 0x00, 0x00, 0, 0, 0, 0 }, 206 ERR_SOCKS_CONNECTION_FAILED, 207 }, 208 // Failure of the null byte 209 { 210 { 0x00, 0x5B, 0x00, 0x00, 0, 0, 0, 0 }, 211 ERR_SOCKS_CONNECTION_FAILED, 212 }, 213 }; 214 215 //--------------------------------------- 216 217 for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); ++i) { 218 MockWrite data_writes[] = { 219 MockWrite(SYNCHRONOUS, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; 220 MockRead data_reads[] = { 221 MockRead(SYNCHRONOUS, tests[i].fail_reply, 222 arraysize(tests[i].fail_reply)) }; 223 CapturingNetLog log; 224 225 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 226 data_writes, arraysize(data_writes), 227 host_resolver_.get(), 228 "localhost", 80, 229 &log); 230 231 int rv = user_sock_->Connect(callback_.callback()); 232 EXPECT_EQ(ERR_IO_PENDING, rv); 233 234 CapturingNetLog::CapturedEntryList entries; 235 log.GetEntries(&entries); 236 EXPECT_TRUE(LogContainsBeginEvent( 237 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 238 239 rv = callback_.WaitForResult(); 240 EXPECT_EQ(tests[i].fail_code, rv); 241 EXPECT_FALSE(user_sock_->IsConnected()); 242 EXPECT_TRUE(tcp_sock_->IsConnected()); 243 log.GetEntries(&entries); 244 EXPECT_TRUE(LogContainsEndEvent( 245 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 246 } 247} 248 249// Tests scenario when the server sends the handshake response in 250// more than one packet. 251TEST_F(SOCKSClientSocketTest, PartialServerReads) { 252 const char kSOCKSPartialReply1[] = { 0x00 }; 253 const char kSOCKSPartialReply2[] = { 0x5A, 0x00, 0x00, 0, 0, 0, 0 }; 254 255 MockWrite data_writes[] = { 256 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; 257 MockRead data_reads[] = { 258 MockRead(ASYNC, kSOCKSPartialReply1, arraysize(kSOCKSPartialReply1)), 259 MockRead(ASYNC, kSOCKSPartialReply2, arraysize(kSOCKSPartialReply2)) }; 260 CapturingNetLog log; 261 262 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 263 data_writes, arraysize(data_writes), 264 host_resolver_.get(), 265 "localhost", 80, 266 &log); 267 268 int rv = user_sock_->Connect(callback_.callback()); 269 EXPECT_EQ(ERR_IO_PENDING, rv); 270 CapturingNetLog::CapturedEntryList entries; 271 log.GetEntries(&entries); 272 EXPECT_TRUE(LogContainsBeginEvent( 273 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 274 275 rv = callback_.WaitForResult(); 276 EXPECT_EQ(OK, rv); 277 EXPECT_TRUE(user_sock_->IsConnected()); 278 log.GetEntries(&entries); 279 EXPECT_TRUE(LogContainsEndEvent( 280 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 281} 282 283// Tests scenario when the client sends the handshake request in 284// more than one packet. 285TEST_F(SOCKSClientSocketTest, PartialClientWrites) { 286 const char kSOCKSPartialRequest1[] = { 0x04, 0x01 }; 287 const char kSOCKSPartialRequest2[] = { 0x00, 0x50, 127, 0, 0, 1, 0 }; 288 289 MockWrite data_writes[] = { 290 MockWrite(ASYNC, arraysize(kSOCKSPartialRequest1)), 291 // simulate some empty writes 292 MockWrite(ASYNC, 0), 293 MockWrite(ASYNC, 0), 294 MockWrite(ASYNC, kSOCKSPartialRequest2, 295 arraysize(kSOCKSPartialRequest2)) }; 296 MockRead data_reads[] = { 297 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply)) }; 298 CapturingNetLog log; 299 300 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 301 data_writes, arraysize(data_writes), 302 host_resolver_.get(), 303 "localhost", 80, 304 &log); 305 306 int rv = user_sock_->Connect(callback_.callback()); 307 EXPECT_EQ(ERR_IO_PENDING, rv); 308 CapturingNetLog::CapturedEntryList entries; 309 log.GetEntries(&entries); 310 EXPECT_TRUE(LogContainsBeginEvent( 311 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 312 313 rv = callback_.WaitForResult(); 314 EXPECT_EQ(OK, rv); 315 EXPECT_TRUE(user_sock_->IsConnected()); 316 log.GetEntries(&entries); 317 EXPECT_TRUE(LogContainsEndEvent( 318 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 319} 320 321// Tests the case when the server sends a smaller sized handshake data 322// and closes the connection. 323TEST_F(SOCKSClientSocketTest, FailedSocketRead) { 324 MockWrite data_writes[] = { 325 MockWrite(ASYNC, kSOCKSOkRequest, arraysize(kSOCKSOkRequest)) }; 326 MockRead data_reads[] = { 327 MockRead(ASYNC, kSOCKSOkReply, arraysize(kSOCKSOkReply) - 2), 328 // close connection unexpectedly 329 MockRead(SYNCHRONOUS, 0) }; 330 CapturingNetLog log; 331 332 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 333 data_writes, arraysize(data_writes), 334 host_resolver_.get(), 335 "localhost", 80, 336 &log); 337 338 int rv = user_sock_->Connect(callback_.callback()); 339 EXPECT_EQ(ERR_IO_PENDING, rv); 340 CapturingNetLog::CapturedEntryList entries; 341 log.GetEntries(&entries); 342 EXPECT_TRUE(LogContainsBeginEvent( 343 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 344 345 rv = callback_.WaitForResult(); 346 EXPECT_EQ(ERR_CONNECTION_CLOSED, rv); 347 EXPECT_FALSE(user_sock_->IsConnected()); 348 log.GetEntries(&entries); 349 EXPECT_TRUE(LogContainsEndEvent( 350 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 351} 352 353// Tries to connect to an unknown hostname. Should fail rather than 354// falling back to SOCKS4a. 355TEST_F(SOCKSClientSocketTest, FailedDNS) { 356 const char hostname[] = "unresolved.ipv4.address"; 357 358 host_resolver_->rules()->AddSimulatedFailure(hostname); 359 360 CapturingNetLog log; 361 362 user_sock_ = BuildMockSocket(NULL, 0, 363 NULL, 0, 364 host_resolver_.get(), 365 hostname, 80, 366 &log); 367 368 int rv = user_sock_->Connect(callback_.callback()); 369 EXPECT_EQ(ERR_IO_PENDING, rv); 370 CapturingNetLog::CapturedEntryList entries; 371 log.GetEntries(&entries); 372 EXPECT_TRUE(LogContainsBeginEvent( 373 entries, 0, NetLog::TYPE_SOCKS_CONNECT)); 374 375 rv = callback_.WaitForResult(); 376 EXPECT_EQ(ERR_NAME_NOT_RESOLVED, rv); 377 EXPECT_FALSE(user_sock_->IsConnected()); 378 log.GetEntries(&entries); 379 EXPECT_TRUE(LogContainsEndEvent( 380 entries, -1, NetLog::TYPE_SOCKS_CONNECT)); 381} 382 383// Calls Disconnect() while a host resolve is in progress. The outstanding host 384// resolve should be cancelled. 385TEST_F(SOCKSClientSocketTest, DisconnectWhileHostResolveInProgress) { 386 scoped_ptr<HangingHostResolverWithCancel> hanging_resolver( 387 new HangingHostResolverWithCancel()); 388 389 // Doesn't matter what the socket data is, we will never use it -- garbage. 390 MockWrite data_writes[] = { MockWrite(SYNCHRONOUS, "", 0) }; 391 MockRead data_reads[] = { MockRead(SYNCHRONOUS, "", 0) }; 392 393 user_sock_ = BuildMockSocket(data_reads, arraysize(data_reads), 394 data_writes, arraysize(data_writes), 395 hanging_resolver.get(), 396 "foo", 80, 397 NULL); 398 399 // Start connecting (will get stuck waiting for the host to resolve). 400 int rv = user_sock_->Connect(callback_.callback()); 401 EXPECT_EQ(ERR_IO_PENDING, rv); 402 403 EXPECT_FALSE(user_sock_->IsConnected()); 404 EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); 405 406 // The host resolver should have received the resolve request. 407 EXPECT_TRUE(hanging_resolver->HasOutstandingRequest()); 408 409 // Disconnect the SOCKS socket -- this should cancel the outstanding resolve. 410 user_sock_->Disconnect(); 411 412 EXPECT_FALSE(hanging_resolver->HasOutstandingRequest()); 413 414 EXPECT_FALSE(user_sock_->IsConnected()); 415 EXPECT_FALSE(user_sock_->IsConnectedAndIdle()); 416} 417 418// Tries to connect to an IPv6 IP. Should fail, as SOCKS4 does not support 419// IPv6. 420TEST_F(SOCKSClientSocketTest, NoIPv6) { 421 const char kHostName[] = "::1"; 422 423 user_sock_ = BuildMockSocket(NULL, 0, 424 NULL, 0, 425 host_resolver_.get(), 426 kHostName, 80, 427 NULL); 428 429 EXPECT_EQ(ERR_NAME_NOT_RESOLVED, 430 callback_.GetResult(user_sock_->Connect(callback_.callback()))); 431} 432 433// Same as above, but with a real resolver, to protect against regressions. 434TEST_F(SOCKSClientSocketTest, NoIPv6RealResolver) { 435 const char kHostName[] = "::1"; 436 437 scoped_ptr<HostResolver> host_resolver( 438 HostResolver::CreateSystemResolver(HostResolver::Options(), NULL)); 439 440 user_sock_ = BuildMockSocket(NULL, 0, 441 NULL, 0, 442 host_resolver.get(), 443 kHostName, 80, 444 NULL); 445 446 EXPECT_EQ(ERR_NAME_NOT_RESOLVED, 447 callback_.GetResult(user_sock_->Connect(callback_.callback()))); 448} 449 450} // namespace net 451