// socks_client_socket.cc revision c2e0dbddbe15c98d52c4786dac06cb8952a8ae6d
1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/socket/socks_client_socket.h" 6 7#include "base/basictypes.h" 8#include "base/bind.h" 9#include "base/compiler_specific.h" 10#include "base/sys_byteorder.h" 11#include "net/base/io_buffer.h" 12#include "net/base/net_log.h" 13#include "net/base/net_util.h" 14#include "net/socket/client_socket_handle.h" 15 16namespace net { 17 18// Every SOCKS server requests a user-id from the client. It is optional 19// and we send an empty string. 20static const char kEmptyUserId[] = ""; 21 22// For SOCKS4, the client sends 8 bytes plus the size of the user-id. 23static const unsigned int kWriteHeaderSize = 8; 24 25// For SOCKS4 the server sends 8 bytes for acknowledgement. 26static const unsigned int kReadHeaderSize = 8; 27 28// Server Response codes for SOCKS. 29static const uint8 kServerResponseOk = 0x5A; 30static const uint8 kServerResponseRejected = 0x5B; 31static const uint8 kServerResponseNotReachable = 0x5C; 32static const uint8 kServerResponseMismatchedUserId = 0x5D; 33 34static const uint8 kSOCKSVersion4 = 0x04; 35static const uint8 kSOCKSStreamRequest = 0x01; 36 37// A struct holding the essential details of the SOCKS4 Server Request. 38// The port in the header is stored in network byte order. 39struct SOCKS4ServerRequest { 40 uint8 version; 41 uint8 command; 42 uint16 nw_port; 43 uint8 ip[4]; 44}; 45COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize, 46 socks4_server_request_struct_wrong_size); 47 48// A struct holding details of the SOCKS4 Server Response. 
49struct SOCKS4ServerResponse { 50 uint8 reserved_null; 51 uint8 code; 52 uint16 port; 53 uint8 ip[4]; 54}; 55COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize, 56 socks4_server_response_struct_wrong_size); 57 58SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket, 59 const HostResolver::RequestInfo& req_info, 60 HostResolver* host_resolver) 61 : transport_(transport_socket), 62 next_state_(STATE_NONE), 63 completed_handshake_(false), 64 bytes_sent_(0), 65 bytes_received_(0), 66 host_resolver_(host_resolver), 67 host_request_info_(req_info), 68 net_log_(transport_socket->socket()->NetLog()) { 69} 70 71SOCKSClientSocket::SOCKSClientSocket(StreamSocket* transport_socket, 72 const HostResolver::RequestInfo& req_info, 73 HostResolver* host_resolver) 74 : transport_(new ClientSocketHandle()), 75 next_state_(STATE_NONE), 76 completed_handshake_(false), 77 bytes_sent_(0), 78 bytes_received_(0), 79 host_resolver_(host_resolver), 80 host_request_info_(req_info), 81 net_log_(transport_socket->NetLog()) { 82 transport_->set_socket(transport_socket); 83} 84 85SOCKSClientSocket::~SOCKSClientSocket() { 86 Disconnect(); 87} 88 89int SOCKSClientSocket::Connect(const CompletionCallback& callback) { 90 DCHECK(transport_.get()); 91 DCHECK(transport_->socket()); 92 DCHECK_EQ(STATE_NONE, next_state_); 93 DCHECK(user_callback_.is_null()); 94 95 // If already connected, then just return OK. 96 if (completed_handshake_) 97 return OK; 98 99 next_state_ = STATE_RESOLVE_HOST; 100 101 net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT); 102 103 int rv = DoLoop(OK); 104 if (rv == ERR_IO_PENDING) { 105 user_callback_ = callback; 106 } else { 107 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); 108 } 109 return rv; 110} 111 112void SOCKSClientSocket::Disconnect() { 113 completed_handshake_ = false; 114 host_resolver_.Cancel(); 115 transport_->socket()->Disconnect(); 116 117 // Reset other states to make sure they aren't mistakenly used later. 
118 // These are the states initialized by Connect(). 119 next_state_ = STATE_NONE; 120 user_callback_.Reset(); 121} 122 123bool SOCKSClientSocket::IsConnected() const { 124 return completed_handshake_ && transport_->socket()->IsConnected(); 125} 126 127bool SOCKSClientSocket::IsConnectedAndIdle() const { 128 return completed_handshake_ && transport_->socket()->IsConnectedAndIdle(); 129} 130 131const BoundNetLog& SOCKSClientSocket::NetLog() const { 132 return net_log_; 133} 134 135void SOCKSClientSocket::SetSubresourceSpeculation() { 136 if (transport_.get() && transport_->socket()) { 137 transport_->socket()->SetSubresourceSpeculation(); 138 } else { 139 NOTREACHED(); 140 } 141} 142 143void SOCKSClientSocket::SetOmniboxSpeculation() { 144 if (transport_.get() && transport_->socket()) { 145 transport_->socket()->SetOmniboxSpeculation(); 146 } else { 147 NOTREACHED(); 148 } 149} 150 151bool SOCKSClientSocket::WasEverUsed() const { 152 if (transport_.get() && transport_->socket()) { 153 return transport_->socket()->WasEverUsed(); 154 } 155 NOTREACHED(); 156 return false; 157} 158 159bool SOCKSClientSocket::UsingTCPFastOpen() const { 160 if (transport_.get() && transport_->socket()) { 161 return transport_->socket()->UsingTCPFastOpen(); 162 } 163 NOTREACHED(); 164 return false; 165} 166 167bool SOCKSClientSocket::WasNpnNegotiated() const { 168 if (transport_.get() && transport_->socket()) { 169 return transport_->socket()->WasNpnNegotiated(); 170 } 171 NOTREACHED(); 172 return false; 173} 174 175NextProto SOCKSClientSocket::GetNegotiatedProtocol() const { 176 if (transport_.get() && transport_->socket()) { 177 return transport_->socket()->GetNegotiatedProtocol(); 178 } 179 NOTREACHED(); 180 return kProtoUnknown; 181} 182 183bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 184 if (transport_.get() && transport_->socket()) { 185 return transport_->socket()->GetSSLInfo(ssl_info); 186 } 187 NOTREACHED(); 188 return false; 189 190} 191 192// Read is called by the 
transport layer above to read. This can only be done 193// if the SOCKS handshake is complete. 194int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len, 195 const CompletionCallback& callback) { 196 DCHECK(completed_handshake_); 197 DCHECK_EQ(STATE_NONE, next_state_); 198 DCHECK(user_callback_.is_null()); 199 200 return transport_->socket()->Read(buf, buf_len, callback); 201} 202 203// Write is called by the transport layer. This can only be done if the 204// SOCKS handshake is complete. 205int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len, 206 const CompletionCallback& callback) { 207 DCHECK(completed_handshake_); 208 DCHECK_EQ(STATE_NONE, next_state_); 209 DCHECK(user_callback_.is_null()); 210 211 return transport_->socket()->Write(buf, buf_len, callback); 212} 213 214bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) { 215 return transport_->socket()->SetReceiveBufferSize(size); 216} 217 218bool SOCKSClientSocket::SetSendBufferSize(int32 size) { 219 return transport_->socket()->SetSendBufferSize(size); 220} 221 222void SOCKSClientSocket::DoCallback(int result) { 223 DCHECK_NE(ERR_IO_PENDING, result); 224 DCHECK(!user_callback_.is_null()); 225 226 // Since Run() may result in Read being called, 227 // clear user_callback_ up front. 
228 CompletionCallback c = user_callback_; 229 user_callback_.Reset(); 230 DVLOG(1) << "Finished setting up SOCKS handshake"; 231 c.Run(result); 232} 233 234void SOCKSClientSocket::OnIOComplete(int result) { 235 DCHECK_NE(STATE_NONE, next_state_); 236 int rv = DoLoop(result); 237 if (rv != ERR_IO_PENDING) { 238 net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv); 239 DoCallback(rv); 240 } 241} 242 243int SOCKSClientSocket::DoLoop(int last_io_result) { 244 DCHECK_NE(next_state_, STATE_NONE); 245 int rv = last_io_result; 246 do { 247 State state = next_state_; 248 next_state_ = STATE_NONE; 249 switch (state) { 250 case STATE_RESOLVE_HOST: 251 DCHECK_EQ(OK, rv); 252 rv = DoResolveHost(); 253 break; 254 case STATE_RESOLVE_HOST_COMPLETE: 255 rv = DoResolveHostComplete(rv); 256 break; 257 case STATE_HANDSHAKE_WRITE: 258 DCHECK_EQ(OK, rv); 259 rv = DoHandshakeWrite(); 260 break; 261 case STATE_HANDSHAKE_WRITE_COMPLETE: 262 rv = DoHandshakeWriteComplete(rv); 263 break; 264 case STATE_HANDSHAKE_READ: 265 DCHECK_EQ(OK, rv); 266 rv = DoHandshakeRead(); 267 break; 268 case STATE_HANDSHAKE_READ_COMPLETE: 269 rv = DoHandshakeReadComplete(rv); 270 break; 271 default: 272 NOTREACHED() << "bad state"; 273 rv = ERR_UNEXPECTED; 274 break; 275 } 276 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); 277 return rv; 278} 279 280int SOCKSClientSocket::DoResolveHost() { 281 next_state_ = STATE_RESOLVE_HOST_COMPLETE; 282 // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4 283 // addresses for the target host. 
284 host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4); 285 return host_resolver_.Resolve( 286 host_request_info_, &addresses_, 287 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)), 288 net_log_); 289} 290 291int SOCKSClientSocket::DoResolveHostComplete(int result) { 292 if (result != OK) { 293 // Resolving the hostname failed; fail the request rather than automatically 294 // falling back to SOCKS4a (since it can be confusing to see invalid IP 295 // addresses being sent to the SOCKS4 server when it doesn't support 4A.) 296 return result; 297 } 298 299 next_state_ = STATE_HANDSHAKE_WRITE; 300 return OK; 301} 302 303// Builds the buffer that is to be sent to the server. 304const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const { 305 SOCKS4ServerRequest request; 306 request.version = kSOCKSVersion4; 307 request.command = kSOCKSStreamRequest; 308 request.nw_port = base::HostToNet16(host_request_info_.port()); 309 310 DCHECK(!addresses_.empty()); 311 const IPEndPoint& endpoint = addresses_.front(); 312 313 // We disabled IPv6 results when resolving the hostname, so none of the 314 // results in the list will be IPv6. 315 // TODO(eroman): we only ever use the first address in the list. It would be 316 // more robust to try all the IP addresses we have before 317 // failing the connect attempt. 318 CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily()); 319 CHECK_LE(endpoint.address().size(), sizeof(request.ip)); 320 memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size()); 321 322 DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort(); 323 324 std::string handshake_data(reinterpret_cast<char*>(&request), 325 sizeof(request)); 326 handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId)); 327 328 return handshake_data; 329} 330 331// Writes the SOCKS handshake data to the underlying socket connection. 
332int SOCKSClientSocket::DoHandshakeWrite() { 333 next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE; 334 335 if (buffer_.empty()) { 336 buffer_ = BuildHandshakeWriteBuffer(); 337 bytes_sent_ = 0; 338 } 339 340 int handshake_buf_len = buffer_.size() - bytes_sent_; 341 DCHECK_GT(handshake_buf_len, 0); 342 handshake_buf_ = new IOBuffer(handshake_buf_len); 343 memcpy(handshake_buf_->data(), &buffer_[bytes_sent_], 344 handshake_buf_len); 345 return transport_->socket()->Write( 346 handshake_buf_, handshake_buf_len, 347 base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this))); 348} 349 350int SOCKSClientSocket::DoHandshakeWriteComplete(int result) { 351 if (result < 0) 352 return result; 353 354 // We ignore the case when result is 0, since the underlying Write 355 // may return spurious writes while waiting on the socket. 356 357 bytes_sent_ += result; 358 if (bytes_sent_ == buffer_.size()) { 359 next_state_ = STATE_HANDSHAKE_READ; 360 buffer_.clear(); 361 } else if (bytes_sent_ < buffer_.size()) { 362 next_state_ = STATE_HANDSHAKE_WRITE; 363 } else { 364 return ERR_UNEXPECTED; 365 } 366 367 return OK; 368} 369 370int SOCKSClientSocket::DoHandshakeRead() { 371 next_state_ = STATE_HANDSHAKE_READ_COMPLETE; 372 373 if (buffer_.empty()) { 374 bytes_received_ = 0; 375 } 376 377 int handshake_buf_len = kReadHeaderSize - bytes_received_; 378 handshake_buf_ = new IOBuffer(handshake_buf_len); 379 return transport_->socket()->Read(handshake_buf_, handshake_buf_len, 380 base::Bind(&SOCKSClientSocket::OnIOComplete, 381 base::Unretained(this))); 382} 383 384int SOCKSClientSocket::DoHandshakeReadComplete(int result) { 385 if (result < 0) 386 return result; 387 388 // The underlying socket closed unexpectedly. 389 if (result == 0) 390 return ERR_CONNECTION_CLOSED; 391 392 if (bytes_received_ + result > kReadHeaderSize) { 393 // TODO(eroman): Describe failure in NetLog. 
394 return ERR_SOCKS_CONNECTION_FAILED; 395 } 396 397 buffer_.append(handshake_buf_->data(), result); 398 bytes_received_ += result; 399 if (bytes_received_ < kReadHeaderSize) { 400 next_state_ = STATE_HANDSHAKE_READ; 401 return OK; 402 } 403 404 const SOCKS4ServerResponse* response = 405 reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data()); 406 407 if (response->reserved_null != 0x00) { 408 LOG(ERROR) << "Unknown response from SOCKS server."; 409 return ERR_SOCKS_CONNECTION_FAILED; 410 } 411 412 switch (response->code) { 413 case kServerResponseOk: 414 completed_handshake_ = true; 415 return OK; 416 case kServerResponseRejected: 417 LOG(ERROR) << "SOCKS request rejected or failed"; 418 return ERR_SOCKS_CONNECTION_FAILED; 419 case kServerResponseNotReachable: 420 LOG(ERROR) << "SOCKS request failed because client is not running " 421 << "identd (or not reachable from the server)"; 422 return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE; 423 case kServerResponseMismatchedUserId: 424 LOG(ERROR) << "SOCKS request failed because client's identd could " 425 << "not confirm the user ID string in the request"; 426 return ERR_SOCKS_CONNECTION_FAILED; 427 default: 428 LOG(ERROR) << "SOCKS server sent unknown response"; 429 return ERR_SOCKS_CONNECTION_FAILED; 430 } 431 432 // Note: we ignore the last 6 bytes as specified by the SOCKS protocol 433} 434 435int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const { 436 return transport_->socket()->GetPeerAddress(address); 437} 438 439int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const { 440 return transport_->socket()->GetLocalAddress(address); 441} 442 443} // namespace net 444