1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <algorithm>
6#include <utility>
7#include <vector>
8
9#include "base/bind.h"
10#include "base/bind_helpers.h"
11#include "base/callback_helpers.h"
12#include "base/compiler_specific.h"
13#include "base/format_macros.h"
14#include "base/memory/ref_counted.h"
15#include "base/memory/scoped_ptr.h"
16#include "base/memory/weak_ptr.h"
17#include "base/message_loop/message_loop.h"
18#include "base/message_loop/message_loop_proxy.h"
19#include "base/run_loop.h"
20#include "base/strings/string_split.h"
21#include "base/strings/string_util.h"
22#include "base/strings/stringprintf.h"
23#include "base/time/time.h"
24#include "net/base/address_list.h"
25#include "net/base/io_buffer.h"
26#include "net/base/ip_endpoint.h"
27#include "net/base/net_errors.h"
28#include "net/base/net_log.h"
29#include "net/base/net_util.h"
30#include "net/base/test_completion_callback.h"
31#include "net/http/http_response_headers.h"
32#include "net/http/http_util.h"
33#include "net/server/http_server.h"
34#include "net/server/http_server_request_info.h"
35#include "net/socket/tcp_client_socket.h"
36#include "net/socket/tcp_server_socket.h"
37#include "net/url_request/url_fetcher.h"
38#include "net/url_request/url_fetcher_delegate.h"
39#include "net/url_request/url_request_context.h"
40#include "net/url_request/url_request_context_getter.h"
41#include "net/url_request/url_request_test_util.h"
42#include "testing/gtest/include/gtest/gtest.h"
43
44namespace net {
45
46namespace {
47
// Size of the buffer TestHttpClient allocates for each individual socket
// read. Longer responses are simply collected over several reads.
const int kMaxExpectedResponseLength = 2 * 1024;
49
50void SetTimedOutAndQuitLoop(const base::WeakPtr<bool> timed_out,
51                            const base::Closure& quit_loop_func) {
52  if (timed_out) {
53    *timed_out = true;
54    quit_loop_func.Run();
55  }
56}
57
58bool RunLoopWithTimeout(base::RunLoop* run_loop) {
59  bool timed_out = false;
60  base::WeakPtrFactory<bool> timed_out_weak_factory(&timed_out);
61  base::MessageLoop::current()->PostDelayedTask(
62      FROM_HERE,
63      base::Bind(&SetTimedOutAndQuitLoop,
64                 timed_out_weak_factory.GetWeakPtr(),
65                 run_loop->QuitClosure()),
66      base::TimeDelta::FromSeconds(1));
67  run_loop->Run();
68  return !timed_out;
69}
70
71class TestHttpClient {
72 public:
73  TestHttpClient() : connect_result_(OK) {}
74
75  int ConnectAndWait(const IPEndPoint& address) {
76    AddressList addresses(address);
77    NetLog::Source source;
78    socket_.reset(new TCPClientSocket(addresses, NULL, source));
79
80    base::RunLoop run_loop;
81    connect_result_ = socket_->Connect(base::Bind(&TestHttpClient::OnConnect,
82                                                  base::Unretained(this),
83                                                  run_loop.QuitClosure()));
84    if (connect_result_ != OK && connect_result_ != ERR_IO_PENDING)
85      return connect_result_;
86
87    if (!RunLoopWithTimeout(&run_loop))
88      return ERR_TIMED_OUT;
89    return connect_result_;
90  }
91
92  void Send(const std::string& data) {
93    write_buffer_ =
94        new DrainableIOBuffer(new StringIOBuffer(data), data.length());
95    Write();
96  }
97
98  bool Read(std::string* message, int expected_bytes) {
99    int total_bytes_received = 0;
100    message->clear();
101    while (total_bytes_received < expected_bytes) {
102      net::TestCompletionCallback callback;
103      ReadInternal(callback.callback());
104      int bytes_received = callback.WaitForResult();
105      if (bytes_received <= 0)
106        return false;
107
108      total_bytes_received += bytes_received;
109      message->append(read_buffer_->data(), bytes_received);
110    }
111    return true;
112  }
113
114  bool ReadResponse(std::string* message) {
115    if (!Read(message, 1))
116      return false;
117    while (!IsCompleteResponse(*message)) {
118      std::string chunk;
119      if (!Read(&chunk, 1))
120        return false;
121      message->append(chunk);
122    }
123    return true;
124  }
125
126 private:
127  void OnConnect(const base::Closure& quit_loop, int result) {
128    connect_result_ = result;
129    quit_loop.Run();
130  }
131
132  void Write() {
133    int result = socket_->Write(
134        write_buffer_.get(),
135        write_buffer_->BytesRemaining(),
136        base::Bind(&TestHttpClient::OnWrite, base::Unretained(this)));
137    if (result != ERR_IO_PENDING)
138      OnWrite(result);
139  }
140
141  void OnWrite(int result) {
142    ASSERT_GT(result, 0);
143    write_buffer_->DidConsume(result);
144    if (write_buffer_->BytesRemaining())
145      Write();
146  }
147
148  void ReadInternal(const net::CompletionCallback& callback) {
149    read_buffer_ = new IOBufferWithSize(kMaxExpectedResponseLength);
150    int result =
151        socket_->Read(read_buffer_.get(), kMaxExpectedResponseLength, callback);
152    if (result != ERR_IO_PENDING)
153      callback.Run(result);
154  }
155
156  bool IsCompleteResponse(const std::string& response) {
157    // Check end of headers first.
158    int end_of_headers = HttpUtil::LocateEndOfHeaders(response.data(),
159                                                      response.size());
160    if (end_of_headers < 0)
161      return false;
162
163    // Return true if response has data equal to or more than content length.
164    int64 body_size = static_cast<int64>(response.size()) - end_of_headers;
165    DCHECK_LE(0, body_size);
166    scoped_refptr<HttpResponseHeaders> headers(new HttpResponseHeaders(
167        HttpUtil::AssembleRawHeaders(response.data(), end_of_headers)));
168    return body_size >= headers->GetContentLength();
169  }
170
171  scoped_refptr<IOBufferWithSize> read_buffer_;
172  scoped_refptr<DrainableIOBuffer> write_buffer_;
173  scoped_ptr<TCPClientSocket> socket_;
174  int connect_result_;
175};
176
177}  // namespace
178
179class HttpServerTest : public testing::Test,
180                       public HttpServer::Delegate {
181 public:
182  HttpServerTest() : quit_after_request_count_(0) {}
183
184  virtual void SetUp() OVERRIDE {
185    scoped_ptr<ServerSocket> server_socket(
186        new TCPServerSocket(NULL, net::NetLog::Source()));
187    server_socket->ListenWithAddressAndPort("127.0.0.1", 0, 1);
188    server_.reset(new HttpServer(server_socket.Pass(), this));
189    ASSERT_EQ(OK, server_->GetLocalAddress(&server_address_));
190  }
191
192  virtual void OnConnect(int connection_id) OVERRIDE {}
193
194  virtual void OnHttpRequest(int connection_id,
195                             const HttpServerRequestInfo& info) OVERRIDE {
196    requests_.push_back(std::make_pair(info, connection_id));
197    if (requests_.size() == quit_after_request_count_)
198      run_loop_quit_func_.Run();
199  }
200
201  virtual void OnWebSocketRequest(int connection_id,
202                                  const HttpServerRequestInfo& info) OVERRIDE {
203    NOTREACHED();
204  }
205
206  virtual void OnWebSocketMessage(int connection_id,
207                                  const std::string& data) OVERRIDE {
208    NOTREACHED();
209  }
210
211  virtual void OnClose(int connection_id) OVERRIDE {}
212
213  bool RunUntilRequestsReceived(size_t count) {
214    quit_after_request_count_ = count;
215    if (requests_.size() == count)
216      return true;
217
218    base::RunLoop run_loop;
219    run_loop_quit_func_ = run_loop.QuitClosure();
220    bool success = RunLoopWithTimeout(&run_loop);
221    run_loop_quit_func_.Reset();
222    return success;
223  }
224
225  HttpServerRequestInfo GetRequest(size_t request_index) {
226    return requests_[request_index].first;
227  }
228
229  int GetConnectionId(size_t request_index) {
230    return requests_[request_index].second;
231  }
232
233  void HandleAcceptResult(scoped_ptr<StreamSocket> socket) {
234    server_->accepted_socket_.reset(socket.release());
235    server_->HandleAcceptResult(OK);
236  }
237
238 protected:
239  scoped_ptr<HttpServer> server_;
240  IPEndPoint server_address_;
241  base::Closure run_loop_quit_func_;
242  std::vector<std::pair<HttpServerRequestInfo, int> > requests_;
243
244 private:
245  size_t quit_after_request_count_;
246};
247
248namespace {
249
250class WebSocketTest : public HttpServerTest {
251  virtual void OnHttpRequest(int connection_id,
252                             const HttpServerRequestInfo& info) OVERRIDE {
253    NOTREACHED();
254  }
255
256  virtual void OnWebSocketRequest(int connection_id,
257                                  const HttpServerRequestInfo& info) OVERRIDE {
258    HttpServerTest::OnHttpRequest(connection_id, info);
259  }
260
261  virtual void OnWebSocketMessage(int connection_id,
262                                  const std::string& data) OVERRIDE {
263  }
264};
265
266TEST_F(HttpServerTest, Request) {
267  TestHttpClient client;
268  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
269  client.Send("GET /test HTTP/1.1\r\n\r\n");
270  ASSERT_TRUE(RunUntilRequestsReceived(1));
271  ASSERT_EQ("GET", GetRequest(0).method);
272  ASSERT_EQ("/test", GetRequest(0).path);
273  ASSERT_EQ("", GetRequest(0).data);
274  ASSERT_EQ(0u, GetRequest(0).headers.size());
275  ASSERT_TRUE(StartsWithASCII(GetRequest(0).peer.ToString(),
276                              "127.0.0.1",
277                              true));
278}
279
280TEST_F(HttpServerTest, RequestWithHeaders) {
281  TestHttpClient client;
282  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
283  const char* kHeaders[][3] = {
284      {"Header", ": ", "1"},
285      {"HeaderWithNoWhitespace", ":", "1"},
286      {"HeaderWithWhitespace", "   :  \t   ", "1 1 1 \t  "},
287      {"HeaderWithColon", ": ", "1:1"},
288      {"EmptyHeader", ":", ""},
289      {"EmptyHeaderWithWhitespace", ":  \t  ", ""},
290      {"HeaderWithNonASCII", ":  ", "\xf7"},
291  };
292  std::string headers;
293  for (size_t i = 0; i < arraysize(kHeaders); ++i) {
294    headers +=
295        std::string(kHeaders[i][0]) + kHeaders[i][1] + kHeaders[i][2] + "\r\n";
296  }
297
298  client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n");
299  ASSERT_TRUE(RunUntilRequestsReceived(1));
300  ASSERT_EQ("", GetRequest(0).data);
301
302  for (size_t i = 0; i < arraysize(kHeaders); ++i) {
303    std::string field = base::StringToLowerASCII(std::string(kHeaders[i][0]));
304    std::string value = kHeaders[i][2];
305    ASSERT_EQ(1u, GetRequest(0).headers.count(field)) << field;
306    ASSERT_EQ(value, GetRequest(0).headers[field]) << kHeaders[i][0];
307  }
308}
309
310TEST_F(HttpServerTest, RequestWithDuplicateHeaders) {
311  TestHttpClient client;
312  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
313  const char* kHeaders[][3] = {
314      {"FirstHeader", ": ", "1"},
315      {"DuplicateHeader", ": ", "2"},
316      {"MiddleHeader", ": ", "3"},
317      {"DuplicateHeader", ": ", "4"},
318      {"LastHeader", ": ", "5"},
319  };
320  std::string headers;
321  for (size_t i = 0; i < arraysize(kHeaders); ++i) {
322    headers +=
323        std::string(kHeaders[i][0]) + kHeaders[i][1] + kHeaders[i][2] + "\r\n";
324  }
325
326  client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n");
327  ASSERT_TRUE(RunUntilRequestsReceived(1));
328  ASSERT_EQ("", GetRequest(0).data);
329
330  for (size_t i = 0; i < arraysize(kHeaders); ++i) {
331    std::string field = base::StringToLowerASCII(std::string(kHeaders[i][0]));
332    std::string value = (field == "duplicateheader") ? "2,4" : kHeaders[i][2];
333    ASSERT_EQ(1u, GetRequest(0).headers.count(field)) << field;
334    ASSERT_EQ(value, GetRequest(0).headers[field]) << kHeaders[i][0];
335  }
336}
337
338TEST_F(HttpServerTest, HasHeaderValueTest) {
339  TestHttpClient client;
340  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
341  const char* kHeaders[] = {
342      "Header: Abcd",
343      "HeaderWithNoWhitespace:E",
344      "HeaderWithWhitespace   :  \t   f \t  ",
345      "DuplicateHeader: g",
346      "HeaderWithComma: h, i ,j",
347      "DuplicateHeader: k",
348      "EmptyHeader:",
349      "EmptyHeaderWithWhitespace:  \t  ",
350      "HeaderWithNonASCII:  \xf7",
351  };
352  std::string headers;
353  for (size_t i = 0; i < arraysize(kHeaders); ++i) {
354    headers += std::string(kHeaders[i]) + "\r\n";
355  }
356
357  client.Send("GET /test HTTP/1.1\r\n" + headers + "\r\n");
358  ASSERT_TRUE(RunUntilRequestsReceived(1));
359  ASSERT_EQ("", GetRequest(0).data);
360
361  ASSERT_TRUE(GetRequest(0).HasHeaderValue("header", "abcd"));
362  ASSERT_FALSE(GetRequest(0).HasHeaderValue("header", "bc"));
363  ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithnowhitespace", "e"));
364  ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithwhitespace", "f"));
365  ASSERT_TRUE(GetRequest(0).HasHeaderValue("duplicateheader", "g"));
366  ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithcomma", "h"));
367  ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithcomma", "i"));
368  ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithcomma", "j"));
369  ASSERT_TRUE(GetRequest(0).HasHeaderValue("duplicateheader", "k"));
370  ASSERT_FALSE(GetRequest(0).HasHeaderValue("emptyheader", "x"));
371  ASSERT_FALSE(GetRequest(0).HasHeaderValue("emptyheaderwithwhitespace", "x"));
372  ASSERT_TRUE(GetRequest(0).HasHeaderValue("headerwithnonascii", "\xf7"));
373}
374
375TEST_F(HttpServerTest, RequestWithBody) {
376  TestHttpClient client;
377  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
378  std::string body = "a" + std::string(1 << 10, 'b') + "c";
379  client.Send(base::StringPrintf(
380      "GET /test HTTP/1.1\r\n"
381      "SomeHeader: 1\r\n"
382      "Content-Length: %" PRIuS "\r\n\r\n%s",
383      body.length(),
384      body.c_str()));
385  ASSERT_TRUE(RunUntilRequestsReceived(1));
386  ASSERT_EQ(2u, GetRequest(0).headers.size());
387  ASSERT_EQ(body.length(), GetRequest(0).data.length());
388  ASSERT_EQ('a', body[0]);
389  ASSERT_EQ('c', *body.rbegin());
390}
391
392TEST_F(WebSocketTest, RequestWebSocket) {
393  TestHttpClient client;
394  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
395  client.Send(
396      "GET /test HTTP/1.1\r\n"
397      "Upgrade: WebSocket\r\n"
398      "Connection: SomethingElse, Upgrade\r\n"
399      "Sec-WebSocket-Version: 8\r\n"
400      "Sec-WebSocket-Key: key\r\n"
401      "\r\n");
402  ASSERT_TRUE(RunUntilRequestsReceived(1));
403}
404
405TEST_F(HttpServerTest, RequestWithTooLargeBody) {
406  class TestURLFetcherDelegate : public URLFetcherDelegate {
407   public:
408    TestURLFetcherDelegate(const base::Closure& quit_loop_func)
409        : quit_loop_func_(quit_loop_func) {}
410    virtual ~TestURLFetcherDelegate() {}
411
412    virtual void OnURLFetchComplete(const URLFetcher* source) OVERRIDE {
413      EXPECT_EQ(HTTP_INTERNAL_SERVER_ERROR, source->GetResponseCode());
414      quit_loop_func_.Run();
415    }
416
417   private:
418    base::Closure quit_loop_func_;
419  };
420
421  base::RunLoop run_loop;
422  TestURLFetcherDelegate delegate(run_loop.QuitClosure());
423
424  scoped_refptr<URLRequestContextGetter> request_context_getter(
425      new TestURLRequestContextGetter(base::MessageLoopProxy::current()));
426  scoped_ptr<URLFetcher> fetcher(
427      URLFetcher::Create(GURL(base::StringPrintf("http://127.0.0.1:%d/test",
428                                                 server_address_.port())),
429                         URLFetcher::GET,
430                         &delegate));
431  fetcher->SetRequestContext(request_context_getter.get());
432  fetcher->AddExtraRequestHeader(
433      base::StringPrintf("content-length:%d", 1 << 30));
434  fetcher->Start();
435
436  ASSERT_TRUE(RunLoopWithTimeout(&run_loop));
437  ASSERT_EQ(0u, requests_.size());
438}
439
440TEST_F(HttpServerTest, Send200) {
441  TestHttpClient client;
442  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
443  client.Send("GET /test HTTP/1.1\r\n\r\n");
444  ASSERT_TRUE(RunUntilRequestsReceived(1));
445  server_->Send200(GetConnectionId(0), "Response!", "text/plain");
446
447  std::string response;
448  ASSERT_TRUE(client.ReadResponse(&response));
449  ASSERT_TRUE(StartsWithASCII(response, "HTTP/1.1 200 OK", true));
450  ASSERT_TRUE(EndsWith(response, "Response!", true));
451}
452
453TEST_F(HttpServerTest, SendRaw) {
454  TestHttpClient client;
455  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
456  client.Send("GET /test HTTP/1.1\r\n\r\n");
457  ASSERT_TRUE(RunUntilRequestsReceived(1));
458  server_->SendRaw(GetConnectionId(0), "Raw Data ");
459  server_->SendRaw(GetConnectionId(0), "More Data");
460  server_->SendRaw(GetConnectionId(0), "Third Piece of Data");
461
462  const std::string expected_response("Raw Data More DataThird Piece of Data");
463  std::string response;
464  ASSERT_TRUE(client.Read(&response, expected_response.length()));
465  ASSERT_EQ(expected_response, response);
466}
467
468class MockStreamSocket : public StreamSocket {
469 public:
470  MockStreamSocket()
471      : connected_(true),
472        read_buf_(NULL),
473        read_buf_len_(0) {}
474
475  // StreamSocket
476  virtual int Connect(const CompletionCallback& callback) OVERRIDE {
477    return ERR_NOT_IMPLEMENTED;
478  }
479  virtual void Disconnect() OVERRIDE {
480    connected_ = false;
481    if (!read_callback_.is_null()) {
482      read_buf_ = NULL;
483      read_buf_len_ = 0;
484      base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED);
485    }
486  }
487  virtual bool IsConnected() const OVERRIDE { return connected_; }
488  virtual bool IsConnectedAndIdle() const OVERRIDE { return IsConnected(); }
489  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE {
490    return ERR_NOT_IMPLEMENTED;
491  }
492  virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE {
493    return ERR_NOT_IMPLEMENTED;
494  }
495  virtual const BoundNetLog& NetLog() const OVERRIDE { return net_log_; }
496  virtual void SetSubresourceSpeculation() OVERRIDE {}
497  virtual void SetOmniboxSpeculation() OVERRIDE {}
498  virtual bool WasEverUsed() const OVERRIDE { return true; }
499  virtual bool UsingTCPFastOpen() const OVERRIDE { return false; }
500  virtual bool WasNpnNegotiated() const OVERRIDE { return false; }
501  virtual NextProto GetNegotiatedProtocol() const OVERRIDE {
502    return kProtoUnknown;
503  }
504  virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { return false; }
505
506  // Socket
507  virtual int Read(IOBuffer* buf, int buf_len,
508                   const CompletionCallback& callback) OVERRIDE {
509    if (!connected_) {
510      return ERR_SOCKET_NOT_CONNECTED;
511    }
512    if (pending_read_data_.empty()) {
513      read_buf_ = buf;
514      read_buf_len_ = buf_len;
515      read_callback_ = callback;
516      return ERR_IO_PENDING;
517    }
518    DCHECK_GT(buf_len, 0);
519    int read_len = std::min(static_cast<int>(pending_read_data_.size()),
520                            buf_len);
521    memcpy(buf->data(), pending_read_data_.data(), read_len);
522    pending_read_data_.erase(0, read_len);
523    return read_len;
524  }
525  virtual int Write(IOBuffer* buf, int buf_len,
526                    const CompletionCallback& callback) OVERRIDE {
527    return ERR_NOT_IMPLEMENTED;
528  }
529  virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
530    return ERR_NOT_IMPLEMENTED;
531  }
532  virtual int SetSendBufferSize(int32 size) OVERRIDE {
533    return ERR_NOT_IMPLEMENTED;
534  }
535
536  void DidRead(const char* data, int data_len) {
537    if (!read_buf_.get()) {
538      pending_read_data_.append(data, data_len);
539      return;
540    }
541    int read_len = std::min(data_len, read_buf_len_);
542    memcpy(read_buf_->data(), data, read_len);
543    pending_read_data_.assign(data + read_len, data_len - read_len);
544    read_buf_ = NULL;
545    read_buf_len_ = 0;
546    base::ResetAndReturn(&read_callback_).Run(read_len);
547  }
548
549 private:
550  virtual ~MockStreamSocket() {}
551
552  bool connected_;
553  scoped_refptr<IOBuffer> read_buf_;
554  int read_buf_len_;
555  CompletionCallback read_callback_;
556  std::string pending_read_data_;
557  BoundNetLog net_log_;
558
559  DISALLOW_COPY_AND_ASSIGN(MockStreamSocket);
560};
561
562TEST_F(HttpServerTest, RequestWithBodySplitAcrossPackets) {
563  MockStreamSocket* socket = new MockStreamSocket();
564  HandleAcceptResult(make_scoped_ptr<StreamSocket>(socket));
565  std::string body("body");
566  std::string request_text = base::StringPrintf(
567      "GET /test HTTP/1.1\r\n"
568      "SomeHeader: 1\r\n"
569      "Content-Length: %" PRIuS "\r\n\r\n%s",
570      body.length(),
571      body.c_str());
572  socket->DidRead(request_text.c_str(), request_text.length() - 2);
573  ASSERT_EQ(0u, requests_.size());
574  socket->DidRead(request_text.c_str() + request_text.length() - 2, 2);
575  ASSERT_EQ(1u, requests_.size());
576  ASSERT_EQ(body, GetRequest(0).data);
577}
578
579TEST_F(HttpServerTest, MultipleRequestsOnSameConnection) {
580  // The idea behind this test is that requests with or without bodies should
581  // not break parsing of the next request.
582  TestHttpClient client;
583  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
584  std::string body = "body";
585  client.Send(base::StringPrintf(
586      "GET /test HTTP/1.1\r\n"
587      "Content-Length: %" PRIuS "\r\n\r\n%s",
588      body.length(),
589      body.c_str()));
590  ASSERT_TRUE(RunUntilRequestsReceived(1));
591  ASSERT_EQ(body, GetRequest(0).data);
592
593  int client_connection_id = GetConnectionId(0);
594  server_->Send200(client_connection_id, "Content for /test", "text/plain");
595  std::string response1;
596  ASSERT_TRUE(client.ReadResponse(&response1));
597  ASSERT_TRUE(StartsWithASCII(response1, "HTTP/1.1 200 OK", true));
598  ASSERT_TRUE(EndsWith(response1, "Content for /test", true));
599
600  client.Send("GET /test2 HTTP/1.1\r\n\r\n");
601  ASSERT_TRUE(RunUntilRequestsReceived(2));
602  ASSERT_EQ("/test2", GetRequest(1).path);
603
604  ASSERT_EQ(client_connection_id, GetConnectionId(1));
605  server_->Send404(client_connection_id);
606  std::string response2;
607  ASSERT_TRUE(client.ReadResponse(&response2));
608  ASSERT_TRUE(StartsWithASCII(response2, "HTTP/1.1 404 Not Found", true));
609
610  client.Send("GET /test3 HTTP/1.1\r\n\r\n");
611  ASSERT_TRUE(RunUntilRequestsReceived(3));
612  ASSERT_EQ("/test3", GetRequest(2).path);
613
614  ASSERT_EQ(client_connection_id, GetConnectionId(2));
615  server_->Send200(client_connection_id, "Content for /test3", "text/plain");
616  std::string response3;
617  ASSERT_TRUE(client.ReadResponse(&response3));
618  ASSERT_TRUE(StartsWithASCII(response3, "HTTP/1.1 200 OK", true));
619  ASSERT_TRUE(EndsWith(response3, "Content for /test3", true));
620}
621
622class CloseOnConnectHttpServerTest : public HttpServerTest {
623 public:
624  virtual void OnConnect(int connection_id) OVERRIDE {
625    connection_ids_.push_back(connection_id);
626    server_->Close(connection_id);
627  }
628
629 protected:
630  std::vector<int> connection_ids_;
631};
632
633TEST_F(CloseOnConnectHttpServerTest, ServerImmediatelyClosesConnection) {
634  TestHttpClient client;
635  ASSERT_EQ(OK, client.ConnectAndWait(server_address_));
636  client.Send("GET / HTTP/1.1\r\n\r\n");
637  ASSERT_FALSE(RunUntilRequestsReceived(1));
638  ASSERT_EQ(1ul, connection_ids_.size());
639  ASSERT_EQ(0ul, requests_.size());
640}
641
642}  // namespace
643
644}  // namespace net
645