embedded_test_server.cc revision 424c4d7b64af9d0d8fd9624f381f469654d5e3d2
1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/test/embedded_test_server/embedded_test_server.h" 6 7#include "base/bind.h" 8#include "base/files/file_path.h" 9#include "base/file_util.h" 10#include "base/path_service.h" 11#include "base/run_loop.h" 12#include "base/stl_util.h" 13#include "base/strings/string_util.h" 14#include "base/strings/stringprintf.h" 15#include "base/threading/thread_restrictions.h" 16#include "net/base/ip_endpoint.h" 17#include "net/base/net_errors.h" 18#include "net/test/embedded_test_server/http_connection.h" 19#include "net/test/embedded_test_server/http_request.h" 20#include "net/test/embedded_test_server/http_response.h" 21#include "net/tools/fetch/http_listen_socket.h" 22 23namespace net { 24namespace test_server { 25 26namespace { 27 28class CustomHttpResponse : public HttpResponse { 29 public: 30 CustomHttpResponse(const std::string& headers, const std::string& contents) 31 : headers_(headers), contents_(contents) { 32 } 33 34 virtual std::string ToResponseString() const OVERRIDE { 35 return headers_ + "\r\n" + contents_; 36 } 37 38 private: 39 std::string headers_; 40 std::string contents_; 41 42 DISALLOW_COPY_AND_ASSIGN(CustomHttpResponse); 43}; 44 45// Handles |request| by serving a file from under |server_root|. 46scoped_ptr<HttpResponse> HandleFileRequest( 47 const base::FilePath& server_root, 48 const HttpRequest& request) { 49 // This is a test-only server. Ignore I/O thread restrictions. 50 base::ThreadRestrictions::ScopedAllowIO allow_io; 51 52 // Trim the first byte ('/'). 53 std::string request_path(request.relative_url.substr(1)); 54 55 // Remove the query string if present. 56 size_t query_pos = request_path.find('?'); 57 if (query_pos != std::string::npos) 58 request_path = request_path.substr(0, query_pos); 59 60 base::FilePath file_path(server_root.AppendASCII(request_path)); 61 std::string file_contents; 62 if (!file_util::ReadFileToString(file_path, &file_contents)) 63 return scoped_ptr<HttpResponse>(); 64 65 base::FilePath headers_path( 66 file_path.AddExtension(FILE_PATH_LITERAL("mock-http-headers"))); 67 68 if (base::PathExists(headers_path)) { 69 std::string headers_contents; 70 if (!file_util::ReadFileToString(headers_path, &headers_contents)) 71 return scoped_ptr<HttpResponse>(); 72 73 scoped_ptr<CustomHttpResponse> http_response( 74 new CustomHttpResponse(headers_contents, file_contents)); 75 return http_response.PassAs<HttpResponse>(); 76 } 77 78 scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse); 79 http_response->set_code(HTTP_OK); 80 http_response->set_content(file_contents); 81 return http_response.PassAs<HttpResponse>(); 82} 83 84} // namespace 85 86HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor, 87 StreamListenSocket::Delegate* delegate) 88 : TCPListenSocket(socket_descriptor, delegate) { 89 DCHECK(thread_checker_.CalledOnValidThread()); 90} 91 92void HttpListenSocket::Listen() { 93 DCHECK(thread_checker_.CalledOnValidThread()); 94 TCPListenSocket::Listen(); 95} 96 97HttpListenSocket::~HttpListenSocket() { 98 DCHECK(thread_checker_.CalledOnValidThread()); 99} 100 101EmbeddedTestServer::EmbeddedTestServer( 102 const scoped_refptr<base::SingleThreadTaskRunner>& io_thread) 103 : io_thread_(io_thread), 104 port_(-1), 105 weak_factory_(this) { 106 DCHECK(io_thread_.get()); 107 DCHECK(thread_checker_.CalledOnValidThread()); 108} 109 110EmbeddedTestServer::~EmbeddedTestServer() { 111 DCHECK(thread_checker_.CalledOnValidThread()); 112 113 if (Started() && !ShutdownAndWaitUntilComplete()) { 114 LOG(ERROR) << "EmbeddedTestServer failed to shut down."; 115 } 116} 117 118bool EmbeddedTestServer::InitializeAndWaitUntilReady() { 119 DCHECK(thread_checker_.CalledOnValidThread()); 120 121 base::RunLoop run_loop; 122 if (!io_thread_->PostTaskAndReply( 123 FROM_HERE, 124 base::Bind(&EmbeddedTestServer::InitializeOnIOThread, 125 base::Unretained(this)), 126 run_loop.QuitClosure())) { 127 return false; 128 } 129 run_loop.Run(); 130 131 return Started() && base_url_.is_valid(); 132} 133 134bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() { 135 DCHECK(thread_checker_.CalledOnValidThread()); 136 137 base::RunLoop run_loop; 138 if (!io_thread_->PostTaskAndReply( 139 FROM_HERE, 140 base::Bind(&EmbeddedTestServer::ShutdownOnIOThread, 141 base::Unretained(this)), 142 run_loop.QuitClosure())) { 143 return false; 144 } 145 run_loop.Run(); 146 147 return true; 148} 149 150void EmbeddedTestServer::InitializeOnIOThread() { 151 DCHECK(io_thread_->BelongsToCurrentThread()); 152 DCHECK(!Started()); 153 154 SocketDescriptor socket_descriptor = 155 TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_); 156 if (socket_descriptor == kInvalidSocket) 157 return; 158 159 listen_socket_ = new HttpListenSocket(socket_descriptor, this); 160 listen_socket_->Listen(); 161 162 IPEndPoint address; 163 int result = listen_socket_->GetLocalAddress(&address); 164 if (result == OK) { 165 base_url_ = GURL(std::string("http://") + address.ToString()); 166 } else { 167 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result); 168 } 169} 170 171void EmbeddedTestServer::ShutdownOnIOThread() { 172 DCHECK(io_thread_->BelongsToCurrentThread()); 173 174 listen_socket_ = NULL; // Release the listen socket. 175 STLDeleteContainerPairSecondPointers(connections_.begin(), 176 connections_.end()); 177 connections_.clear(); 178} 179 180void EmbeddedTestServer::HandleRequest(HttpConnection* connection, 181 scoped_ptr<HttpRequest> request) { 182 DCHECK(io_thread_->BelongsToCurrentThread()); 183 184 bool request_handled = false; 185 186 for (size_t i = 0; i < request_handlers_.size(); ++i) { 187 scoped_ptr<HttpResponse> response = 188 request_handlers_[i].Run(*request.get()); 189 if (response.get()) { 190 connection->SendResponse(response.Pass()); 191 request_handled = true; 192 break; 193 } 194 } 195 196 if (!request_handled) { 197 LOG(WARNING) << "Request not handled. Returning 404: " 198 << request->relative_url; 199 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse); 200 not_found_response->set_code(HTTP_NOT_FOUND); 201 connection->SendResponse( 202 not_found_response.PassAs<HttpResponse>()); 203 } 204 205 // Drop the connection, since we do not support multiple requests per 206 // connection. 207 connections_.erase(connection->socket_.get()); 208 delete connection; 209} 210 211GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { 212 DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */)) 213 << relative_url; 214 return base_url_.Resolve(relative_url); 215} 216 217void EmbeddedTestServer::ServeFilesFromDirectory( 218 const base::FilePath& directory) { 219 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory)); 220} 221 222void EmbeddedTestServer::RegisterRequestHandler( 223 const HandleRequestCallback& callback) { 224 request_handlers_.push_back(callback); 225} 226 227void EmbeddedTestServer::DidAccept(StreamListenSocket* server, 228 StreamListenSocket* connection) { 229 DCHECK(io_thread_->BelongsToCurrentThread()); 230 231 HttpConnection* http_connection = new HttpConnection( 232 connection, 233 base::Bind(&EmbeddedTestServer::HandleRequest, 234 weak_factory_.GetWeakPtr())); 235 connections_[connection] = http_connection; 236} 237 238void EmbeddedTestServer::DidRead(StreamListenSocket* connection, 239 const char* data, 240 int length) { 241 DCHECK(io_thread_->BelongsToCurrentThread()); 242 243 HttpConnection* http_connection = FindConnection(connection); 244 if (http_connection == NULL) { 245 LOG(WARNING) << "Unknown connection."; 246 return; 247 } 248 http_connection->ReceiveData(std::string(data, length)); 249} 250 251void EmbeddedTestServer::DidClose(StreamListenSocket* connection) { 252 DCHECK(io_thread_->BelongsToCurrentThread()); 253 254 HttpConnection* http_connection = FindConnection(connection); 255 if (http_connection == NULL) { 256 LOG(WARNING) << "Unknown connection."; 257 return; 258 } 259 delete http_connection; 260 connections_.erase(connection); 261} 262 263HttpConnection* EmbeddedTestServer::FindConnection( 264 StreamListenSocket* socket) { 265 DCHECK(io_thread_->BelongsToCurrentThread()); 266 267 std::map<StreamListenSocket*, HttpConnection*>::iterator it = 268 connections_.find(socket); 269 if (it == connections_.end()) { 270 return NULL; 271 } 272 return it->second; 273} 274 275} // namespace test_server 276} // namespace net 277