embedded_test_server.cc revision 90dce4d38c5ff5333bea97d859d4e484e27edf0c
1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/test/embedded_test_server/embedded_test_server.h" 6 7#include "base/bind.h" 8#include "base/files/file_path.h" 9#include "base/file_util.h" 10#include "base/path_service.h" 11#include "base/run_loop.h" 12#include "base/stl_util.h" 13#include "base/string_util.h" 14#include "base/stringprintf.h" 15#include "base/threading/thread_restrictions.h" 16#include "net/base/ip_endpoint.h" 17#include "net/base/net_errors.h" 18#include "net/test/embedded_test_server/http_connection.h" 19#include "net/test/embedded_test_server/http_request.h" 20#include "net/test/embedded_test_server/http_response.h" 21#include "net/tools/fetch/http_listen_socket.h" 22 23namespace net { 24namespace test_server { 25 26namespace { 27 28// Handles |request| by serving a file from under |server_root|. 29scoped_ptr<HttpResponse> HandleFileRequest( 30 const base::FilePath& server_root, 31 const HttpRequest& request) { 32 // This is a test-only server. Ignore I/O thread restrictions. 33 base::ThreadRestrictions::ScopedAllowIO allow_io; 34 35 // Trim the first byte ('/'). 36 std::string request_path(request.relative_url.substr(1)); 37 38 // Remove the query string if present. 39 size_t query_pos = request_path.find('?'); 40 if (query_pos != std::string::npos) 41 request_path = request_path.substr(0, query_pos); 42 43 std::string file_contents; 44 if (!file_util::ReadFileToString( 45 server_root.AppendASCII(request_path), &file_contents)) { 46 return scoped_ptr<HttpResponse>(NULL); 47 } 48 49 scoped_ptr<BasicHttpResponse> http_response(new BasicHttpResponse); 50 http_response->set_code(net::test_server::SUCCESS); 51 http_response->set_content(file_contents); 52 return http_response.PassAs<HttpResponse>(); 53} 54 55} // namespace 56 57HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor, 58 StreamListenSocket::Delegate* delegate) 59 : TCPListenSocket(socket_descriptor, delegate) { 60 DCHECK(thread_checker_.CalledOnValidThread()); 61} 62 63void HttpListenSocket::Listen() { 64 DCHECK(thread_checker_.CalledOnValidThread()); 65 TCPListenSocket::Listen(); 66} 67 68HttpListenSocket::~HttpListenSocket() { 69 DCHECK(thread_checker_.CalledOnValidThread()); 70} 71 72EmbeddedTestServer::EmbeddedTestServer( 73 const scoped_refptr<base::SingleThreadTaskRunner>& io_thread) 74 : io_thread_(io_thread), 75 port_(-1), 76 weak_factory_(this) { 77 DCHECK(io_thread_); 78 DCHECK(thread_checker_.CalledOnValidThread()); 79} 80 81EmbeddedTestServer::~EmbeddedTestServer() { 82 DCHECK(thread_checker_.CalledOnValidThread()); 83 84 if (Started() && !ShutdownAndWaitUntilComplete()) { 85 LOG(ERROR) << "EmbeddedTestServer failed to shut down."; 86 } 87} 88 89bool EmbeddedTestServer::InitializeAndWaitUntilReady() { 90 DCHECK(thread_checker_.CalledOnValidThread()); 91 92 base::RunLoop run_loop; 93 if (!io_thread_->PostTaskAndReply( 94 FROM_HERE, 95 base::Bind(&EmbeddedTestServer::InitializeOnIOThread, 96 base::Unretained(this)), 97 run_loop.QuitClosure())) { 98 return false; 99 } 100 run_loop.Run(); 101 102 return Started() && base_url_.is_valid(); 103} 104 105bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() { 106 DCHECK(thread_checker_.CalledOnValidThread()); 107 108 base::RunLoop run_loop; 109 if (!io_thread_->PostTaskAndReply( 110 FROM_HERE, 111 base::Bind(&EmbeddedTestServer::ShutdownOnIOThread, 112 base::Unretained(this)), 113 run_loop.QuitClosure())) { 114 return false; 115 } 116 run_loop.Run(); 117 118 return true; 119} 120 121void EmbeddedTestServer::InitializeOnIOThread() { 122 DCHECK(io_thread_->BelongsToCurrentThread()); 123 DCHECK(!Started()); 124 125 SocketDescriptor socket_descriptor = 126 TCPListenSocket::CreateAndBindAnyPort("127.0.0.1", &port_); 127 if (socket_descriptor == TCPListenSocket::kInvalidSocket) 128 return; 129 130 listen_socket_ = new HttpListenSocket(socket_descriptor, this); 131 listen_socket_->Listen(); 132 133 IPEndPoint address; 134 int result = listen_socket_->GetLocalAddress(&address); 135 if (result == OK) { 136 base_url_ = GURL(std::string("http://") + address.ToString()); 137 } else { 138 LOG(ERROR) << "GetLocalAddress failed: " << ErrorToString(result); 139 } 140} 141 142void EmbeddedTestServer::ShutdownOnIOThread() { 143 DCHECK(io_thread_->BelongsToCurrentThread()); 144 145 listen_socket_ = NULL; // Release the listen socket. 146 STLDeleteContainerPairSecondPointers(connections_.begin(), 147 connections_.end()); 148 connections_.clear(); 149} 150 151void EmbeddedTestServer::HandleRequest(HttpConnection* connection, 152 scoped_ptr<HttpRequest> request) { 153 DCHECK(io_thread_->BelongsToCurrentThread()); 154 155 bool request_handled = false; 156 157 for (size_t i = 0; i < request_handlers_.size(); ++i) { 158 scoped_ptr<HttpResponse> response = 159 request_handlers_[i].Run(*request.get()); 160 if (response.get()) { 161 connection->SendResponse(response.Pass()); 162 request_handled = true; 163 break; 164 } 165 } 166 167 if (!request_handled) { 168 LOG(WARNING) << "Request not handled. Returning 404: " 169 << request->relative_url; 170 scoped_ptr<BasicHttpResponse> not_found_response(new BasicHttpResponse); 171 not_found_response->set_code(NOT_FOUND); 172 connection->SendResponse( 173 not_found_response.PassAs<HttpResponse>()); 174 } 175 176 // Drop the connection, since we do not support multiple requests per 177 // connection. 178 connections_.erase(connection->socket_.get()); 179 delete connection; 180} 181 182GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { 183 DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */)) 184 << relative_url; 185 return base_url_.Resolve(relative_url); 186} 187 188void EmbeddedTestServer::ServeFilesFromDirectory( 189 const base::FilePath& directory) { 190 RegisterRequestHandler(base::Bind(&HandleFileRequest, directory)); 191} 192 193void EmbeddedTestServer::RegisterRequestHandler( 194 const HandleRequestCallback& callback) { 195 request_handlers_.push_back(callback); 196} 197 198void EmbeddedTestServer::DidAccept(StreamListenSocket* server, 199 StreamListenSocket* connection) { 200 DCHECK(io_thread_->BelongsToCurrentThread()); 201 202 HttpConnection* http_connection = new HttpConnection( 203 connection, 204 base::Bind(&EmbeddedTestServer::HandleRequest, 205 weak_factory_.GetWeakPtr())); 206 connections_[connection] = http_connection; 207} 208 209void EmbeddedTestServer::DidRead(StreamListenSocket* connection, 210 const char* data, 211 int length) { 212 DCHECK(io_thread_->BelongsToCurrentThread()); 213 214 HttpConnection* http_connection = FindConnection(connection); 215 if (http_connection == NULL) { 216 LOG(WARNING) << "Unknown connection."; 217 return; 218 } 219 http_connection->ReceiveData(std::string(data, length)); 220} 221 222void EmbeddedTestServer::DidClose(StreamListenSocket* connection) { 223 DCHECK(io_thread_->BelongsToCurrentThread()); 224 225 HttpConnection* http_connection = FindConnection(connection); 226 if (http_connection == NULL) { 227 LOG(WARNING) << "Unknown connection."; 228 return; 229 } 230 delete http_connection; 231 connections_.erase(connection); 232} 233 234HttpConnection* EmbeddedTestServer::FindConnection( 235 StreamListenSocket* socket) { 236 DCHECK(io_thread_->BelongsToCurrentThread()); 237 238 std::map<StreamListenSocket*, HttpConnection*>::iterator it = 239 connections_.find(socket); 240 if (it == connections_.end()) { 241 return NULL; 242 } 243 return it->second; 244} 245 246} // namespace test_server 247} // namespace net 248