1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <string>
6#include <vector>
7
8#include "base/bind.h"
9#include "base/compiler_specific.h"
10#include "base/location.h"
11#include "base/memory/ref_counted.h"
12#include "base/memory/scoped_ptr.h"
13#include "base/message_loop/message_loop.h"
14#include "base/message_loop/message_loop_proxy.h"
15#include "base/run_loop.h"
16#include "base/single_thread_task_runner.h"
17#include "base/threading/thread.h"
18#include "base/time/time.h"
19#include "chrome/test/chromedriver/net/test_http_server.h"
20#include "chrome/test/chromedriver/net/websocket.h"
21#include "net/url_request/url_request_test_util.h"
22#include "testing/gtest/include/gtest/gtest.h"
23#include "url/gurl.h"
24
25namespace {
26
27void OnConnectFinished(base::RunLoop* run_loop, int* save_error, int error) {
28  *save_error = error;
29  run_loop->Quit();
30}
31
32void RunPending(base::MessageLoop* loop) {
33  base::RunLoop run_loop;
34  loop->PostTask(FROM_HERE, run_loop.QuitClosure());
35  run_loop.Run();
36}
37
38class Listener : public WebSocketListener {
39 public:
40  explicit Listener(const std::vector<std::string>& messages)
41      : messages_(messages) {}
42
43  virtual ~Listener() {
44    EXPECT_TRUE(messages_.empty());
45  }
46
47  virtual void OnMessageReceived(const std::string& message) OVERRIDE {
48    ASSERT_TRUE(messages_.size());
49    EXPECT_EQ(messages_[0], message);
50    messages_.erase(messages_.begin());
51    if (messages_.empty())
52      base::MessageLoop::current()->Quit();
53  }
54
55  virtual void OnClose() OVERRIDE {
56    EXPECT_TRUE(false);
57  }
58
59 private:
60  std::vector<std::string> messages_;
61};
62
63class CloseListener : public WebSocketListener {
64 public:
65  explicit CloseListener(base::RunLoop* run_loop)
66      : run_loop_(run_loop) {}
67
68  virtual ~CloseListener() {
69    EXPECT_FALSE(run_loop_);
70  }
71
72  virtual void OnMessageReceived(const std::string& message) OVERRIDE {}
73
74  virtual void OnClose() OVERRIDE {
75    EXPECT_TRUE(run_loop_);
76    if (run_loop_)
77      run_loop_->Quit();
78    run_loop_ = NULL;
79  }
80
81 private:
82  base::RunLoop* run_loop_;
83};
84
85class WebSocketTest : public testing::Test {
86 public:
87  WebSocketTest() {}
88  virtual ~WebSocketTest() {}
89
90  virtual void SetUp() OVERRIDE {
91    ASSERT_TRUE(server_.Start());
92  }
93
94  virtual void TearDown() OVERRIDE {
95    server_.Stop();
96  }
97
98 protected:
99  scoped_ptr<WebSocket> CreateWebSocket(const GURL& url,
100                                        WebSocketListener* listener) {
101    int error;
102    scoped_ptr<WebSocket> sock(new WebSocket(url, listener));
103    base::RunLoop run_loop;
104    sock->Connect(base::Bind(&OnConnectFinished, &run_loop, &error));
105    loop_.PostDelayedTask(
106        FROM_HERE, run_loop.QuitClosure(),
107        base::TimeDelta::FromSeconds(10));
108    run_loop.Run();
109    if (error == net::OK)
110      return sock.Pass();
111    return scoped_ptr<WebSocket>();
112  }
113
114  scoped_ptr<WebSocket> CreateConnectedWebSocket(WebSocketListener* listener) {
115    return CreateWebSocket(server_.web_socket_url(), listener);
116  }
117
118  void SendReceive(const std::vector<std::string>& messages) {
119    Listener listener(messages);
120    scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener));
121    ASSERT_TRUE(sock);
122    for (size_t i = 0; i < messages.size(); ++i) {
123      ASSERT_TRUE(sock->Send(messages[i]));
124    }
125    base::RunLoop run_loop;
126    loop_.PostDelayedTask(
127        FROM_HERE, run_loop.QuitClosure(),
128        base::TimeDelta::FromSeconds(10));
129    run_loop.Run();
130  }
131
132  base::MessageLoopForIO loop_;
133  TestHttpServer server_;
134};
135
136}  // namespace
137
138TEST_F(WebSocketTest, CreateDestroy) {
139  CloseListener listener(NULL);
140  WebSocket sock(GURL("ws://127.0.0.1:2222"), &listener);
141}
142
143TEST_F(WebSocketTest, Connect) {
144  CloseListener listener(NULL);
145  ASSERT_TRUE(CreateWebSocket(server_.web_socket_url(), &listener));
146  RunPending(&loop_);
147  ASSERT_TRUE(server_.WaitForConnectionsToClose());
148}
149
150TEST_F(WebSocketTest, ConnectNoServer) {
151  CloseListener listener(NULL);
152  ASSERT_FALSE(CreateWebSocket(GURL("ws://127.0.0.1:33333"), NULL));
153}
154
155TEST_F(WebSocketTest, Connect404) {
156  server_.SetRequestAction(TestHttpServer::kNotFound);
157  CloseListener listener(NULL);
158  ASSERT_FALSE(CreateWebSocket(server_.web_socket_url(), NULL));
159  RunPending(&loop_);
160  ASSERT_TRUE(server_.WaitForConnectionsToClose());
161}
162
163TEST_F(WebSocketTest, ConnectServerClosesConn) {
164  server_.SetRequestAction(TestHttpServer::kClose);
165  CloseListener listener(NULL);
166  ASSERT_FALSE(CreateWebSocket(server_.web_socket_url(), &listener));
167}
168
169TEST_F(WebSocketTest, CloseOnReceive) {
170  server_.SetMessageAction(TestHttpServer::kCloseOnMessage);
171  base::RunLoop run_loop;
172  CloseListener listener(&run_loop);
173  scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener));
174  ASSERT_TRUE(sock);
175  ASSERT_TRUE(sock->Send("hi"));
176  loop_.PostDelayedTask(
177      FROM_HERE, run_loop.QuitClosure(),
178      base::TimeDelta::FromSeconds(10));
179  run_loop.Run();
180}
181
182TEST_F(WebSocketTest, CloseOnSend) {
183  base::RunLoop run_loop;
184  CloseListener listener(&run_loop);
185  scoped_ptr<WebSocket> sock(CreateConnectedWebSocket(&listener));
186  ASSERT_TRUE(sock);
187  server_.Stop();
188
189  sock->Send("hi");
190  loop_.PostDelayedTask(
191      FROM_HERE, run_loop.QuitClosure(),
192      base::TimeDelta::FromSeconds(10));
193  run_loop.Run();
194  ASSERT_FALSE(sock->Send("hi"));
195}
196
197TEST_F(WebSocketTest, SendReceive) {
198  std::vector<std::string> messages;
199  messages.push_back("hello");
200  SendReceive(messages);
201}
202
203TEST_F(WebSocketTest, SendReceiveLarge) {
204  std::vector<std::string> messages;
205  messages.push_back(std::string(10 << 20, 'a'));
206  SendReceive(messages);
207}
208
209TEST_F(WebSocketTest, SendReceiveMultiple) {
210  std::vector<std::string> messages;
211  messages.push_back("1");
212  messages.push_back("2");
213  messages.push_back("3");
214  SendReceive(messages);
215}
216