1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
6#define NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
7
8#include <string>
9
10#include "base/memory/ref_counted.h"
11#include "base/memory/scoped_ptr.h"
12#include "base/time/time.h"
13#include "net/base/privacy_mode.h"
14#include "net/dns/host_resolver.h"
15#include "net/http/http_response_info.h"
16#include "net/proxy/proxy_server.h"
17#include "net/socket/client_socket_pool.h"
18#include "net/socket/client_socket_pool_base.h"
19#include "net/socket/client_socket_pool_histograms.h"
20#include "net/socket/ssl_client_socket.h"
21#include "net/ssl/ssl_config_service.h"
22
23namespace net {
24
25class CertVerifier;
26class ClientSocketFactory;
27class ConnectJobFactory;
28class HostPortPair;
29class HttpProxyClientSocketPool;
30class HttpProxySocketParams;
31class SOCKSClientSocketPool;
32class SOCKSSocketParams;
33class SSLClientSocket;
34class TransportClientSocketPool;
35class TransportSecurityState;
36class TransportSocketParams;
37
38// SSLSocketParams only needs the socket params for the transport socket
39// that will be used (denoted by |proxy|).
40class NET_EXPORT_PRIVATE SSLSocketParams
41    : public base::RefCounted<SSLSocketParams> {
42 public:
43  SSLSocketParams(const scoped_refptr<TransportSocketParams>& transport_params,
44                  const scoped_refptr<SOCKSSocketParams>& socks_params,
45                  const scoped_refptr<HttpProxySocketParams>& http_proxy_params,
46                  ProxyServer::Scheme proxy,
47                  const HostPortPair& host_and_port,
48                  const SSLConfig& ssl_config,
49                  PrivacyMode privacy_mode,
50                  int load_flags,
51                  bool force_spdy_over_ssl,
52                  bool want_spdy_over_npn);
53
54  const scoped_refptr<TransportSocketParams>& transport_params() {
55      return transport_params_;
56  }
57  const scoped_refptr<HttpProxySocketParams>& http_proxy_params() {
58    return http_proxy_params_;
59  }
60  const scoped_refptr<SOCKSSocketParams>& socks_params() {
61    return socks_params_;
62  }
63  ProxyServer::Scheme proxy() const { return proxy_; }
64  const HostPortPair& host_and_port() const { return host_and_port_; }
65  const SSLConfig& ssl_config() const { return ssl_config_; }
66  PrivacyMode privacy_mode() const { return privacy_mode_; }
67  int load_flags() const { return load_flags_; }
68  bool force_spdy_over_ssl() const { return force_spdy_over_ssl_; }
69  bool want_spdy_over_npn() const { return want_spdy_over_npn_; }
70  bool ignore_limits() const { return ignore_limits_; }
71
72 private:
73  friend class base::RefCounted<SSLSocketParams>;
74  ~SSLSocketParams();
75
76  const scoped_refptr<TransportSocketParams> transport_params_;
77  const scoped_refptr<HttpProxySocketParams> http_proxy_params_;
78  const scoped_refptr<SOCKSSocketParams> socks_params_;
79  const ProxyServer::Scheme proxy_;
80  const HostPortPair host_and_port_;
81  const SSLConfig ssl_config_;
82  const PrivacyMode privacy_mode_;
83  const int load_flags_;
84  const bool force_spdy_over_ssl_;
85  const bool want_spdy_over_npn_;
86  bool ignore_limits_;
87
88  DISALLOW_COPY_AND_ASSIGN(SSLSocketParams);
89};
90
// SSLConnectJob handles the SSL handshake after setting up the underlying
// connection as specified in the params.
//
// The work is driven by a state machine (see State and DoLoop()): first an
// underlying connection is made (plain transport, SOCKS, or an HTTP proxy
// tunnel), then the SSL handshake is performed on top of it.
class SSLConnectJob : public ConnectJob {
 public:
  // |params| describes the connection to make. |transport_pool|,
  // |socks_pool| and |http_proxy_pool| are the pools the underlying
  // connection can be obtained from.
  // NOTE(review): presumably only the pool matching params->proxy() must be
  // non-NULL — confirm against the .cc.
  // |client_socket_factory| and |context| are stored for use by the SSL
  // connect step. |delegate| and |net_log| are presumably forwarded to
  // ConnectJob together with |group_name| and |timeout_duration| — confirm.
  SSLConnectJob(
      const std::string& group_name,
      const scoped_refptr<SSLSocketParams>& params,
      const base::TimeDelta& timeout_duration,
      TransportClientSocketPool* transport_pool,
      SOCKSClientSocketPool* socks_pool,
      HttpProxyClientSocketPool* http_proxy_pool,
      ClientSocketFactory* client_socket_factory,
      HostResolver* host_resolver,
      const SSLClientSocketContext& context,
      Delegate* delegate,
      NetLog* net_log);
  virtual ~SSLConnectJob();

  // ConnectJob methods.
  virtual LoadState GetLoadState() const OVERRIDE;

  // NOTE(review): presumably copies extra error details (e.g.
  // |error_response_info_|) into |handle| after a failed connect — confirm.
  virtual void GetAdditionalErrorState(ClientSocketHandle * handle) OVERRIDE;

 private:
  // States of the connection state machine run by DoLoop(). Each *_CONNECT
  // state starts a step and the matching *_COMPLETE state handles its result.
  // NOTE(review): STATE_NONE presumably means "no step pending" — confirm.
  enum State {
    STATE_TRANSPORT_CONNECT,
    STATE_TRANSPORT_CONNECT_COMPLETE,
    STATE_SOCKS_CONNECT,
    STATE_SOCKS_CONNECT_COMPLETE,
    STATE_TUNNEL_CONNECT,
    STATE_TUNNEL_CONNECT_COMPLETE,
    STATE_SSL_CONNECT,
    STATE_SSL_CONNECT_COMPLETE,
    STATE_NONE,
  };

  // Completion handler for asynchronous steps of the state machine.
  void OnIOComplete(int result);

  // Runs the state transition loop.
  int DoLoop(int result);

  // One handler per State value; the *Complete variants receive the result
  // of the preceding step.
  int DoTransportConnect();
  int DoTransportConnectComplete(int result);
  int DoSOCKSConnect();
  int DoSOCKSConnectComplete(int result);
  int DoTunnelConnect();
  int DoTunnelConnectComplete(int result);
  int DoSSLConnect();
  int DoSSLConnectComplete(int result);

  // Starts the SSL connection process.  Returns OK on success and
  // ERR_IO_PENDING if it cannot immediately service the request.
  // Otherwise, it returns a net error code.
  virtual int ConnectInternal() OVERRIDE;

  scoped_refptr<SSLSocketParams> params_;
  // Pools the underlying connection may come from (not owned).
  TransportClientSocketPool* const transport_pool_;
  SOCKSClientSocketPool* const socks_pool_;
  HttpProxyClientSocketPool* const http_proxy_pool_;
  ClientSocketFactory* const client_socket_factory_;
  HostResolver* const host_resolver_;

  const SSLClientSocketContext context_;

  // Current position in the state machine.
  State next_state_;
  // NOTE(review): presumably bound to OnIOComplete() and passed to
  // asynchronous operations — confirm in the .cc.
  CompletionCallback callback_;
  // Handle to the underlying connection the SSL socket is layered on.
  scoped_ptr<ClientSocketHandle> transport_socket_handle_;
  scoped_ptr<SSLClientSocket> ssl_socket_;

  // NOTE(review): presumably the response info from a failed proxy tunnel,
  // reported via GetAdditionalErrorState() — confirm.
  HttpResponseInfo error_response_info_;

  DISALLOW_COPY_AND_ASSIGN(SSLConnectJob);
};
164
// SSLClientSocketPool is a ClientSocketPool of SSL sockets, organized by
// group name. Its ConnectJobs (presumably SSLConnectJobs) build on sockets
// obtained from the transport, SOCKS or HTTP proxy pools given to the
// constructor. It is also a LayeredPool, and it observes SSLConfigService so
// that it can react to SSL config changes.
class NET_EXPORT_PRIVATE SSLClientSocketPool
    : public ClientSocketPool,
      public LayeredPool,
      public SSLConfigService::Observer {
 public:
  // Only the pools that will be used are required. e.g. if you never
  // try to create an SSL over SOCKS socket, |socks_pool| may be NULL.
  SSLClientSocketPool(
      int max_sockets,
      int max_sockets_per_group,
      ClientSocketPoolHistograms* histograms,
      HostResolver* host_resolver,
      CertVerifier* cert_verifier,
      ServerBoundCertService* server_bound_cert_service,
      TransportSecurityState* transport_security_state,
      const std::string& ssl_session_cache_shard,
      ClientSocketFactory* client_socket_factory,
      TransportClientSocketPool* transport_pool,
      SOCKSClientSocketPool* socks_pool,
      HttpProxyClientSocketPool* http_proxy_pool,
      SSLConfigService* ssl_config_service,
      NetLog* net_log);

  virtual ~SSLClientSocketPool();

  // ClientSocketPool implementation.
  // NOTE(review): |connect_params| / |params| are untyped; presumably they
  // point to a scoped_refptr<SSLSocketParams> (see the
  // REGISTER_SOCKET_PARAMS_FOR_POOL line below) — confirm.
  virtual int RequestSocket(const std::string& group_name,
                            const void* connect_params,
                            RequestPriority priority,
                            ClientSocketHandle* handle,
                            const CompletionCallback& callback,
                            const BoundNetLog& net_log) OVERRIDE;

  virtual void RequestSockets(const std::string& group_name,
                              const void* params,
                              int num_sockets,
                              const BoundNetLog& net_log) OVERRIDE;

  virtual void CancelRequest(const std::string& group_name,
                             ClientSocketHandle* handle) OVERRIDE;

  virtual void ReleaseSocket(const std::string& group_name,
                             StreamSocket* socket,
                             int id) OVERRIDE;

  virtual void FlushWithError(int error) OVERRIDE;

  virtual bool IsStalled() const OVERRIDE;

  virtual void CloseIdleSockets() OVERRIDE;

  virtual int IdleSocketCount() const OVERRIDE;

  virtual int IdleSocketCountInGroup(
      const std::string& group_name) const OVERRIDE;

  virtual LoadState GetLoadState(
      const std::string& group_name,
      const ClientSocketHandle* handle) const OVERRIDE;

  virtual void AddLayeredPool(LayeredPool* layered_pool) OVERRIDE;

  virtual void RemoveLayeredPool(LayeredPool* layered_pool) OVERRIDE;

  // |include_nested_pools| presumably controls whether the underlying pools'
  // info is included as well — confirm in the .cc.
  virtual base::DictionaryValue* GetInfoAsValue(
      const std::string& name,
      const std::string& type,
      bool include_nested_pools) const OVERRIDE;

  virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;

  virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;

  // LayeredPool implementation.
  virtual bool CloseOneIdleConnection() OVERRIDE;

 private:
  typedef ClientSocketPoolBase<SSLSocketParams> PoolBase;

  // SSLConfigService::Observer implementation.

  // When the user changes the SSL config, we flush all idle sockets so they
  // won't get re-used.
  virtual void OnSSLConfigChanged() OVERRIDE;

  // Creates the ConnectJobs used by |base_|. It holds the same pools and
  // helpers that each new job needs.
  class SSLConnectJobFactory : public PoolBase::ConnectJobFactory {
   public:
    SSLConnectJobFactory(
        TransportClientSocketPool* transport_pool,
        SOCKSClientSocketPool* socks_pool,
        HttpProxyClientSocketPool* http_proxy_pool,
        ClientSocketFactory* client_socket_factory,
        HostResolver* host_resolver,
        const SSLClientSocketContext& context,
        NetLog* net_log);

    virtual ~SSLConnectJobFactory() {}

    // ClientSocketPoolBase::ConnectJobFactory methods.
    virtual ConnectJob* NewConnectJob(
        const std::string& group_name,
        const PoolBase::Request& request,
        ConnectJob::Delegate* delegate) const OVERRIDE;

    // Returns |timeout_|.
    // NOTE(review): presumably derived from the underlying pools' timeouts —
    // confirm in the .cc.
    virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;

   private:
    TransportClientSocketPool* const transport_pool_;
    SOCKSClientSocketPool* const socks_pool_;
    HttpProxyClientSocketPool* const http_proxy_pool_;
    ClientSocketFactory* const client_socket_factory_;
    HostResolver* const host_resolver_;
    const SSLClientSocketContext context_;
    base::TimeDelta timeout_;
    NetLog* net_log_;

    DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory);
  };

  // Underlying pools (any may be NULL if unused; see the constructor).
  TransportClientSocketPool* const transport_pool_;
  SOCKSClientSocketPool* const socks_pool_;
  HttpProxyClientSocketPool* const http_proxy_pool_;
  // Generic pool machinery, templated on SSLSocketParams; the
  // ClientSocketPool methods above presumably delegate to it.
  PoolBase base_;
  const scoped_refptr<SSLConfigService> ssl_config_service_;

  DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool);
};
292
// Associates SSLSocketParams with SSLClientSocketPool as its connect-params
// type. NOTE(review): the macro is defined outside this file (presumably
// net/socket/client_socket_pool.h). It likely enables checking of the
// untyped |connect_params| passed to RequestSocket() — confirm.
REGISTER_SOCKET_PARAMS_FOR_POOL(SSLClientSocketPool, SSLSocketParams);
294
295}  // namespace net
296
297#endif  // NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
298