// ssl_client_socket_pool.h revision 5f1c94371a64b3196d4be9466099bb892df9b88e
// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5#ifndef NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
6#define NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
7
8#include <map>
9#include <string>
10#include <vector>
11
12#include "base/memory/ref_counted.h"
13#include "base/memory/scoped_ptr.h"
14#include "base/time/time.h"
15#include "net/base/privacy_mode.h"
16#include "net/dns/host_resolver.h"
17#include "net/http/http_response_info.h"
18#include "net/socket/client_socket_pool.h"
19#include "net/socket/client_socket_pool_base.h"
20#include "net/socket/client_socket_pool_histograms.h"
21#include "net/socket/ssl_client_socket.h"
22#include "net/ssl/ssl_config_service.h"
23
24namespace net {
25
// Forward declarations for types referenced by the declarations below.
// ChannelIDService is listed explicitly because the SSLClientSocketPool
// constructor takes a ChannelIDService*; previously it was only visible
// through whatever the included headers happened to declare.
class CertVerifier;
class ChannelIDService;
class ClientSocketFactory;
class ConnectJobFactory;
class CTVerifier;
class HostPortPair;
class HttpProxyClientSocketPool;
class HttpProxySocketParams;
class SOCKSClientSocketPool;
class SOCKSSocketParams;
class SSLClientSocket;
class TransportClientSocketPool;
class TransportSecurityState;
class TransportSocketParams;
39
40class NET_EXPORT_PRIVATE SSLSocketParams
41    : public base::RefCounted<SSLSocketParams> {
42 public:
43  enum ConnectionType { DIRECT, SOCKS_PROXY, HTTP_PROXY };
44
45  // Exactly one of |direct_params|, |socks_proxy_params|, and
46  // |http_proxy_params| must be non-NULL.
47  SSLSocketParams(
48      const scoped_refptr<TransportSocketParams>& direct_params,
49      const scoped_refptr<SOCKSSocketParams>& socks_proxy_params,
50      const scoped_refptr<HttpProxySocketParams>& http_proxy_params,
51      const HostPortPair& host_and_port,
52      const SSLConfig& ssl_config,
53      PrivacyMode privacy_mode,
54      int load_flags,
55      bool force_spdy_over_ssl,
56      bool want_spdy_over_npn);
57
58  // Returns the type of the underlying connection.
59  ConnectionType GetConnectionType() const;
60
61  // Must be called only when GetConnectionType() returns DIRECT.
62  const scoped_refptr<TransportSocketParams>&
63      GetDirectConnectionParams() const;
64
65  // Must be called only when GetConnectionType() returns SOCKS_PROXY.
66  const scoped_refptr<SOCKSSocketParams>&
67      GetSocksProxyConnectionParams() const;
68
69  // Must be called only when GetConnectionType() returns HTTP_PROXY.
70  const scoped_refptr<HttpProxySocketParams>&
71      GetHttpProxyConnectionParams() const;
72
73  const HostPortPair& host_and_port() const { return host_and_port_; }
74  const SSLConfig& ssl_config() const { return ssl_config_; }
75  PrivacyMode privacy_mode() const { return privacy_mode_; }
76  int load_flags() const { return load_flags_; }
77  bool force_spdy_over_ssl() const { return force_spdy_over_ssl_; }
78  bool want_spdy_over_npn() const { return want_spdy_over_npn_; }
79  bool ignore_limits() const { return ignore_limits_; }
80
81 private:
82  friend class base::RefCounted<SSLSocketParams>;
83  ~SSLSocketParams();
84
85  const scoped_refptr<TransportSocketParams> direct_params_;
86  const scoped_refptr<SOCKSSocketParams> socks_proxy_params_;
87  const scoped_refptr<HttpProxySocketParams> http_proxy_params_;
88  const HostPortPair host_and_port_;
89  const SSLConfig ssl_config_;
90  const PrivacyMode privacy_mode_;
91  const int load_flags_;
92  const bool force_spdy_over_ssl_;
93  const bool want_spdy_over_npn_;
94  bool ignore_limits_;
95
96  DISALLOW_COPY_AND_ASSIGN(SSLSocketParams);
97};
98
99// SSLConnectJobMessenger handles communication between concurrent
100// SSLConnectJobs that share the same SSL session cache key.
101//
102// SSLConnectJobMessengers tell the session cache when a certain
103// connection should be monitored for success or failure, and
104// tell SSLConnectJobs when to pause or resume their connections.
105class SSLConnectJobMessenger {
106 public:
107  struct SocketAndCallback {
108    SocketAndCallback(SSLClientSocket* ssl_socket,
109                      const base::Closure& job_resumption_callback);
110    ~SocketAndCallback();
111
112    SSLClientSocket* socket;
113    base::Closure callback;
114  };
115
116  typedef std::vector<SocketAndCallback> SSLPendingSocketsAndCallbacks;
117
118  SSLConnectJobMessenger();
119  ~SSLConnectJobMessenger();
120
121  // Removes |socket| from the set of sockets being monitored. This
122  // guarantees that |job_resumption_callback| will not be called for
123  // the socket.
124  void RemovePendingSocket(SSLClientSocket* ssl_socket);
125
126  // Returns true if |ssl_socket|'s Connect() method should be called.
127  bool CanProceed(SSLClientSocket* ssl_socket);
128
129  // Configures the SSLConnectJobMessenger to begin monitoring |ssl_socket|'s
130  // connection status. After a successful connection, or an error,
131  // the messenger will determine which sockets that have been added
132  // via AddPendingSocket() to allow to proceed.
133  void MonitorConnectionResult(SSLClientSocket* ssl_socket);
134
135  // Adds |socket| to the list of sockets waiting to Connect(). When
136  // the messenger has determined that it's an appropriate time for |socket|
137  // to connect, it will asynchronously invoke |callback|.
138  //
139  // Note: It is an error to call AddPendingSocket() without having first
140  // called MonitorConnectionResult() and configuring a socket that WILL
141  // have Connect() called on it.
142  void AddPendingSocket(SSLClientSocket* ssl_socket,
143                        const base::Closure& callback);
144
145 private:
146  // Processes pending callbacks when a socket completes its SSL handshake --
147  // either successfully or unsuccessfully.
148  void OnSSLHandshakeCompleted();
149
150  // Runs all callbacks stored in |pending_sockets_and_callbacks_|.
151  void RunAllCallbacks(
152      const SSLPendingSocketsAndCallbacks& pending_socket_and_callbacks);
153
154  base::WeakPtrFactory<SSLConnectJobMessenger> weak_factory_;
155
156  SSLPendingSocketsAndCallbacks pending_sockets_and_callbacks_;
157  // Note: this field is a vector to allow for future design changes. Currently,
158  // this vector should only ever have one entry.
159  std::vector<SSLClientSocket*> connecting_sockets_;
160};
161
162// SSLConnectJob handles the SSL handshake after setting up the underlying
163// connection as specified in the params.
164class SSLConnectJob : public ConnectJob {
165 public:
166  // Note: the SSLConnectJob does not own |messenger| so it must outlive the
167  // job.
168  SSLConnectJob(const std::string& group_name,
169                RequestPriority priority,
170                const scoped_refptr<SSLSocketParams>& params,
171                const base::TimeDelta& timeout_duration,
172                TransportClientSocketPool* transport_pool,
173                SOCKSClientSocketPool* socks_pool,
174                HttpProxyClientSocketPool* http_proxy_pool,
175                ClientSocketFactory* client_socket_factory,
176                HostResolver* host_resolver,
177                const SSLClientSocketContext& context,
178                SSLConnectJobMessenger* messenger,
179                Delegate* delegate,
180                NetLog* net_log);
181  virtual ~SSLConnectJob();
182
183  // ConnectJob methods.
184  virtual LoadState GetLoadState() const OVERRIDE;
185
186  virtual void GetAdditionalErrorState(ClientSocketHandle * handle) OVERRIDE;
187
188 private:
189  enum State {
190    STATE_TRANSPORT_CONNECT,
191    STATE_TRANSPORT_CONNECT_COMPLETE,
192    STATE_SOCKS_CONNECT,
193    STATE_SOCKS_CONNECT_COMPLETE,
194    STATE_TUNNEL_CONNECT,
195    STATE_TUNNEL_CONNECT_COMPLETE,
196    STATE_CREATE_SSL_SOCKET,
197    STATE_CHECK_FOR_RESUME,
198    STATE_SSL_CONNECT,
199    STATE_SSL_CONNECT_COMPLETE,
200    STATE_NONE,
201  };
202
203  void OnIOComplete(int result);
204
205  // Runs the state transition loop.
206  int DoLoop(int result);
207
208  int DoTransportConnect();
209  int DoTransportConnectComplete(int result);
210  int DoSOCKSConnect();
211  int DoSOCKSConnectComplete(int result);
212  int DoTunnelConnect();
213  int DoTunnelConnectComplete(int result);
214  int DoCreateSSLSocket();
215  int DoCheckForResume();
216  int DoSSLConnect();
217  int DoSSLConnectComplete(int result);
218
219  // Tells a waiting SSLConnectJob to resume its SSL connection.
220  void ResumeSSLConnection();
221
222  // Returns the initial state for the state machine based on the
223  // |connection_type|.
224  static State GetInitialState(SSLSocketParams::ConnectionType connection_type);
225
226  // Starts the SSL connection process.  Returns OK on success and
227  // ERR_IO_PENDING if it cannot immediately service the request.
228  // Otherwise, it returns a net error code.
229  virtual int ConnectInternal() OVERRIDE;
230
231  scoped_refptr<SSLSocketParams> params_;
232  TransportClientSocketPool* const transport_pool_;
233  SOCKSClientSocketPool* const socks_pool_;
234  HttpProxyClientSocketPool* const http_proxy_pool_;
235  ClientSocketFactory* const client_socket_factory_;
236  HostResolver* const host_resolver_;
237
238  const SSLClientSocketContext context_;
239
240  State next_state_;
241  CompletionCallback io_callback_;
242  scoped_ptr<ClientSocketHandle> transport_socket_handle_;
243  scoped_ptr<SSLClientSocket> ssl_socket_;
244
245  SSLConnectJobMessenger* messenger_;
246  HttpResponseInfo error_response_info_;
247
248  base::WeakPtrFactory<SSLConnectJob> weak_factory_;
249
250  DISALLOW_COPY_AND_ASSIGN(SSLConnectJob);
251};
252
253class NET_EXPORT_PRIVATE SSLClientSocketPool
254    : public ClientSocketPool,
255      public HigherLayeredPool,
256      public SSLConfigService::Observer {
257 public:
258  typedef SSLSocketParams SocketParams;
259
260  // Only the pools that will be used are required. i.e. if you never
261  // try to create an SSL over SOCKS socket, |socks_pool| may be NULL.
262  SSLClientSocketPool(int max_sockets,
263                      int max_sockets_per_group,
264                      ClientSocketPoolHistograms* histograms,
265                      HostResolver* host_resolver,
266                      CertVerifier* cert_verifier,
267                      ChannelIDService* channel_id_service,
268                      TransportSecurityState* transport_security_state,
269                      CTVerifier* cert_transparency_verifier,
270                      const std::string& ssl_session_cache_shard,
271                      ClientSocketFactory* client_socket_factory,
272                      TransportClientSocketPool* transport_pool,
273                      SOCKSClientSocketPool* socks_pool,
274                      HttpProxyClientSocketPool* http_proxy_pool,
275                      SSLConfigService* ssl_config_service,
276                      bool enable_ssl_connect_job_waiting,
277                      NetLog* net_log);
278
279  virtual ~SSLClientSocketPool();
280
281  // ClientSocketPool implementation.
282  virtual int RequestSocket(const std::string& group_name,
283                            const void* connect_params,
284                            RequestPriority priority,
285                            ClientSocketHandle* handle,
286                            const CompletionCallback& callback,
287                            const BoundNetLog& net_log) OVERRIDE;
288
289  virtual void RequestSockets(const std::string& group_name,
290                              const void* params,
291                              int num_sockets,
292                              const BoundNetLog& net_log) OVERRIDE;
293
294  virtual void CancelRequest(const std::string& group_name,
295                             ClientSocketHandle* handle) OVERRIDE;
296
297  virtual void ReleaseSocket(const std::string& group_name,
298                             scoped_ptr<StreamSocket> socket,
299                             int id) OVERRIDE;
300
301  virtual void FlushWithError(int error) OVERRIDE;
302
303  virtual void CloseIdleSockets() OVERRIDE;
304
305  virtual int IdleSocketCount() const OVERRIDE;
306
307  virtual int IdleSocketCountInGroup(
308      const std::string& group_name) const OVERRIDE;
309
310  virtual LoadState GetLoadState(
311      const std::string& group_name,
312      const ClientSocketHandle* handle) const OVERRIDE;
313
314  virtual base::DictionaryValue* GetInfoAsValue(
315      const std::string& name,
316      const std::string& type,
317      bool include_nested_pools) const OVERRIDE;
318
319  virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
320
321  virtual ClientSocketPoolHistograms* histograms() const OVERRIDE;
322
323  // LowerLayeredPool implementation.
324  virtual bool IsStalled() const OVERRIDE;
325
326  virtual void AddHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
327
328  virtual void RemoveHigherLayeredPool(HigherLayeredPool* higher_pool) OVERRIDE;
329
330  // HigherLayeredPool implementation.
331  virtual bool CloseOneIdleConnection() OVERRIDE;
332
333 private:
334  typedef ClientSocketPoolBase<SSLSocketParams> PoolBase;
335
336  // SSLConfigService::Observer implementation.
337
338  // When the user changes the SSL config, we flush all idle sockets so they
339  // won't get re-used.
340  virtual void OnSSLConfigChanged() OVERRIDE;
341
342  class SSLConnectJobFactory : public PoolBase::ConnectJobFactory {
343   public:
344    SSLConnectJobFactory(TransportClientSocketPool* transport_pool,
345                         SOCKSClientSocketPool* socks_pool,
346                         HttpProxyClientSocketPool* http_proxy_pool,
347                         ClientSocketFactory* client_socket_factory,
348                         HostResolver* host_resolver,
349                         const SSLClientSocketContext& context,
350                         bool enable_ssl_connect_job_waiting,
351                         NetLog* net_log);
352
353    virtual ~SSLConnectJobFactory();
354
355    // ClientSocketPoolBase::ConnectJobFactory methods.
356    virtual scoped_ptr<ConnectJob> NewConnectJob(
357        const std::string& group_name,
358        const PoolBase::Request& request,
359        ConnectJob::Delegate* delegate) const OVERRIDE;
360
361    virtual base::TimeDelta ConnectionTimeout() const OVERRIDE;
362
363   private:
364    // Maps SSLConnectJob cache keys to SSLConnectJobMessenger objects.
365    typedef std::map<std::string, SSLConnectJobMessenger*> MessengerMap;
366
367    TransportClientSocketPool* const transport_pool_;
368    SOCKSClientSocketPool* const socks_pool_;
369    HttpProxyClientSocketPool* const http_proxy_pool_;
370    ClientSocketFactory* const client_socket_factory_;
371    HostResolver* const host_resolver_;
372    const SSLClientSocketContext context_;
373    base::TimeDelta timeout_;
374    bool enable_ssl_connect_job_waiting_;
375    NetLog* net_log_;
376    // |messenger_map_| is currently a pointer so that an element can be
377    // added to it inside of the const method NewConnectJob. In the future,
378    // elements will be added in a different method.
379    // TODO(mshelley) Change this to a non-pointer.
380    scoped_ptr<MessengerMap> messenger_map_;
381
382    DISALLOW_COPY_AND_ASSIGN(SSLConnectJobFactory);
383  };
384
385  TransportClientSocketPool* const transport_pool_;
386  SOCKSClientSocketPool* const socks_pool_;
387  HttpProxyClientSocketPool* const http_proxy_pool_;
388  PoolBase base_;
389  const scoped_refptr<SSLConfigService> ssl_config_service_;
390
391  DISALLOW_COPY_AND_ASSIGN(SSLClientSocketPool);
392};
393
394}  // namespace net
395
396#endif  // NET_SOCKET_SSL_CLIENT_SOCKET_POOL_H_
397