embedded_test_server.cc revision a3f6a49ab37290eeeb8db0f41ec0f1cb74a68be7
1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/test/embedded_test_server/embedded_test_server.h"
6
7#include "base/bind.h"
8#include "base/file_util.h"
9#include "base/files/file_path.h"
10#include "base/message_loop/message_loop.h"
11#include "base/path_service.h"
12#include "base/process/process_metrics.h"
13#include "base/run_loop.h"
14#include "base/stl_util.h"
15#include "base/strings/string_util.h"
16#include "base/strings/stringprintf.h"
17#include "base/threading/thread_restrictions.h"
18#include "net/base/ip_endpoint.h"
19#include "net/base/net_errors.h"
20#include "net/test/embedded_test_server/http_connection.h"
21#include "net/test/embedded_test_server/http_request.h"
22#include "net/test/embedded_test_server/http_response.h"
23#include "net/tools/fetch/http_listen_socket.h"
24
25namespace net {
26namespace test_server {
27
28namespace {
29
30class CustomHttpResponse : public HttpResponse {
31 public:
32  CustomHttpResponse(const std::string& headers, const std::string& contents)
33      : headers_(headers), contents_(contents) {
34  }
35
36  virtual std::string ToResponseString() const OVERRIDE {
37    return headers_ + "\r\n" + contents_;
38  }
39
40 private:
41  std::string headers_;
42  std::string contents_;
43
44  DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse);
45};
46
47// Handles |request| by serving a file from under |server_root|.
48scoped_ptr<HttpResponse> HandleFileRequest(
49    const base::FilePath& server_root,
50    const HttpRequest& request) {
51  // This is a test-only server. Ignore I/O thread restrictions.
52  base::ThreadRestrictions::ScopedAllowIO allow_io;
53
54  // Trim the first byte ('/').
55  std::string request_path(request.relative_url.substr(1));
56
57  // Remove the query string if present.
58  size_t query_pos = request_path.find('?');
59  if (query_pos != std::string::npos)
60    request_path = request_path.substr(0, query_pos);
61
62  base::FilePath file_path(server_root.AppendASCII(request_path));
63  std::string file_contents;
64  if (!base::ReadFileToString(file_path, &file_contents))
65    return scoped_ptr<HttpResponse>();
66
67  base::FilePath headers_path(
68      file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
69
70  if (base::PathExists(headers_path)) {
71    std::string headers_contents;
72    if (!base::ReadFileToString(headers_path, &headers_contents))
73      return scoped_ptr<HttpResponse>();
74
75    scoped_ptr<CustomHttpResponse> http_response(
76        new CustomHttpResponse(headers_contents, file_contents));
77    return http_response.PassAs<HttpResponse>();
78  }
79
80  scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
81  http_response->set_code(HTTP_OK);
82  http_response->set_content(file_contents);
83  return http_response.PassAs<HttpResponse>();
84}
85
86}  // namespace
87
88HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor,
89                                   StreamListenSocket::Delegate* delegate)
90    : TCPListenSocket(socket_descriptor, delegate) {
91  DCHECK(thread_checker_.CalledOnValidThread());
92}
93
94void HttpListenSocket::Listen() {
95  DCHECK(thread_checker_.CalledOnValidThread());
96  TCPListenSocket::Listen();
97}
98
99HttpListenSocket::~HttpListenSocket() {
100  DCHECK(thread_checker_.CalledOnValidThread());
101}
102
103void HttpListenSocket::DetachFromThread() {
104  thread_checker_.DetachFromThread();
105}
106
107EmbeddedTestServer::EmbeddedTestServer()
108    : port_(-1),
109      weak_factory_(this) {
110  DCHECK(thread_checker_.CalledOnValidThread());
111}
112
113EmbeddedTestServer::~EmbeddedTestServer() {
114  DCHECK(thread_checker_.CalledOnValidThread());
115
116  if (Started() && !ShutdownAndWaitUntilComplete()) {
117    LOG(ERROR) << "EmbeddedTestServer failed to shut down.";
118  }
119}
120
121bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
122  StartThread();
123  DCHECK(thread_checker_.CalledOnValidThread());
124  if (!PostTaskToIOThreadAndWait(base::Bind(
125          &EmbeddedTestServer::InitializeOnIOThread, base::Unretained(this)))) {
126    return false;
127  }
128  return Started() && base_url_.is_valid();
129}
130
131void EmbeddedTestServer::StopThread() {
132  DCHECK(io_thread_ && io_thread_->IsRunning());
133
134#if defined(OS_LINUX)
135  const int thread_count =
136      base::GetNumberOfThreads(base::GetCurrentProcessHandle());
137#endif
138
139  io_thread_->Stop();
140  io_thread_.reset();
141  thread_checker_.DetachFromThread();
142  listen_socket_->DetachFromThread();
143
144#if defined(OS_LINUX)
145  // Busy loop to wait for thread count to decrease. This is needed because
146  // pthread_join does not guarantee that kernel stat is updated when it
147  // returns. Thus, GetNumberOfThreads does not immediately reflect the stopped
148  // thread and hits the thread number DCHECK in render_sandbox_host_linux.cc
149  // in browser_tests.
150  while (thread_count ==
151         base::GetNumberOfThreads(base::GetCurrentProcessHandle())) {
152    base::PlatformThread::YieldCurrentThread();
153  }
154#endif
155}
156
157void EmbeddedTestServer::RestartThreadAndListen() {
158  StartThread();
159  CHECK(PostTaskToIOThreadAndWait(base::Bind(
160      &EmbeddedTestServer::ListenOnIOThread, base::Unretained(this))));
161}
162
163bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
164  DCHECK(thread_checker_.CalledOnValidThread());
165
166  return PostTaskToIOThreadAndWait(base::Bind(
167      &EmbeddedTestServer::ShutdownOnIOThread, base::Unretained(this)));
168}
169
170void EmbeddedTestServer::StartThread() {
171  DCHECK(!io_thread_.get());
172  base::Thread::Options thread_options;
173  thread_options.message_loop_type = base::MessageLoop::TYPE_IO;
174  io_thread_.reset(new base::Thread("EmbeddedTestServer io thread"));
175  CHECK(io_thread_->StartWithOptions(thread_options));
176}
177
178void EmbeddedTestServer::InitializeOnIOThread() {
179  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
180  DCHECK(!Started());
181
182  SocketDescriptor socket_descriptor =
183      TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_);
184  if (socket_descriptor == kInvalidSocket)
185    return;
186
187  listen_socket_.reset(new HttpListenSocket(socket_descriptor, this));
188  listen_socket_->Listen();
189
190  IPEndPoint address;
191  int result = listen_socket_->GetLocalAddress(&address);
192  if (result == OK) {
193    base_url_ = GURL(std::string("http://") + address.ToString());
194  } else {
195    LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
196  }
197}
198
199void EmbeddedTestServer::ListenOnIOThread() {
200  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
201  DCHECK(Started());
202  listen_socket_->Listen();
203}
204
205void EmbeddedTestServer::ShutdownOnIOThread() {
206  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
207
208  listen_socket_.reset();
209  STLDeleteContainerPairSecondPointers(connections_.begin(),
210                                       connections_.end());
211  connections_.clear();
212}
213
214void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
215                               scoped_ptr<HttpRequest> request) {
216  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
217
218  bool request_handled = false;
219
220  for (size_t i = 0; i < request_handlers_.size(); ++i) {
221    scoped_ptr<HttpResponse> response =
222        request_handlers_[i].Run(*request.get());
223    if (response.get()) {
224      connection->SendResponse(response.Pass());
225      request_handled = true;
226      break;
227    }
228  }
229
230  if (!request_handled) {
231    LOG(WARNING) << "Request not handled. Returning 404: "
232                 << request->relative_url;
233    scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
234    not_found_response->set_code(HTTP_NOT_FOUND);
235    connection->SendResponse(
236        not_found_response.PassAs<HttpResponse>());
237  }
238
239  // Drop the connection, since we do not support multiple requests per
240  // connection.
241  connections_.erase(connection->socket_.get());
242  delete connection;
243}
244
245GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
246  DCHECK(Started()) << "You must start the server first.";
247  DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */))
248      << relative_url;
249  return base_url_.Resolve(relative_url);
250}
251
252void EmbeddedTestServer::ServeFilesFromDirectory(
253    const base::FilePath& directory) {
254  RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
255}
256
257void EmbeddedTestServer::RegisterRequestHandler(
258    const HandleRequestCallback& callback) {
259  request_handlers_.push_back(callback);
260}
261
262void EmbeddedTestServer::DidAccept(
263    StreamListenSocket* server,
264    scoped_ptr<StreamListenSocket> connection) {
265  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
266
267  HttpConnection* http_connection = new HttpConnection(
268      connection.Pass(),
269      base::Bind(&EmbeddedTestServer::HandleRequest,
270                 weak_factory_.GetWeakPtr()));
271  // TODO(szym): Make HttpConnection the StreamListenSocket delegate.
272  connections_[http_connection->socket_.get()] = http_connection;
273}
274
275void EmbeddedTestServer::DidRead(StreamListenSocket* connection,
276                         const char* data,
277                         int length) {
278  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
279
280  HttpConnection* http_connection = FindConnection(connection);
281  if (http_connection == NULL) {
282    LOG(WARNING) << "Unknown connection.";
283    return;
284  }
285  http_connection->ReceiveData(std::string(data, length));
286}
287
288void EmbeddedTestServer::DidClose(StreamListenSocket* connection) {
289  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
290
291  HttpConnection* http_connection = FindConnection(connection);
292  if (http_connection == NULL) {
293    LOG(WARNING) << "Unknown connection.";
294    return;
295  }
296  delete http_connection;
297  connections_.erase(connection);
298}
299
300HttpConnection* EmbeddedTestServer::FindConnection(
301    StreamListenSocket* socket) {
302  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
303
304  std::map<StreamListenSocket*, HttpConnection*>::iterator it =
305      connections_.find(socket);
306  if (it == connections_.end()) {
307    return NULL;
308  }
309  return it->second;
310}
311
312bool EmbeddedTestServer::PostTaskToIOThreadAndWait(
313    const base::Closure& closure) {
314  // Note that PostTaskAndReply below requires base::MessageLoopProxy::current()
315  // to return a loop for posting the reply task. However, in order to make
316  // EmbeddedTestServer universally usable, it needs to cope with the situation
317  // where it's running on a thread on which a message loop is not (yet)
318  // available or as has been destroyed already.
319  //
320  // To handle this situation, create temporary message loop to support the
321  // PostTaskAndReply operation if the current thread as no message loop.
322  scoped_ptr<base::MessageLoop> temporary_loop;
323  if (!base::MessageLoop::current())
324    temporary_loop.reset(new base::MessageLoop());
325
326  base::RunLoop run_loop;
327  if (!io_thread_->message_loop_proxy()->PostTaskAndReply(
328          FROM_HERE, closure, run_loop.QuitClosure())) {
329    return false;
330  }
331  run_loop.Run();
332
333  return true;
334}
335
336}  // namespace test_server
337}  // namespace net
338