1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/socket/socks_client_socket.h" 6 7#include "base/basictypes.h" 8#include "base/bind.h" 9#include "base/callback_helpers.h" 10#include "base/compiler_specific.h" 11#include "base/sys_byteorder.h" 12#include "net/base/io_buffer.h" 13#include "net/base/net_log.h" 14#include "net/base/net_util.h" 15#include "net/socket/client_socket_handle.h" 16 17namespace net { 18 19// Every SOCKS server requests a user-id from the client. It is optional 20// and we send an empty string. 21static const char kEmptyUserId[] = ""; 22 23// For SOCKS4, the client sends 8 bytes plus the size of the user-id. 24static const unsigned int kWriteHeaderSize = 8; 25 26// For SOCKS4 the server sends 8 bytes for acknowledgement. 27static const unsigned int kReadHeaderSize = 8; 28 29// Server Response codes for SOCKS. 30static const uint8 kServerResponseOk = 0x5A; 31static const uint8 kServerResponseRejected = 0x5B; 32static const uint8 kServerResponseNotReachable = 0x5C; 33static const uint8 kServerResponseMismatchedUserId = 0x5D; 34 35static const uint8 kSOCKSVersion4 = 0x04; 36static const uint8 kSOCKSStreamRequest = 0x01; 37 38// A struct holding the essential details of the SOCKS4 Server Request. 39// The port in the header is stored in network byte order. 40struct SOCKS4ServerRequest { 41 uint8 version; 42 uint8 command; 43 uint16 nw_port; 44 uint8 ip[4]; 45}; 46COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize, 47 socks4_server_request_struct_wrong_size); 48 49// A struct holding details of the SOCKS4 Server Response. 50struct SOCKS4ServerResponse { 51 uint8 reserved_null; 52 uint8 code; 53 uint16 port; 54 uint8 ip[4]; 55}; 56COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, 57 socks4_server_response_struct_wrong_size); 58 59SOCKSClientSocket::SOCKSClientSocket( 60 scoped_ptr<ClientSocketHandle> transport_socket, 61 const HostResolver::RequestInfo& req_info, 62 RequestPriority priority, 63 HostResolver* host_resolver) 64 : transport_(transport_socket.Pass()), 65 next_state_(STATE_NONE), 66 completed_handshake_(false), 67 bytes_sent_(0), 68 bytes_received_(0), 69 was_ever_used_(false), 70 host_resolver_(host_resolver), 71 host_request_info_(req_info), 72 priority_(priority), 73 net_log_(transport_->socket()->NetLog()) {} 74 75SOCKSClientSocket::~SOCKSClientSocket() { 76 Disconnect(); 77} 78 79int SOCKSClientSocket::Connect(const CompletionCallback& callback) { 80 DCHECK(transport_.get()); 81 DCHECK(transport_->socket()); 82 DCHECK_EQ(STATE_NONE, next_state_); 83 DCHECK(user_callback_.is_null()); 84 85 // If already connected, then just return OK. 86 if (completed_handshake_) 87 return OK; 88 89 next_state_ = STATE_RESOLVE_HOST; 90 91 net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT); 92 93 int rv = DoLoop(OK); 94 if (rv == ERR_IO_PENDING) { 95 user_callback_ = callback; 96 } else { 97 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); 98 } 99 return rv; 100} 101 102void SOCKSClientSocket::Disconnect() { 103 completed_handshake_ = false; 104 host_resolver_.Cancel(); 105 transport_->socket()->Disconnect(); 106 107 // Reset other states to make sure they aren't mistakenly used later. 108 // These are the states initialized by Connect(). 109 next_state_ = STATE_NONE; 110 user_callback_.Reset(); 111} 112 113bool SOCKSClientSocket::IsConnected() const { 114 return completed_handshake_ && transport_->socket()->IsConnected(); 115} 116 117bool SOCKSClientSocket::IsConnectedAndIdle() const { 118 return completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); 119} 120 121const BoundNetLog& SOCKSClientSocket::NetLog() const { 122 return net_log_; 123} 124 125void SOCKSClientSocket::SetSubresourceSpeculation() { 126 if (transport_.get() && transport_->socket()) { 127 transport_->socket()->SetSubresourceSpeculation(); 128 } else { 129 NOTREACHED(); 130 } 131} 132 133void SOCKSClientSocket::SetOmniboxSpeculation() { 134 if (transport_.get() && transport_->socket()) { 135 transport_->socket()->SetOmniboxSpeculation(); 136 } else { 137 NOTREACHED(); 138 } 139} 140 141bool SOCKSClientSocket::WasEverUsed() const { 142 return was_ever_used_; 143} 144 145bool SOCKSClientSocket::UsingTCPFastOpen() const { 146 if (transport_.get() && transport_->socket()) { 147 return transport_->socket()->UsingTCPFastOpen(); 148 } 149 NOTREACHED(); 150 return false; 151} 152 153bool SOCKSClientSocket::WasNpnNegotiated() const { 154 if (transport_.get() && transport_->socket()) { 155 return transport_->socket()->WasNpnNegotiated(); 156 } 157 NOTREACHED(); 158 return false; 159} 160 161NextProto SOCKSClientSocket::GetNegotiatedProtocol() const { 162 if (transport_.get() && transport_->socket()) { 163 return transport_->socket()->GetNegotiatedProtocol(); 164 } 165 NOTREACHED(); 166 return kProtoUnknown; 167} 168 169bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 170 if (transport_.get() && transport_->socket()) { 171 return transport_->socket()->GetSSLInfo(ssl_info); 172 } 173 NOTREACHED(); 174 return false; 175 176} 177 178// Read is called by the transport layer above to read. This can only be done 179// if the SOCKS handshake is complete. 180int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, 181 const CompletionCallback& callback) { 182 DCHECK(completed_handshake_); 183 DCHECK_EQ(STATE_NONE, next_state_); 184 DCHECK(user_callback_.is_null()); 185 DCHECK(!callback.is_null()); 186 187 int rv = transport_->socket()->Read( 188 buf, buf_len, 189 base::Bind(&SOCKSClientSocket::OnReadWriteComplete, 190 base::Unretained(this), callback)); 191 if (rv > 0) 192 was_ever_used_ = true; 193 return rv; 194} 195 196// Write is called by the transport layer. This can only be done if the 197// SOCKS handshake is complete. 198int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, 199 const CompletionCallback& callback) { 200 DCHECK(completed_handshake_); 201 DCHECK_EQ(STATE_NONE, next_state_); 202 DCHECK(user_callback_.is_null()); 203 DCHECK(!callback.is_null()); 204 205 int rv = transport_->socket()->Write( 206 buf, buf_len, 207 base::Bind(&SOCKSClientSocket::OnReadWriteComplete, 208 base::Unretained(this), callback)); 209 if (rv > 0) 210 was_ever_used_ = true; 211 return rv; 212} 213 214int SOCKSClientSocket::SetReceiveBufferSize(int32 size) { 215 return transport_->socket()->SetReceiveBufferSize(size); 216} 217 218int SOCKSClientSocket::SetSendBufferSize(int32 size) { 219 return transport_->socket()->SetSendBufferSize(size); 220} 221 222void SOCKSClientSocket::DoCallback(int result) { 223 DCHECK_NE(ERR_IO_PENDING, result); 224 DCHECK(!user_callback_.is_null()); 225 226 // Since Run() may result in Read being called, 227 // clear user_callback_ up front. 228 DVLOG(1) << "Finished setting up SOCKS handshake"; 229 base::ResetAndReturn(&user_callback_).Run(result); 230} 231 232void SOCKSClientSocket::OnIOComplete(int result) { 233 DCHECK_NE(STATE_NONE, next_state_); 234 int rv = DoLoop(result); 235 if (rv != ERR_IO_PENDING) { 236 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); 237 DoCallback(rv); 238 } 239} 240 241void SOCKSClientSocket::OnReadWriteComplete(const CompletionCallback& callback, 242 int result) { 243 DCHECK_NE(ERR_IO_PENDING, result); 244 DCHECK(!callback.is_null()); 245 246 if (result > 0) 247 was_ever_used_ = true; 248 callback.Run(result); 249} 250 251int SOCKSClientSocket::DoLoop(int last_io_result) { 252 DCHECK_NE(next_state_, STATE_NONE); 253 int rv = last_io_result; 254 do { 255 State state = next_state_; 256 next_state_ = STATE_NONE; 257 switch (state) { 258 case STATE_RESOLVE_HOST: 259 DCHECK_EQ(OK, rv); 260 rv = DoResolveHost(); 261 break; 262 case STATE_RESOLVE_HOST_COMPLETE: 263 rv = DoResolveHostComplete(rv); 264 break; 265 case STATE_HANDSHAKE_WRITE: 266 DCHECK_EQ(OK, rv); 267 rv = DoHandshakeWrite(); 268 break; 269 case STATE_HANDSHAKE_WRITE_COMPLETE: 270 rv = DoHandshakeWriteComplete(rv); 271 break; 272 case STATE_HANDSHAKE_READ: 273 DCHECK_EQ(OK, rv); 274 rv = DoHandshakeRead(); 275 break; 276 case STATE_HANDSHAKE_READ_COMPLETE: 277 rv = DoHandshakeReadComplete(rv); 278 break; 279 default: 280 NOTREACHED() << "bad state"; 281 rv = ERR_UNEXPECTED; 282 break; 283 } 284 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); 285 return rv; 286} 287 288int SOCKSClientSocket::DoResolveHost() { 289 next_state_ = STATE_RESOLVE_HOST_COMPLETE; 290 // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4 291 // addresses for the target host. 292 host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4); 293 return host_resolver_.Resolve( 294 host_request_info_, 295 priority_, 296 &addresses_, 297 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)), 298 net_log_); 299} 300 301int SOCKSClientSocket::DoResolveHostComplete(int result) { 302 if (result != OK) { 303 // Resolving the hostname failed; fail the request rather than automatically 304 // falling back to SOCKS4a (since it can be confusing to see invalid IP 305 // addresses being sent to the SOCKS4 server when it doesn't support 4A.) 306 return result; 307 } 308 309 next_state_ = STATE_HANDSHAKE_WRITE; 310 return OK; 311} 312 313// Builds the buffer that is to be sent to the server. 314const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const { 315 SOCKS4ServerRequest request; 316 request.version = kSOCKSVersion4; 317 request.command = kSOCKSStreamRequest; 318 request.nw_port = base::HostToNet16(host_request_info_.port()); 319 320 DCHECK(!addresses_.empty()); 321 const IPEndPoint& endpoint = addresses_.front(); 322 323 // We disabled IPv6 results when resolving the hostname, so none of the 324 // results in the list will be IPv6. 325 // TODO(eroman): we only ever use the first address in the list. It would be 326 // more robust to try all the IP addresses we have before 327 // failing the connect attempt. 328 CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily()); 329 CHECK_LE(endpoint.address().size(), sizeof(request.ip)); 330 memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size()); 331 332 DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort(); 333 334 std::string handshake_data(reinterpret_cast<char*>(&request), 335 sizeof(request)); 336 handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId)); 337 338 return handshake_data; 339} 340 341// Writes the SOCKS handshake data to the underlying socket connection. 342int SOCKSClientSocket::DoHandshakeWrite() { 343 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; 344 345 if (buffer_.empty()) { 346 buffer_ = BuildHandshakeWriteBuffer(); 347 bytes_sent_ = 0; 348 } 349 350 int handshake_buf_len = buffer_.size() - bytes_sent_; 351 DCHECK_GT(handshake_buf_len, 0); 352 handshake_buf_ = new IOBuffer(handshake_buf_len); 353 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], 354 handshake_buf_len); 355 return transport_->socket()->Write( 356 handshake_buf_.get(), 357 handshake_buf_len, 358 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); 359} 360 361int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { 362 if (result < 0) 363 return result; 364 365 // We ignore the case when result is 0, since the underlying Write 366 // may return spurious writes while waiting on the socket. 367 368 bytes_sent_ += result; 369 if (bytes_sent_ == buffer_.size()) { 370 next_state_ = STATE_HANDSHAKE_READ; 371 buffer_.clear(); 372 } else if (bytes_sent_ < buffer_.size()) { 373 next_state_ = STATE_HANDSHAKE_WRITE; 374 } else { 375 return ERR_UNEXPECTED; 376 } 377 378 return OK; 379} 380 381int SOCKSClientSocket::DoHandshakeRead() { 382 next_state_ = STATE_HANDSHAKE_READ_COMPLETE; 383 384 if (buffer_.empty()) { 385 bytes_received_ = 0; 386 } 387 388 int handshake_buf_len = kReadHeaderSize - bytes_received_; 389 handshake_buf_ = new IOBuffer(handshake_buf_len); 390 return transport_->socket()->Read( 391 handshake_buf_.get(), 392 handshake_buf_len, 393 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); 394} 395 396int SOCKSClientSocket::DoHandshakeReadComplete(int result) { 397 if (result < 0) 398 return result; 399 400 // The underlying socket closed unexpectedly. 401 if (result == 0) 402 return ERR_CONNECTION_CLOSED; 403 404 if (bytes_received_ + result > kReadHeaderSize) { 405 // TODO(eroman): Describe failure in NetLog. 406 return ERR_SOCKS_CONNECTION_FAILED; 407 } 408 409 buffer_.append(handshake_buf_->data(), result); 410 bytes_received_ += result; 411 if (bytes_received_ < kReadHeaderSize) { 412 next_state_ = STATE_HANDSHAKE_READ; 413 return OK; 414 } 415 416 const SOCKS4ServerResponse* response = 417 reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data()); 418 419 if (response->reserved_null != 0x00) { 420 LOG(ERROR) << "Unknown response from SOCKS server."; 421 return ERR_SOCKS_CONNECTION_FAILED; 422 } 423 424 switch (response->code) { 425 case kServerResponseOk: 426 completed_handshake_ = true; 427 return OK; 428 case kServerResponseRejected: 429 LOG(ERROR) << "SOCKS request rejected or failed"; 430 return ERR_SOCKS_CONNECTION_FAILED; 431 case kServerResponseNotReachable: 432 LOG(ERROR) << "SOCKS request failed because client is not running " 433 << "identd (or not reachable from the server)"; 434 return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE; 435 case kServerResponseMismatchedUserId: 436 LOG(ERROR) << "SOCKS request failed because client's identd could " 437 << "not confirm the user ID string in the request"; 438 return ERR_SOCKS_CONNECTION_FAILED; 439 default: 440 LOG(ERROR) << "SOCKS server sent unknown response"; 441 return ERR_SOCKS_CONNECTION_FAILED; 442 } 443 444 // Note: we ignore the last 6 bytes as specified by the SOCKS protocol 445} 446 447int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const { 448 return transport_->socket()->GetPeerAddress(address); 449} 450 451int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const { 452 return transport_->socket()->GetLocalAddress(address); 453} 454 455} // namespace net 456