embedded_test_server.cc revision b2df76ea8fec9e32f6f3718986dba0d95315b29c
1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/test/embedded_test_server/embedded_test_server.h"
6
7#include "base/bind.h"
8#include "base/run_loop.h"
9#include "base/stl_util.h"
10#include "base/string_util.h"
11#include "base/stringprintf.h"
12#include "net/test/embedded_test_server/http_connection.h"
13#include "net/test/embedded_test_server/http_request.h"
14#include "net/test/embedded_test_server/http_response.h"
15#include "net/tools/fetch/http_listen_socket.h"
16
17namespace net {
18namespace test_server {
19
20namespace {
21
// Loopback address the test server binds to.
const char kIp[] = "127.0.0.1";
// First port that is tried when binding the listen socket.
const int kPort = 8040;
// Number of additional consecutive ports tried after kPort fails to bind.
const int kRetries = 10;
25
26// Callback to handle requests with default predefined response for requests
27// matching the address |url|.
28scoped_ptr<HttpResponse> HandleDefaultRequest(const GURL& url,
29                                              const HttpResponse& response,
30                                              const HttpRequest& request) {
31  const GURL request_url = url.Resolve(request.relative_url);
32  if (url.path() != request_url.path())
33    return scoped_ptr<HttpResponse>(NULL);
34  return scoped_ptr<HttpResponse>(new HttpResponse(response));
35}
36
37}  // namespace
38
39HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor,
40                                   StreamListenSocket::Delegate* delegate)
41    : TCPListenSocket(socket_descriptor, delegate) {
42  DCHECK(thread_checker_.CalledOnValidThread());
43}
44
45void HttpListenSocket::Listen() {
46  DCHECK(thread_checker_.CalledOnValidThread());
47  TCPListenSocket::Listen();
48}
49
50HttpListenSocket::~HttpListenSocket() {
51  DCHECK(thread_checker_.CalledOnValidThread());
52}
53
54EmbeddedTestServer::EmbeddedTestServer(
55    const scoped_refptr<base::SingleThreadTaskRunner>& io_thread)
56    : io_thread_(io_thread),
57      port_(-1),
58      weak_factory_(this) {
59  DCHECK(io_thread_);
60  DCHECK(thread_checker_.CalledOnValidThread());
61}
62
63EmbeddedTestServer::~EmbeddedTestServer() {
64  DCHECK(thread_checker_.CalledOnValidThread());
65}
66
67bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
68  DCHECK(thread_checker_.CalledOnValidThread());
69
70  base::RunLoop run_loop;
71  if (!io_thread_->PostTaskAndReply(
72          FROM_HERE,
73          base::Bind(&EmbeddedTestServer::InitializeOnIOThread,
74                     base::Unretained(this)),
75          run_loop.QuitClosure())) {
76    return false;
77  }
78  run_loop.Run();
79
80  return Started();
81}
82
83bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
84  DCHECK(thread_checker_.CalledOnValidThread());
85
86  base::RunLoop run_loop;
87  if (!io_thread_->PostTaskAndReply(
88          FROM_HERE,
89          base::Bind(&EmbeddedTestServer::ShutdownOnIOThread,
90                     base::Unretained(this)),
91          run_loop.QuitClosure())) {
92    return false;
93  }
94  run_loop.Run();
95
96  return true;
97}
98
99void EmbeddedTestServer::InitializeOnIOThread() {
100  DCHECK(io_thread_->BelongsToCurrentThread());
101  DCHECK(!Started());
102
103  int retries_left = kRetries + 1;
104  int try_port = kPort;
105
106  while (retries_left > 0) {
107    SocketDescriptor socket_descriptor = TCPListenSocket::CreateAndBind(
108        kIp,
109        try_port);
110    if (socket_descriptor != TCPListenSocket::kInvalidSocket) {
111      listen_socket_ = new HttpListenSocket(socket_descriptor, this);
112      listen_socket_->Listen();
113      base_url_ = GURL(base::StringPrintf("http://%s:%d", kIp, try_port));
114      port_ = try_port;
115      break;
116    }
117    retries_left--;
118    try_port++;
119  }
120}
121
122void EmbeddedTestServer::ShutdownOnIOThread() {
123  DCHECK(io_thread_->BelongsToCurrentThread());
124
125  listen_socket_ = NULL;  // Release the listen socket.
126  STLDeleteContainerPairSecondPointers(connections_.begin(),
127                                       connections_.end());
128  connections_.clear();
129}
130
131void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
132                               scoped_ptr<HttpRequest> request) {
133  DCHECK(io_thread_->BelongsToCurrentThread());
134
135  for (size_t i = 0; i < request_handlers_.size(); ++i) {
136    scoped_ptr<HttpResponse> response =
137        request_handlers_[i].Run(*request.get());
138    if (response.get()) {
139      connection->SendResponse(response.Pass());
140      return;
141    }
142  }
143
144  LOG(WARNING) << "Request not handled. Returning 404: "
145               << request->relative_url;
146  scoped_ptr<HttpResponse> not_found_response(new HttpResponse());
147  not_found_response->set_code(NOT_FOUND);
148  connection->SendResponse(not_found_response.Pass());
149
150  // Drop the connection, since we do not support multiple requests per
151  // connection.
152  connections_.erase(connection->socket_.get());
153  delete connection;
154}
155
156GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
157  DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */))
158      << relative_url;
159  return base_url_.Resolve(relative_url);
160}
161
162void EmbeddedTestServer::RegisterRequestHandler(
163    const HandleRequestCallback& callback) {
164  request_handlers_.push_back(callback);
165}
166
167void EmbeddedTestServer::DidAccept(StreamListenSocket* server,
168                           StreamListenSocket* connection) {
169  DCHECK(io_thread_->BelongsToCurrentThread());
170
171  HttpConnection* http_connection = new HttpConnection(
172      connection,
173      base::Bind(&EmbeddedTestServer::HandleRequest,
174                 weak_factory_.GetWeakPtr()));
175  connections_[connection] = http_connection;
176}
177
178void EmbeddedTestServer::DidRead(StreamListenSocket* connection,
179                         const char* data,
180                         int length) {
181  DCHECK(io_thread_->BelongsToCurrentThread());
182
183  HttpConnection* http_connection = FindConnection(connection);
184  if (http_connection == NULL) {
185    LOG(WARNING) << "Unknown connection.";
186    return;
187  }
188  http_connection->ReceiveData(std::string(data, length));
189}
190
191void EmbeddedTestServer::DidClose(StreamListenSocket* connection) {
192  DCHECK(io_thread_->BelongsToCurrentThread());
193
194  HttpConnection* http_connection = FindConnection(connection);
195  if (http_connection == NULL) {
196    LOG(WARNING) << "Unknown connection.";
197    return;
198  }
199  delete http_connection;
200  connections_.erase(connection);
201}
202
203HttpConnection* EmbeddedTestServer::FindConnection(
204    StreamListenSocket* socket) {
205  DCHECK(io_thread_->BelongsToCurrentThread());
206
207  std::map<StreamListenSocket*, HttpConnection*>::iterator it =
208      connections_.find(socket);
209  if (it == connections_.end()) {
210    return NULL;
211  }
212  return it->second;
213}
214
215}  // namespace test_server
216}  // namespace net
217