embedded_test_server.cc revision b2df76ea8fec9e32f6f3718986dba0d95315b29c
1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/test/embedded_test_server/embedded_test_server.h" 6 7#include "base/bind.h" 8#include "base/run_loop.h" 9#include "base/stl_util.h" 10#include "base/string_util.h" 11#include "base/stringprintf.h" 12#include "net/test/embedded_test_server/http_connection.h" 13#include "net/test/embedded_test_server/http_request.h" 14#include "net/test/embedded_test_server/http_response.h" 15#include "net/tools/fetch/http_listen_socket.h" 16 17namespace net { 18namespace test_server { 19 20namespace { 21 22const int kPort = 8040; 23const char kIp[] = "127.0.0.1"; 24const int kRetries = 10; 25 26// Callback to handle requests with default predefined response for requests 27// matching the address |url|. 28scoped_ptr<HttpResponse> HandleDefaultRequest(const GURL& url, 29 const HttpResponse& response, 30 const HttpRequest& request) { 31 const GURL request_url = url.Resolve(request.relative_url); 32 if (url.path() != request_url.path()) 33 return scoped_ptr<HttpResponse>(NULL); 34 return scoped_ptr<HttpResponse>(new HttpResponse(response)); 35} 36 37} // namespace 38 39HttpListenSocket::HttpListenSocket(const SocketDescriptor socket_descriptor, 40 StreamListenSocket::Delegate* delegate) 41 : TCPListenSocket(socket_descriptor, delegate) { 42 DCHECK(thread_checker_.CalledOnValidThread()); 43} 44 45void HttpListenSocket::Listen() { 46 DCHECK(thread_checker_.CalledOnValidThread()); 47 TCPListenSocket::Listen(); 48} 49 50HttpListenSocket::~HttpListenSocket() { 51 DCHECK(thread_checker_.CalledOnValidThread()); 52} 53 54EmbeddedTestServer::EmbeddedTestServer( 55 const scoped_refptr<base::SingleThreadTaskRunner>& io_thread) 56 : io_thread_(io_thread), 57 port_(-1), 58 weak_factory_(this) { 59 DCHECK(io_thread_); 60 DCHECK(thread_checker_.CalledOnValidThread()); 61} 62 63EmbeddedTestServer::~EmbeddedTestServer() { 64 DCHECK(thread_checker_.CalledOnValidThread()); 65} 66 67bool EmbeddedTestServer::InitializeAndWaitUntilReady() { 68 DCHECK(thread_checker_.CalledOnValidThread()); 69 70 base::RunLoop run_loop; 71 if (!io_thread_->PostTaskAndReply( 72 FROM_HERE, 73 base::Bind(&EmbeddedTestServer::InitializeOnIOThread, 74 base::Unretained(this)), 75 run_loop.QuitClosure())) { 76 return false; 77 } 78 run_loop.Run(); 79 80 return Started(); 81} 82 83bool EmbeddedTestServer::ShutdownAndWaitUntilComplete() { 84 DCHECK(thread_checker_.CalledOnValidThread()); 85 86 base::RunLoop run_loop; 87 if (!io_thread_->PostTaskAndReply( 88 FROM_HERE, 89 base::Bind(&EmbeddedTestServer::ShutdownOnIOThread, 90 base::Unretained(this)), 91 run_loop.QuitClosure())) { 92 return false; 93 } 94 run_loop.Run(); 95 96 return true; 97} 98 99void EmbeddedTestServer::InitializeOnIOThread() { 100 DCHECK(io_thread_->BelongsToCurrentThread()); 101 DCHECK(!Started()); 102 103 int retries_left = kRetries + 1; 104 int try_port = kPort; 105 106 while (retries_left > 0) { 107 SocketDescriptor socket_descriptor = TCPListenSocket::CreateAndBind( 108 kIp, 109 try_port); 110 if (socket_descriptor != TCPListenSocket::kInvalidSocket) { 111 listen_socket_ = new HttpListenSocket(socket_descriptor, this); 112 listen_socket_->Listen(); 113 base_url_ = GURL(base::StringPrintf("http://%s:%d", kIp, try_port)); 114 port_ = try_port; 115 break; 116 } 117 retries_left--; 118 try_port++; 119 } 120} 121 122void EmbeddedTestServer::ShutdownOnIOThread() { 123 DCHECK(io_thread_->BelongsToCurrentThread()); 124 125 listen_socket_ = NULL; // Release the listen socket. 126 STLDeleteContainerPairSecondPointers(connections_.begin(), 127 connections_.end()); 128 connections_.clear(); 129} 130 131void EmbeddedTestServer::HandleRequest(HttpConnection* connection, 132 scoped_ptr<HttpRequest> request) { 133 DCHECK(io_thread_->BelongsToCurrentThread()); 134 135 for (size_t i = 0; i < request_handlers_.size(); ++i) { 136 scoped_ptr<HttpResponse> response = 137 request_handlers_[i].Run(*request.get()); 138 if (response.get()) { 139 connection->SendResponse(response.Pass()); 140 return; 141 } 142 } 143 144 LOG(WARNING) << "Request not handled. Returning 404: " 145 << request->relative_url; 146 scoped_ptr<HttpResponse> not_found_response(new HttpResponse()); 147 not_found_response->set_code(NOT_FOUND); 148 connection->SendResponse(not_found_response.Pass()); 149 150 // Drop the connection, since we do not support multiple requests per 151 // connection. 152 connections_.erase(connection->socket_.get()); 153 delete connection; 154} 155 156GURL EmbeddedTestServer::GetURL(const std::string& relative_url) const { 157 DCHECK(StartsWithASCII(relative_url, "/", true /* case_sensitive */)) 158 << relative_url; 159 return base_url_.Resolve(relative_url); 160} 161 162void EmbeddedTestServer::RegisterRequestHandler( 163 const HandleRequestCallback& callback) { 164 request_handlers_.push_back(callback); 165} 166 167void EmbeddedTestServer::DidAccept(StreamListenSocket* server, 168 StreamListenSocket* connection) { 169 DCHECK(io_thread_->BelongsToCurrentThread()); 170 171 HttpConnection* http_connection = new HttpConnection( 172 connection, 173 base::Bind(&EmbeddedTestServer::HandleRequest, 174 weak_factory_.GetWeakPtr())); 175 connections_[connection] = http_connection; 176} 177 178void EmbeddedTestServer::DidRead(StreamListenSocket* connection, 179 const char* data, 180 int length) { 181 DCHECK(io_thread_->BelongsToCurrentThread()); 182 183 HttpConnection* http_connection = FindConnection(connection); 184 if (http_connection == NULL) { 185 LOG(WARNING) << "Unknown connection."; 186 return; 187 } 188 http_connection->ReceiveData(std::string(data, length)); 189} 190 191void EmbeddedTestServer::DidClose(StreamListenSocket* connection) { 192 DCHECK(io_thread_->BelongsToCurrentThread()); 193 194 HttpConnection* http_connection = FindConnection(connection); 195 if (http_connection == NULL) { 196 LOG(WARNING) << "Unknown connection."; 197 return; 198 } 199 delete http_connection; 200 connections_.erase(connection); 201} 202 203HttpConnection* EmbeddedTestServer::FindConnection( 204 StreamListenSocket* socket) { 205 DCHECK(io_thread_->BelongsToCurrentThread()); 206 207 std::map<StreamListenSocket*, HttpConnection*>::iterator it = 208 connections_.find(socket); 209 if (it == connections_.end()) { 210 return NULL; 211 } 212 return it->second; 213} 214 215} // namespace test_server 216} // namespace net 217