1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/websockets/websocket_handshake_stream_create_helper.h"
6
7#include <string>
8#include <vector>
9
10#include "net/base/completion_callback.h"
11#include "net/base/net_errors.h"
12#include "net/http/http_request_headers.h"
13#include "net/http/http_request_info.h"
14#include "net/http/http_response_headers.h"
15#include "net/http/http_response_info.h"
16#include "net/socket/client_socket_handle.h"
17#include "net/socket/socket_test_util.h"
18#include "net/websockets/websocket_basic_handshake_stream.h"
19#include "net/websockets/websocket_stream.h"
20#include "net/websockets/websocket_test_util.h"
21#include "testing/gtest/include/gtest/gtest.h"
22#include "url/gurl.h"
23
24namespace net {
25namespace {
26
27// This class encapsulates the details of creating a mock ClientSocketHandle.
28class MockClientSocketHandleFactory {
29 public:
30  MockClientSocketHandleFactory()
31      : histograms_("a"),
32        pool_(1, 1, &histograms_, socket_factory_maker_.factory()) {}
33
34  // The created socket expects |expect_written| to be written to the socket,
35  // and will respond with |return_to_read|. The test will fail if the expected
36  // text is not written, or if all the bytes are not read.
37  scoped_ptr<ClientSocketHandle> CreateClientSocketHandle(
38      const std::string& expect_written,
39      const std::string& return_to_read) {
40    socket_factory_maker_.SetExpectations(expect_written, return_to_read);
41    scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle);
42    socket_handle->Init(
43        "a",
44        scoped_refptr<MockTransportSocketParams>(),
45        MEDIUM,
46        CompletionCallback(),
47        &pool_,
48        BoundNetLog());
49    return socket_handle.Pass();
50  }
51
52 private:
53  WebSocketDeterministicMockClientSocketFactoryMaker socket_factory_maker_;
54  ClientSocketPoolHistograms histograms_;
55  MockTransportClientSocketPool pool_;
56
57  DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory);
58};
59
60class TestConnectDelegate : public WebSocketStream::ConnectDelegate {
61 public:
62  virtual ~TestConnectDelegate() {}
63
64  virtual void OnSuccess(scoped_ptr<WebSocketStream> stream) OVERRIDE {}
65  virtual void OnFailure(const std::string& failure_message) OVERRIDE {}
66  virtual void OnStartOpeningHandshake(
67      scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE {}
68  virtual void OnFinishOpeningHandshake(
69      scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE {}
70  virtual void OnSSLCertificateError(
71      scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks>
72          ssl_error_callbacks,
73      const SSLInfo& ssl_info,
74      bool fatal) OVERRIDE {}
75};
76
77class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test {
78 protected:
79  scoped_ptr<WebSocketStream> CreateAndInitializeStream(
80      const std::string& socket_url,
81      const std::string& socket_path,
82      const std::vector<std::string>& sub_protocols,
83      const std::string& origin,
84      const std::string& extra_request_headers,
85      const std::string& extra_response_headers) {
86    WebSocketHandshakeStreamCreateHelper create_helper(&connect_delegate_,
87                                                       sub_protocols);
88    create_helper.set_failure_message(&failure_message_);
89
90    scoped_ptr<ClientSocketHandle> socket_handle =
91        socket_handle_factory_.CreateClientSocketHandle(
92            WebSocketStandardRequest(
93                socket_path, origin, extra_request_headers),
94            WebSocketStandardResponse(extra_response_headers));
95
96    scoped_ptr<WebSocketHandshakeStreamBase> handshake(
97        create_helper.CreateBasicStream(socket_handle.Pass(), false));
98
99    // If in future the implementation type returned by CreateBasicStream()
100    // changes, this static_cast will be wrong. However, in that case the test
101    // will fail and AddressSanitizer should identify the issue.
102    static_cast<WebSocketBasicHandshakeStream*>(handshake.get())
103        ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ==");
104
105    HttpRequestInfo request_info;
106    request_info.url = GURL(socket_url);
107    request_info.method = "GET";
108    request_info.load_flags = LOAD_DISABLE_CACHE | LOAD_DO_NOT_PROMPT_FOR_LOGIN;
109    int rv = handshake->InitializeStream(
110        &request_info, DEFAULT_PRIORITY, BoundNetLog(), CompletionCallback());
111    EXPECT_EQ(OK, rv);
112
113    HttpRequestHeaders headers;
114    headers.SetHeader("Host", "localhost");
115    headers.SetHeader("Connection", "Upgrade");
116    headers.SetHeader("Pragma", "no-cache");
117    headers.SetHeader("Cache-Control", "no-cache");
118    headers.SetHeader("Upgrade", "websocket");
119    headers.SetHeader("Origin", origin);
120    headers.SetHeader("Sec-WebSocket-Version", "13");
121    headers.SetHeader("User-Agent", "");
122    headers.SetHeader("Accept-Encoding", "gzip, deflate");
123    headers.SetHeader("Accept-Language", "en-us,fr");
124
125    HttpResponseInfo response;
126    TestCompletionCallback dummy;
127
128    rv = handshake->SendRequest(headers, &response, dummy.callback());
129
130    EXPECT_EQ(OK, rv);
131
132    rv = handshake->ReadResponseHeaders(dummy.callback());
133    EXPECT_EQ(OK, rv);
134    EXPECT_EQ(101, response.headers->response_code());
135    EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade"));
136    EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket"));
137    return handshake->Upgrade();
138  }
139
140  MockClientSocketHandleFactory socket_handle_factory_;
141  TestConnectDelegate connect_delegate_;
142  std::string failure_message_;
143};
144
145// Confirm that the basic case works as expected.
146TEST_F(WebSocketHandshakeStreamCreateHelperTest, BasicStream) {
147  scoped_ptr<WebSocketStream> stream =
148      CreateAndInitializeStream("ws://localhost/", "/",
149                                std::vector<std::string>(), "http://localhost/",
150                                "", "");
151  EXPECT_EQ("", stream->GetExtensions());
152  EXPECT_EQ("", stream->GetSubProtocol());
153}
154
155// Verify that the sub-protocols are passed through.
156TEST_F(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) {
157  std::vector<std::string> sub_protocols;
158  sub_protocols.push_back("chat");
159  sub_protocols.push_back("superchat");
160  scoped_ptr<WebSocketStream> stream =
161      CreateAndInitializeStream("ws://localhost/",
162                                "/",
163                                sub_protocols,
164                                "http://localhost/",
165                                "Sec-WebSocket-Protocol: chat, superchat\r\n",
166                                "Sec-WebSocket-Protocol: superchat\r\n");
167  EXPECT_EQ("superchat", stream->GetSubProtocol());
168}
169
170// Verify that extension name is available. Bad extension names are tested in
171// websocket_stream_test.cc.
172TEST_F(WebSocketHandshakeStreamCreateHelperTest, Extensions) {
173  scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream(
174      "ws://localhost/",
175      "/",
176      std::vector<std::string>(),
177      "http://localhost/",
178      "",
179      "Sec-WebSocket-Extensions: permessage-deflate\r\n");
180  EXPECT_EQ("permessage-deflate", stream->GetExtensions());
181}
182
183// Verify that extension parameters are available. Bad parameters are tested in
184// websocket_stream_test.cc.
185TEST_F(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) {
186  scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream(
187      "ws://localhost/",
188      "/",
189      std::vector<std::string>(),
190      "http://localhost/",
191      "",
192      "Sec-WebSocket-Extensions: permessage-deflate;"
193      " client_max_window_bits=14; server_max_window_bits=14;"
194      " server_no_context_takeover; client_no_context_takeover\r\n");
195
196  EXPECT_EQ(
197      "permessage-deflate;"
198      " client_max_window_bits=14; server_max_window_bits=14;"
199      " server_no_context_takeover; client_no_context_takeover",
200      stream->GetExtensions());
201}
202
203}  // namespace
204}  // namespace net
205