1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/socket/socks_client_socket_pool.h" 6 7#include "base/bind.h" 8#include "base/bind_helpers.h" 9#include "base/time/time.h" 10#include "base/values.h" 11#include "net/base/net_errors.h" 12#include "net/socket/client_socket_factory.h" 13#include "net/socket/client_socket_handle.h" 14#include "net/socket/client_socket_pool_base.h" 15#include "net/socket/socks5_client_socket.h" 16#include "net/socket/socks_client_socket.h" 17#include "net/socket/transport_client_socket_pool.h" 18 19namespace net { 20 21SOCKSSocketParams::SOCKSSocketParams( 22 const scoped_refptr<TransportSocketParams>& proxy_server, 23 bool socks_v5, 24 const HostPortPair& host_port_pair) 25 : transport_params_(proxy_server), 26 destination_(host_port_pair), 27 socks_v5_(socks_v5) { 28 if (transport_params_.get()) 29 ignore_limits_ = transport_params_->ignore_limits(); 30 else 31 ignore_limits_ = false; 32} 33 34SOCKSSocketParams::~SOCKSSocketParams() {} 35 36// SOCKSConnectJobs will time out after this many seconds. Note this is on 37// top of the timeout for the transport socket. 38static const int kSOCKSConnectJobTimeoutInSeconds = 30; 39 40SOCKSConnectJob::SOCKSConnectJob( 41 const std::string& group_name, 42 RequestPriority priority, 43 const scoped_refptr<SOCKSSocketParams>& socks_params, 44 const base::TimeDelta& timeout_duration, 45 TransportClientSocketPool* transport_pool, 46 HostResolver* host_resolver, 47 Delegate* delegate, 48 NetLog* net_log) 49 : ConnectJob(group_name, timeout_duration, priority, delegate, 50 BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), 51 socks_params_(socks_params), 52 transport_pool_(transport_pool), 53 resolver_(host_resolver), 54 callback_(base::Bind(&SOCKSConnectJob::OnIOComplete, 55 base::Unretained(this))) { 56} 57 58SOCKSConnectJob::~SOCKSConnectJob() { 59 // We don't worry about cancelling the tcp socket since the destructor in 60 // scoped_ptr<ClientSocketHandle> transport_socket_handle_ will take care of 61 // it. 62} 63 64LoadState SOCKSConnectJob::GetLoadState() const { 65 switch (next_state_) { 66 case STATE_TRANSPORT_CONNECT: 67 case STATE_TRANSPORT_CONNECT_COMPLETE: 68 return transport_socket_handle_->GetLoadState(); 69 case STATE_SOCKS_CONNECT: 70 case STATE_SOCKS_CONNECT_COMPLETE: 71 return LOAD_STATE_CONNECTING; 72 default: 73 NOTREACHED(); 74 return LOAD_STATE_IDLE; 75 } 76} 77 78void SOCKSConnectJob::OnIOComplete(int result) { 79 int rv = DoLoop(result); 80 if (rv != ERR_IO_PENDING) 81 NotifyDelegateOfCompletion(rv); // Deletes |this| 82} 83 84int SOCKSConnectJob::DoLoop(int result) { 85 DCHECK_NE(next_state_, STATE_NONE); 86 87 int rv = result; 88 do { 89 State state = next_state_; 90 next_state_ = STATE_NONE; 91 switch (state) { 92 case STATE_TRANSPORT_CONNECT: 93 DCHECK_EQ(OK, rv); 94 rv = DoTransportConnect(); 95 break; 96 case STATE_TRANSPORT_CONNECT_COMPLETE: 97 rv = DoTransportConnectComplete(rv); 98 break; 99 case STATE_SOCKS_CONNECT: 100 DCHECK_EQ(OK, rv); 101 rv = DoSOCKSConnect(); 102 break; 103 case STATE_SOCKS_CONNECT_COMPLETE: 104 rv = DoSOCKSConnectComplete(rv); 105 break; 106 default: 107 NOTREACHED() << "bad state"; 108 rv = ERR_FAILED; 109 break; 110 } 111 } while (rv != ERR_IO_PENDING && next_state_ != STATE_NONE); 112 113 return rv; 114} 115 116int SOCKSConnectJob::DoTransportConnect() { 117 next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; 118 transport_socket_handle_.reset(new ClientSocketHandle()); 119 return transport_socket_handle_->Init(group_name(), 120 socks_params_->transport_params(), 121 priority(), 122 callback_, 123 transport_pool_, 124 net_log()); 125} 126 127int SOCKSConnectJob::DoTransportConnectComplete(int result) { 128 if (result != OK) 129 return ERR_PROXY_CONNECTION_FAILED; 130 131 // Reset the timer to just the length of time allowed for SOCKS handshake 132 // so that a fast TCP connection plus a slow SOCKS failure doesn't take 133 // longer to timeout than it should. 134 ResetTimer(base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds)); 135 next_state_ = STATE_SOCKS_CONNECT; 136 return result; 137} 138 139int SOCKSConnectJob::DoSOCKSConnect() { 140 next_state_ = STATE_SOCKS_CONNECT_COMPLETE; 141 142 // Add a SOCKS connection on top of the tcp socket. 143 if (socks_params_->is_socks_v5()) { 144 socket_.reset(new SOCKS5ClientSocket(transport_socket_handle_.Pass(), 145 socks_params_->destination())); 146 } else { 147 socket_.reset(new SOCKSClientSocket(transport_socket_handle_.Pass(), 148 socks_params_->destination(), 149 priority(), 150 resolver_)); 151 } 152 return socket_->Connect( 153 base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this))); 154} 155 156int SOCKSConnectJob::DoSOCKSConnectComplete(int result) { 157 if (result != OK) { 158 socket_->Disconnect(); 159 return result; 160 } 161 162 SetSocket(socket_.Pass()); 163 return result; 164} 165 166int SOCKSConnectJob::ConnectInternal() { 167 next_state_ = STATE_TRANSPORT_CONNECT; 168 return DoLoop(OK); 169} 170 171scoped_ptr<ConnectJob> 172SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( 173 const std::string& group_name, 174 const PoolBase::Request& request, 175 ConnectJob::Delegate* delegate) const { 176 return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name, 177 request.priority(), 178 request.params(), 179 ConnectionTimeout(), 180 transport_pool_, 181 host_resolver_, 182 delegate, 183 net_log_)); 184} 185 186base::TimeDelta 187SOCKSClientSocketPool::SOCKSConnectJobFactory::ConnectionTimeout() const { 188 return transport_pool_->ConnectionTimeout() + 189 base::TimeDelta::FromSeconds(kSOCKSConnectJobTimeoutInSeconds); 190} 191 192SOCKSClientSocketPool::SOCKSClientSocketPool( 193 int max_sockets, 194 int max_sockets_per_group, 195 ClientSocketPoolHistograms* histograms, 196 HostResolver* host_resolver, 197 TransportClientSocketPool* transport_pool, 198 NetLog* net_log) 199 : transport_pool_(transport_pool), 200 base_(this, max_sockets, max_sockets_per_group, histograms, 201 ClientSocketPool::unused_idle_socket_timeout(), 202 ClientSocketPool::used_idle_socket_timeout(), 203 new SOCKSConnectJobFactory(transport_pool, 204 host_resolver, 205 net_log)) { 206 // We should always have a |transport_pool_| except in unit tests. 207 if (transport_pool_) 208 base_.AddLowerLayeredPool(transport_pool_); 209} 210 211SOCKSClientSocketPool::~SOCKSClientSocketPool() { 212} 213 214int SOCKSClientSocketPool::RequestSocket( 215 const std::string& group_name, const void* socket_params, 216 RequestPriority priority, ClientSocketHandle* handle, 217 const CompletionCallback& callback, const BoundNetLog& net_log) { 218 const scoped_refptr<SOCKSSocketParams>* casted_socket_params = 219 static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params); 220 221 return base_.RequestSocket(group_name, *casted_socket_params, priority, 222 handle, callback, net_log); 223} 224 225void SOCKSClientSocketPool::RequestSockets( 226 const std::string& group_name, 227 const void* params, 228 int num_sockets, 229 const BoundNetLog& net_log) { 230 const scoped_refptr<SOCKSSocketParams>* casted_params = 231 static_cast<const scoped_refptr<SOCKSSocketParams>*>(params); 232 233 base_.RequestSockets(group_name, *casted_params, num_sockets, net_log); 234} 235 236void SOCKSClientSocketPool::CancelRequest(const std::string& group_name, 237 ClientSocketHandle* handle) { 238 base_.CancelRequest(group_name, handle); 239} 240 241void SOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, 242 scoped_ptr<StreamSocket> socket, 243 int id) { 244 base_.ReleaseSocket(group_name, socket.Pass(), id); 245} 246 247void SOCKSClientSocketPool::FlushWithError(int error) { 248 base_.FlushWithError(error); 249} 250 251void SOCKSClientSocketPool::CloseIdleSockets() { 252 base_.CloseIdleSockets(); 253} 254 255int SOCKSClientSocketPool::IdleSocketCount() const { 256 return base_.idle_socket_count(); 257} 258 259int SOCKSClientSocketPool::IdleSocketCountInGroup( 260 const std::string& group_name) const { 261 return base_.IdleSocketCountInGroup(group_name); 262} 263 264LoadState SOCKSClientSocketPool::GetLoadState( 265 const std::string& group_name, const ClientSocketHandle* handle) const { 266 return base_.GetLoadState(group_name, handle); 267} 268 269base::DictionaryValue* SOCKSClientSocketPool::GetInfoAsValue( 270 const std::string& name, 271 const std::string& type, 272 bool include_nested_pools) const { 273 base::DictionaryValue* dict = base_.GetInfoAsValue(name, type); 274 if (include_nested_pools) { 275 base::ListValue* list = new base::ListValue(); 276 list->Append(transport_pool_->GetInfoAsValue("transport_socket_pool", 277 "transport_socket_pool", 278 false)); 279 dict->Set("nested_pools", list); 280 } 281 return dict; 282} 283 284base::TimeDelta SOCKSClientSocketPool::ConnectionTimeout() const { 285 return base_.ConnectionTimeout(); 286} 287 288ClientSocketPoolHistograms* SOCKSClientSocketPool::histograms() const { 289 return base_.histograms(); 290}; 291 292bool SOCKSClientSocketPool::IsStalled() const { 293 return base_.IsStalled(); 294} 295 296void SOCKSClientSocketPool::AddHigherLayeredPool( 297 HigherLayeredPool* higher_pool) { 298 base_.AddHigherLayeredPool(higher_pool); 299} 300 301void SOCKSClientSocketPool::RemoveHigherLayeredPool( 302 HigherLayeredPool* higher_pool) { 303 base_.RemoveHigherLayeredPool(higher_pool); 304} 305 306bool SOCKSClientSocketPool::CloseOneIdleConnection() { 307 if (base_.CloseOneIdleSocket()) 308 return true; 309 return base_.CloseOneIdleConnectionInHigherLayeredPool(); 310} 311 312} // namespace net 313