1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome_frame/test/test_server.h"
6
7#include <windows.h>
8#include <objbase.h>
9#include <urlmon.h>
10
11#include "base/bind.h"
12#include "base/logging.h"
13#include "base/strings/string_number_conversions.h"
14#include "base/strings/string_piece.h"
15#include "base/strings/string_util.h"
16#include "base/strings/stringprintf.h"
17#include "base/strings/utf_string_conversions.h"
18#include "chrome_frame/test/chrome_frame_test_utils.h"
19#include "net/base/winsock_init.h"
20#include "net/http/http_util.h"
21#include "net/socket/tcp_listen_socket.h"
22
23namespace test_server {
// Status lines substituted into the first slot of kDefaultHeaderTemplate.
const char kStatusOk[] = "200 OK";
const char kStatusNotFound[] = "404 Not Found";

// Content type used when a response does not provide one of its own.
const char kDefaultContentType[] = "text/html; charset=UTF-8";

// printf-style template for the head of a response. The arguments are, in
// order: the status line (narrow string), the content type (narrow string)
// and the body length (int). The trailing empty line ends the header section.
const char kDefaultHeaderTemplate[] =
    "HTTP/1.1 %hs\r\n"
    "Connection: close\r\n"
    "Content-Type: %hs\r\n"
    "Content-Length: %i\r\n\r\n";
32
33void Request::ParseHeaders(const std::string& headers) {
34  DCHECK(method_.length() == 0);
35
36  size_t pos = headers.find("\r\n");
37  DCHECK(pos != std::string::npos);
38  if (pos != std::string::npos) {
39    headers_ = headers.substr(pos + 2);
40
41    base::StringTokenizer tokenizer(
42        headers.begin(), headers.begin() + pos, " ");
43    std::string* parse[] = { &method_, &path_, &version_ };
44    int field = 0;
45    while (tokenizer.GetNext() && field < arraysize(parse)) {
46      parse[field++]->assign(tokenizer.token_begin(),
47                             tokenizer.token_end());
48    }
49  }
50
51  // Check for content-length in case we're being sent some data.
52  net::HttpUtil::HeadersIterator it(headers_.begin(), headers_.end(),
53                                    "\r\n");
54  while (it.GetNext()) {
55    if (LowerCaseEqualsASCII(it.name(), "content-length")) {
56      int int_content_length;
57      base::StringToInt(base::StringPiece(it.values_begin(),
58                                          it.values_end()),
59                        &int_content_length);
60      content_length_ = int_content_length;
61      break;
62    }
63  }
64}
65
66void Request::OnDataReceived(const std::string& data) {
67  content_ += data;
68
69  if (method_.length() == 0) {
70    size_t index = content_.find("\r\n\r\n");
71    if (index != std::string::npos) {
72      // Parse the headers before returning and chop them of the
73      // data buffer we've already received.
74      std::string headers(content_.substr(0, index + 2));
75      ParseHeaders(headers);
76      content_.erase(0, index + 4);
77    }
78  }
79}
80
81ResponseForPath::~ResponseForPath() {
82}
83
84SimpleResponse::~SimpleResponse() {
85}
86
87bool FileResponse::GetContentType(std::string* content_type) const {
88  size_t length = ContentLength();
89  char buffer[4096];
90  void* data = NULL;
91
92  if (length) {
93    // Create a copy of the first few bytes of the file.
94    // If we try and use the mapped file directly, FindMimeFromData will crash
95    // 'cause it cheats and temporarily tries to write to the buffer!
96    length = std::min(arraysize(buffer), length);
97    memcpy(buffer, file_->data(), length);
98    data = buffer;
99  }
100
101  LPOLESTR mime_type = NULL;
102  FindMimeFromData(NULL, file_path_.value().c_str(), data, length, NULL,
103                   FMFD_DEFAULT, &mime_type, 0);
104  if (mime_type) {
105    *content_type = WideToASCII(mime_type);
106    ::CoTaskMemFree(mime_type);
107  }
108
109  return content_type->length() > 0;
110}
111
112void FileResponse::WriteContents(net::StreamListenSocket* socket) const {
113  DCHECK(file_.get());
114  if (file_.get()) {
115    socket->Send(reinterpret_cast<const char*>(file_->data()),
116                 file_->length(), false);
117  }
118}
119
120size_t FileResponse::ContentLength() const {
121  if (file_.get() == NULL) {
122    file_.reset(new base::MemoryMappedFile());
123    if (!file_->Initialize(file_path_)) {
124      NOTREACHED();
125      file_.reset();
126    }
127  }
128  return file_.get() ? file_->length() : 0;
129}
130
131bool RedirectResponse::GetCustomHeaders(std::string* headers) const {
132  *headers = base::StringPrintf("HTTP/1.1 302 Found\r\n"
133                                "Connection: close\r\n"
134                                "Content-Length: 0\r\n"
135                                "Content-Type: text/html\r\n"
136                                "Location: %hs\r\n\r\n",
137                                redirect_url_.c_str());
138  return true;
139}
140
141SimpleWebServer::SimpleWebServer(int port) {
142  Construct(chrome_frame_test::GetLocalIPv4Address(), port);
143}
144
145SimpleWebServer::SimpleWebServer(const std::string& address, int port) {
146  Construct(address, port);
147}
148
149SimpleWebServer::~SimpleWebServer() {
150  ConnectionList::const_iterator it;
151  for (it = connections_.begin(); it != connections_.end(); ++it)
152    delete (*it);
153  connections_.clear();
154}
155
156void SimpleWebServer::Construct(const std::string& address, int port) {
157  CHECK(base::MessageLoop::current())
158      << "SimpleWebServer requires a message loop";
159  net::EnsureWinsockInit();
160  AddResponse(&quit_);
161  host_ = address;
162  server_ = net::TCPListenSocket::CreateAndListen(address, port, this);
163  LOG_IF(DFATAL, !server_.get())
164      << "Failed to create listener socket at " << address << ":" << port;
165}
166
167void SimpleWebServer::AddResponse(Response* response) {
168  responses_.push_back(response);
169}
170
171void SimpleWebServer::DeleteAllResponses() {
172  std::list<Response*>::const_iterator it;
173  for (it = responses_.begin(); it != responses_.end(); ++it) {
174    if ((*it) != &quit_)
175      delete (*it);
176  }
177}
178
179Response* SimpleWebServer::FindResponse(const Request& request) const {
180  std::list<Response*>::const_iterator it;
181  for (it = responses_.begin(); it != responses_.end(); it++) {
182    Response* response = (*it);
183    if (response->Matches(request)) {
184      return response;
185    }
186  }
187  return NULL;
188}
189
190Connection* SimpleWebServer::FindConnection(
191    const net::StreamListenSocket* socket) const {
192  ConnectionList::const_iterator it;
193  for (it = connections_.begin(); it != connections_.end(); it++) {
194    if ((*it)->IsSame(socket)) {
195      return (*it);
196    }
197  }
198  return NULL;
199}
200
201void SimpleWebServer::DidAccept(
202    net::StreamListenSocket* server,
203    scoped_ptr<net::StreamListenSocket> connection) {
204  connections_.push_back(new Connection(connection.Pass()));
205}
206
207void SimpleWebServer::DidRead(net::StreamListenSocket* connection,
208                              const char* data,
209                              int len) {
210  Connection* c = FindConnection(connection);
211  DCHECK(c);
212  Request& r = c->request();
213  std::string str(data, len);
214  r.OnDataReceived(str);
215  if (r.AllContentReceived()) {
216    const Request& request = c->request();
217    Response* response = FindResponse(request);
218    if (response) {
219      std::string headers;
220      if (!response->GetCustomHeaders(&headers)) {
221        std::string content_type;
222        if (!response->GetContentType(&content_type))
223          content_type = kDefaultContentType;
224        headers = base::StringPrintf(kDefaultHeaderTemplate, kStatusOk,
225                                     content_type.c_str(),
226                                     response->ContentLength());
227      }
228
229      connection->Send(headers, false);
230      response->WriteContents(connection);
231      response->IncrementAccessCounter();
232    } else {
233      std::string payload = "sorry, I can't find " + request.path();
234      std::string headers(base::StringPrintf(kDefaultHeaderTemplate,
235                                             kStatusNotFound,
236                                             kDefaultContentType,
237                                             payload.length()));
238      connection->Send(headers, false);
239      connection->Send(payload, false);
240    }
241  }
242}
243
244void SimpleWebServer::DidClose(net::StreamListenSocket* sock) {
245  // To keep the historical list of connections reasonably tidy, we delete
246  // 404's when the connection ends.
247  Connection* c = FindConnection(sock);
248  DCHECK(c);
249  c->OnSocketClosed();
250  if (!FindResponse(c->request())) {
251    // extremely inefficient, but in one line and not that common... :)
252    connections_.erase(std::find(connections_.begin(), connections_.end(), c));
253    delete c;
254  }
255}
256
257HTTPTestServer::HTTPTestServer(int port, const std::wstring& address,
258                               base::FilePath root_dir)
259    : port_(port), address_(address), root_dir_(root_dir) {
260  net::EnsureWinsockInit();
261  server_ =
262      net::TCPListenSocket::CreateAndListen(WideToUTF8(address), port, this);
263}
264
265HTTPTestServer::~HTTPTestServer() {
266}
267
268std::list<scoped_refptr<ConfigurableConnection>>::iterator
269HTTPTestServer::FindConnection(const net::StreamListenSocket* socket) {
270  ConnectionList::iterator it;
271  // Scan through the list searching for the desired socket. Along the way,
272  // erase any connections for which the corresponding socket has already been
273  // forgotten about as a result of all data having been sent.
274  for (it = connection_list_.begin(); it != connection_list_.end(); ) {
275    ConfigurableConnection* connection = it->get();
276    if (connection->socket_ == NULL) {
277      connection_list_.erase(it++);
278      continue;
279    }
280    if (connection->socket_ == socket)
281      break;
282    ++it;
283  }
284
285  return it;
286}
287
288scoped_refptr<ConfigurableConnection> HTTPTestServer::ConnectionFromSocket(
289    const net::StreamListenSocket* socket) {
290  ConnectionList::iterator it = FindConnection(socket);
291  if (it != connection_list_.end())
292    return *it;
293  return NULL;
294}
295
296void HTTPTestServer::DidAccept(net::StreamListenSocket* server,
297                               scoped_ptr<net::StreamListenSocket> socket) {
298  connection_list_.push_back(new ConfigurableConnection(socket.Pass()));
299}
300
301void HTTPTestServer::DidRead(net::StreamListenSocket* socket,
302                             const char* data,
303                             int len) {
304  scoped_refptr<ConfigurableConnection> connection =
305      ConnectionFromSocket(socket);
306  if (connection) {
307    std::string str(data, len);
308    connection->r_.OnDataReceived(str);
309    if (connection->r_.AllContentReceived()) {
310      VLOG(1) << __FUNCTION__ << ": " << connection->r_.method() << " "
311              << connection->r_.path();
312      std::wstring path = UTF8ToWide(connection->r_.path());
313      if (LowerCaseEqualsASCII(connection->r_.method(), "post"))
314        this->Post(connection, path, connection->r_);
315      else
316        this->Get(connection, path, connection->r_);
317    }
318  }
319}
320
321void HTTPTestServer::DidClose(net::StreamListenSocket* socket) {
322  ConnectionList::iterator it = FindConnection(socket);
323  if (it != connection_list_.end())
324    connection_list_.erase(it);
325}
326
327std::wstring HTTPTestServer::Resolve(const std::wstring& path) {
328  // Remove the first '/' if needed.
329  std::wstring stripped_path = path;
330  if (path.size() && path[0] == L'/')
331    stripped_path = path.substr(1);
332
333  if (port_ == 80) {
334    if (stripped_path.empty()) {
335      return base::StringPrintf(L"http://%ls", address_.c_str());
336    } else {
337      return base::StringPrintf(L"http://%ls/%ls", address_.c_str(),
338                          stripped_path.c_str());
339    }
340  } else {
341    if (stripped_path.empty()) {
342      return base::StringPrintf(L"http://%ls:%d", address_.c_str(), port_);
343    } else {
344      return base::StringPrintf(L"http://%ls:%d/%ls", address_.c_str(), port_,
345                                stripped_path.c_str());
346    }
347  }
348}
349
350void ConfigurableConnection::SendChunk() {
351  int size = (int)data_.size();
352  const char* chunk_ptr = data_.c_str() + cur_pos_;
353  int bytes_to_send = std::min(options_.chunk_size_, size - cur_pos_);
354
355  socket_->Send(chunk_ptr, bytes_to_send);
356  VLOG(1) << "Sent(" << cur_pos_ << "," << bytes_to_send << "): "
357          << base::StringPiece(chunk_ptr, bytes_to_send);
358
359  cur_pos_ += bytes_to_send;
360  if (cur_pos_ < size) {
361    base::MessageLoop::current()->PostDelayedTask(
362        FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this),
363        base::TimeDelta::FromMilliseconds(options_.timeout_));
364  } else {
365    Close();
366  }
367}
368
369void ConfigurableConnection::Close() {
370  socket_.reset();
371}
372
373void ConfigurableConnection::Send(const std::string& headers,
374                                  const std::string& content) {
375  SendOptions options(SendOptions::IMMEDIATE, 0, 0);
376  SendWithOptions(headers, content, options);
377}
378
379void ConfigurableConnection::SendWithOptions(const std::string& headers,
380                                             const std::string& content,
381                                             const SendOptions& options) {
382  std::string content_length_header;
383  if (!content.empty() &&
384      std::string::npos == headers.find("Context-Length:")) {
385    content_length_header = base::StringPrintf("Content-Length: %u\r\n",
386                                               content.size());
387  }
388
389  // Save the options.
390  options_ = options;
391
392  if (options_.speed_ == SendOptions::IMMEDIATE) {
393    socket_->Send(headers);
394    socket_->Send(content_length_header, true);
395    socket_->Send(content);
396    // Post a task to close the socket since StreamListenSocket doesn't like
397    // instances to go away from within its callbacks.
398    base::MessageLoop::current()->PostTask(
399        FROM_HERE, base::Bind(&ConfigurableConnection::Close, this));
400
401    return;
402  }
403
404  if (options_.speed_ == SendOptions::IMMEDIATE_HEADERS_DELAYED_CONTENT) {
405    socket_->Send(headers);
406    socket_->Send(content_length_header, true);
407    VLOG(1) << "Headers sent: " << headers << content_length_header;
408    data_.append(content);
409  }
410
411  if (options_.speed_ == SendOptions::DELAYED) {
412    data_ = headers;
413    data_.append(content_length_header);
414    data_.append("\r\n");
415  }
416
417  base::MessageLoop::current()->PostDelayedTask(
418      FROM_HERE, base::Bind(&ConfigurableConnection::SendChunk, this),
419      base::TimeDelta::FromMilliseconds(options.timeout_));
420}
421
422}  // namespace test_server
423