1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/socks_client_socket_pool.h"
6
7#include "base/time.h"
8#include "base/values.h"
9#include "googleurl/src/gurl.h"
10#include "net/base/net_errors.h"
11#include "net/socket/client_socket_factory.h"
12#include "net/socket/client_socket_handle.h"
13#include "net/socket/client_socket_pool_base.h"
14#include "net/socket/socks5_client_socket.h"
15#include "net/socket/socks_client_socket.h"
16#include "net/socket/transport_client_socket_pool.h"
17
18namespace net {
19
20SOCKSSocketParams::SOCKSSocketParams(
21    const scoped_refptr<TransportSocketParams>& proxy_server,
22    bool socks_v5,
23    const HostPortPair& host_port_pair,
24    RequestPriority priority,
25    const GURL& referrer)
26    : transport_params_(proxy_server),
27      destination_(host_port_pair),
28      socks_v5_(socks_v5) {
29  if (transport_params_)
30    ignore_limits_ = transport_params_->ignore_limits();
31  else
32    ignore_limits_ = false;
33  // The referrer is used by the DNS prefetch system to correlate resolutions
34  // with the page that triggered them. It doesn't impact the actual addresses
35  // that we resolve to.
36  destination_.set_referrer(referrer);
37  destination_.set_priority(priority);
38}
39
40#ifdef ANDROID
41bool SOCKSSocketParams::getUID(uid_t *uid) const {
42  if (transport_params_)
43    return transport_params_->getUID(uid);
44  else
45    return false;
46}
47
48void SOCKSSocketParams::setUID(uid_t uid) {
49  if (transport_params_)
50    return transport_params_->setUID(uid);
51}
52#endif
53
54SOCKSSocketParams::~SOCKSSocketParams() {}
55
// Number of seconds a SOCKSConnectJob is given for the SOCKS handshake. This
// budget is added on top of the transport socket's own connect timeout.
// A namespace-scope const int already has internal linkage, so no |static|.
const int kSOCKSConnectJobTimeoutInSeconds = 30;
59
60SOCKSConnectJob::SOCKSConnectJob(
61    const std::string& group_name,
62    const scoped_refptr<SOCKSSocketParams>& socks_params,
63    const base::TimeDelta& timeout_duration,
64    TransportClientSocketPool* transport_pool,
65    HostResolver* host_resolver,
66    Delegate* delegate,
67    NetLog* net_log)
68    : ConnectJob(group_name, timeout_duration, delegate,
69                 BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)),
70      socks_params_(socks_params),
71      transport_pool_(transport_pool),
72      resolver_(host_resolver),
73      ALLOW_THIS_IN_INITIALIZER_LIST(
74          callback_(this, &SOCKSConnectJob::OnIOComplete)) {
75}
76
77SOCKSConnectJob::~SOCKSConnectJob() {
78  // We don't worry about cancelling the tcp socket since the destructor in
79  // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of
80  // it.
81}
82
83LoadState SOCKSConnectJob::GetLoadState() const {
84  switch (next_state_) {
85    case STATE_TRANSPORT_CONNECT:
86    case STATE_TRANSPORT_CONNECT_COMPLETE:
87      return transport_socket_handle_->GetLoadState();
88    case STATE_SOCKS_CONNECT:
89    case STATE_SOCKS_CONNECT_COMPLETE:
90      return LOAD_STATE_CONNECTING;
91    default:
92      NOTREACHED();
93      return LOAD_STATE_IDLE;
94  }
95}
96
97void SOCKSConnectJob::OnIOComplete(int result) {
98  int rv = DoLoop(result);
99  if (rv != ERR_IO_PENDING)
100    NotifyDelegateOfCompletion(rv);  // Deletes |this|
101}
102
103int SOCKSConnectJob::DoLoop(int result) {
104  DCHECK_NE(next_state_, STATE_NONE);
105
106  int rv = result;
107  do {
108    State state = next_state_;
109    next_state_ = STATE_NONE;
110    switch (state) {
111      case STATE_TRANSPORT_CONNECT:
112        DCHECK_EQ(OK, rv);
113        rv = DoTransportConnect();
114        break;
115      case STATE_TRANSPORT_CONNECT_COMPLETE:
116        rv = DoTransportConnectComplete(rv);
117        break;
118      case STATE_SOCKS_CONNECT:
119        DCHECK_EQ(OK, rv);
120        rv = DoSOCKSConnect();
121        break;
122      case STATE_SOCKS_CONNECT_COMPLETE:
123        rv = DoSOCKSConnectComplete(rv);
124        break;
125      default:
126        NOTREACHED() << "bad state";
127        rv = ERR_FAILED;
128        break;
129    }
130  } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE);
131
132  return rv;
133}
134
135int SOCKSConnectJob::DoTransportConnect() {
136  next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE;
137  transport_socket_handle_.reset(new ClientSocketHandle());
138  return transport_socket_handle_->Init(group_name(),
139                                        socks_params_->transport_params(),
140                                        socks_params_->destination().priority(),
141                                        &callback_,
142                                        transport_pool_,
143                                        net_log());
144}
145
146int SOCKSConnectJob::DoTransportConnectComplete(int result) {
147  if (result != OK)
148    return ERR_PROXY_CONNECTION_FAILED;
149
150  // Reset the timer to just the length of time allowed for SOCKS handshake
151  // so that a fast TCP connection plus a slow SOCKS failure doesn't take
152  // longer to timeout than it should.
153  ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds));
154  next_state_ = STATE_SOCKS_CONNECT;
155  return result;
156}
157
158int SOCKSConnectJob::DoSOCKSConnect() {
159  next_state_ = STATE_SOCKS_CONNECT_COMPLETE;
160
161  // Add a SOCKS connection on top of the tcp socket.
162  if (socks_params_->is_socks_v5()) {
163    socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.release(),
164                                         socks_params_->destination()));
165  } else {
166    socket_.reset(new SOCKSClientSocket(transport_socket_handle_.release(),
167                                        socks_params_->destination(),
168                                        resolver_));
169  }
170
171#ifdef ANDROID
172  uid_t calling_uid = 0;
173  bool valid_uid = socks_params_->transport_params()->getUID(&calling_uid);
174#endif
175
176  return socket_->Connect(&callback_
177#ifdef ANDROID
178                          , socks_params_->ignore_limits()
179                          , valid_uid
180                          , calling_uid
181#endif
182                         );
183}
184
185int SOCKSConnectJob::DoSOCKSConnectComplete(int result) {
186  if (result != OK) {
187    socket_->Disconnect();
188    return result;
189  }
190
191  set_socket(socket_.release());
192  return result;
193}
194
195int SOCKSConnectJob::ConnectInternal() {
196  next_state_ = STATE_TRANSPORT_CONNECT;
197  return DoLoop(OK);
198}
199
200ConnectJob* SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob(
201    const std::string& group_name,
202    const PoolBase::Request& request,
203    ConnectJob::Delegate* delegate) const {
204  return new SOCKSConnectJob(group_name,
205                             request.params(),
206                             ConnectionTimeout(),
207                             transport_pool_,
208                             host_resolver_,
209                             delegate,
210                             net_log_);
211}
212
213base::TimeDelta
214SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const {
215  return transport_pool_->ConnectionTimeout() +
216      base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds);
217}
218
219SOCKSClientSocketPool::SOCKSClientSocketPool(
220    int max_sockets,
221    int max_sockets_per_group,
222    ClientSocketPoolHistograms* histograms,
223    HostResolver* host_resolver,
224    TransportClientSocketPool* transport_pool,
225    NetLog* net_log)
226    : transport_pool_(transport_pool),
227      base_(max_sockets, max_sockets_per_group, histograms,
228            base::TimeDelta::FromSeconds(
229                ClientSocketPool::unused_idle_socket_timeout()),
230            base::TimeDelta::FromSeconds(kUsedIdleSocketTimeout),
231            new SOCKSConnectJobFactory(transport_pool,
232                                       host_resolver,
233                                       net_log)) {
234}
235
236SOCKSClientSocketPool::~SOCKSClientSocketPool() {}
237
238int SOCKSClientSocketPool::RequestSocket(const std::string& group_name,
239                                         const void* socket_params,
240                                         RequestPriority priority,
241                                         ClientSocketHandle* handle,
242                                         CompletionCallback* callback,
243                                         const BoundNetLog& net_log) {
244  const scoped_refptr<SOCKSSocketParams>* casted_socket_params =
245      static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params);
246
247  return base_.RequestSocket(group_name, *casted_socket_params, priority,
248                             handle, callback, net_log);
249}
250
251void SOCKSClientSocketPool::RequestSockets(
252    const std::string& group_name,
253    const void* params,
254    int num_sockets,
255    const BoundNetLog& net_log) {
256  const scoped_refptr<SOCKSSocketParams>* casted_params =
257      static_cast<const scoped_refptr<SOCKSSocketParams>*>(params);
258
259  base_.RequestSockets(group_name, *casted_params, num_sockets, net_log);
260}
261
262void SOCKSClientSocketPool::CancelRequest(const std::string& group_name,
263                                          ClientSocketHandle* handle) {
264  base_.CancelRequest(group_name, handle);
265}
266
267void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
268                                          ClientSocket* socket, int id) {
269  base_.ReleaseSocket(group_name, socket, id);
270}
271
272void SOCKSClientSocketPool::Flush() {
273  base_.Flush();
274}
275
276void SOCKSClientSocketPool::CloseIdleSockets() {
277  base_.CloseIdleSockets();
278}
279
280int SOCKSClientSocketPool::IdleSocketCount() const {
281  return base_.idle_socket_count();
282}
283
284int SOCKSClientSocketPool::IdleSocketCountInGroup(
285    const std::string& group_name) const {
286  return base_.IdleSocketCountInGroup(group_name);
287}
288
289LoadState SOCKSClientSocketPool::GetLoadState(
290    const std::string& group_name, const ClientSocketHandle* handle) const {
291  return base_.GetLoadState(group_name, handle);
292}
293
294DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue(
295    const std::string& name,
296    const std::string& type,
297    bool include_nested_pools) const {
298  DictionaryValue* dict = base_.GetInfoAsValue(name, type);
299  if (include_nested_pools) {
300    ListValue* list = new ListValue();
301    list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool",
302                                                 "transport_socket_pool",
303                                                 false));
304    dict->Set("nested_pools", list);
305  }
306  return dict;
307}
308
309base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const {
310  return base_.ConnectionTimeout();
311}
312
313ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const {
314  return base_.histograms();
315};
316
317}  // namespace net
318