1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/http/http_stream_factory_impl.h"
6
7#include <string>
8
9#include "base/basictypes.h"
10#include "net/base/cert_verifier.h"
11#include "net/base/mock_host_resolver.h"
12#include "net/base/net_log.h"
13#include "net/base/ssl_config_service_defaults.h"
14#include "net/base/test_completion_callback.h"
15#include "net/http/http_auth_handler_factory.h"
16#include "net/http/http_network_session.h"
17#include "net/http/http_network_session_peer.h"
18#include "net/http/http_request_info.h"
19#include "net/proxy/proxy_info.h"
20#include "net/proxy/proxy_service.h"
21#include "net/socket/socket_test_util.h"
22#include "net/spdy/spdy_session.h"
23#include "net/spdy/spdy_session_pool.h"
24#include "testing/gtest/include/gtest/gtest.h"
25
26namespace net {
27
28namespace {
29
30class MockHttpStreamFactoryImpl : public HttpStreamFactoryImpl {
31 public:
32  MockHttpStreamFactoryImpl(HttpNetworkSession* session)
33      : HttpStreamFactoryImpl(session),
34        preconnect_done_(false),
35        waiting_for_preconnect_(false) {}
36
37
38  void WaitForPreconnects() {
39    while (!preconnect_done_) {
40      waiting_for_preconnect_ = true;
41      MessageLoop::current()->Run();
42      waiting_for_preconnect_ = false;
43    }
44  }
45
46 private:
47  // HttpStreamFactoryImpl methods.
48  virtual void OnPreconnectsCompleteInternal() {
49    preconnect_done_ = true;
50    if (waiting_for_preconnect_)
51      MessageLoop::current()->Quit();
52  }
53
54  bool preconnect_done_;
55  bool waiting_for_preconnect_;
56};
57
58struct SessionDependencies {
59  // Custom proxy service dependency.
60  explicit SessionDependencies(ProxyService* proxy_service)
61      : host_resolver(new MockHostResolver),
62        cert_verifier(new CertVerifier),
63        proxy_service(proxy_service),
64        ssl_config_service(new SSLConfigServiceDefaults),
65        http_auth_handler_factory(
66            HttpAuthHandlerFactory::CreateDefault(host_resolver.get())),
67        net_log(NULL) {}
68
69  scoped_ptr<MockHostResolverBase> host_resolver;
70  scoped_ptr<CertVerifier> cert_verifier;
71  scoped_refptr<ProxyService> proxy_service;
72  scoped_refptr<SSLConfigService> ssl_config_service;
73  MockClientSocketFactory socket_factory;
74  scoped_ptr<HttpAuthHandlerFactory> http_auth_handler_factory;
75  NetLog* net_log;
76};
77
78HttpNetworkSession* CreateSession(SessionDependencies* session_deps) {
79  HttpNetworkSession::Params params;
80  params.host_resolver = session_deps->host_resolver.get();
81  params.cert_verifier = session_deps->cert_verifier.get();
82  params.proxy_service = session_deps->proxy_service;
83  params.ssl_config_service = session_deps->ssl_config_service;
84  params.client_socket_factory = &session_deps->socket_factory;
85  params.http_auth_handler_factory =
86      session_deps->http_auth_handler_factory.get();
87  params.net_log = session_deps->net_log;
88  return new HttpNetworkSession(params);
89}
90
// One preconnect scenario exercised by every test in this file.
struct TestCase {
  int num_streams;  // Number of streams passed to PreconnectStreams().
  bool ssl;         // true -> request an https URL, false -> an http URL.
};

// Table of scenarios: one and two streams, each over http and https.
// Declared const because the tests only ever read it.
const TestCase kTests[] = {
  { 1, false },
  { 2, false },
  { 1, true },
  { 2, true },
};
102
103void PreconnectHelper(const TestCase& test,
104                      HttpNetworkSession* session) {
105  HttpNetworkSessionPeer peer(session);
106  MockHttpStreamFactoryImpl* mock_factory =
107      new MockHttpStreamFactoryImpl(session);
108  peer.SetHttpStreamFactory(mock_factory);
109  SSLConfig ssl_config;
110  session->ssl_config_service()->GetSSLConfig(&ssl_config);
111
112  HttpRequestInfo request;
113  request.method = "GET";
114  request.url = test.ssl ?  GURL("https://www.google.com") :
115      GURL("http://www.google.com");
116  request.load_flags = 0;
117
118  ProxyInfo proxy_info;
119  TestCompletionCallback callback;
120
121  session->http_stream_factory()->PreconnectStreams(
122      test.num_streams, request, ssl_config, BoundNetLog());
123  mock_factory->WaitForPreconnects();
124};
125
126template<typename ParentPool>
127class CapturePreconnectsSocketPool : public ParentPool {
128 public:
129  CapturePreconnectsSocketPool(HostResolver* host_resolver,
130                               CertVerifier* cert_verifier);
131
132  int last_num_streams() const {
133    return last_num_streams_;
134  }
135
136  virtual int RequestSocket(const std::string& group_name,
137                            const void* socket_params,
138                            RequestPriority priority,
139                            ClientSocketHandle* handle,
140                            CompletionCallback* callback,
141                            const BoundNetLog& net_log) {
142    ADD_FAILURE();
143    return ERR_UNEXPECTED;
144  }
145
146  virtual void RequestSockets(const std::string& group_name,
147                              const void* socket_params,
148                              int num_sockets,
149                              const BoundNetLog& net_log) {
150    last_num_streams_ = num_sockets;
151  }
152
153  virtual void CancelRequest(const std::string& group_name,
154                             ClientSocketHandle* handle) {
155    ADD_FAILURE();
156  }
157  virtual void ReleaseSocket(const std::string& group_name,
158                             ClientSocket* socket,
159                             int id) {
160    ADD_FAILURE();
161  }
162  virtual void CloseIdleSockets() {
163    ADD_FAILURE();
164  }
165  virtual int IdleSocketCount() const {
166    ADD_FAILURE();
167    return 0;
168  }
169  virtual int IdleSocketCountInGroup(const std::string& group_name) const {
170    ADD_FAILURE();
171    return 0;
172  }
173  virtual LoadState GetLoadState(const std::string& group_name,
174                                 const ClientSocketHandle* handle) const {
175    ADD_FAILURE();
176    return LOAD_STATE_IDLE;
177  }
178  virtual base::TimeDelta ConnectionTimeout() const {
179    return base::TimeDelta();
180  }
181
182 private:
183  int last_num_streams_;
184};
185
186typedef CapturePreconnectsSocketPool<TransportClientSocketPool>
187CapturePreconnectsTransportSocketPool;
188typedef CapturePreconnectsSocketPool<HttpProxyClientSocketPool>
189CapturePreconnectsHttpProxySocketPool;
190typedef CapturePreconnectsSocketPool<SOCKSClientSocketPool>
191CapturePreconnectsSOCKSSocketPool;
192typedef CapturePreconnectsSocketPool<SSLClientSocketPool>
193CapturePreconnectsSSLSocketPool;
194
195template<typename ParentPool>
196CapturePreconnectsSocketPool<ParentPool>::CapturePreconnectsSocketPool(
197    HostResolver* host_resolver, CertVerifier* /* cert_verifier */)
198    : ParentPool(0, 0, NULL, host_resolver, NULL, NULL),
199      last_num_streams_(-1) {}
200
201template<>
202CapturePreconnectsHttpProxySocketPool::CapturePreconnectsSocketPool(
203    HostResolver* host_resolver, CertVerifier* /* cert_verifier */)
204    : HttpProxyClientSocketPool(0, 0, NULL, host_resolver, NULL, NULL, NULL),
205      last_num_streams_(-1) {}
206
207template<>
208CapturePreconnectsSSLSocketPool::CapturePreconnectsSocketPool(
209    HostResolver* host_resolver, CertVerifier* cert_verifier)
210    : SSLClientSocketPool(0, 0, NULL, host_resolver, cert_verifier, NULL, NULL,
211                          NULL, NULL, NULL, NULL, NULL, NULL, NULL),
212      last_num_streams_(-1) {}
213
214TEST(HttpStreamFactoryTest, PreconnectDirect) {
215  for (size_t i = 0; i < arraysize(kTests); ++i) {
216    SessionDependencies session_deps(ProxyService::CreateDirect());
217    scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps));
218    HttpNetworkSessionPeer peer(session);
219    CapturePreconnectsTransportSocketPool* transport_conn_pool =
220        new CapturePreconnectsTransportSocketPool(
221            session_deps.host_resolver.get(),
222            session_deps.cert_verifier.get());
223    peer.SetTransportSocketPool(transport_conn_pool);
224    CapturePreconnectsSSLSocketPool* ssl_conn_pool =
225        new CapturePreconnectsSSLSocketPool(
226            session_deps.host_resolver.get(),
227            session_deps.cert_verifier.get());
228    peer.SetSSLSocketPool(ssl_conn_pool);
229    PreconnectHelper(kTests[i], session);
230    if (kTests[i].ssl)
231      EXPECT_EQ(kTests[i].num_streams, ssl_conn_pool->last_num_streams());
232    else
233      EXPECT_EQ(kTests[i].num_streams, transport_conn_pool->last_num_streams());
234  }
235}
236
237TEST(HttpStreamFactoryTest, PreconnectHttpProxy) {
238  for (size_t i = 0; i < arraysize(kTests); ++i) {
239    SessionDependencies session_deps(ProxyService::CreateFixed("http_proxy"));
240    scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps));
241    HttpNetworkSessionPeer peer(session);
242    HostPortPair proxy_host("http_proxy", 80);
243    CapturePreconnectsHttpProxySocketPool* http_proxy_pool =
244        new CapturePreconnectsHttpProxySocketPool(
245            session_deps.host_resolver.get(),
246            session_deps.cert_verifier.get());
247    peer.SetSocketPoolForHTTPProxy(proxy_host, http_proxy_pool);
248    CapturePreconnectsSSLSocketPool* ssl_conn_pool =
249        new CapturePreconnectsSSLSocketPool(
250            session_deps.host_resolver.get(),
251            session_deps.cert_verifier.get());
252    peer.SetSocketPoolForSSLWithProxy(proxy_host, ssl_conn_pool);
253    PreconnectHelper(kTests[i], session);
254    if (kTests[i].ssl)
255      EXPECT_EQ(kTests[i].num_streams, ssl_conn_pool->last_num_streams());
256    else
257      EXPECT_EQ(kTests[i].num_streams, http_proxy_pool->last_num_streams());
258  }
259}
260
261TEST(HttpStreamFactoryTest, PreconnectSocksProxy) {
262  for (size_t i = 0; i < arraysize(kTests); ++i) {
263    SessionDependencies session_deps(
264        ProxyService::CreateFixed("socks4://socks_proxy:1080"));
265    scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps));
266    HttpNetworkSessionPeer peer(session);
267    HostPortPair proxy_host("socks_proxy", 1080);
268    CapturePreconnectsSOCKSSocketPool* socks_proxy_pool =
269        new CapturePreconnectsSOCKSSocketPool(
270            session_deps.host_resolver.get(),
271            session_deps.cert_verifier.get());
272    peer.SetSocketPoolForSOCKSProxy(proxy_host, socks_proxy_pool);
273    CapturePreconnectsSSLSocketPool* ssl_conn_pool =
274        new CapturePreconnectsSSLSocketPool(
275            session_deps.host_resolver.get(),
276            session_deps.cert_verifier.get());
277    peer.SetSocketPoolForSSLWithProxy(proxy_host, ssl_conn_pool);
278    PreconnectHelper(kTests[i], session);
279    if (kTests[i].ssl)
280      EXPECT_EQ(kTests[i].num_streams, ssl_conn_pool->last_num_streams());
281    else
282      EXPECT_EQ(kTests[i].num_streams, socks_proxy_pool->last_num_streams());
283  }
284}
285
286TEST(HttpStreamFactoryTest, PreconnectDirectWithExistingSpdySession) {
287  for (size_t i = 0; i < arraysize(kTests); ++i) {
288    SessionDependencies session_deps(ProxyService::CreateDirect());
289    scoped_refptr<HttpNetworkSession> session(CreateSession(&session_deps));
290    HttpNetworkSessionPeer peer(session);
291
292    // Set an existing SpdySession in the pool.
293    HostPortPair host_port_pair("www.google.com", 443);
294    HostPortProxyPair pair(host_port_pair, ProxyServer::Direct());
295    scoped_refptr<SpdySession> spdy_session =
296        session->spdy_session_pool()->Get(pair, BoundNetLog());
297
298    CapturePreconnectsTransportSocketPool* transport_conn_pool =
299        new CapturePreconnectsTransportSocketPool(
300            session_deps.host_resolver.get(),
301            session_deps.cert_verifier.get());
302    peer.SetTransportSocketPool(transport_conn_pool);
303    CapturePreconnectsSSLSocketPool* ssl_conn_pool =
304        new CapturePreconnectsSSLSocketPool(
305            session_deps.host_resolver.get(),
306            session_deps.cert_verifier.get());
307    peer.SetSSLSocketPool(ssl_conn_pool);
308    PreconnectHelper(kTests[i], session);
309    // We shouldn't be preconnecting if we have an existing session, which is
310    // the case for https://www.google.com.
311    if (kTests[i].ssl)
312      EXPECT_EQ(-1, ssl_conn_pool->last_num_streams());
313    else
314      EXPECT_EQ(kTests[i].num_streams, transport_conn_pool->last_num_streams());
315  }
316}
317
318}  // namespace
319
320}  // namespace net
321