1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/websockets/websocket_handshake_stream_create_helper.h" 6 7#include <string> 8#include <vector> 9 10#include "net/base/completion_callback.h" 11#include "net/base/net_errors.h" 12#include "net/http/http_request_headers.h" 13#include "net/http/http_request_info.h" 14#include "net/http/http_response_headers.h" 15#include "net/http/http_response_info.h" 16#include "net/socket/client_socket_handle.h" 17#include "net/socket/socket_test_util.h" 18#include "net/websockets/websocket_basic_handshake_stream.h" 19#include "net/websockets/websocket_stream.h" 20#include "net/websockets/websocket_test_util.h" 21#include "testing/gtest/include/gtest/gtest.h" 22#include "url/gurl.h" 23 24namespace net { 25namespace { 26 27// This class encapsulates the details of creating a mock ClientSocketHandle. 28class MockClientSocketHandleFactory { 29 public: 30 MockClientSocketHandleFactory() 31 : histograms_("a"), 32 pool_(1, 1, &histograms_, socket_factory_maker_.factory()) {} 33 34 // The created socket expects |expect_written| to be written to the socket, 35 // and will respond with |return_to_read|. The test will fail if the expected 36 // text is not written, or if all the bytes are not read. 37 scoped_ptr<ClientSocketHandle> CreateClientSocketHandle( 38 const std::string& expect_written, 39 const std::string& return_to_read) { 40 socket_factory_maker_.SetExpectations(expect_written, return_to_read); 41 scoped_ptr<ClientSocketHandle> socket_handle(new ClientSocketHandle); 42 socket_handle->Init( 43 "a", 44 scoped_refptr<MockTransportSocketParams>(), 45 MEDIUM, 46 CompletionCallback(), 47 &pool_, 48 BoundNetLog()); 49 return socket_handle.Pass(); 50 } 51 52 private: 53 WebSocketDeterministicMockClientSocketFactoryMaker socket_factory_maker_; 54 ClientSocketPoolHistograms histograms_; 55 MockTransportClientSocketPool pool_; 56 57 DISALLOW_COPY_AND_ASSIGN(MockClientSocketHandleFactory); 58}; 59 60class TestConnectDelegate : public WebSocketStream::ConnectDelegate { 61 public: 62 virtual ~TestConnectDelegate() {} 63 64 virtual void OnSuccess(scoped_ptr<WebSocketStream> stream) OVERRIDE {} 65 virtual void OnFailure(const std::string& failure_message) OVERRIDE {} 66 virtual void OnStartOpeningHandshake( 67 scoped_ptr<WebSocketHandshakeRequestInfo> request) OVERRIDE {} 68 virtual void OnFinishOpeningHandshake( 69 scoped_ptr<WebSocketHandshakeResponseInfo> response) OVERRIDE {} 70 virtual void OnSSLCertificateError( 71 scoped_ptr<WebSocketEventInterface::SSLErrorCallbacks> 72 ssl_error_callbacks, 73 const SSLInfo& ssl_info, 74 bool fatal) OVERRIDE {} 75}; 76 77class WebSocketHandshakeStreamCreateHelperTest : public ::testing::Test { 78 protected: 79 scoped_ptr<WebSocketStream> CreateAndInitializeStream( 80 const std::string& socket_url, 81 const std::string& socket_path, 82 const std::vector<std::string>& sub_protocols, 83 const std::string& origin, 84 const std::string& extra_request_headers, 85 const std::string& extra_response_headers) { 86 WebSocketHandshakeStreamCreateHelper create_helper(&connect_delegate_, 87 sub_protocols); 88 create_helper.set_failure_message(&failure_message_); 89 90 scoped_ptr<ClientSocketHandle> socket_handle = 91 socket_handle_factory_.CreateClientSocketHandle( 92 WebSocketStandardRequest( 93 socket_path, origin, extra_request_headers), 94 WebSocketStandardResponse(extra_response_headers)); 95 96 scoped_ptr<WebSocketHandshakeStreamBase> handshake( 97 create_helper.CreateBasicStream(socket_handle.Pass(), false)); 98 99 // If in future the implementation type returned by CreateBasicStream() 100 // changes, this static_cast will be wrong. However, in that case the test 101 // will fail and AddressSanitizer should identify the issue. 102 static_cast<WebSocketBasicHandshakeStream*>(handshake.get()) 103 ->SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ=="); 104 105 HttpRequestInfo request_info; 106 request_info.url = GURL(socket_url); 107 request_info.method = "GET"; 108 request_info.load_flags = LOAD_DISABLE_CACHE | LOAD_DO_NOT_PROMPT_FOR_LOGIN; 109 int rv = handshake->InitializeStream( 110 &request_info, DEFAULT_PRIORITY, BoundNetLog(), CompletionCallback()); 111 EXPECT_EQ(OK, rv); 112 113 HttpRequestHeaders headers; 114 headers.SetHeader("Host", "localhost"); 115 headers.SetHeader("Connection", "Upgrade"); 116 headers.SetHeader("Pragma", "no-cache"); 117 headers.SetHeader("Cache-Control", "no-cache"); 118 headers.SetHeader("Upgrade", "websocket"); 119 headers.SetHeader("Origin", origin); 120 headers.SetHeader("Sec-WebSocket-Version", "13"); 121 headers.SetHeader("User-Agent", ""); 122 headers.SetHeader("Accept-Encoding", "gzip, deflate"); 123 headers.SetHeader("Accept-Language", "en-us,fr"); 124 125 HttpResponseInfo response; 126 TestCompletionCallback dummy; 127 128 rv = handshake->SendRequest(headers, &response, dummy.callback()); 129 130 EXPECT_EQ(OK, rv); 131 132 rv = handshake->ReadResponseHeaders(dummy.callback()); 133 EXPECT_EQ(OK, rv); 134 EXPECT_EQ(101, response.headers->response_code()); 135 EXPECT_TRUE(response.headers->HasHeaderValue("Connection", "Upgrade")); 136 EXPECT_TRUE(response.headers->HasHeaderValue("Upgrade", "websocket")); 137 return handshake->Upgrade(); 138 } 139 140 MockClientSocketHandleFactory socket_handle_factory_; 141 TestConnectDelegate connect_delegate_; 142 std::string failure_message_; 143}; 144 145// Confirm that the basic case works as expected. 146TEST_F(WebSocketHandshakeStreamCreateHelperTest, BasicStream) { 147 scoped_ptr<WebSocketStream> stream = 148 CreateAndInitializeStream("ws://localhost/", "/", 149 std::vector<std::string>(), "http://localhost/", 150 "", ""); 151 EXPECT_EQ("", stream->GetExtensions()); 152 EXPECT_EQ("", stream->GetSubProtocol()); 153} 154 155// Verify that the sub-protocols are passed through. 156TEST_F(WebSocketHandshakeStreamCreateHelperTest, SubProtocols) { 157 std::vector<std::string> sub_protocols; 158 sub_protocols.push_back("chat"); 159 sub_protocols.push_back("superchat"); 160 scoped_ptr<WebSocketStream> stream = 161 CreateAndInitializeStream("ws://localhost/", 162 "/", 163 sub_protocols, 164 "http://localhost/", 165 "Sec-WebSocket-Protocol: chat, superchat\r\n", 166 "Sec-WebSocket-Protocol: superchat\r\n"); 167 EXPECT_EQ("superchat", stream->GetSubProtocol()); 168} 169 170// Verify that extension name is available. Bad extension names are tested in 171// websocket_stream_test.cc. 172TEST_F(WebSocketHandshakeStreamCreateHelperTest, Extensions) { 173 scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( 174 "ws://localhost/", 175 "/", 176 std::vector<std::string>(), 177 "http://localhost/", 178 "", 179 "Sec-WebSocket-Extensions: permessage-deflate\r\n"); 180 EXPECT_EQ("permessage-deflate", stream->GetExtensions()); 181} 182 183// Verify that extension parameters are available. Bad parameters are tested in 184// websocket_stream_test.cc. 185TEST_F(WebSocketHandshakeStreamCreateHelperTest, ExtensionParameters) { 186 scoped_ptr<WebSocketStream> stream = CreateAndInitializeStream( 187 "ws://localhost/", 188 "/", 189 std::vector<std::string>(), 190 "http://localhost/", 191 "", 192 "Sec-WebSocket-Extensions: permessage-deflate;" 193 " client_max_window_bits=14; server_max_window_bits=14;" 194 " server_no_context_takeover; client_no_context_takeover\r\n"); 195 196 EXPECT_EQ( 197 "permessage-deflate;" 198 " client_max_window_bits=14; server_max_window_bits=14;" 199 " server_no_context_takeover; client_no_context_takeover", 200 stream->GetExtensions()); 201} 202 203} // namespace 204} // namespace net 205