1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "content/browser/renderer_host/websocket_dispatcher_host.h"
6
7#include <vector>
8
9#include "base/bind.h"
10#include "base/bind_helpers.h"
11#include "base/memory/ref_counted.h"
12#include "content/browser/renderer_host/websocket_host.h"
13#include "content/common/websocket.h"
14#include "content/common/websocket_messages.h"
15#include "ipc/ipc_message.h"
16#include "testing/gtest/include/gtest/gtest.h"
17#include "url/gurl.h"
18#include "url/origin.h"
19
20namespace content {
21namespace {
22
// Render process ID handed to the dispatcher under test. The value is
// arbitrary but distinctive, so it is unlikely to match anything by chance.
const int kMagicRenderProcessId = 506116062;
25
26// A mock of WebsocketHost which records received messages.
27class MockWebSocketHost : public WebSocketHost {
28 public:
29  MockWebSocketHost(int routing_id,
30                    WebSocketDispatcherHost* dispatcher,
31                    net::URLRequestContext* url_request_context)
32      : WebSocketHost(routing_id, dispatcher, url_request_context) {
33  }
34
35  virtual ~MockWebSocketHost() {}
36
37  virtual bool OnMessageReceived(const IPC::Message& message) OVERRIDE{
38    received_messages_.push_back(message);
39    return true;
40  }
41
42  std::vector<IPC::Message> received_messages_;
43};
44
45class WebSocketDispatcherHostTest : public ::testing::Test {
46 public:
47  WebSocketDispatcherHostTest() {
48    dispatcher_host_ = new WebSocketDispatcherHost(
49        kMagicRenderProcessId,
50        base::Bind(&WebSocketDispatcherHostTest::OnGetRequestContext,
51                   base::Unretained(this)),
52        base::Bind(&WebSocketDispatcherHostTest::CreateWebSocketHost,
53                   base::Unretained(this)));
54  }
55
56  virtual ~WebSocketDispatcherHostTest() {}
57
58 protected:
59  scoped_refptr<WebSocketDispatcherHost> dispatcher_host_;
60
61  // Stores allocated MockWebSocketHost instances. Doesn't take ownership of
62  // them.
63  std::vector<MockWebSocketHost*> mock_hosts_;
64
65 private:
66  net::URLRequestContext* OnGetRequestContext() {
67    return NULL;
68  }
69
70  WebSocketHost* CreateWebSocketHost(int routing_id) {
71    MockWebSocketHost* host =
72        new MockWebSocketHost(routing_id, dispatcher_host_.get(), NULL);
73    mock_hosts_.push_back(host);
74    return host;
75  }
76};
77
78TEST_F(WebSocketDispatcherHostTest, Construct) {
79  // Do nothing.
80}
81
82TEST_F(WebSocketDispatcherHostTest, UnrelatedMessage) {
83  IPC::Message message;
84  EXPECT_FALSE(dispatcher_host_->OnMessageReceived(message));
85}
86
87TEST_F(WebSocketDispatcherHostTest, RenderProcessIdGetter) {
88  EXPECT_EQ(kMagicRenderProcessId, dispatcher_host_->render_process_id());
89}
90
91TEST_F(WebSocketDispatcherHostTest, AddChannelRequest) {
92  int routing_id = 123;
93  GURL socket_url("ws://example.com/test");
94  std::vector<std::string> requested_protocols;
95  requested_protocols.push_back("hello");
96  url::Origin origin("http://example.com/test");
97  int render_frame_id = -2;
98  WebSocketHostMsg_AddChannelRequest message(
99      routing_id, socket_url, requested_protocols, origin, render_frame_id);
100
101  ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message));
102
103  ASSERT_EQ(1U, mock_hosts_.size());
104  MockWebSocketHost* host = mock_hosts_[0];
105
106  ASSERT_EQ(1U, host->received_messages_.size());
107  const IPC::Message& forwarded_message = host->received_messages_[0];
108  EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID, forwarded_message.type());
109  EXPECT_EQ(routing_id, forwarded_message.routing_id());
110}
111
112TEST_F(WebSocketDispatcherHostTest, SendFrameButNoHostYet) {
113  int routing_id = 123;
114  std::vector<char> data;
115  WebSocketMsg_SendFrame message(
116      routing_id, true, WEB_SOCKET_MESSAGE_TYPE_TEXT, data);
117
118  // Expected to be ignored.
119  EXPECT_TRUE(dispatcher_host_->OnMessageReceived(message));
120
121  EXPECT_EQ(0U, mock_hosts_.size());
122}
123
124TEST_F(WebSocketDispatcherHostTest, SendFrame) {
125  int routing_id = 123;
126
127  GURL socket_url("ws://example.com/test");
128  std::vector<std::string> requested_protocols;
129  requested_protocols.push_back("hello");
130  url::Origin origin("http://example.com/test");
131  int render_frame_id = -2;
132  WebSocketHostMsg_AddChannelRequest add_channel_message(
133      routing_id, socket_url, requested_protocols, origin, render_frame_id);
134
135  ASSERT_TRUE(dispatcher_host_->OnMessageReceived(add_channel_message));
136
137  std::vector<char> data;
138  WebSocketMsg_SendFrame send_frame_message(
139      routing_id, true, WEB_SOCKET_MESSAGE_TYPE_TEXT, data);
140
141  EXPECT_TRUE(dispatcher_host_->OnMessageReceived(send_frame_message));
142
143  ASSERT_EQ(1U, mock_hosts_.size());
144  MockWebSocketHost* host = mock_hosts_[0];
145
146  ASSERT_EQ(2U, host->received_messages_.size());
147  {
148    const IPC::Message& forwarded_message = host->received_messages_[0];
149    EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID, forwarded_message.type());
150    EXPECT_EQ(routing_id, forwarded_message.routing_id());
151  }
152  {
153    const IPC::Message& forwarded_message = host->received_messages_[1];
154    EXPECT_EQ(WebSocketMsg_SendFrame::ID, forwarded_message.type());
155    EXPECT_EQ(routing_id, forwarded_message.routing_id());
156  }
157}
158
159}  // namespace
160}  // namespace content
161