1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/server/http_server.h"
6
7#include "base/compiler_specific.h"
8#include "base/logging.h"
9#include "base/stl_util.h"
10#include "base/strings/string_number_conversions.h"
11#include "base/strings/string_util.h"
12#include "base/strings/stringprintf.h"
13#include "base/sys_byteorder.h"
14#include "build/build_config.h"
15#include "net/base/net_errors.h"
16#include "net/server/http_connection.h"
17#include "net/server/http_server_request_info.h"
18#include "net/server/http_server_response_info.h"
19#include "net/server/web_socket.h"
20#include "net/socket/server_socket.h"
21#include "net/socket/stream_socket.h"
22#include "net/socket/tcp_server_socket.h"
23
24namespace net {
25
26HttpServer::HttpServer(scoped_ptr<ServerSocket> server_socket,
27                       HttpServer::Delegate* delegate)
28    : server_socket_(server_socket.Pass()),
29      delegate_(delegate),
30      last_id_(0),
31      weak_ptr_factory_(this) {
32  DCHECK(server_socket_);
33  DoAcceptLoop();
34}
35
36HttpServer::~HttpServer() {
37  STLDeleteContainerPairSecondPointers(
38      id_to_connection_.begin(), id_to_connection_.end());
39}
40
41void HttpServer::AcceptWebSocket(
42    int connection_id,
43    const HttpServerRequestInfo& request) {
44  HttpConnection* connection = FindConnection(connection_id);
45  if (connection == NULL)
46    return;
47  DCHECK(connection->web_socket());
48  connection->web_socket()->Accept(request);
49}
50
51void HttpServer::SendOverWebSocket(int connection_id,
52                                   const std::string& data) {
53  HttpConnection* connection = FindConnection(connection_id);
54  if (connection == NULL)
55    return;
56  DCHECK(connection->web_socket());
57  connection->web_socket()->Send(data);
58}
59
60void HttpServer::SendRaw(int connection_id, const std::string& data) {
61  HttpConnection* connection = FindConnection(connection_id);
62  if (connection == NULL)
63    return;
64
65  bool writing_in_progress = !connection->write_buf()->IsEmpty();
66  if (connection->write_buf()->Append(data) && !writing_in_progress)
67    DoWriteLoop(connection);
68}
69
70void HttpServer::SendResponse(int connection_id,
71                              const HttpServerResponseInfo& response) {
72  SendRaw(connection_id, response.Serialize());
73}
74
75void HttpServer::Send(int connection_id,
76                      HttpStatusCode status_code,
77                      const std::string& data,
78                      const std::string& content_type) {
79  HttpServerResponseInfo response(status_code);
80  response.SetContentHeaders(data.size(), content_type);
81  SendResponse(connection_id, response);
82  SendRaw(connection_id, data);
83}
84
85void HttpServer::Send200(int connection_id,
86                         const std::string& data,
87                         const std::string& content_type) {
88  Send(connection_id, HTTP_OK, data, content_type);
89}
90
91void HttpServer::Send404(int connection_id) {
92  SendResponse(connection_id, HttpServerResponseInfo::CreateFor404());
93}
94
95void HttpServer::Send500(int connection_id, const std::string& message) {
96  SendResponse(connection_id, HttpServerResponseInfo::CreateFor500(message));
97}
98
99void HttpServer::Close(int connection_id) {
100  HttpConnection* connection = FindConnection(connection_id);
101  if (connection == NULL)
102    return;
103
104  id_to_connection_.erase(connection_id);
105  delegate_->OnClose(connection_id);
106
107  // The call stack might have callbacks which still have the pointer of
108  // connection. Instead of referencing connection with ID all the time,
109  // destroys the connection in next run loop to make sure any pending
110  // callbacks in the call stack return.
111  base::MessageLoopProxy::current()->DeleteSoon(FROM_HERE, connection);
112}
113
114int HttpServer::GetLocalAddress(IPEndPoint* address) {
115  return server_socket_->GetLocalAddress(address);
116}
117
118void HttpServer::SetReceiveBufferSize(int connection_id, int32 size) {
119  HttpConnection* connection = FindConnection(connection_id);
120  DCHECK(connection);
121  connection->read_buf()->set_max_buffer_size(size);
122}
123
124void HttpServer::SetSendBufferSize(int connection_id, int32 size) {
125  HttpConnection* connection = FindConnection(connection_id);
126  DCHECK(connection);
127  connection->write_buf()->set_max_buffer_size(size);
128}
129
130void HttpServer::DoAcceptLoop() {
131  int rv;
132  do {
133    rv = server_socket_->Accept(&accepted_socket_,
134                                base::Bind(&HttpServer::OnAcceptCompleted,
135                                           weak_ptr_factory_.GetWeakPtr()));
136    if (rv == ERR_IO_PENDING)
137      return;
138    rv = HandleAcceptResult(rv);
139  } while (rv == OK);
140}
141
142void HttpServer::OnAcceptCompleted(int rv) {
143  if (HandleAcceptResult(rv) == OK)
144    DoAcceptLoop();
145}
146
147int HttpServer::HandleAcceptResult(int rv) {
148  if (rv < 0) {
149    LOG(ERROR) << "Accept error: rv=" << rv;
150    return rv;
151  }
152
153  HttpConnection* connection =
154      new HttpConnection(++last_id_, accepted_socket_.Pass());
155  id_to_connection_[connection->id()] = connection;
156  delegate_->OnConnect(connection->id());
157  if (!HasClosedConnection(connection))
158    DoReadLoop(connection);
159  return OK;
160}
161
162void HttpServer::DoReadLoop(HttpConnection* connection) {
163  int rv;
164  do {
165    HttpConnection::ReadIOBuffer* read_buf = connection->read_buf();
166    // Increases read buffer size if necessary.
167    if (read_buf->RemainingCapacity() == 0 && !read_buf->IncreaseCapacity()) {
168      Close(connection->id());
169      return;
170    }
171
172    rv = connection->socket()->Read(
173        read_buf,
174        read_buf->RemainingCapacity(),
175        base::Bind(&HttpServer::OnReadCompleted,
176                   weak_ptr_factory_.GetWeakPtr(), connection->id()));
177    if (rv == ERR_IO_PENDING)
178      return;
179    rv = HandleReadResult(connection, rv);
180  } while (rv == OK);
181}
182
183void HttpServer::OnReadCompleted(int connection_id, int rv) {
184  HttpConnection* connection = FindConnection(connection_id);
185  if (!connection)  // It might be closed right before by write error.
186    return;
187
188  if (HandleReadResult(connection, rv) == OK)
189    DoReadLoop(connection);
190}
191
// Handles the result |rv| of a read on |connection| and dispatches every
// complete HTTP request or WebSocket message now held in its read buffer.
// Returns OK when the caller may keep reading. Returns a net error code when
// the connection has been closed, either here or by a delegate callback.
int HttpServer::HandleReadResult(HttpConnection* connection, int rv) {
  // 0 means the peer closed the socket; negative values are net errors.
  if (rv <= 0) {
    Close(connection->id());
    return rv == 0 ? ERR_CONNECTION_CLOSED : rv;
  }

  HttpConnection::ReadIOBuffer* read_buf = connection->read_buf();
  read_buf->DidRead(rv);

  // Handles http requests or websocket messages.
  // Keeps going until the buffer is empty or only a partial request/frame is
  // left (the partial data stays buffered for the next read).
  while (read_buf->GetSize() > 0) {
    if (connection->web_socket()) {
      std::string message;
      WebSocket::ParseResult result = connection->web_socket()->Read(&message);
      if (result == WebSocket::FRAME_INCOMPLETE)
        break;

      if (result == WebSocket::FRAME_CLOSE ||
          result == WebSocket::FRAME_ERROR) {
        Close(connection->id());
        return ERR_CONNECTION_CLOSED;
      }
      delegate_->OnWebSocketMessage(connection->id(), message);
      // The delegate may have closed the connection from the callback.
      if (HasClosedConnection(connection))
        return ERR_CONNECTION_CLOSED;
      continue;
    }

    HttpServerRequestInfo request;
    size_t pos = 0;
    // ParseHeaders() returns false both for incomplete and for malformed
    // headers; either way, wait for more data. NOTE(review): a malformed
    // request therefore presumably stays buffered until DoReadLoop() fails to
    // grow the buffer and closes the connection.
    if (!ParseHeaders(read_buf->StartOfBuffer(), read_buf->GetSize(),
                      &request, &pos)) {
      break;
    }

    // Sets peer address if exists.
    connection->socket()->GetPeerAddress(&request.peer);

    // "Connection: upgrade" switches this connection to WebSocket framing.
    if (request.HasHeaderValue("connection", "upgrade")) {
      scoped_ptr<WebSocket> websocket(
          WebSocket::CreateWebSocket(this, connection, request, &pos));
      if (!websocket)  // Not enough data was received.
        break;
      connection->SetWebSocket(websocket.Pass());
      read_buf->DidConsume(pos);
      delegate_->OnWebSocketRequest(connection->id(), request);
      if (HasClosedConnection(connection))
        return ERR_CONNECTION_CLOSED;
      continue;
    }

    const char kContentLength[] = "content-length";
    if (request.headers.count(kContentLength) > 0) {
      size_t content_length = 0;
      // Bodies larger than 100 MiB (or with an unparsable length) are refused.
      const size_t kMaxBodySize = 100 << 20;
      if (!base::StringToSizeT(request.GetHeaderValue(kContentLength),
                               &content_length) ||
          content_length > kMaxBodySize) {
        // NOTE(review): the 500 is queued immediately before Close(); whether
        // it reaches the peer depends on the write finishing before the
        // connection object is deleted.
        SendResponse(connection->id(),
                     HttpServerResponseInfo::CreateFor500(
                         "request content-length too big or unknown: " +
                         request.GetHeaderValue(kContentLength)));
        Close(connection->id());
        return ERR_CONNECTION_CLOSED;
      }

      // |pos| is the end of the headers; the body must follow in full.
      if (read_buf->GetSize() - pos < content_length)
        break;  // Not enough data was received yet.
      request.data.assign(read_buf->StartOfBuffer() + pos, content_length);
      pos += content_length;
    }

    // Drop the headers (and body, if any) of this request from the buffer
    // before handing it to the delegate.
    read_buf->DidConsume(pos);
    delegate_->OnHttpRequest(connection->id(), request);
    if (HasClosedConnection(connection))
      return ERR_CONNECTION_CLOSED;
  }

  return OK;
}
272
273void HttpServer::DoWriteLoop(HttpConnection* connection) {
274  int rv = OK;
275  HttpConnection::QueuedWriteIOBuffer* write_buf = connection->write_buf();
276  while (rv == OK && write_buf->GetSizeToWrite() > 0) {
277    rv = connection->socket()->Write(
278        write_buf,
279        write_buf->GetSizeToWrite(),
280        base::Bind(&HttpServer::OnWriteCompleted,
281                   weak_ptr_factory_.GetWeakPtr(), connection->id()));
282    if (rv == ERR_IO_PENDING || rv == OK)
283      return;
284    rv = HandleWriteResult(connection, rv);
285  }
286}
287
288void HttpServer::OnWriteCompleted(int connection_id, int rv) {
289  HttpConnection* connection = FindConnection(connection_id);
290  if (!connection)  // It might be closed right before by read error.
291    return;
292
293  if (HandleWriteResult(connection, rv) == OK)
294    DoWriteLoop(connection);
295}
296
297int HttpServer::HandleWriteResult(HttpConnection* connection, int rv) {
298  if (rv < 0) {
299    Close(connection->id());
300    return rv;
301  }
302
303  connection->write_buf()->DidConsume(rv);
304  return OK;
305}
306
namespace {

//
// HTTP Request Parser
// This HTTP request parser uses a simple state machine to quickly parse
// through the headers.  The parser is not 100% complete, as it is designed
// for use in this simple test driver.
//
// Known issues:
//   - does not handle whitespace on first HTTP line correctly.  Expects
//     a single space between the method/url and url/protocol.
//   - the header block must end with CRLF; a bare LF on an otherwise empty
//     line is treated as a syntax error.
//   - a header line that contains no colon and ends in CR terminates the
//     header block, the same way an empty line does.

// Input character types.
enum header_parse_inputs {
  INPUT_LWS,
  INPUT_CR,
  INPUT_LF,
  INPUT_COLON,
  INPUT_DEFAULT,
  MAX_INPUTS,
};

// Parser states.
enum header_parse_states {
  ST_METHOD,     // Receiving the method
  ST_URL,        // Receiving the URL
  ST_PROTO,      // Receiving the protocol
  ST_HEADER,     // Starting a Request Header
  ST_NAME,       // Receiving a request header name
  ST_SEPARATOR,  // Receiving the separator between header name and value
  ST_VALUE,      // Receiving a request header value
  ST_DONE,       // Parsing is complete and successful
  ST_ERR,        // Parsing encountered invalid syntax.
  MAX_STATES
};

// State transition table: parser_state[current_state][input] is the next
// state. It is read-only lookup data, so it is const; nothing at runtime can
// corrupt the parser by writing to it.
const int parser_state[MAX_STATES][MAX_INPUTS] = {
/* METHOD    */ { ST_URL,       ST_ERR,     ST_ERR,   ST_ERR,       ST_METHOD },
/* URL       */ { ST_PROTO,     ST_ERR,     ST_ERR,   ST_URL,       ST_URL },
/* PROTOCOL  */ { ST_ERR,       ST_HEADER,  ST_NAME,  ST_ERR,       ST_PROTO },
/* HEADER    */ { ST_ERR,       ST_ERR,     ST_NAME,  ST_ERR,       ST_ERR },
/* NAME      */ { ST_SEPARATOR, ST_DONE,    ST_ERR,   ST_VALUE,     ST_NAME },
/* SEPARATOR */ { ST_SEPARATOR, ST_ERR,     ST_ERR,   ST_VALUE,     ST_ERR },
/* VALUE     */ { ST_VALUE,     ST_HEADER,  ST_NAME,  ST_VALUE,     ST_VALUE },
/* DONE      */ { ST_DONE,      ST_DONE,    ST_DONE,  ST_DONE,      ST_DONE },
/* ERR       */ { ST_ERR,       ST_ERR,     ST_ERR,   ST_ERR,       ST_ERR }
};

// Converts an input character to the parser's input token. Space and tab are
// linear whitespace; CR, LF and ':' have their own tokens; every other byte is
// INPUT_DEFAULT.
int charToInput(char ch) {
  switch(ch) {
    case ' ':
    case '\t':
      return INPUT_LWS;
    case '\r':
      return INPUT_CR;
    case '\n':
      return INPUT_LF;
    case ':':
      return INPUT_COLON;
  }
  return INPUT_DEFAULT;
}

}  // namespace
373
374bool HttpServer::ParseHeaders(const char* data,
375                              size_t data_len,
376                              HttpServerRequestInfo* info,
377                              size_t* ppos) {
378  size_t& pos = *ppos;
379  int state = ST_METHOD;
380  std::string buffer;
381  std::string header_name;
382  std::string header_value;
383  while (pos < data_len) {
384    char ch = data[pos++];
385    int input = charToInput(ch);
386    int next_state = parser_state[state][input];
387
388    bool transition = (next_state != state);
389    HttpServerRequestInfo::HeadersMap::iterator it;
390    if (transition) {
391      // Do any actions based on state transitions.
392      switch (state) {
393        case ST_METHOD:
394          info->method = buffer;
395          buffer.clear();
396          break;
397        case ST_URL:
398          info->path = buffer;
399          buffer.clear();
400          break;
401        case ST_PROTO:
402          // TODO(mbelshe): Deal better with parsing protocol.
403          DCHECK(buffer == "HTTP/1.1");
404          buffer.clear();
405          break;
406        case ST_NAME:
407          header_name = base::StringToLowerASCII(buffer);
408          buffer.clear();
409          break;
410        case ST_VALUE:
411          base::TrimWhitespaceASCII(buffer, base::TRIM_LEADING, &header_value);
412          it = info->headers.find(header_name);
413          // See last paragraph ("Multiple message-header fields...")
414          // of www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
415          if (it == info->headers.end()) {
416            info->headers[header_name] = header_value;
417          } else {
418            it->second.append(",");
419            it->second.append(header_value);
420          }
421          buffer.clear();
422          break;
423        case ST_SEPARATOR:
424          break;
425      }
426      state = next_state;
427    } else {
428      // Do any actions based on current state
429      switch (state) {
430        case ST_METHOD:
431        case ST_URL:
432        case ST_PROTO:
433        case ST_VALUE:
434        case ST_NAME:
435          buffer.append(&ch, 1);
436          break;
437        case ST_DONE:
438          DCHECK(input == INPUT_LF);
439          return true;
440        case ST_ERR:
441          return false;
442      }
443    }
444  }
445  // No more characters, but we haven't finished parsing yet.
446  return false;
447}
448
449HttpConnection* HttpServer::FindConnection(int connection_id) {
450  IdToConnectionMap::iterator it = id_to_connection_.find(connection_id);
451  if (it == id_to_connection_.end())
452    return NULL;
453  return it->second;
454}
455
456// This is called after any delegate callbacks are called to check if Close()
457// has been called during callback processing. Using the pointer of connection,
458// |connection| is safe here because Close() deletes the connection in next run
459// loop.
460bool HttpServer::HasClosedConnection(HttpConnection* connection) {
461  return FindConnection(connection->id()) != connection;
462}
463
464}  // namespace net
465