1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
6#define NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
7
8#include <string>
9
10#include "base/memory/ref_counted.h"
11#include "base/memory/scoped_ptr.h"
12#include "base/time/time.h"
13#include "net/base/privacy_mode.h"
14#include "net/dns/host_resolver.h"
15#include "net/http/http_response_info.h"
16#include "net/socket/client_socket_pool.h"
17#include "net/socket/client_socket_pool_base.h"
18#include "net/socket/client_socket_pool_histograms.h"
19#include "net/socket/ssl_client_socket.h"
20#include "net/ssl/ssl_config_service.h"
21
22namespace net {
23
24class CertVerifier;
25class ClientSocketFactory;
26class ConnectJobFactory;
27class CTVerifier;
28class HostPortPair;
29class HttpProxyClientSocketPool;
30class HttpProxySocketParams;
31class SOCKSClientSocketPool;
32class SOCKSSocketParams;
33class SSLClientSocket;
34class TransportClientSocketPool;
35class TransportSecurityState;
36class TransportSocketParams;
37
38class NET_EXPORT_PRIVATE SSLSocketParams
39    : public base::RefCounted<SSLSocketParams> {
40 public:
41  enum ConnectionType { DIRECT, SOCKS_PROXY, HTTP_PROXY };
42
43  // Exactly one of |direct_params|, |socks_proxy_params|, and
44  // |http_proxy_params| must be non-NULL.
45  SSLSocketParams(
46      const scoped_refptr<TransportSocketParams>& direct_params,
47      const scoped_refptr<SOCKSSocketParams>& socks_proxy_params,
48      const scoped_refptr<HttpProxySocketParams>& http_proxy_params,
49      const HostPortPair& host_and_port,
50      const SSLConfig& ssl_config,
51      PrivacyMode privacy_mode,
52      int load_flags,
53      bool force_spdy_over_ssl,
54      bool want_spdy_over_npn);
55
56  // Returns the type of the underlying connection.
57  ConnectionType GetConnectionType() const;
58
59  // Must be called only when GetConnectionType() returns DIRECT.
60  const scoped_refptr<TransportSocketParams>&
61      GetDirectConnectionParams() const;
62
63  // Must be called only when GetConnectionType() returns SOCKS_PROXY.
64  const scoped_refptr<SOCKSSocketParams>&
65      GetSocksProxyConnectionParams() const;
66
67  // Must be called only when GetConnectionType() returns HTTP_PROXY.
68  const scoped_refptr<HttpProxySocketParams>&
69      GetHttpProxyConnectionParams() const;
70
71  const HostPortPair& host_and_port() const { return host_and_port_; }
72  const SSLConfig& ssl_config() const { return ssl_config_; }
73  PrivacyMode privacy_mode() const { return privacy_mode_; }
74  int load_flags() const { return load_flags_; }
75  bool force_spdy_over_ssl() const { return force_spdy_over_ssl_; }
76  bool want_spdy_over_npn() const { return want_spdy_over_npn_; }
77  bool ignore_limits() const { return ignore_limits_; }
78
79 private:
80  friend class base::RefCounted<SSLSocketParams>;
81  ~SSLSocketParams();
82
83  const scoped_refptr<TransportSocketParams> direct_params_;
84  const scoped_refptr<SOCKSSocketParams> socks_proxy_params_;
85  const scoped_refptr<HttpProxySocketParams> http_proxy_params_;
86  const HostPortPair host_and_port_;
87  const SSLConfig ssl_config_;
88  const PrivacyMode privacy_mode_;
89  const int load_flags_;
90  const bool force_spdy_over_ssl_;
91  const bool want_spdy_over_npn_;
92  bool ignore_limits_;
93
94  DISALLOW_COPY_AND_ASSIGN(SSLSocketParams);
95};
96
97// SSLConnectJob handles the SSL handshake after setting up the underlying
98// connection as specified in the params.
99class SSLConnectJob : public ConnectJob {
100 public:
101  SSLConnectJob(
102      const std::string& group_name,
103      RequestPriority priority,
104      const scoped_refptr<SSLSocketParams>& params,
105      const base::TimeDelta& timeout_duration,
106      TransportClientSocketPool* transport_pool,
107      SOCKSClientSocketPool* socks_pool,
108      HttpProxyClientSocketPool* http_proxy_pool,
109      ClientSocketFactory* client_socket_factory,
110      HostResolver* host_resolver,
111      const SSLClientSocketContext& context,
112      Delegate* delegate,
113      NetLog* net_log);
114  virtual ~SSLConnectJob();
115
116  // ConnectJob methods.
117  virtual LoadState GetLoadState() const OVERRIDE;
118
119  virtual void GetAdditionalErrorState(ClientSocketHandle * handle) OVERRIDE;
120
121 private:
122  enum State {
123    STATE_TRANSPORT_CONNECT,
124    STATE_TRANSPORT_CONNECT_COMPLETE,
125    STATE_SOCKS_CONNECT,
126    STATE_SOCKS_CONNECT_COMPLETE,
127    STATE_TUNNEL_CONNECT,
128    STATE_TUNNEL_CONNECT_COMPLETE,
129    STATE_SSL_CONNECT,
130    STATE_SSL_CONNECT_COMPLETE,
131    STATE_NONE,
132  };
133
134  void OnIOComplete(int result);
135
136  // Runs the state transition loop.
137  int DoLoop(int result);
138
139  int DoTransportConnect();
140  int DoTransportConnectComplete(int result);
141  int DoSOCKSConnect();
142  int DoSOCKSConnectComplete(int result);
143  int DoTunnelConnect();
144  int DoTunnelConnectComplete(int result);
145  int DoSSLConnect();
146  int DoSSLConnectComplete(int result);
147
148  // Returns the initial state for the state machine based on the
149  // |connection_type|.
150  static State GetInitialState(SSLSocketParams::ConnectionType connection_type);
151
152  // Starts the SSL connection process.  Returns OK on success and
153  // ERR_IO_PENDING if it cannot immediately service the request.
154  // Otherwise, it returns a net error code.
155  virtual int ConnectInternal() OVERRIDE;
156
157  scoped_refptr<SSLSocketParams> params_;
158  TransportClientSocketPool* const transport_pool_;
159  SOCKSClientSocketPool* const socks_pool_;
160  HttpProxyClientSocketPool* const http_proxy_pool_;
161  ClientSocketFactory* const client_socket_factory_;
162  HostResolver* const host_resolver_;
163
164  const SSLClientSocketContext context_;
165
166  State next_state_;
167  CompletionCallback callback_;
168  scoped_ptr<ClientSocketHandle> transport_socket_handle_;
169  scoped_ptr<SSLClientSocket> ssl_socket_;
170
171  HttpResponseInfo error_response_info_;
172
173  DISALLOW_COPY_AND_ASSIGN(SSLConnectJob);
174};
175
176class NET_EXPORT_PRIVATE SSLClientSocketPool
177    : public ClientSocketPool,
178      public HigherLayeredPool,
179      public SSLConfigService::Observer {
180 public:
181  typedef SSLSocketParams SocketParams;
182
183  // Only the pools that will be used are required. i.e. if you never
184  // try to create an SSL over SOCKS socket, |socks_pool| may be NULL.
185  SSLClientSocketPool(
186      int max_sockets,
187      int max_sockets_per_group,
188      ClientSocketPoolHistograms* histograms,
189      HostResolver* host_resolver,
190      CertVerifier* cert_verifier,
191      ServerBoundCertService* server_bound_cert_service,
192      TransportSecurityState* transport_security_state,
193      CTVerifier* cert_transparency_verifier,
194      const std::string& ssl_session_cache_shard,
195      ClientSocketFactory* client_socket_factory,
196      TransportClientSocketPool* transport_pool,
197      SOCKSClientSocketPool* socks_pool,
198      HttpProxyClientSocketPool* http_proxy_pool,
199      SSLConfigService* ssl_config_service,
200      NetLog* net_log);
201
202  virtual ~SSLClientSocketPool();
203
204  // ClientSocketPool implementation.
205  virtual int RequestSocket(const std::string& group_name,
206                            const void* connect_params,
207                            RequestPriority priority,
208                            ClientSocketHandle* handle,
209                            const CompletionCallback& callback,
210                            const BoundNetLog& net_log) OVERRIDE;
211
212  virtual void RequestSockets(const std::string& group_name,
213                              const void* params,
214                              int num_sockets,
215                              const BoundNetLog& net_log) OVERRIDE;
216
217  virtual void CancelRequest(const std::string& group_name,
218                             ClientSocketHandle* handle) OVERRIDE;
219
220  virtual void ReleaseSocket(const std::string& group_name,
221                             scoped_ptr<StreamSocket> socket,
222                             int id) OVERRIDE;
223
224  virtual void FlushWithError(int error) OVERRIDE;
225
226  virtual void CloseIdleSockets() OVERRIDE;
227
228  virtual int IdleSocketCount() const OVERRIDE;
229
230  virtual int IdleSocketCountInGroup(
231      const std::string& group_name) const OVERRIDE;
232
233  virtual LoadState GetLoadState(
234      const std::string& group_name,
235      const ClientSocketHandle* handle) const OVERRIDE;
236
237  virtual base::DictionaryValue* GetInfoAsValue(
238      const std::string& name,
239      const std::string& type,
240      bool include_nested_pools) const OVERRIDE;
241
242  virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
243
244  virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
245
246  // LowerLayeredPool implementation.
247  virtual bool IsStalled() const OVERRIDE;
248
249  virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
250
251  virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
252
253  // HigherLayeredPool implementation.
254  virtual bool CloseOneIdleConnection() OVERRIDE;
255
256 private:
257  typedef ClientSocketPoolBase<SSLSocketParams> PoolBase;
258
259  // SSLConfigService::Observer implementation.
260
261  // When the user changes the SSL config, we flush all idle sockets so they
262  // won't get re-used.
263  virtual void OnSSLConfigChanged() OVERRIDE;
264
265  class SSLConnectJobFactory : public PoolBase::ConnectJobFactory {
266   public:
267    SSLConnectJobFactory(
268        TransportClientSocketPool* transport_pool,
269        SOCKSClientSocketPool* socks_pool,
270        HttpProxyClientSocketPool* http_proxy_pool,
271        ClientSocketFactory* client_socket_factory,
272        HostResolver* host_resolver,
273        const SSLClientSocketContext& context,
274        NetLog* net_log);
275
276    virtual ~SSLConnectJobFactory() {}
277
278    // ClientSocketPoolBase::ConnectJobFactory methods.
279    virtual scoped_ptr<ConnectJob> NewConnectJob(
280        const std::string& group_name,
281        const PoolBase::Request& request,
282        ConnectJob::Delegate* delegate) const OVERRIDE;
283
284    virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
285
286   private:
287    TransportClientSocketPool* const transport_pool_;
288    SOCKSClientSocketPool* const socks_pool_;
289    HttpProxyClientSocketPool* const http_proxy_pool_;
290    ClientSocketFactory* const client_socket_factory_;
291    HostResolver* const host_resolver_;
292    const SSLClientSocketContext context_;
293    base::TimeDelta timeout_;
294    NetLog* net_log_;
295
296    DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory);
297  };
298
299  TransportClientSocketPool* const transport_pool_;
300  SOCKSClientSocketPool* const socks_pool_;
301  HttpProxyClientSocketPool* const http_proxy_pool_;
302  PoolBase base_;
303  const scoped_refptr<SSLConfigService> ssl_config_service_;
304
305  DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool);
306};
307
308}  // namespace net
309
310#endif  // NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
311