1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "content/browser/renderer_host/websocket_dispatcher_host.h" 6 7#include <algorithm> 8#include <vector> 9 10#include "base/bind.h" 11#include "base/bind_helpers.h" 12#include "base/memory/ref_counted.h" 13#include "base/memory/weak_ptr.h" 14#include "content/browser/renderer_host/websocket_host.h" 15#include "content/common/websocket.h" 16#include "content/common/websocket_messages.h" 17#include "ipc/ipc_message.h" 18#include "testing/gtest/include/gtest/gtest.h" 19#include "url/gurl.h" 20#include "url/origin.h" 21 22namespace content { 23namespace { 24 25// This number is unlikely to occur by chance. 26static const int kMagicRenderProcessId = 506116062; 27 28class WebSocketDispatcherHostTest; 29 30// A mock of WebsocketHost which records received messages. 31class MockWebSocketHost : public WebSocketHost { 32 public: 33 MockWebSocketHost(int routing_id, 34 WebSocketDispatcherHost* dispatcher, 35 net::URLRequestContext* url_request_context, 36 WebSocketDispatcherHostTest* owner); 37 38 virtual ~MockWebSocketHost() {} 39 40 virtual bool OnMessageReceived(const IPC::Message& message) OVERRIDE { 41 received_messages_.push_back(message); 42 return true; 43 } 44 45 virtual void GoAway() OVERRIDE; 46 47 std::vector<IPC::Message> received_messages_; 48 base::WeakPtr<WebSocketDispatcherHostTest> owner_; 49}; 50 51class WebSocketDispatcherHostTest : public ::testing::Test { 52 public: 53 WebSocketDispatcherHostTest() 54 : weak_ptr_factory_(this) { 55 dispatcher_host_ = new WebSocketDispatcherHost( 56 kMagicRenderProcessId, 57 base::Bind(&WebSocketDispatcherHostTest::OnGetRequestContext, 58 base::Unretained(this)), 59 base::Bind(&WebSocketDispatcherHostTest::CreateWebSocketHost, 60 base::Unretained(this))); 61 } 62 63 virtual ~WebSocketDispatcherHostTest() { 64 // We need to invalidate the issued WeakPtrs at the beginning of the 65 // destructor in order not to access destructed member variables. 66 weak_ptr_factory_.InvalidateWeakPtrs(); 67 } 68 69 void GoAway(int routing_id) { 70 gone_hosts_.push_back(routing_id); 71 } 72 73 base::WeakPtr<WebSocketDispatcherHostTest> GetWeakPtr() { 74 return weak_ptr_factory_.GetWeakPtr(); 75 } 76 77 protected: 78 scoped_refptr<WebSocketDispatcherHost> dispatcher_host_; 79 80 // Stores allocated MockWebSocketHost instances. Doesn't take ownership of 81 // them. 82 std::vector<MockWebSocketHost*> mock_hosts_; 83 std::vector<int> gone_hosts_; 84 85 base::WeakPtrFactory<WebSocketDispatcherHostTest> weak_ptr_factory_; 86 87 private: 88 net::URLRequestContext* OnGetRequestContext() { 89 return NULL; 90 } 91 92 WebSocketHost* CreateWebSocketHost(int routing_id) { 93 MockWebSocketHost* host = 94 new MockWebSocketHost(routing_id, dispatcher_host_.get(), NULL, this); 95 mock_hosts_.push_back(host); 96 return host; 97 } 98}; 99 100MockWebSocketHost::MockWebSocketHost( 101 int routing_id, 102 WebSocketDispatcherHost* dispatcher, 103 net::URLRequestContext* url_request_context, 104 WebSocketDispatcherHostTest* owner) 105 : WebSocketHost(routing_id, dispatcher, url_request_context), 106 owner_(owner->GetWeakPtr()) {} 107 108void MockWebSocketHost::GoAway() { 109 if (owner_) 110 owner_->GoAway(routing_id()); 111} 112 113TEST_F(WebSocketDispatcherHostTest, Construct) { 114 // Do nothing. 115} 116 117TEST_F(WebSocketDispatcherHostTest, UnrelatedMessage) { 118 IPC::Message message; 119 EXPECT_FALSE(dispatcher_host_->OnMessageReceived(message)); 120} 121 122TEST_F(WebSocketDispatcherHostTest, RenderProcessIdGetter) { 123 EXPECT_EQ(kMagicRenderProcessId, dispatcher_host_->render_process_id()); 124} 125 126TEST_F(WebSocketDispatcherHostTest, AddChannelRequest) { 127 int routing_id = 123; 128 GURL socket_url("ws://example.com/test"); 129 std::vector<std::string> requested_protocols; 130 requested_protocols.push_back("hello"); 131 url::Origin origin("http://example.com/test"); 132 int render_frame_id = -2; 133 WebSocketHostMsg_AddChannelRequest message( 134 routing_id, socket_url, requested_protocols, origin, render_frame_id); 135 136 ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message)); 137 138 ASSERT_EQ(1U, mock_hosts_.size()); 139 MockWebSocketHost* host = mock_hosts_[0]; 140 141 ASSERT_EQ(1U, host->received_messages_.size()); 142 const IPC::Message& forwarded_message = host->received_messages_[0]; 143 EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID, forwarded_message.type()); 144 EXPECT_EQ(routing_id, forwarded_message.routing_id()); 145} 146 147TEST_F(WebSocketDispatcherHostTest, SendFrameButNoHostYet) { 148 int routing_id = 123; 149 std::vector<char> data; 150 WebSocketMsg_SendFrame message( 151 routing_id, true, WEB_SOCKET_MESSAGE_TYPE_TEXT, data); 152 153 // Expected to be ignored. 154 EXPECT_TRUE(dispatcher_host_->OnMessageReceived(message)); 155 156 EXPECT_EQ(0U, mock_hosts_.size()); 157} 158 159TEST_F(WebSocketDispatcherHostTest, SendFrame) { 160 int routing_id = 123; 161 162 GURL socket_url("ws://example.com/test"); 163 std::vector<std::string> requested_protocols; 164 requested_protocols.push_back("hello"); 165 url::Origin origin("http://example.com/test"); 166 int render_frame_id = -2; 167 WebSocketHostMsg_AddChannelRequest add_channel_message( 168 routing_id, socket_url, requested_protocols, origin, render_frame_id); 169 170 ASSERT_TRUE(dispatcher_host_->OnMessageReceived(add_channel_message)); 171 172 std::vector<char> data; 173 WebSocketMsg_SendFrame send_frame_message( 174 routing_id, true, WEB_SOCKET_MESSAGE_TYPE_TEXT, data); 175 176 EXPECT_TRUE(dispatcher_host_->OnMessageReceived(send_frame_message)); 177 178 ASSERT_EQ(1U, mock_hosts_.size()); 179 MockWebSocketHost* host = mock_hosts_[0]; 180 181 ASSERT_EQ(2U, host->received_messages_.size()); 182 { 183 const IPC::Message& forwarded_message = host->received_messages_[0]; 184 EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID, forwarded_message.type()); 185 EXPECT_EQ(routing_id, forwarded_message.routing_id()); 186 } 187 { 188 const IPC::Message& forwarded_message = host->received_messages_[1]; 189 EXPECT_EQ(WebSocketMsg_SendFrame::ID, forwarded_message.type()); 190 EXPECT_EQ(routing_id, forwarded_message.routing_id()); 191 } 192} 193 194TEST_F(WebSocketDispatcherHostTest, Destruct) { 195 WebSocketHostMsg_AddChannelRequest message1( 196 123, GURL("ws://example.com/test"), std::vector<std::string>(), 197 url::Origin("http://example.com"), -1); 198 WebSocketHostMsg_AddChannelRequest message2( 199 456, GURL("ws://example.com/test2"), std::vector<std::string>(), 200 url::Origin("http://example.com"), -1); 201 202 ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message1)); 203 ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message2)); 204 205 ASSERT_EQ(2u, mock_hosts_.size()); 206 207 mock_hosts_.clear(); 208 dispatcher_host_ = NULL; 209 210 ASSERT_EQ(2u, gone_hosts_.size()); 211 // The gone_hosts_ ordering is not predictable because it depends on the 212 // hash_map ordering. 213 std::sort(gone_hosts_.begin(), gone_hosts_.end()); 214 EXPECT_EQ(123, gone_hosts_[0]); 215 EXPECT_EQ(456, gone_hosts_[1]); 216} 217 218} // namespace 219} // namespace content 220