1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/test/embedded_test_server/embedded_test_server.h"
6
7#include "base/bind.h"
8#include "base/files/file_path.h"
9#include "base/files/file_util.h"
10#include "base/message_loop/message_loop.h"
11#include "base/path_service.h"
12#include "base/process/process_metrics.h"
13#include "base/run_loop.h"
14#include "base/stl_util.h"
15#include "base/strings/string_util.h"
16#include "base/strings/stringprintf.h"
17#include "base/threading/thread_restrictions.h"
18#include "net/base/ip_endpoint.h"
19#include "net/base/net_errors.h"
20#include "net/test/embedded_test_server/http_connection.h"
21#include "net/test/embedded_test_server/http_request.h"
22#include "net/test/embedded_test_server/http_response.h"
23
24namespace net {
25namespace test_server {
26
27namespace {
28
29class CustomHttpResponse : public HttpResponse {
30 public:
31  CustomHttpResponse(const std::string& headers, const std::string& contents)
32      : headers_(headers), contents_(contents) {
33  }
34
35  virtual std::string ToResponseString() const OVERRIDE {
36    return headers_ + "\r\n" + contents_;
37  }
38
39 private:
40  std::string headers_;
41  std::string contents_;
42
43  DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse);
44};
45
46// Handles |request| by serving a file from under |server_root|.
47scoped_ptr<HttpResponse> HandleFileRequest(
48    const base::FilePath& server_root,
49    const HttpRequest& request) {
50  // This is a test-only server. Ignore I/O thread restrictions.
51  base::ThreadRestrictions::ScopedAllowIO allow_io;
52
53  // Trim the first byte ('/').
54  std::string request_path(request.relative_url.substr(1));
55
56  // Remove the query string if present.
57  size_t query_pos = request_path.find('?');
58  if (query_pos != std::string::npos)
59    request_path = request_path.substr(0, query_pos);
60
61  base::FilePath file_path(server_root.AppendASCII(request_path));
62  std::string file_contents;
63  if (!base::ReadFileToString(file_path, &file_contents))
64    return scoped_ptr<HttpResponse>();
65
66  base::FilePath headers_path(
67      file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers")));
68
69  if (base::PathExists(headers_path)) {
70    std::string headers_contents;
71    if (!base::ReadFileToString(headers_path, &headers_contents))
72      return scoped_ptr<HttpResponse>();
73
74    scoped_ptr<CustomHttpResponse> http_response(
75        new CustomHttpResponse(headers_contents, file_contents));
76    return http_response.PassAs<HttpResponse>();
77  }
78
79  scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse);
80  http_response->set_code(HTTP_OK);
81  http_response->set_content(file_contents);
82  return http_response.PassAs<HttpResponse>();
83}
84
85}  // namespace
86
87HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor,
88                                   StreamListenSocket::Delegate* delegate)
89    : TCPListenSocket(socket_descriptor, delegate) {
90  DCHECK(thread_checker_.CalledOnValidThread());
91}
92
93void HttpListenSocket::Listen() {
94  DCHECK(thread_checker_.CalledOnValidThread());
95  TCPListenSocket::Listen();
96}
97
// Resumes listening on the current thread. It is reached from
// EmbeddedTestServer::ListenOnIOThread, which RestartThreadAndListen() posts to
// a freshly started IO thread after StopThread() tore down the previous one.
void HttpListenSocket::ListenOnIOThread() {
  // StopThread() detaches |thread_checker_|, so this presumably binds it to the
  // new IO thread.
  DCHECK(thread_checker_.CalledOnValidThread());
#if !defined(OS_POSIX)
  // This method may be called after the IO thread is changed, thus we need to
  // call |WatchSocket| again to make sure it listens on the current IO thread.
  // Only needed for non POSIX platforms, since on POSIX platforms
  // StreamListenSocket::Listen already calls WatchSocket inside the function.
  WatchSocket(WAITING_ACCEPT);
#endif
  Listen();
}
109
110HttpListenSocket::~HttpListenSocket() {
111  DCHECK(thread_checker_.CalledOnValidThread());
112}
113
114void HttpListenSocket::DetachFromThread() {
115  thread_checker_.DetachFromThread();
116}
117
118EmbeddedTestServer::EmbeddedTestServer()
119    : port_(-1),
120      weak_factory_(this) {
121  DCHECK(thread_checker_.CalledOnValidThread());
122}
123
124EmbeddedTestServer::~EmbeddedTestServer() {
125  DCHECK(thread_checker_.CalledOnValidThread());
126
127  if (Started() && !ShutdownAndWaitUntilComplete()) {
128    LOG(ERROR) << "EmbeddedTestServer failed to shut down.";
129  }
130}
131
132bool EmbeddedTestServer::InitializeAndWaitUntilReady() {
133  StartThread();
134  DCHECK(thread_checker_.CalledOnValidThread());
135  if (!PostTaskToIOThreadAndWait(base::Bind(
136          &EmbeddedTestServer::InitializeOnIOThread, base::Unretained(this)))) {
137    return false;
138  }
139  return Started() && base_url_.is_valid();
140}
141
// Stops and destroys the IO thread while keeping |listen_socket_| (and its
// bound port) alive; RestartThreadAndListen() later starts a new IO thread and
// resumes listening on the same socket.
//
// Both the server's and the socket's thread checkers are detached so they can
// be used from whichever thread touches them next.
//
// NOTE(review): |listen_socket_| is dereferenced unconditionally, so this
// presumably must only be called after a successful
// InitializeAndWaitUntilReady().
void EmbeddedTestServer::StopThread() {
  DCHECK(io_thread_ && io_thread_->IsRunning());

#if defined(OS_LINUX)
  // Snapshot the process thread count before stopping, so we can wait below
  // until the kernel reflects the exit of the IO thread.
  const int thread_count =
      base::GetNumberOfThreads(base::GetCurrentProcessHandle());
#endif

  io_thread_->Stop();
  io_thread_.reset();
  thread_checker_.DetachFromThread();
  listen_socket_->DetachFromThread();

#if defined(OS_LINUX)
  // Busy loop to wait for thread count to decrease. This is needed because
  // pthread_join does not guarantee that kernel stat is updated when it
  // returns. Thus, GetNumberOfThreads does not immediately reflect the stopped
  // thread and hits the thread number DCHECK in render_sandbox_host_linux.cc
  // in browser_tests.
  // NOTE(review): this compares for equality only; if another thread is
  // created concurrently the count could stay equal and keep this spinning.
  while (thread_count ==
         base::GetNumberOfThreads(base::GetCurrentProcessHandle())) {
    base::PlatformThread::YieldCurrentThread();
  }
#endif
}
167
168void EmbeddedTestServer::RestartThreadAndListen() {
169  StartThread();
170  CHECK(PostTaskToIOThreadAndWait(base::Bind(
171      &EmbeddedTestServer::ListenOnIOThread, base::Unretained(this))));
172}
173
174bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() {
175  DCHECK(thread_checker_.CalledOnValidThread());
176
177  return PostTaskToIOThreadAndWait(base::Bind(
178      &EmbeddedTestServer::ShutdownOnIOThread, base::Unretained(this)));
179}
180
181void EmbeddedTestServer::StartThread() {
182  DCHECK(!io_thread_.get());
183  base::Thread::Options thread_options;
184  thread_options.message_loop_type = base::MessageLoop::TYPE_IO;
185  io_thread_.reset(new base::Thread("EmbeddedTestServer io thread"));
186  CHECK(io_thread_->StartWithOptions(thread_options));
187}
188
189void EmbeddedTestServer::InitializeOnIOThread() {
190  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
191  DCHECK(!Started());
192
193  SocketDescriptor socket_descriptor =
194      TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_);
195  if (socket_descriptor == kInvalidSocket)
196    return;
197
198  listen_socket_.reset(new HttpListenSocket(socket_descriptor, this));
199  listen_socket_->Listen();
200
201  IPEndPoint address;
202  int result = listen_socket_->GetLocalAddress(&address);
203  if (result == OK) {
204    base_url_ = GURL(std::string("http://") + address.ToString());
205  } else {
206    LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result);
207  }
208}
209
210void EmbeddedTestServer::ListenOnIOThread() {
211  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
212  DCHECK(Started());
213  listen_socket_->ListenOnIOThread();
214}
215
216void EmbeddedTestServer::ShutdownOnIOThread() {
217  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
218
219  listen_socket_.reset();
220  STLDeleteContainerPairSecondPointers(connections_.begin(),
221                                       connections_.end());
222  connections_.clear();
223}
224
225void EmbeddedTestServer::HandleRequest(HttpConnection* connection,
226                               scoped_ptr<HttpRequest> request) {
227  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
228
229  bool request_handled = false;
230
231  for (size_t i = 0; i < request_handlers_.size(); ++i) {
232    scoped_ptr<HttpResponse> response =
233        request_handlers_[i].Run(*request.get());
234    if (response.get()) {
235      connection->SendResponse(response.Pass());
236      request_handled = true;
237      break;
238    }
239  }
240
241  if (!request_handled) {
242    LOG(WARNING) << "Request not handled. Returning 404: "
243                 << request->relative_url;
244    scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse);
245    not_found_response->set_code(HTTP_NOT_FOUND);
246    connection->SendResponse(
247        not_found_response.PassAs<HttpResponse>());
248  }
249
250  // Drop the connection, since we do not support multiple requests per
251  // connection.
252  connections_.erase(connection->socket_.get());
253  delete connection;
254}
255
256GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const {
257  DCHECK(Started()) << "You must start the server first.";
258  DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */))
259      << relative_url;
260  return base_url_.Resolve(relative_url);
261}
262
263void EmbeddedTestServer::ServeFilesFromDirectory(
264    const base::FilePath& directory) {
265  RegisterRequestHandler(base::Bind(&HandleFileRequest, directory));
266}
267
268void EmbeddedTestServer::RegisterRequestHandler(
269    const HandleRequestCallback& callback) {
270  request_handlers_.push_back(callback);
271}
272
273void EmbeddedTestServer::DidAccept(
274    StreamListenSocket* server,
275    scoped_ptr<StreamListenSocket> connection) {
276  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
277
278  HttpConnection* http_connection = new HttpConnection(
279      connection.Pass(),
280      base::Bind(&EmbeddedTestServer::HandleRequest,
281                 weak_factory_.GetWeakPtr()));
282  // TODO(szym): Make HttpConnection the StreamListenSocket delegate.
283  connections_[http_connection->socket_.get()] = http_connection;
284}
285
286void EmbeddedTestServer::DidRead(StreamListenSocket* connection,
287                         const char* data,
288                         int length) {
289  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
290
291  HttpConnection* http_connection = FindConnection(connection);
292  if (http_connection == NULL) {
293    LOG(WARNING) << "Unknown connection.";
294    return;
295  }
296  http_connection->ReceiveData(std::string(data, length));
297}
298
299void EmbeddedTestServer::DidClose(StreamListenSocket* connection) {
300  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
301
302  HttpConnection* http_connection = FindConnection(connection);
303  if (http_connection == NULL) {
304    LOG(WARNING) << "Unknown connection.";
305    return;
306  }
307  delete http_connection;
308  connections_.erase(connection);
309}
310
311HttpConnection* EmbeddedTestServer::FindConnection(
312    StreamListenSocket* socket) {
313  DCHECK(io_thread_->message_loop_proxy()->BelongsToCurrentThread());
314
315  std::map<StreamListenSocket*, HttpConnection*>::iterator it =
316      connections_.find(socket);
317  if (it == connections_.end()) {
318    return NULL;
319  }
320  return it->second;
321}
322
323bool EmbeddedTestServer::PostTaskToIOThreadAndWait(
324    const base::Closure& closure) {
325  // Note that PostTaskAndReply below requires base::MessageLoopProxy::current()
326  // to return a loop for posting the reply task. However, in order to make
327  // EmbeddedTestServer universally usable, it needs to cope with the situation
328  // where it's running on a thread on which a message loop is not (yet)
329  // available or as has been destroyed already.
330  //
331  // To handle this situation, create temporary message loop to support the
332  // PostTaskAndReply operation if the current thread as no message loop.
333  scoped_ptr<base::MessageLoop> temporary_loop;
334  if (!base::MessageLoop::current())
335    temporary_loop.reset(new base::MessageLoop());
336
337  base::RunLoop run_loop;
338  if (!io_thread_->message_loop_proxy()->PostTaskAndReply(
339          FROM_HERE, closure, run_loop.QuitClosure())) {
340    return false;
341  }
342  run_loop.Run();
343
344  return true;
345}
346
347}  // namespace test_server
348}  // namespace net
349