1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/socks5_client_socket.h"
6
7#include "base/basictypes.h"
8#include "base/callback_helpers.h"
9#include "base/compiler_specific.h"
10#include "base/debug/trace_event.h"
11#include "base/format_macros.h"
12#include "base/strings/string_util.h"
13#include "base/sys_byteorder.h"
14#include "net/base/io_buffer.h"
15#include "net/base/net_log.h"
16#include "net/base/net_util.h"
17#include "net/socket/client_socket_handle.h"
18
19namespace net {
20
21const unsigned int SOCKS5ClientSocket::kGreetReadHeaderSize = 2;
22const unsigned int SOCKS5ClientSocket::kWriteHeaderSize = 10;
23const unsigned int SOCKS5ClientSocket::kReadHeaderSize = 5;
24const uint8 SOCKS5ClientSocket::kSOCKS5Version = 0x05;
25const uint8 SOCKS5ClientSocket::kTunnelCommand = 0x01;
26const uint8 SOCKS5ClientSocket::kNullByte = 0x00;
27
28COMPILE_ASSERT(sizeof(struct in_addr) == 4, incorrect_system_size_of_IPv4);
29COMPILE_ASSERT(sizeof(struct in6_addr) == 16, incorrect_system_size_of_IPv6);
30
// Creates a SOCKS5 client socket that tunnels over |transport_socket| (the
// connection to the proxy, presumably already connected — Connect() writes
// to it immediately). The hostname and port in |req_info| are the CONNECT
// destination sent to the proxy by BuildHandshakeWriteBuffer().
SOCKS5ClientSocket::SOCKS5ClientSocket(
    scoped_ptr<ClientSocketHandle> transport_socket,
    const HostResolver::RequestInfo& req_info)
    // Completion callback for handshake I/O. base::Unretained: the callback
    // does not keep |this| alive.
    : io_callback_(base::Bind(&SOCKS5ClientSocket::OnIOComplete,
                              base::Unretained(this))),
      transport_(transport_socket.Pass()),
      next_state_(STATE_NONE),
      completed_handshake_(false),
      bytes_sent_(0),
      bytes_received_(0),
      read_header_size(kReadHeaderSize),
      was_ever_used_(false),
      host_request_info_(req_info),
      // Reuses the transport socket's NetLog. This dereferences |transport_|,
      // so it relies on |transport_| being declared (hence initialized)
      // before |net_log_| in the header — TODO confirm if members move.
      net_log_(transport_->socket()->NetLog()) {
}
46
47SOCKS5ClientSocket::~SOCKS5ClientSocket() {
48  Disconnect();
49}
50
51int SOCKS5ClientSocket::Connect(const CompletionCallback& callback) {
52  DCHECK(transport_.get());
53  DCHECK(transport_->socket());
54  DCHECK_EQ(STATE_NONE, next_state_);
55  DCHECK(user_callback_.is_null());
56
57  // If already connected, then just return OK.
58  if (completed_handshake_)
59    return OK;
60
61  net_log_.BeginEvent(NetLog::TYPE_SOCKS5_CONNECT);
62
63  next_state_ = STATE_GREET_WRITE;
64  buffer_.clear();
65
66  int rv = DoLoop(OK);
67  if (rv == ERR_IO_PENDING) {
68    user_callback_ = callback;
69  } else {
70    net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_CONNECT, rv);
71  }
72  return rv;
73}
74
75void SOCKS5ClientSocket::Disconnect() {
76  completed_handshake_ = false;
77  transport_->socket()->Disconnect();
78
79  // Reset other states to make sure they aren't mistakenly used later.
80  // These are the states initialized by Connect().
81  next_state_ = STATE_NONE;
82  user_callback_.Reset();
83}
84
85bool SOCKS5ClientSocket::IsConnected() const {
86  return completed_handshake_ && transport_->socket()->IsConnected();
87}
88
89bool SOCKS5ClientSocket::IsConnectedAndIdle() const {
90  return completed_handshake_ && transport_->socket()->IsConnectedAndIdle();
91}
92
93const BoundNetLog& SOCKS5ClientSocket::NetLog() const {
94  return net_log_;
95}
96
97void SOCKS5ClientSocket::SetSubresourceSpeculation() {
98  if (transport_.get() && transport_->socket()) {
99    transport_->socket()->SetSubresourceSpeculation();
100  } else {
101    NOTREACHED();
102  }
103}
104
105void SOCKS5ClientSocket::SetOmniboxSpeculation() {
106  if (transport_.get() && transport_->socket()) {
107    transport_->socket()->SetOmniboxSpeculation();
108  } else {
109    NOTREACHED();
110  }
111}
112
113bool SOCKS5ClientSocket::WasEverUsed() const {
114  return was_ever_used_;
115}
116
117bool SOCKS5ClientSocket::UsingTCPFastOpen() const {
118  if (transport_.get() && transport_->socket()) {
119    return transport_->socket()->UsingTCPFastOpen();
120  }
121  NOTREACHED();
122  return false;
123}
124
125bool SOCKS5ClientSocket::WasNpnNegotiated() const {
126  if (transport_.get() && transport_->socket()) {
127    return transport_->socket()->WasNpnNegotiated();
128  }
129  NOTREACHED();
130  return false;
131}
132
133NextProto SOCKS5ClientSocket::GetNegotiatedProtocol() const {
134  if (transport_.get() && transport_->socket()) {
135    return transport_->socket()->GetNegotiatedProtocol();
136  }
137  NOTREACHED();
138  return kProtoUnknown;
139}
140
141bool SOCKS5ClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
142  if (transport_.get() && transport_->socket()) {
143    return transport_->socket()->GetSSLInfo(ssl_info);
144  }
145  NOTREACHED();
146  return false;
147
148}
149
150// Read is called by the transport layer above to read. This can only be done
151// if the SOCKS handshake is complete.
152int SOCKS5ClientSocket::Read(IOBuffer* buf, int buf_len,
153                             const CompletionCallback& callback) {
154  DCHECK(completed_handshake_);
155  DCHECK_EQ(STATE_NONE, next_state_);
156  DCHECK(user_callback_.is_null());
157  DCHECK(!callback.is_null());
158
159  int rv = transport_->socket()->Read(
160      buf, buf_len,
161      base::Bind(&SOCKS5ClientSocket::OnReadWriteComplete,
162                 base::Unretained(this), callback));
163  if (rv > 0)
164    was_ever_used_ = true;
165  return rv;
166}
167
168// Write is called by the transport layer. This can only be done if the
169// SOCKS handshake is complete.
170int SOCKS5ClientSocket::Write(IOBuffer* buf, int buf_len,
171                              const CompletionCallback& callback) {
172  DCHECK(completed_handshake_);
173  DCHECK_EQ(STATE_NONE, next_state_);
174  DCHECK(user_callback_.is_null());
175  DCHECK(!callback.is_null());
176
177  int rv = transport_->socket()->Write(
178      buf, buf_len,
179      base::Bind(&SOCKS5ClientSocket::OnReadWriteComplete,
180                 base::Unretained(this), callback));
181  if (rv > 0)
182    was_ever_used_ = true;
183  return rv;
184}
185
186int SOCKS5ClientSocket::SetReceiveBufferSize(int32 size) {
187  return transport_->socket()->SetReceiveBufferSize(size);
188}
189
190int SOCKS5ClientSocket::SetSendBufferSize(int32 size) {
191  return transport_->socket()->SetSendBufferSize(size);
192}
193
194void SOCKS5ClientSocket::DoCallback(int result) {
195  DCHECK_NE(ERR_IO_PENDING, result);
196  DCHECK(!user_callback_.is_null());
197
198  // Since Run() may result in Read being called,
199  // clear user_callback_ up front.
200  base::ResetAndReturn(&user_callback_).Run(result);
201}
202
203void SOCKS5ClientSocket::OnIOComplete(int result) {
204  DCHECK_NE(STATE_NONE, next_state_);
205  int rv = DoLoop(result);
206  if (rv != ERR_IO_PENDING) {
207    net_log_.EndEvent(NetLog::TYPE_SOCKS5_CONNECT);
208    DoCallback(rv);
209  }
210}
211
212void SOCKS5ClientSocket::OnReadWriteComplete(const CompletionCallback& callback,
213                                             int result) {
214  DCHECK_NE(ERR_IO_PENDING, result);
215  DCHECK(!callback.is_null());
216
217  if (result > 0)
218    was_ever_used_ = true;
219  callback.Run(result);
220}
221
// Drives the SOCKS5 handshake state machine, starting at |next_state_| with
// |last_io_result| as the result of the previous step.
// Each STATE_X step begins a NetLog event and issues transport I/O. The
// matching STATE_X_COMPLETE step consumes the I/O result and ends that event
// with the resulting net error code.
// Returns ERR_IO_PENDING while waiting on asynchronous I/O; the loop is
// resumed later from OnIOComplete(). Otherwise it returns the final result:
// OK when the handshake completed, or the error from the failing step.
int SOCKS5ClientSocket::DoLoop(int last_io_result) {
  DCHECK_NE(next_state_, STATE_NONE);
  int rv = last_io_result;
  do {
    // Reset |next_state_| before dispatching; a handler that has more work
    // to do sets it again, so STATE_NONE after a step ends the loop.
    State state = next_state_;
    next_state_ = STATE_NONE;
    switch (state) {
      case STATE_GREET_WRITE:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLog::TYPE_SOCKS5_GREET_WRITE);
        rv = DoGreetWrite();
        break;
      case STATE_GREET_WRITE_COMPLETE:
        rv = DoGreetWriteComplete(rv);
        net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_GREET_WRITE, rv);
        break;
      case STATE_GREET_READ:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLog::TYPE_SOCKS5_GREET_READ);
        rv = DoGreetRead();
        break;
      case STATE_GREET_READ_COMPLETE:
        rv = DoGreetReadComplete(rv);
        net_log_.EndEventWithNetErrorCode(NetLog::TYPE_SOCKS5_GREET_READ, rv);
        break;
      case STATE_HANDSHAKE_WRITE:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLog::TYPE_SOCKS5_HANDSHAKE_WRITE);
        rv = DoHandshakeWrite();
        break;
      case STATE_HANDSHAKE_WRITE_COMPLETE:
        rv = DoHandshakeWriteComplete(rv);
        net_log_.EndEventWithNetErrorCode(
            NetLog::TYPE_SOCKS5_HANDSHAKE_WRITE, rv);
        break;
      case STATE_HANDSHAKE_READ:
        DCHECK_EQ(OK, rv);
        net_log_.BeginEvent(NetLog::TYPE_SOCKS5_HANDSHAKE_READ);
        rv = DoHandshakeRead();
        break;
      case STATE_HANDSHAKE_READ_COMPLETE:
        rv = DoHandshakeReadComplete(rv);
        net_log_.EndEventWithNetErrorCode(
            NetLog::TYPE_SOCKS5_HANDSHAKE_READ, rv);
        break;
      default:
        NOTREACHED() << "bad state";
        rv = ERR_UNEXPECTED;
        break;
    }
    // Stop when waiting on I/O, or when no handler scheduled another state.
  } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
  return rv;
}
275
// Client greeting sent first on the transport: SOCKS version 5, offering a
// single authentication method, "no authentication" (0x00).
const char kSOCKS5GreetWriteData[] = {
  0x05,  // Protocol version.
  0x01,  // Number of authentication methods offered.
  0x00,  // Method: no authentication.
};
277
278int SOCKS5ClientSocket::DoGreetWrite() {
279  // Since we only have 1 byte to send the hostname length in, if the
280  // URL has a hostname longer than 255 characters we can't send it.
281  if (0xFF < host_request_info_.hostname().size()) {
282    net_log_.AddEvent(NetLog::TYPE_SOCKS_HOSTNAME_TOO_BIG);
283    return ERR_SOCKS_CONNECTION_FAILED;
284  }
285
286  if (buffer_.empty()) {
287    buffer_ = std::string(kSOCKS5GreetWriteData,
288                          arraysize(kSOCKS5GreetWriteData));
289    bytes_sent_ = 0;
290  }
291
292  next_state_ = STATE_GREET_WRITE_COMPLETE;
293  size_t handshake_buf_len = buffer_.size() - bytes_sent_;
294  handshake_buf_ = new IOBuffer(handshake_buf_len);
295  memcpy(handshake_buf_->data(), &buffer_.data()[bytes_sent_],
296         handshake_buf_len);
297  return transport_->socket()
298      ->Write(handshake_buf_.get(), handshake_buf_len, io_callback_);
299}
300
301int SOCKS5ClientSocket::DoGreetWriteComplete(int result) {
302  if (result < 0)
303    return result;
304
305  bytes_sent_ += result;
306  if (bytes_sent_ == buffer_.size()) {
307    buffer_.clear();
308    bytes_received_ = 0;
309    next_state_ = STATE_GREET_READ;
310  } else {
311    next_state_ = STATE_GREET_WRITE;
312  }
313  return OK;
314}
315
316int SOCKS5ClientSocket::DoGreetRead() {
317  next_state_ = STATE_GREET_READ_COMPLETE;
318  size_t handshake_buf_len = kGreetReadHeaderSize - bytes_received_;
319  handshake_buf_ = new IOBuffer(handshake_buf_len);
320  return transport_->socket()
321      ->Read(handshake_buf_.get(), handshake_buf_len, io_callback_);
322}
323
324int SOCKS5ClientSocket::DoGreetReadComplete(int result) {
325  if (result < 0)
326    return result;
327
328  if (result == 0) {
329    net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTEDLY_CLOSED_DURING_GREETING);
330    return ERR_SOCKS_CONNECTION_FAILED;
331  }
332
333  bytes_received_ += result;
334  buffer_.append(handshake_buf_->data(), result);
335  if (bytes_received_ < kGreetReadHeaderSize) {
336    next_state_ = STATE_GREET_READ;
337    return OK;
338  }
339
340  // Got the greet data.
341  if (buffer_[0] != kSOCKS5Version) {
342    net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_VERSION,
343                      NetLog::IntegerCallback("version", buffer_[0]));
344    return ERR_SOCKS_CONNECTION_FAILED;
345  }
346  if (buffer_[1] != 0x00) {
347    net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_AUTH,
348                      NetLog::IntegerCallback("method", buffer_[1]));
349    return ERR_SOCKS_CONNECTION_FAILED;
350  }
351
352  buffer_.clear();
353  next_state_ = STATE_HANDSHAKE_WRITE;
354  return OK;
355}
356
357int SOCKS5ClientSocket::BuildHandshakeWriteBuffer(std::string* handshake)
358    const {
359  DCHECK(handshake->empty());
360
361  handshake->push_back(kSOCKS5Version);
362  handshake->push_back(kTunnelCommand);  // Connect command
363  handshake->push_back(kNullByte);  // Reserved null
364
365  handshake->push_back(kEndPointDomain);  // The type of the address.
366
367  DCHECK_GE(static_cast<size_t>(0xFF), host_request_info_.hostname().size());
368
369  // First add the size of the hostname, followed by the hostname.
370  handshake->push_back(static_cast<unsigned char>(
371      host_request_info_.hostname().size()));
372  handshake->append(host_request_info_.hostname());
373
374  uint16 nw_port = base::HostToNet16(host_request_info_.port());
375  handshake->append(reinterpret_cast<char*>(&nw_port), sizeof(nw_port));
376  return OK;
377}
378
379// Writes the SOCKS handshake data to the underlying socket connection.
380int SOCKS5ClientSocket::DoHandshakeWrite() {
381  next_state_ = STATE_HANDSHAKE_WRITE_COMPLETE;
382
383  if (buffer_.empty()) {
384    int rv = BuildHandshakeWriteBuffer(&buffer_);
385    if (rv != OK)
386      return rv;
387    bytes_sent_ = 0;
388  }
389
390  int handshake_buf_len = buffer_.size() - bytes_sent_;
391  DCHECK_LT(0, handshake_buf_len);
392  handshake_buf_ = new IOBuffer(handshake_buf_len);
393  memcpy(handshake_buf_->data(), &buffer_[bytes_sent_],
394         handshake_buf_len);
395  return transport_->socket()
396      ->Write(handshake_buf_.get(), handshake_buf_len, io_callback_);
397}
398
399int SOCKS5ClientSocket::DoHandshakeWriteComplete(int result) {
400  if (result < 0)
401    return result;
402
403  // We ignore the case when result is 0, since the underlying Write
404  // may return spurious writes while waiting on the socket.
405
406  bytes_sent_ += result;
407  if (bytes_sent_ == buffer_.size()) {
408    next_state_ = STATE_HANDSHAKE_READ;
409    buffer_.clear();
410  } else if (bytes_sent_ < buffer_.size()) {
411    next_state_ = STATE_HANDSHAKE_WRITE;
412  } else {
413    NOTREACHED();
414  }
415
416  return OK;
417}
418
419int SOCKS5ClientSocket::DoHandshakeRead() {
420  next_state_ = STATE_HANDSHAKE_READ_COMPLETE;
421
422  if (buffer_.empty()) {
423    bytes_received_ = 0;
424    read_header_size = kReadHeaderSize;
425  }
426
427  int handshake_buf_len = read_header_size - bytes_received_;
428  handshake_buf_ = new IOBuffer(handshake_buf_len);
429  return transport_->socket()
430      ->Read(handshake_buf_.get(), handshake_buf_len, io_callback_);
431}
432
433int SOCKS5ClientSocket::DoHandshakeReadComplete(int result) {
434  if (result < 0)
435    return result;
436
437  // The underlying socket closed unexpectedly.
438  if (result == 0) {
439    net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTEDLY_CLOSED_DURING_HANDSHAKE);
440    return ERR_SOCKS_CONNECTION_FAILED;
441  }
442
443  buffer_.append(handshake_buf_->data(), result);
444  bytes_received_ += result;
445
446  // When the first few bytes are read, check how many more are required
447  // and accordingly increase them
448  if (bytes_received_ == kReadHeaderSize) {
449    if (buffer_[0] != kSOCKS5Version || buffer_[2] != kNullByte) {
450      net_log_.AddEvent(NetLog::TYPE_SOCKS_UNEXPECTED_VERSION,
451                        NetLog::IntegerCallback("version", buffer_[0]));
452      return ERR_SOCKS_CONNECTION_FAILED;
453    }
454    if (buffer_[1] != 0x00) {
455      net_log_.AddEvent(NetLog::TYPE_SOCKS_SERVER_ERROR,
456                        NetLog::IntegerCallback("error_code", buffer_[1]));
457      return ERR_SOCKS_CONNECTION_FAILED;
458    }
459
460    // We check the type of IP/Domain the server returns and accordingly
461    // increase the size of the response. For domains, we need to read the
462    // size of the domain, so the initial request size is upto the domain
463    // size. Since for IPv4/IPv6 the size is fixed and hence no 'size' is
464    // read, we substract 1 byte from the additional request size.
465    SocksEndPointAddressType address_type =
466        static_cast<SocksEndPointAddressType>(buffer_[3]);
467    if (address_type == kEndPointDomain)
468      read_header_size += static_cast<uint8>(buffer_[4]);
469    else if (address_type == kEndPointResolvedIPv4)
470      read_header_size += sizeof(struct in_addr) - 1;
471    else if (address_type == kEndPointResolvedIPv6)
472      read_header_size += sizeof(struct in6_addr) - 1;
473    else {
474      net_log_.AddEvent(NetLog::TYPE_SOCKS_UNKNOWN_ADDRESS_TYPE,
475                        NetLog::IntegerCallback("address_type", buffer_[3]));
476      return ERR_SOCKS_CONNECTION_FAILED;
477    }
478
479    read_header_size += 2;  // for the port.
480    next_state_ = STATE_HANDSHAKE_READ;
481    return OK;
482  }
483
484  // When the final bytes are read, setup handshake. We ignore the rest
485  // of the response since they represent the SOCKSv5 endpoint and have
486  // no use when doing a tunnel connection.
487  if (bytes_received_ == read_header_size) {
488    completed_handshake_ = true;
489    buffer_.clear();
490    next_state_ = STATE_NONE;
491    return OK;
492  }
493
494  next_state_ = STATE_HANDSHAKE_READ;
495  return OK;
496}
497
498int SOCKS5ClientSocket::GetPeerAddress(IPEndPoint* address) const {
499  return transport_->socket()->GetPeerAddress(address);
500}
501
502int SOCKS5ClientSocket::GetLocalAddress(IPEndPoint* address) const {
503  return transport_->socket()->GetLocalAddress(address);
504}
505
506}  // namespace net
507