1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/socks_client_socket.h"
6
7#include "base/basictypes.h"
8#include "base/bind.h"
9#include "base/callback_helpers.h"
10#include "base/compiler_specific.h"
11#include "base/sys_byteorder.h"
12#include "net/base/io_buffer.h"
13#include "net/base/net_log.h"
14#include "net/base/net_util.h"
15#include "net/socket/client_socket_handle.h"
16
17namespace net {
18
19// Every SOCKS server requests a user-id from the client. It is optional
20// and we send an empty string.
21static const char kEmptyUserId[] = "";
22
23// For SOCKS4, the client sends 8 bytes  plus the size of the user-id.
24static const unsigned int kWriteHeaderSize = 8;
25
26// For SOCKS4 the server sends 8 bytes for acknowledgement.
27static const unsigned int kReadHeaderSize = 8;
28
29// Server Response codes for SOCKS.
30static const uint8 kServerResponseOk  = 0x5A;
31static const uint8 kServerResponseRejected = 0x5B;
32static const uint8 kServerResponseNotReachable = 0x5C;
33static const uint8 kServerResponseMismatchedUserId = 0x5D;
34
35static const uint8 kSOCKSVersion4 = 0x04;
36static const uint8 kSOCKSStreamRequest = 0x01;
37
38// A struct holding the essential details of the SOCKS4 Server Request.
39// The port in the header is stored in network byte order.
40struct SOCKS4ServerRequest {
41  uint8 version;
42  uint8 command;
43  uint16 nw_port;
44  uint8 ip[4];
45};
46COMPILE_ASSERT(sizeof(SOCKS4ServerRequest) == kWriteHeaderSize,
47               socks4_server_request_struct_wrong_size);
48
49// A struct holding details of the SOCKS4 Server Response.
50struct SOCKS4ServerResponse {
51  uint8 reserved_null;
52  uint8 code;
53  uint16 port;
54  uint8 ip[4];
55};
56COMPILE_ASSERT(sizeof(SOCKS4ServerResponse) == kReadHeaderSize,
57               socks4_server_response_struct_wrong_size);
58
59SOCKSClientSocket::SOCKSClientSocket(
60    scoped_ptr<ClientSocketHandle> transport_socket,
61    const HostResolver::RequestInfo& req_info,
62    RequestPriority priority,
63    HostResolver* host_resolver)
64    : transport_(transport_socket.Pass()),
65      next_state_(STATE_NONE),
66      completed_handshake_(false),
67      bytes_sent_(0),
68      bytes_received_(0),
69      was_ever_used_(false),
70      host_resolver_(host_resolver),
71      host_request_info_(req_info),
72      priority_(priority),
73      net_log_(transport_->socket()->NetLog()) {}
74
75SOCKSClientSocket::~SOCKSClientSocket() {
76  Disconnect();
77}
78
79int SOCKSClientSocket::Connect(const CompletionCallback& callback) {
80  DCHECK(transport_.get());
81  DCHECK(transport_->socket());
82  DCHECK_EQ(STATE_NONE, next_state_);
83  DCHECK(user_callback_.is_null());
84
85  // If already connected, then just return OK.
86  if (completed_handshake_)
87    return OK;
88
89  next_state_ = STATE_RESOLVE_HOST;
90
91  net_log_.BeginEvent(NetLog::TYPE_SOCKS_CONNECT);
92
93  int rv = DoLoop(OK);
94  if (rv == ERR_IO_PENDING) {
95    user_callback_ = callback;
96  } else {
97    net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
98  }
99  return rv;
100}
101
102void SOCKSClientSocket::Disconnect() {
103  completed_handshake_ = false;
104  host_resolver_.Cancel();
105  transport_->socket()->Disconnect();
106
107  // Reset other states to make sure they aren't mistakenly used later.
108  // These are the states initialized by Connect().
109  next_state_ = STATE_NONE;
110  user_callback_.Reset();
111}
112
113bool SOCKSClientSocket::IsConnected() const {
114  return completed_handshake_ && transport_->socket()->IsConnected();
115}
116
117bool SOCKSClientSocket::IsConnectedAndIdle() const {
118  return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
119}
120
121const BoundNetLog& SOCKSClientSocket::NetLog() const {
122  return net_log_;
123}
124
125void SOCKSClientSocket::SetSubresourceSpeculation() {
126  if (transport_.get() && transport_->socket()) {
127    transport_->socket()->SetSubresourceSpeculation();
128  } else {
129    NOTREACHED();
130  }
131}
132
133void SOCKSClientSocket::SetOmniboxSpeculation() {
134  if (transport_.get() && transport_->socket()) {
135    transport_->socket()->SetOmniboxSpeculation();
136  } else {
137    NOTREACHED();
138  }
139}
140
141bool SOCKSClientSocket::WasEverUsed() const {
142  return was_ever_used_;
143}
144
145bool SOCKSClientSocket::UsingTCPFastOpen() const {
146  if (transport_.get() && transport_->socket()) {
147    return transport_->socket()->UsingTCPFastOpen();
148  }
149  NOTREACHED();
150  return false;
151}
152
153bool SOCKSClientSocket::WasNpnNegotiated() const {
154  if (transport_.get() && transport_->socket()) {
155    return transport_->socket()->WasNpnNegotiated();
156  }
157  NOTREACHED();
158  return false;
159}
160
161NextProto SOCKSClientSocket::GetNegotiatedProtocol() const {
162  if (transport_.get() && transport_->socket()) {
163    return transport_->socket()->GetNegotiatedProtocol();
164  }
165  NOTREACHED();
166  return kProtoUnknown;
167}
168
169bool SOCKSClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
170  if (transport_.get() && transport_->socket()) {
171    return transport_->socket()->GetSSLInfo(ssl_info);
172  }
173  NOTREACHED();
174  return false;
175
176}
177
178// Read is called by the transport layer above to read. This can only be done
179// if the SOCKS handshake is complete.
180int SOCKSClientSocket::Read(IOBuffer* buf, int buf_len,
181                            const CompletionCallback& callback) {
182  DCHECK(completed_handshake_);
183  DCHECK_EQ(STATE_NONE, next_state_);
184  DCHECK(user_callback_.is_null());
185  DCHECK(!callback.is_null());
186
187  int rv = transport_->socket()->Read(
188      buf, buf_len,
189      base::Bind(&SOCKSClientSocket::OnReadWriteComplete,
190                 base::Unretained(this), callback));
191  if (rv > 0)
192    was_ever_used_ = true;
193  return rv;
194}
195
196// Write is called by the transport layer. This can only be done if the
197// SOCKS handshake is complete.
198int SOCKSClientSocket::Write(IOBuffer* buf, int buf_len,
199                             const CompletionCallback& callback) {
200  DCHECK(completed_handshake_);
201  DCHECK_EQ(STATE_NONE, next_state_);
202  DCHECK(user_callback_.is_null());
203  DCHECK(!callback.is_null());
204
205  int rv = transport_->socket()->Write(
206      buf, buf_len,
207      base::Bind(&SOCKSClientSocket::OnReadWriteComplete,
208                 base::Unretained(this), callback));
209  if (rv > 0)
210    was_ever_used_ = true;
211  return rv;
212}
213
214int SOCKSClientSocket::SetReceiveBufferSize(int32 size) {
215  return transport_->socket()->SetReceiveBufferSize(size);
216}
217
218int SOCKSClientSocket::SetSendBufferSize(int32 size) {
219  return transport_->socket()->SetSendBufferSize(size);
220}
221
222void SOCKSClientSocket::DoCallback(int result) {
223  DCHECK_NE(ERR_IO_PENDING, result);
224  DCHECK(!user_callback_.is_null());
225
226  // Since Run() may result in Read being called,
227  // clear user_callback_ up front.
228  DVLOG(1) << "Finished setting up SOCKS handshake";
229  base::ResetAndReturn(&user_callback_).Run(result);
230}
231
232void SOCKSClientSocket::OnIOComplete(int result) {
233  DCHECK_NE(STATE_NONE, next_state_);
234  int rv = DoLoop(result);
235  if (rv != ERR_IO_PENDING) {
236    net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS_CONNECT, rv);
237    DoCallback(rv);
238  }
239}
240
241void SOCKSClientSocket::OnReadWriteComplete(const CompletionCallback& callback,
242                                            int result) {
243  DCHECK_NE(ERR_IO_PENDING, result);
244  DCHECK(!callback.is_null());
245
246  if (result > 0)
247    was_ever_used_ = true;
248  callback.Run(result);
249}
250
251int SOCKSClientSocket::DoLoop(int last_io_result) {
252  DCHECK_NE(next_state_, STATE_NONE);
253  int rv = last_io_result;
254  do {
255    State state = next_state_;
256    next_state_ = STATE_NONE;
257    switch (state) {
258      case STATE_RESOLVE_HOST:
259        DCHECK_EQ(OK, rv);
260        rv = DoResolveHost();
261        break;
262      case STATE_RESOLVE_HOST_COMPLETE:
263        rv = DoResolveHostComplete(rv);
264        break;
265      case STATE_HANDSHAKE_WRITE:
266        DCHECK_EQ(OK, rv);
267        rv = DoHandshakeWrite();
268        break;
269      case STATE_HANDSHAKE_WRITE_COMPLETE:
270        rv = DoHandshakeWriteComplete(rv);
271        break;
272      case STATE_HANDSHAKE_READ:
273        DCHECK_EQ(OK, rv);
274        rv = DoHandshakeRead();
275        break;
276      case STATE_HANDSHAKE_READ_COMPLETE:
277        rv = DoHandshakeReadComplete(rv);
278        break;
279      default:
280        NOTREACHED() << "bad state";
281        rv = ERR_UNEXPECTED;
282        break;
283    }
284  } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
285  return rv;
286}
287
288int SOCKSClientSocket::DoResolveHost() {
289  next_state_ = STATE_RESOLVE_HOST_COMPLETE;
290  // SOCKS4 only supports IPv4 addresses, so only try getting the IPv4
291  // addresses for the target host.
292  host_request_info_.set_address_family(ADDRESS_FAMILY_IPV4);
293  return host_resolver_.Resolve(
294      host_request_info_,
295      priority_,
296      &addresses_,
297      base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)),
298      net_log_);
299}
300
301int SOCKSClientSocket::DoResolveHostComplete(int result) {
302  if (result != OK) {
303    // Resolving the hostname failed; fail the request rather than automatically
304    // falling back to SOCKS4a (since it can be confusing to see invalid IP
305    // addresses being sent to the SOCKS4 server when it doesn't support 4A.)
306    return result;
307  }
308
309  next_state_ = STATE_HANDSHAKE_WRITE;
310  return OK;
311}
312
313// Builds the buffer that is to be sent to the server.
314const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const {
315  SOCKS4ServerRequest request;
316  request.version = kSOCKSVersion4;
317  request.command = kSOCKSStreamRequest;
318  request.nw_port = base::HostToNet16(host_request_info_.port());
319
320  DCHECK(!addresses_.empty());
321  const IPEndPoint& endpoint = addresses_.front();
322
323  // We disabled IPv6 results when resolving the hostname, so none of the
324  // results in the list will be IPv6.
325  // TODO(eroman): we only ever use the first address in the list. It would be
326  //               more robust to try all the IP addresses we have before
327  //               failing the connect attempt.
328  CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily());
329  CHECK_LE(endpoint.address().size(), sizeof(request.ip));
330  memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size());
331
332  DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort();
333
334  std::string handshake_data(reinterpret_cast<char*>(&request),
335                             sizeof(request));
336  handshake_data.append(kEmptyUserId, arraysize(kEmptyUserId));
337
338  return handshake_data;
339}
340
341// Writes the SOCKS handshake data to the underlying socket connection.
342int SOCKSClientSocket::DoHandshakeWrite() {
343  next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
344
345  if (buffer_.empty()) {
346    buffer_ = BuildHandshakeWriteBuffer();
347    bytes_sent_ = 0;
348  }
349
350  int handshake_buf_len = buffer_.size() - bytes_sent_;
351  DCHECK_GT(handshake_buf_len, 0);
352  handshake_buf_ = new IOBuffer(handshake_buf_len);
353  memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
354         handshake_buf_len);
355  return transport_->socket()->Write(
356      handshake_buf_.get(),
357      handshake_buf_len,
358      base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
359}
360
361int SOCKSClientSocket::DoHandshakeWriteComplete(int result) {
362  if (result < 0)
363    return result;
364
365  // We ignore the case when result is 0, since the underlying Write
366  // may return spurious writes while waiting on the socket.
367
368  bytes_sent_ += result;
369  if (bytes_sent_ == buffer_.size()) {
370    next_state_ = STATE_HANDSHAKE_READ;
371    buffer_.clear();
372  } else if (bytes_sent_ < buffer_.size()) {
373    next_state_ = STATE_HANDSHAKE_WRITE;
374  } else {
375    return ERR_UNEXPECTED;
376  }
377
378  return OK;
379}
380
381int SOCKSClientSocket::DoHandshakeRead() {
382  next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
383
384  if (buffer_.empty()) {
385    bytes_received_ = 0;
386  }
387
388  int handshake_buf_len = kReadHeaderSize - bytes_received_;
389  handshake_buf_ = new IOBuffer(handshake_buf_len);
390  return transport_->socket()->Read(
391      handshake_buf_.get(),
392      handshake_buf_len,
393      base::Bind(&SOCKSClientSocket::OnIOComplete, base::Unretained(this)));
394}
395
396int SOCKSClientSocket::DoHandshakeReadComplete(int result) {
397  if (result < 0)
398    return result;
399
400  // The underlying socket closed unexpectedly.
401  if (result == 0)
402    return ERR_CONNECTION_CLOSED;
403
404  if (bytes_received_ + result > kReadHeaderSize) {
405    // TODO(eroman): Describe failure in NetLog.
406    return ERR_SOCKS_CONNECTION_FAILED;
407  }
408
409  buffer_.append(handshake_buf_->data(), result);
410  bytes_received_ += result;
411  if (bytes_received_ < kReadHeaderSize) {
412    next_state_ = STATE_HANDSHAKE_READ;
413    return OK;
414  }
415
416  const SOCKS4ServerResponse* response =
417      reinterpret_cast<const SOCKS4ServerResponse*>(buffer_.data());
418
419  if (response->reserved_null != 0x00) {
420    LOG(ERROR) << "Unknown response from SOCKS server.";
421    return ERR_SOCKS_CONNECTION_FAILED;
422  }
423
424  switch (response->code) {
425    case kServerResponseOk:
426      completed_handshake_ = true;
427      return OK;
428    case kServerResponseRejected:
429      LOG(ERROR) << "SOCKS request rejected or failed";
430      return ERR_SOCKS_CONNECTION_FAILED;
431    case kServerResponseNotReachable:
432      LOG(ERROR) << "SOCKS request failed because client is not running "
433                 << "identd (or not reachable from the server)";
434      return ERR_SOCKS_CONNECTION_HOST_UNREACHABLE;
435    case kServerResponseMismatchedUserId:
436      LOG(ERROR) << "SOCKS request failed because client's identd could "
437                 << "not confirm the user ID string in the request";
438      return ERR_SOCKS_CONNECTION_FAILED;
439    default:
440      LOG(ERROR) << "SOCKS server sent unknown response";
441      return ERR_SOCKS_CONNECTION_FAILED;
442  }
443
444  // Note: we ignore the last 6 bytes as specified by the SOCKS protocol
445}
446
447int SOCKSClientSocket::GetPeerAddress(IPEndPoint* address) const {
448  return transport_->socket()->GetPeerAddress(address);
449}
450
451int SOCKSClientSocket::GetLocalAddress(IPEndPoint* address) const {
452  return transport_->socket()->GetLocalAddress(address);
453}
454
455}  // namespace net
456