embedded_test_server.cc revision 868fa2fe829687343ffae624259930155e16dbd8
1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/test/embedded_test_server/embedded_test_server.h"
6
7#include "base/bind.h"
8#include "base/files/file_path.h"
9#include "base/file_util.h"
10#include "base/path_service.h"
11#include "base/run_loop.h"
12#include "base/stl_util.h"
13#include "base/strings/string_util.h"
14#include "base/strings/stringprintf.h"
15#include "base/threading/thread_restrictions.h"
16#include "net/base/ip_endpoint.h"
17#include "net/base/net_errors.h"
18#include "net/test/embedded_test_server/http_connection.h"
19#include "net/test/embedded_test_server/http_request.h"
20#include "net/test/embedded_test_server/http_response.h"
21#include "net/tools/fetch/http_listen_socket.h"
22
23namespace net {
24namespace test_server {
25
26namespace {
27
28// Handles |request| by serving a file from under |server_root|.
29scoped_ptr<HttpResponse> HandleFileRequest(
30    const base::FilePath& server_root,
31    const HttpRequest& request) {
32  // This is a test-only server. Ignore I/O thread restrictions.
33  base::ThreadRestrictions::ScopedAllowIO allow_io;
34
35  // Trim the first byte ('/').
36  std::string request_path(request.relative_url.substr(1));
37
38  // Remove the query string if present.
39  size_t query_pos = request_path.find('?');
40  if (query_pos != std::string::npos)
41    request_path = request_path.substr(0, query_pos);
42
43  std::string file_contents;
44  if (!file_util::ReadFileToString(
45          server_root.AppendASCII(request_path), &file_contents)) {
46    return scoped_ptr<HttpResponse>(NULL);
47  }
48
49  scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
50  http_response->set_code(net::test_server::SUCCESS);
51  http_response->set_content(file_contents);
52  return http_response.PassAs<HttpResponse>();
53}
54
55}  // namespace
56
57HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor,
58                                   StreamListenSocket::Delegate* delegate)
59    : TCPListenSocket(socket_descriptor, delegate) {
60  DCHECK(thread_checker_.CalledOnValidThread());
61}
62
63void HttpListenSocket::Listen() {
64  DCHECK(thread_checker_.CalledOnValidThread());
65  TCPListenSocket::Listen();
66}
67
68HttpListenSocket::~HttpListenSocket() {
69  DCHECK(thread_checker_.CalledOnValidThread());
70}
71
72EmbeddedTestServer::EmbeddedTestServer(
73    const scoped_refptr<base::SingleThreadTaskRunner>& io_thread)
74    : io_thread_(io_thread),
75      port_(-1),
76      weak_factory_(this) {
77  DCHECK(io_thread_.get());
78  DCHECK(thread_checker_.CalledOnValidThread());
79}
80
81EmbeddedTestServer::~EmbeddedTestServer() {
82  DCHECK(thread_checker_.CalledOnValidThread());
83
84  if (Started() && !ShutdownAndWaitUntilComplete()) {
85    LOG(ERROR) << "EmbeddedTestServer failed to shut down.";
86  }
87}
88
89bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
90  DCHECK(thread_checker_.CalledOnValidThread());
91
92  base::RunLoop run_loop;
93  if (!io_thread_->PostTaskAndReply(
94          FROM_HERE,
95          base::Bind(&EmbeddedTestServer::InitializeOnIOThread,
96                     base::Unretained(this)),
97          run_loop.QuitClosure())) {
98    return false;
99  }
100  run_loop.Run();
101
102  return Started() && base_url_.is_valid();
103}
104
105bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
106  DCHECK(thread_checker_.CalledOnValidThread());
107
108  base::RunLoop run_loop;
109  if (!io_thread_->PostTaskAndReply(
110          FROM_HERE,
111          base::Bind(&EmbeddedTestServer::ShutdownOnIOThread,
112                     base::Unretained(this)),
113          run_loop.QuitClosure())) {
114    return false;
115  }
116  run_loop.Run();
117
118  return true;
119}
120
121void EmbeddedTestServer::InitializeOnIOThread() {
122  DCHECK(io_thread_->BelongsToCurrentThread());
123  DCHECK(!Started());
124
125  SocketDescriptor socket_descriptor =
126      TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_);
127  if (socket_descriptor == TCPListenSocket::kInvalidSocket)
128    return;
129
130  listen_socket_ = new HttpListenSocket(socket_descriptor, this);
131  listen_socket_->Listen();
132
133  IPEndPoint address;
134  int result = listen_socket_->GetLocalAddress(&address);
135  if (result == OK) {
136    base_url_ = GURL(std::string("http://") + address.ToString());
137  } else {
138    LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
139  }
140}
141
142void EmbeddedTestServer::ShutdownOnIOThread() {
143  DCHECK(io_thread_->BelongsToCurrentThread());
144
145  listen_socket_ = NULL;  // Release the listen socket.
146  STLDeleteContainerPairSecondPointers(connections_.begin(),
147                                       connections_.end());
148  connections_.clear();
149}
150
151void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
152                               scoped_ptr<HttpRequest> request) {
153  DCHECK(io_thread_->BelongsToCurrentThread());
154
155  bool request_handled = false;
156
157  for (size_t i = 0; i < request_handlers_.size(); ++i) {
158    scoped_ptr<HttpResponse> response =
159        request_handlers_[i].Run(*request.get());
160    if (response.get()) {
161      connection->SendResponse(response.Pass());
162      request_handled = true;
163      break;
164    }
165  }
166
167  if (!request_handled) {
168    LOG(WARNING) << "Request not handled. Returning 404: "
169                 << request->relative_url;
170    scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
171    not_found_response->set_code(NOT_FOUND);
172    connection->SendResponse(
173        not_found_response.PassAs<HttpResponse>());
174  }
175
176  // Drop the connection, since we do not support multiple requests per
177  // connection.
178  connections_.erase(connection->socket_.get());
179  delete connection;
180}
181
182GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
183  DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */))
184      << relative_url;
185  return base_url_.Resolve(relative_url);
186}
187
188void EmbeddedTestServer::ServeFilesFromDirectory(
189    const base::FilePath& directory) {
190  RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
191}
192
193void EmbeddedTestServer::RegisterRequestHandler(
194    const HandleRequestCallback& callback) {
195  request_handlers_.push_back(callback);
196}
197
198void EmbeddedTestServer::DidAccept(StreamListenSocket* server,
199                           StreamListenSocket* connection) {
200  DCHECK(io_thread_->BelongsToCurrentThread());
201
202  HttpConnection* http_connection = new HttpConnection(
203      connection,
204      base::Bind(&EmbeddedTestServer::HandleRequest,
205                 weak_factory_.GetWeakPtr()));
206  connections_[connection] = http_connection;
207}
208
209void EmbeddedTestServer::DidRead(StreamListenSocket* connection,
210                         const char* data,
211                         int length) {
212  DCHECK(io_thread_->BelongsToCurrentThread());
213
214  HttpConnection* http_connection = FindConnection(connection);
215  if (http_connection == NULL) {
216    LOG(WARNING) << "Unknown connection.";
217    return;
218  }
219  http_connection->ReceiveData(std::string(data, length));
220}
221
222void EmbeddedTestServer::DidClose(StreamListenSocket* connection) {
223  DCHECK(io_thread_->BelongsToCurrentThread());
224
225  HttpConnection* http_connection = FindConnection(connection);
226  if (http_connection == NULL) {
227    LOG(WARNING) << "Unknown connection.";
228    return;
229  }
230  delete http_connection;
231  connections_.erase(connection);
232}
233
234HttpConnection* EmbeddedTestServer::FindConnection(
235    StreamListenSocket* socket) {
236  DCHECK(io_thread_->BelongsToCurrentThread());
237
238  std::map<StreamListenSocket*, HttpConnection*>::iterator it =
239      connections_.find(socket);
240  if (it == connections_.end()) {
241    return NULL;
242  }
243  return it->second;
244}
245
246}  // namespace test_server
247}  // namespace net
248