1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "content/browser/renderer_host/websocket_dispatcher_host.h"
6
7#include <algorithm>
8#include <vector>
9
10#include "base/bind.h"
11#include "base/bind_helpers.h"
12#include "base/memory/ref_counted.h"
13#include "base/memory/weak_ptr.h"
14#include "content/browser/renderer_host/websocket_host.h"
15#include "content/common/websocket.h"
16#include "content/common/websocket_messages.h"
17#include "ipc/ipc_message.h"
18#include "testing/gtest/include/gtest/gtest.h"
19#include "url/gurl.h"
20#include "url/origin.h"
21
22namespace content {
23namespace {
24
class WebSocketDispatcherHostTest;

// Render process id handed to the dispatcher under test. The value is
// arbitrary but distinctive, so it is unlikely to show up by coincidence.
const int kMagicRenderProcessId = 506116062;
29
30// A mock of WebsocketHost which records received messages.
31class MockWebSocketHost : public WebSocketHost {
32 public:
33  MockWebSocketHost(int routing_id,
34                    WebSocketDispatcherHost* dispatcher,
35                    net::URLRequestContext* url_request_context,
36                    WebSocketDispatcherHostTest* owner);
37
38  virtual ~MockWebSocketHost() {}
39
40  virtual bool OnMessageReceived(const IPC::Message& message) OVERRIDE {
41    received_messages_.push_back(message);
42    return true;
43  }
44
45  virtual void GoAway() OVERRIDE;
46
47  std::vector<IPC::Message> received_messages_;
48  base::WeakPtr<WebSocketDispatcherHostTest> owner_;
49};
50
51class WebSocketDispatcherHostTest : public ::testing::Test {
52 public:
53  WebSocketDispatcherHostTest()
54      : weak_ptr_factory_(this) {
55    dispatcher_host_ = new WebSocketDispatcherHost(
56        kMagicRenderProcessId,
57        base::Bind(&WebSocketDispatcherHostTest::OnGetRequestContext,
58                   base::Unretained(this)),
59        base::Bind(&WebSocketDispatcherHostTest::CreateWebSocketHost,
60                   base::Unretained(this)));
61  }
62
63  virtual ~WebSocketDispatcherHostTest() {
64    // We need to invalidate the issued WeakPtrs at the beginning of the
65    // destructor in order not to access destructed member variables.
66    weak_ptr_factory_.InvalidateWeakPtrs();
67  }
68
69  void GoAway(int routing_id) {
70    gone_hosts_.push_back(routing_id);
71  }
72
73  base::WeakPtr<WebSocketDispatcherHostTest> GetWeakPtr() {
74    return weak_ptr_factory_.GetWeakPtr();
75  }
76
77 protected:
78  scoped_refptr<WebSocketDispatcherHost> dispatcher_host_;
79
80  // Stores allocated MockWebSocketHost instances. Doesn't take ownership of
81  // them.
82  std::vector<MockWebSocketHost*> mock_hosts_;
83  std::vector<int> gone_hosts_;
84
85  base::WeakPtrFactory<WebSocketDispatcherHostTest> weak_ptr_factory_;
86
87 private:
88  net::URLRequestContext* OnGetRequestContext() {
89    return NULL;
90  }
91
92  WebSocketHost* CreateWebSocketHost(int routing_id) {
93    MockWebSocketHost* host =
94        new MockWebSocketHost(routing_id, dispatcher_host_.get(), NULL, this);
95    mock_hosts_.push_back(host);
96    return host;
97  }
98};
99
100MockWebSocketHost::MockWebSocketHost(
101    int routing_id,
102    WebSocketDispatcherHost* dispatcher,
103    net::URLRequestContext* url_request_context,
104    WebSocketDispatcherHostTest* owner)
105    : WebSocketHost(routing_id, dispatcher, url_request_context),
106      owner_(owner->GetWeakPtr()) {}
107
108void MockWebSocketHost::GoAway() {
109  if (owner_)
110    owner_->GoAway(routing_id());
111}
112
113TEST_F(WebSocketDispatcherHostTest, Construct) {
114  // Do nothing.
115}
116
117TEST_F(WebSocketDispatcherHostTest, UnrelatedMessage) {
118  IPC::Message message;
119  EXPECT_FALSE(dispatcher_host_->OnMessageReceived(message));
120}
121
122TEST_F(WebSocketDispatcherHostTest, RenderProcessIdGetter) {
123  EXPECT_EQ(kMagicRenderProcessId, dispatcher_host_->render_process_id());
124}
125
126TEST_F(WebSocketDispatcherHostTest, AddChannelRequest) {
127  int routing_id = 123;
128  GURL socket_url("ws://example.com/test");
129  std::vector<std::string> requested_protocols;
130  requested_protocols.push_back("hello");
131  url::Origin origin("http://example.com/test");
132  int render_frame_id = -2;
133  WebSocketHostMsg_AddChannelRequest message(
134      routing_id, socket_url, requested_protocols, origin, render_frame_id);
135
136  ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message));
137
138  ASSERT_EQ(1U, mock_hosts_.size());
139  MockWebSocketHost* host = mock_hosts_[0];
140
141  ASSERT_EQ(1U, host->received_messages_.size());
142  const IPC::Message& forwarded_message = host->received_messages_[0];
143  EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID, forwarded_message.type());
144  EXPECT_EQ(routing_id, forwarded_message.routing_id());
145}
146
147TEST_F(WebSocketDispatcherHostTest, SendFrameButNoHostYet) {
148  int routing_id = 123;
149  std::vector<char> data;
150  WebSocketMsg_SendFrame message(
151      routing_id, true, WEB_SOCKET_MESSAGE_TYPE_TEXT, data);
152
153  // Expected to be ignored.
154  EXPECT_TRUE(dispatcher_host_->OnMessageReceived(message));
155
156  EXPECT_EQ(0U, mock_hosts_.size());
157}
158
159TEST_F(WebSocketDispatcherHostTest, SendFrame) {
160  int routing_id = 123;
161
162  GURL socket_url("ws://example.com/test");
163  std::vector<std::string> requested_protocols;
164  requested_protocols.push_back("hello");
165  url::Origin origin("http://example.com/test");
166  int render_frame_id = -2;
167  WebSocketHostMsg_AddChannelRequest add_channel_message(
168      routing_id, socket_url, requested_protocols, origin, render_frame_id);
169
170  ASSERT_TRUE(dispatcher_host_->OnMessageReceived(add_channel_message));
171
172  std::vector<char> data;
173  WebSocketMsg_SendFrame send_frame_message(
174      routing_id, true, WEB_SOCKET_MESSAGE_TYPE_TEXT, data);
175
176  EXPECT_TRUE(dispatcher_host_->OnMessageReceived(send_frame_message));
177
178  ASSERT_EQ(1U, mock_hosts_.size());
179  MockWebSocketHost* host = mock_hosts_[0];
180
181  ASSERT_EQ(2U, host->received_messages_.size());
182  {
183    const IPC::Message& forwarded_message = host->received_messages_[0];
184    EXPECT_EQ(WebSocketHostMsg_AddChannelRequest::ID, forwarded_message.type());
185    EXPECT_EQ(routing_id, forwarded_message.routing_id());
186  }
187  {
188    const IPC::Message& forwarded_message = host->received_messages_[1];
189    EXPECT_EQ(WebSocketMsg_SendFrame::ID, forwarded_message.type());
190    EXPECT_EQ(routing_id, forwarded_message.routing_id());
191  }
192}
193
194TEST_F(WebSocketDispatcherHostTest, Destruct) {
195  WebSocketHostMsg_AddChannelRequest message1(
196      123, GURL("ws://example.com/test"), std::vector<std::string>(),
197      url::Origin("http://example.com"), -1);
198  WebSocketHostMsg_AddChannelRequest message2(
199      456, GURL("ws://example.com/test2"), std::vector<std::string>(),
200      url::Origin("http://example.com"), -1);
201
202  ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message1));
203  ASSERT_TRUE(dispatcher_host_->OnMessageReceived(message2));
204
205  ASSERT_EQ(2u, mock_hosts_.size());
206
207  mock_hosts_.clear();
208  dispatcher_host_ = NULL;
209
210  ASSERT_EQ(2u, gone_hosts_.size());
211  // The gone_hosts_ ordering is not predictable because it depends on the
212  // hash_map ordering.
213  std::sort(gone_hosts_.begin(), gone_hosts_.end());
214  EXPECT_EQ(123, gone_hosts_[0]);
215  EXPECT_EQ(456, gone_hosts_[1]);
216}
217
218}  // namespace
219}  // namespace content
220