embedded_test_server.cc revision 7dbb3d5cf0c15f500944d211057644d6a2f37371
1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/test/embedded_test_server/embedded_test_server.h"
6
7#include "base/bind.h"
8#include "base/files/file_path.h"
9#include "base/file_util.h"
10#include "base/path_service.h"
11#include "base/run_loop.h"
12#include "base/stl_util.h"
13#include "base/strings/string_util.h"
14#include "base/strings/stringprintf.h"
15#include "base/threading/thread_restrictions.h"
16#include "net/base/ip_endpoint.h"
17#include "net/base/net_errors.h"
18#include "net/test/embedded_test_server/http_connection.h"
19#include "net/test/embedded_test_server/http_request.h"
20#include "net/test/embedded_test_server/http_response.h"
21#include "net/tools/fetch/http_listen_socket.h"
22
23namespace net {
24namespace test_server {
25
26namespace {
27
28class CustomHttpResponse : public HttpResponse {
29 public:
30  CustomHttpResponse(const std::string& headers, const std::string& contents)
31      : headers_(headers), contents_(contents) {
32  }
33
34  virtual std::string ToResponseString() const OVERRIDE {
35    return headers_ + "\r\n" + contents_;
36  }
37
38 private:
39  std::string headers_;
40  std::string contents_;
41
42  DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse);
43};
44
45// Handles |request| by serving a file from under |server_root|.
46scoped_ptr<HttpResponse> HandleFileRequest(
47    const base::FilePath& server_root,
48    const HttpRequest& request) {
49  // This is a test-only server. Ignore I/O thread restrictions.
50  base::ThreadRestrictions::ScopedAllowIO allow_io;
51
52  // Trim the first byte ('/').
53  std::string request_path(request.relative_url.substr(1));
54
55  // Remove the query string if present.
56  size_t query_pos = request_path.find('?');
57  if (query_pos != std::string::npos)
58    request_path = request_path.substr(0, query_pos);
59
60  base::FilePath file_path(server_root.AppendASCII(request_path));
61  std::string file_contents;
62  if (!file_util::ReadFileToString(file_path, &file_contents))
63    return scoped_ptr<HttpResponse>();
64
65  base::FilePath headers_path(
66      file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
67
68  if (base::PathExists(headers_path)) {
69    std::string headers_contents;
70    if (!file_util::ReadFileToString(headers_path, &headers_contents))
71      return scoped_ptr<HttpResponse>();
72
73    scoped_ptr<CustomHttpResponse> http_response(
74        new CustomHttpResponse(headers_contents, file_contents));
75    return http_response.PassAs<HttpResponse>();
76  }
77
78  scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
79  http_response->set_code(HTTP_OK);
80  http_response->set_content(file_contents);
81  return http_response.PassAs<HttpResponse>();
82}
83
84}  // namespace
85
86HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor,
87                                   StreamListenSocket::Delegate* delegate)
88    : TCPListenSocket(socket_descriptor, delegate) {
89  DCHECK(thread_checker_.CalledOnValidThread());
90}
91
92void HttpListenSocket::Listen() {
93  DCHECK(thread_checker_.CalledOnValidThread());
94  TCPListenSocket::Listen();
95}
96
97HttpListenSocket::~HttpListenSocket() {
98  DCHECK(thread_checker_.CalledOnValidThread());
99}
100
101EmbeddedTestServer::EmbeddedTestServer(
102    const scoped_refptr<base::SingleThreadTaskRunner>& io_thread)
103    : io_thread_(io_thread),
104      port_(-1),
105      weak_factory_(this) {
106  DCHECK(io_thread_.get());
107  DCHECK(thread_checker_.CalledOnValidThread());
108}
109
110EmbeddedTestServer::~EmbeddedTestServer() {
111  DCHECK(thread_checker_.CalledOnValidThread());
112
113  if (Started() && !ShutdownAndWaitUntilComplete()) {
114    LOG(ERROR) << "EmbeddedTestServer failed to shut down.";
115  }
116}
117
118bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
119  DCHECK(thread_checker_.CalledOnValidThread());
120
121  base::RunLoop run_loop;
122  if (!io_thread_->PostTaskAndReply(
123          FROM_HERE,
124          base::Bind(&EmbeddedTestServer::InitializeOnIOThread,
125                     base::Unretained(this)),
126          run_loop.QuitClosure())) {
127    return false;
128  }
129  run_loop.Run();
130
131  return Started() && base_url_.is_valid();
132}
133
134bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
135  DCHECK(thread_checker_.CalledOnValidThread());
136
137  base::RunLoop run_loop;
138  if (!io_thread_->PostTaskAndReply(
139          FROM_HERE,
140          base::Bind(&EmbeddedTestServer::ShutdownOnIOThread,
141                     base::Unretained(this)),
142          run_loop.QuitClosure())) {
143    return false;
144  }
145  run_loop.Run();
146
147  return true;
148}
149
150void EmbeddedTestServer::InitializeOnIOThread() {
151  DCHECK(io_thread_->BelongsToCurrentThread());
152  DCHECK(!Started());
153
154  SocketDescriptor socket_descriptor =
155      TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_);
156  if (socket_descriptor == TCPListenSocket::kInvalidSocket)
157    return;
158
159  listen_socket_ = new HttpListenSocket(socket_descriptor, this);
160  listen_socket_->Listen();
161
162  IPEndPoint address;
163  int result = listen_socket_->GetLocalAddress(&address);
164  if (result == OK) {
165    base_url_ = GURL(std::string("http://") + address.ToString());
166  } else {
167    LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
168  }
169}
170
171void EmbeddedTestServer::ShutdownOnIOThread() {
172  DCHECK(io_thread_->BelongsToCurrentThread());
173
174  listen_socket_ = NULL;  // Release the listen socket.
175  STLDeleteContainerPairSecondPointers(connections_.begin(),
176                                       connections_.end());
177  connections_.clear();
178}
179
180void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
181                               scoped_ptr<HttpRequest> request) {
182  DCHECK(io_thread_->BelongsToCurrentThread());
183
184  bool request_handled = false;
185
186  for (size_t i = 0; i < request_handlers_.size(); ++i) {
187    scoped_ptr<HttpResponse> response =
188        request_handlers_[i].Run(*request.get());
189    if (response.get()) {
190      connection->SendResponse(response.Pass());
191      request_handled = true;
192      break;
193    }
194  }
195
196  if (!request_handled) {
197    LOG(WARNING) << "Request not handled. Returning 404: "
198                 << request->relative_url;
199    scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
200    not_found_response->set_code(HTTP_NOT_FOUND);
201    connection->SendResponse(
202        not_found_response.PassAs<HttpResponse>());
203  }
204
205  // Drop the connection, since we do not support multiple requests per
206  // connection.
207  connections_.erase(connection->socket_.get());
208  delete connection;
209}
210
211GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
212  DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */))
213      << relative_url;
214  return base_url_.Resolve(relative_url);
215}
216
217void EmbeddedTestServer::ServeFilesFromDirectory(
218    const base::FilePath& directory) {
219  RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
220}
221
222void EmbeddedTestServer::RegisterRequestHandler(
223    const HandleRequestCallback& callback) {
224  request_handlers_.push_back(callback);
225}
226
227void EmbeddedTestServer::DidAccept(StreamListenSocket* server,
228                           StreamListenSocket* connection) {
229  DCHECK(io_thread_->BelongsToCurrentThread());
230
231  HttpConnection* http_connection = new HttpConnection(
232      connection,
233      base::Bind(&EmbeddedTestServer::HandleRequest,
234                 weak_factory_.GetWeakPtr()));
235  connections_[connection] = http_connection;
236}
237
238void EmbeddedTestServer::DidRead(StreamListenSocket* connection,
239                         const char* data,
240                         int length) {
241  DCHECK(io_thread_->BelongsToCurrentThread());
242
243  HttpConnection* http_connection = FindConnection(connection);
244  if (http_connection == NULL) {
245    LOG(WARNING) << "Unknown connection.";
246    return;
247  }
248  http_connection->ReceiveData(std::string(data, length));
249}
250
251void EmbeddedTestServer::DidClose(StreamListenSocket* connection) {
252  DCHECK(io_thread_->BelongsToCurrentThread());
253
254  HttpConnection* http_connection = FindConnection(connection);
255  if (http_connection == NULL) {
256    LOG(WARNING) << "Unknown connection.";
257    return;
258  }
259  delete http_connection;
260  connections_.erase(connection);
261}
262
263HttpConnection* EmbeddedTestServer::FindConnection(
264    StreamListenSocket* socket) {
265  DCHECK(io_thread_->BelongsToCurrentThread());
266
267  std::map<StreamListenSocket*, HttpConnection*>::iterator it =
268      connections_.find(socket);
269  if (it == connections_.end()) {
270    return NULL;
271  }
272  return it->second;
273}
274
275}  // namespace test_server
276}  // namespace net
277