1// Copyright (c) 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/test/chromedriver/net/test_http_server.h"
6
7#include "base/bind.h"
8#include "base/location.h"
9#include "base/message_loop/message_loop.h"
10#include "base/message_loop/message_loop_proxy.h"
11#include "base/strings/stringprintf.h"
12#include "base/time/time.h"
13#include "net/base/ip_endpoint.h"
14#include "net/base/net_errors.h"
15#include "net/server/http_server_request_info.h"
16#include "net/socket/tcp_server_socket.h"
17#include "testing/gtest/include/gtest/gtest.h"
18
// Byte count requested for each connection's send and receive buffers
// (100 * 2^20 bytes, i.e. 100 MB).
const int kBufferSize = 100 << 20;
20
21TestHttpServer::TestHttpServer()
22    : thread_("ServerThread"),
23      all_closed_event_(false, true),
24      request_action_(kAccept),
25      message_action_(kEchoMessage) {
26}
27
28TestHttpServer::~TestHttpServer() {
29}
30
31bool TestHttpServer::Start() {
32  base::Thread::Options options(base::MessageLoop::TYPE_IO, 0);
33  bool thread_started = thread_.StartWithOptions(options);
34  EXPECT_TRUE(thread_started);
35  if (!thread_started)
36    return false;
37  bool success;
38  base::WaitableEvent event(false, false);
39  thread_.message_loop_proxy()->PostTask(
40      FROM_HERE,
41      base::Bind(&TestHttpServer::StartOnServerThread,
42                 base::Unretained(this), &success, &event));
43  event.Wait();
44  return success;
45}
46
47void TestHttpServer::Stop() {
48  if (!thread_.IsRunning())
49    return;
50  base::WaitableEvent event(false, false);
51  thread_.message_loop_proxy()->PostTask(
52      FROM_HERE,
53      base::Bind(&TestHttpServer::StopOnServerThread,
54                 base::Unretained(this), &event));
55  event.Wait();
56  thread_.Stop();
57}
58
59bool TestHttpServer::WaitForConnectionsToClose() {
60  return all_closed_event_.TimedWait(base::TimeDelta::FromSeconds(10));
61}
62
63void TestHttpServer::SetRequestAction(WebSocketRequestAction action) {
64  base::AutoLock lock(action_lock_);
65  request_action_ = action;
66}
67
68void TestHttpServer::SetMessageAction(WebSocketMessageAction action) {
69  base::AutoLock lock(action_lock_);
70  message_action_ = action;
71}
72
73GURL TestHttpServer::web_socket_url() const {
74  base::AutoLock lock(url_lock_);
75  return web_socket_url_;
76}
77
78void TestHttpServer::OnConnect(int connection_id) {
79  server_->SetSendBufferSize(connection_id, kBufferSize);
80  server_->SetReceiveBufferSize(connection_id, kBufferSize);
81}
82
83void TestHttpServer::OnWebSocketRequest(
84    int connection_id,
85    const net::HttpServerRequestInfo& info) {
86  WebSocketRequestAction action;
87  {
88    base::AutoLock lock(action_lock_);
89    action = request_action_;
90  }
91  connections_.insert(connection_id);
92  all_closed_event_.Reset();
93
94  switch (action) {
95    case kAccept:
96      server_->AcceptWebSocket(connection_id, info);
97      break;
98    case kNotFound:
99      server_->Send404(connection_id);
100      break;
101    case kClose:
102      server_->Close(connection_id);
103      break;
104  }
105}
106
107void TestHttpServer::OnWebSocketMessage(int connection_id,
108                                        const std::string& data) {
109  WebSocketMessageAction action;
110  {
111    base::AutoLock lock(action_lock_);
112    action = message_action_;
113  }
114  switch (action) {
115    case kEchoMessage:
116      server_->SendOverWebSocket(connection_id, data);
117      break;
118    case kCloseOnMessage:
119      server_->Close(connection_id);
120      break;
121  }
122}
123
124void TestHttpServer::OnClose(int connection_id) {
125  connections_.erase(connection_id);
126  if (connections_.empty())
127    all_closed_event_.Signal();
128}
129
130void TestHttpServer::StartOnServerThread(bool* success,
131                                         base::WaitableEvent* event) {
132  scoped_ptr<net::ServerSocket> server_socket(
133      new net::TCPServerSocket(NULL, net::NetLog::Source()));
134  server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1);
135  server_.reset(new net::HttpServer(server_socket.Pass(), this));
136
137  net::IPEndPoint address;
138  int error = server_->GetLocalAddress(&address);
139  EXPECT_EQ(net::OK, error);
140  if (error == net::OK) {
141    base::AutoLock lock(url_lock_);
142    web_socket_url_ = GURL(base::StringPrintf("ws://127.0.0.1:%d",
143                                              address.port()));
144  } else {
145    server_.reset(NULL);
146  }
147  *success = server_.get();
148  event->Signal();
149}
150
151void TestHttpServer::StopOnServerThread(base::WaitableEvent* event) {
152  server_.reset(NULL);
153  event->Signal();
154}
155