1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "content/browser/renderer_host/websocket_dispatcher_host.h" 6 7#include <vector> 8 9#include "base/bind.h" 10#include "base/bind_helpers.h" 11#include "base/memory/ref_counted.h" 12#include "content/browser/renderer_host/websocket_host.h" 13#include "content/common/websocket.h" 14#include "content/common/websocket_messages.h" 15#include "ipc/ipc_message.h" 16#include "testing/gtest/include/gtest/gtest.h" 17#include "url/gurl.h" 18#include "url/origin.h" 19 20namespace content { 21namespace { 22 23// This number is unlikely to occur by chance. 24static const int kMagicRenderProcessId = 506116062; 25 26// A mock of WebsocketHost which records received messages. 27class MockWebSocketHost : public WebSocketHost { 28 public: 29 MockWebSocketHost(int routing_id, 30 WebSocketDispatcherHost* dispatcher, 31 net::URLRequestContext* url_request_context) 32 : WebSocketHost(routing_id, dispatcher, url_request_context) { 33 } 34 35 virtual ~MockWebSocketHost() {} 36 37 virtual bool OnMessageReceived(const IPC::Message& message) OVERRIDE{ 38 received_messages_.push_back(message); 39 return true; 40 } 41 42 std::vector<IPC::Message> received_messages_; 43}; 44 45class WebSocketDispatcherHostTest : public ::testing::Test { 46 public: 47 WebSocketDispatcherHostTest() { 48 dispatcher_host_ = new WebSocketDispatcherHost( 49 kMagicRenderProcessId, 50 base::Bind(&WebSocketDispatcherHostTest::OnGetRequestContext, 51 base::Unretained(this)), 52 base::Bind(&WebSocketDispatcherHostTest::CreateWebSocketHost, 53 base::Unretained(this))); 54 } 55 56 virtual ~WebSocketDispatcherHostTest() {} 57 58 protected: 59 scoped_refptr<WebSocketDispatcherHost> dispatcher_host_; 60 61 // Stores allocated MockWebSocketHost instances. Doesn't take ownership of 62 // them. 63 std::vector<MockWebSocketHost*> mock_hosts_; 64 65 private: 66 net::URLRequestContext* OnGetRequestContext() { 67 return NULL; 68 } 69 70 WebSocketHost* CreateWebSocketHost(int routing_id) { 71 MockWebSocketHost* host = 72 new MockWebSocketHost(routing_id, dispatcher_host_.get(), NULL); 73 mock_hosts_.push_back(host); 74 return host; 75 } 76}; 77 78TEST_F(WebSocketDispatcherHostTest, Construct) { 79 // Do nothing. 80} 81 82TEST_F(WebSocketDispatcherHostTest, UnrelatedMessage) { 83 IPC::Message message; 84 EXPECT_FALSE(dispatcher_host_->OnMessageReceived(message)); 85} 86 87TEST_F(WebSocketDispatcherHostTest, RenderProcessIdGetter) { 88 EXPECT_EQ(kMagicRenderProcessId, dispatcher_host_->render_process_id()); 89} 90 91TEST_F(WebSocketDispatcherHostTest, AddChannelRequest) { 92 int routing_id = 123; 93 GURL socket_url("ws://example.com/test"); 94 std::vector<std::string> requested_protocols; 95 requested_protocols.push_back("hello"); 96 url::Origin origin("http://example.com/test"); 97 int render_frame_id = -2; 98 WebSocketHostMsg_AddChannelRequest message( 99 routing_id, socket_url, requested_protocols, origin, render_frame_id); 100 101 ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message)); 102 103 ASSERT_EQ(1U, mock_hosts_.size()); 104 MockWebSocketHost* host = mock_hosts_[0]; 105 106 ASSERT_EQ(1U, host->received_messages_.size()); 107 const IPC::Message& forwarded_message = host->received_messages_[0]; 108 EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID, forwarded_message.type()); 109 EXPECT_EQ(routing_id, forwarded_message.routing_id()); 110} 111 112TEST_F(WebSocketDispatcherHostTest, SendFrameButNoHostYet) { 113 int routing_id = 123; 114 std::vector<char> data; 115 WebSocketMsg_SendFrame message( 116 routing_id, true, WEB_SOCKET_MESSAGE_TYPE_TEXT, data); 117 118 // Expected to be ignored. 119 EXPECT_TRUE(dispatcher_host_->OnMessageReceived(message)); 120 121 EXPECT_EQ(0U, mock_hosts_.size()); 122} 123 124TEST_F(WebSocketDispatcherHostTest, SendFrame) { 125 int routing_id = 123; 126 127 GURL socket_url("ws://example.com/test"); 128 std::vector<std::string> requested_protocols; 129 requested_protocols.push_back("hello"); 130 url::Origin origin("http://example.com/test"); 131 int render_frame_id = -2; 132 WebSocketHostMsg_AddChannelRequest add_channel_message( 133 routing_id, socket_url, requested_protocols, origin, render_frame_id); 134 135 ASSERT_TRUE(dispatcher_host_->OnMessageReceived(add_channel_message)); 136 137 std::vector<char> data; 138 WebSocketMsg_SendFrame send_frame_message( 139 routing_id, true, WEB_SOCKET_MESSAGE_TYPE_TEXT, data); 140 141 EXPECT_TRUE(dispatcher_host_->OnMessageReceived(send_frame_message)); 142 143 ASSERT_EQ(1U, mock_hosts_.size()); 144 MockWebSocketHost* host = mock_hosts_[0]; 145 146 ASSERT_EQ(2U, host->received_messages_.size()); 147 { 148 const IPC::Message& forwarded_message = host->received_messages_[0]; 149 EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID, forwarded_message.type()); 150 EXPECT_EQ(routing_id, forwarded_message.routing_id()); 151 } 152 { 153 const IPC::Message& forwarded_message = host->received_messages_[1]; 154 EXPECT_EQ(WebSocketMsg_SendFrame::ID, forwarded_message.type()); 155 EXPECT_EQ(routing_id, forwarded_message.routing_id()); 156 } 157} 158 159} // namespace 160} // namespace content 161