socks_client_socket.cc revision c2e0dbddbe15c98d52c4786dac06cb8952a8ae6d
1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/socks_client_socket.h"
6
7#include "base/basictypes.h"
8#include "base/bind.h"
9#include "base/compiler_specific.h"
10#include "base/sys_byteorder.h"
11#include "net/base/io_buffer.h"
12#include "net/base/net_log.h"
13#include "net/base/net_util.h"
14#include "net/socket/client_socket_handle.h"
15
16namespace net {
17
18// Every SOCKS server requests a user-id from the client. It is optional
19// and we send an empty string.
20static const char kEmptyUserId[] = "";
21
22// For SOCKS4, the client sends 8 bytes  plus the size of the user-id.
23static const unsigned int kWriteHeaderSize = 8;
24
25// For SOCKS4 the server sends 8 bytes for acknowledgement.
26static const unsigned int kReadHeaderSize = 8;
27
28// Server Response codes for SOCKS.
29static const uint8 kServerResponseOk  = 0x5A;
30static const uint8 kServerResponseRejected = 0x5B;
31static const uint8 kServerResponseNotReachable = 0x5C;
32static const uint8 kServerResponseMismatchedUserId = 0x5D;
33
34static const uint8 kSOCKSVersion4 = 0x04;
35static const uint8 kSOCKSStreamRequest = 0x01;
36
37// A struct holding the essential details of the SOCKS4 Server Request.
38// The port in the header is stored in network byte order.
39struct SOCKS4ServerRequest {
40  uint8 version;
41  uint8 command;
42  uint16 nw_port;
43  uint8 ip[4];
44};
45COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
46               socks4_server_request_struct_wrong_size);
47
48// A struct holding details of the SOCKS4 Server Response.
49struct SOCKS4ServerResponse {
50  uint8 reserved_null;
51  uint8 code;
52  uint16 port;
53  uint8 ip[4];
54};
55COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
56               socks4_server_response_struct_wrong_size);
57
58SOCKSClientSocket::SOCKSClientSocket(ClientSocketHandle* transport_socket,
59                                     const HostResolver::RequestInfo& req_info,
60                                     HostResolver* host_resolver)
61    : transport_(transport_socket),
62      next_state_(STATE_NONE),
63      completed_handshake_(false),
64      bytes_sent_(0),
65      bytes_received_(0),
66      host_resolver_(host_resolver),
67      host_request_info_(req_info),
68      net_log_(transport_socket->socket()->NetLog()) {
69}
70
71SOCKSClientSocket::SOCKSClientSocket(StreamSocket* transport_socket,
72                                     const HostResolver::RequestInfo& req_info,
73                                     HostResolver* host_resolver)
74    : transport_(new ClientSocketHandle()),
75      next_state_(STATE_NONE),
76      completed_handshake_(false),
77      bytes_sent_(0),
78      bytes_received_(0),
79      host_resolver_(host_resolver),
80      host_request_info_(req_info),
81      net_log_(transport_socket->NetLog()) {
82  transport_->set_socket(transport_socket);
83}
84
85SOCKSClientSocket::~SOCKSClientSocket() {
86  Disconnect();
87}
88
89int SOCKSClientSocket::Connect(const CompletionCallback& callback) {
90  DCHECK(transport_.get());
91  DCHECK(transport_->socket());
92  DCHECK_EQ(STATE_NONE, next_state_);
93  DCHECK(user_callback_.is_null());
94
95  // If already connected, then just return OK.
96  if (completed_handshake_)
97    return OK;
98
99  next_state_ = STATE_RESOLVE_HOST;
100
101  net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT);
102
103  int rv = DoLoop(OK);
104  if (rv == ERR_IO_PENDING) {
105    user_callback_ = callback;
106  } else {
107    net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
108  }
109  return rv;
110}
111
112void SOCKSClientSocket::Disconnect() {
113  completed_handshake_ = false;
114  host_resolver_.Cancel();
115  transport_->socket()->Disconnect();
116
117  // Reset other states to make sure they aren't mistakenly used later.
118  // These are the states initialized by Connect().
119  next_state_ = STATE_NONE;
120  user_callback_.Reset();
121}
122
123bool SOCKSClientSocket::IsConnected() const {
124  return completed_handshake_ && transport_->socket()->IsConnected();
125}
126
127bool SOCKSClientSocket::IsConnectedAndIdle() const {
128  return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
129}
130
131const BoundNetLog& SOCKSClientSocket::NetLog() const {
132  return net_log_;
133}
134
135void SOCKSClientSocket::SetSubresourceSpeculation() {
136  if (transport_.get() && transport_->socket()) {
137    transport_->socket()->SetSubresourceSpeculation();
138  } else {
139    NOTREACHED();
140  }
141}
142
143void SOCKSClientSocket::SetOmniboxSpeculation() {
144  if (transport_.get() && transport_->socket()) {
145    transport_->socket()->SetOmniboxSpeculation();
146  } else {
147    NOTREACHED();
148  }
149}
150
151bool SOCKSClientSocket::WasEverUsed() const {
152  if (transport_.get() && transport_->socket()) {
153    return transport_->socket()->WasEverUsed();
154  }
155  NOTREACHED();
156  return false;
157}
158
159bool SOCKSClientSocket::UsingTCPFastOpen() const {
160  if (transport_.get() && transport_->socket()) {
161    return transport_->socket()->UsingTCPFastOpen();
162  }
163  NOTREACHED();
164  return false;
165}
166
167bool SOCKSClientSocket::WasNpnNegotiated() const {
168  if (transport_.get() && transport_->socket()) {
169    return transport_->socket()->WasNpnNegotiated();
170  }
171  NOTREACHED();
172  return false;
173}
174
175NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
176  if (transport_.get() && transport_->socket()) {
177    return transport_->socket()->GetNegotiatedProtocol();
178  }
179  NOTREACHED();
180  return kProtoUnknown;
181}
182
183bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
184  if (transport_.get() && transport_->socket()) {
185    return transport_->socket()->GetSSLInfo(ssl_info);
186  }
187  NOTREACHED();
188  return false;
189
190}
191
192// Read is called by the transport layer above to read. This can only be done
193// if the SOCKS handshake is complete.
194int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len,
195                            const CompletionCallback& callback) {
196  DCHECK(completed_handshake_);
197  DCHECK_EQ(STATE_NONE, next_state_);
198  DCHECK(user_callback_.is_null());
199
200  return transport_->socket()->Read(buf, buf_len, callback);
201}
202
203// Write is called by the transport layer. This can only be done if the
204// SOCKS handshake is complete.
205int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len,
206                             const CompletionCallback& callback) {
207  DCHECK(completed_handshake_);
208  DCHECK_EQ(STATE_NONE, next_state_);
209  DCHECK(user_callback_.is_null());
210
211  return transport_->socket()->Write(buf, buf_len, callback);
212}
213
214bool SOCKSClientSocket::SetReceiveBufferSize(int32 size) {
215  return transport_->socket()->SetReceiveBufferSize(size);
216}
217
218bool SOCKSClientSocket::SetSendBufferSize(int32 size) {
219  return transport_->socket()->SetSendBufferSize(size);
220}
221
222void SOCKSClientSocket::DoCallback(int result) {
223  DCHECK_NE(ERR_IO_PENDING, result);
224  DCHECK(!user_callback_.is_null());
225
226  // Since Run() may result in Read being called,
227  // clear user_callback_ up front.
228  CompletionCallback c = user_callback_;
229  user_callback_.Reset();
230  DVLOG(1) << "Finished setting up SOCKS handshake";
231  c.Run(result);
232}
233
234void SOCKSClientSocket::OnIOComplete(int result) {
235  DCHECK_NE(STATE_NONE, next_state_);
236  int rv = DoLoop(result);
237  if (rv != ERR_IO_PENDING) {
238    net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
239    DoCallback(rv);
240  }
241}
242
243int SOCKSClientSocket::DoLoop(int last_io_result) {
244  DCHECK_NE(next_state_, STATE_NONE);
245  int rv = last_io_result;
246  do {
247    State state = next_state_;
248    next_state_ = STATE_NONE;
249    switch (state) {
250      case STATE_RESOLVE_HOST:
251        DCHECK_EQ(OK, rv);
252        rv = DoResolveHost();
253        break;
254      case STATE_RESOLVE_HOST_COMPLETE:
255        rv = DoResolveHostComplete(rv);
256        break;
257      case STATE_HANDSHAKE_WRITE:
258        DCHECK_EQ(OK, rv);
259        rv = DoHandshakeWrite();
260        break;
261      case STATE_HANDSHAKE_WRITE_COMPLETE:
262        rv = DoHandshakeWriteComplete(rv);
263        break;
264      case STATE_HANDSHAKE_READ:
265        DCHECK_EQ(OK, rv);
266        rv = DoHandshakeRead();
267        break;
268      case STATE_HANDSHAKE_READ_COMPLETE:
269        rv = DoHandshakeReadComplete(rv);
270        break;
271      default:
272        NOTREACHED() << "bad state";
273        rv = ERR_UNEXPECTED;
274        break;
275    }
276  } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
277  return rv;
278}
279
280int SOCKSClientSocket::DoResolveHost() {
281  next_state_ = STATE_RESOLVE_HOST_COMPLETE;
282  // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
283  // addresses for the target host.
284  host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4);
285  return host_resolver_.Resolve(
286      host_request_info_, &addresses_,
287      base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
288      net_log_);
289}
290
291int SOCKSClientSocket::DoResolveHostComplete(int result) {
292  if (result != OK) {
293    // Resolving the hostname failed; fail the request rather than automatically
294    // falling back to SOCKS4a (since it can be confusing to see invalid IP
295    // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
296    return result;
297  }
298
299  next_state_ = STATE_HANDSHAKE_WRITE;
300  return OK;
301}
302
303// Builds the buffer that is to be sent to the server.
304const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
305  SOCKS4ServerRequest request;
306  request.version = kSOCKSVersion4;
307  request.command = kSOCKSStreamRequest;
308  request.nw_port = base::HostToNet16(host_request_info_.port());
309
310  DCHECK(!addresses_.empty());
311  const IPEndPoint& endpoint = addresses_.front();
312
313  // We disabled IPv6 results when resolving the hostname, so none of the
314  // results in the list will be IPv6.
315  // TODO(eroman): we only ever use the first address in the list. It would be
316  //               more robust to try all the IP addresses we have before
317  //               failing the connect attempt.
318  CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
319  CHECK_LE(endpoint.address().size(), sizeof(request.ip));
320  memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size());
321
322  DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
323
324  std::string handshake_data(reinterpret_cast<char*>(&request),
325                             sizeof(request));
326  handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId));
327
328  return handshake_data;
329}
330
331// Writes the SOCKS handshake data to the underlying socket connection.
332int SOCKSClientSocket::DoHandshakeWrite() {
333  next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
334
335  if (buffer_.empty()) {
336    buffer_ = BuildHandshakeWriteBuffer();
337    bytes_sent_ = 0;
338  }
339
340  int handshake_buf_len = buffer_.size() - bytes_sent_;
341  DCHECK_GT(handshake_buf_len, 0);
342  handshake_buf_ = new IOBuffer(handshake_buf_len);
343  memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
344         handshake_buf_len);
345  return transport_->socket()->Write(
346      handshake_buf_, handshake_buf_len,
347      base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
348}
349
350int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
351  if (result < 0)
352    return result;
353
354  // We ignore the case when result is 0, since the underlying Write
355  // may return spurious writes while waiting on the socket.
356
357  bytes_sent_ += result;
358  if (bytes_sent_ == buffer_.size()) {
359    next_state_ = STATE_HANDSHAKE_READ;
360    buffer_.clear();
361  } else if (bytes_sent_ < buffer_.size()) {
362    next_state_ = STATE_HANDSHAKE_WRITE;
363  } else {
364    return ERR_UNEXPECTED;
365  }
366
367  return OK;
368}
369
370int SOCKSClientSocket::DoHandshakeRead() {
371  next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
372
373  if (buffer_.empty()) {
374    bytes_received_ = 0;
375  }
376
377  int handshake_buf_len = kReadHeaderSize - bytes_received_;
378  handshake_buf_ = new IOBuffer(handshake_buf_len);
379  return transport_->socket()->Read(handshake_buf_, handshake_buf_len,
380                                    base::Bind(&SOCKSClientSocket::OnIOComplete,
381                                               base::Unretained(this)));
382}
383
384int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
385  if (result < 0)
386    return result;
387
388  // The underlying socket closed unexpectedly.
389  if (result == 0)
390    return ERR_CONNECTION_CLOSED;
391
392  if (bytes_received_ + result > kReadHeaderSize) {
393    // TODO(eroman): Describe failure in NetLog.
394    return ERR_SOCKS_CONNECTION_FAILED;
395  }
396
397  buffer_.append(handshake_buf_->data(), result);
398  bytes_received_ += result;
399  if (bytes_received_ < kReadHeaderSize) {
400    next_state_ = STATE_HANDSHAKE_READ;
401    return OK;
402  }
403
404  const SOCKS4ServerResponse* response =
405      reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
406
407  if (response->reserved_null != 0x00) {
408    LOG(ERROR) << "Unknown response from SOCKS server.";
409    return ERR_SOCKS_CONNECTION_FAILED;
410  }
411
412  switch (response->code) {
413    case kServerResponseOk:
414      completed_handshake_ = true;
415      return OK;
416    case kServerResponseRejected:
417      LOG(ERROR) << "SOCKS request rejected or failed";
418      return ERR_SOCKS_CONNECTION_FAILED;
419    case kServerResponseNotReachable:
420      LOG(ERROR) << "SOCKS request failed because client is not running "
421                 << "identd (or not reachable from the server)";
422      return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
423    case kServerResponseMismatchedUserId:
424      LOG(ERROR) << "SOCKS request failed because client's identd could "
425                 << "not confirm the user ID string in the request";
426      return ERR_SOCKS_CONNECTION_FAILED;
427    default:
428      LOG(ERROR) << "SOCKS server sent unknown response";
429      return ERR_SOCKS_CONNECTION_FAILED;
430  }
431
432  // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
433}
434
435int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
436  return transport_->socket()->GetPeerAddress(address);
437}
438
439int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
440  return transport_->socket()->GetLocalAddress(address);
441}
442
443}  // namespace net
444