1// Copyright (c) 2010 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/http/http_response_body_drainer.h"
6
7#include <cstring>
8
9#include "base/compiler_specific.h"
10#include "base/message_loop.h"
11#include "base/task.h"
12#include "net/base/io_buffer.h"
13#include "net/base/net_errors.h"
14#include "net/base/ssl_config_service_defaults.h"
15#include "net/base/test_completion_callback.h"
16#include "net/http/http_network_session.h"
17#include "net/http/http_stream.h"
18#include "net/proxy/proxy_service.h"
19#include "testing/gtest/include/gtest/gtest.h"
20
21namespace net {
22
23namespace {
24
25const int kMagicChunkSize = 1024;
26COMPILE_ASSERT(
27    (HttpResponseBodyDrainer::kDrainBodyBufferSize % kMagicChunkSize) == 0,
28    chunk_size_needs_to_divide_evenly_into_buffer_size);
29
30class CloseResultWaiter {
31 public:
32  CloseResultWaiter()
33      : result_(false),
34        have_result_(false),
35        waiting_for_result_(false) {}
36
37  int WaitForResult() {
38    DCHECK(!waiting_for_result_);
39    while (!have_result_) {
40      waiting_for_result_ = true;
41      MessageLoop::current()->Run();
42      waiting_for_result_ = false;
43    }
44    return result_;
45  }
46
47  void set_result(bool result) {
48    result_ = result;
49    have_result_ = true;
50    if (waiting_for_result_)
51      MessageLoop::current()->Quit();
52  }
53
54 private:
55  int result_;
56  bool have_result_;
57  bool waiting_for_result_;
58
59  DISALLOW_COPY_AND_ASSIGN(CloseResultWaiter);
60};
61
62class MockHttpStream : public HttpStream {
63 public:
64  MockHttpStream(CloseResultWaiter* result_waiter)
65      : result_waiter_(result_waiter),
66        user_callback_(NULL),
67        closed_(false),
68        stall_reads_forever_(false),
69        num_chunks_(0),
70        is_complete_(false),
71        ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)) {}
72  virtual ~MockHttpStream() {}
73
74  // HttpStream implementation:
75  virtual int InitializeStream(const HttpRequestInfo* request_info,
76                               const BoundNetLog& net_log,
77                               CompletionCallback* callback) OVERRIDE {
78    return ERR_UNEXPECTED;
79  }
80  virtual int SendRequest(const HttpRequestHeaders& request_headers,
81                          UploadDataStream* request_body,
82                          HttpResponseInfo* response,
83                          CompletionCallback* callback) OVERRIDE {
84    return ERR_UNEXPECTED;
85  }
86  virtual uint64 GetUploadProgress() const OVERRIDE { return 0; }
87  virtual int ReadResponseHeaders(CompletionCallback* callback) OVERRIDE {
88    return ERR_UNEXPECTED;
89  }
90  virtual const HttpResponseInfo* GetResponseInfo() const OVERRIDE {
91    return NULL;
92  }
93
94  virtual bool CanFindEndOfResponse() const OVERRIDE { return true; }
95  virtual bool IsMoreDataBuffered() const OVERRIDE { return false; }
96  virtual bool IsConnectionReused() const OVERRIDE { return false; }
97  virtual void SetConnectionReused() OVERRIDE {}
98  virtual bool IsConnectionReusable() const OVERRIDE { return false; }
99  virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {}
100  virtual void GetSSLCertRequestInfo(
101      SSLCertRequestInfo* cert_request_info) OVERRIDE {}
102
103  // Mocked API
104  virtual int ReadResponseBody(IOBuffer* buf, int buf_len,
105                               CompletionCallback* callback) OVERRIDE;
106  virtual void Close(bool not_reusable) OVERRIDE {
107    DCHECK(!closed_);
108    closed_ = true;
109    result_waiter_->set_result(not_reusable);
110  }
111
112  virtual HttpStream* RenewStreamForAuth() OVERRIDE {
113    return NULL;
114  }
115
116  virtual bool IsResponseBodyComplete() const OVERRIDE { return is_complete_; }
117
118  virtual bool IsSpdyHttpStream() const OVERRIDE { return false; }
119
120  // Methods to tweak/observer mock behavior:
121  void StallReadsForever() { stall_reads_forever_ = true; }
122
123  void set_num_chunks(int num_chunks) { num_chunks_ = num_chunks; }
124
125 private:
126  void CompleteRead();
127
128  bool closed() const { return closed_; }
129
130  CloseResultWaiter* const result_waiter_;
131  scoped_refptr<IOBuffer> user_buf_;
132  CompletionCallback* user_callback_;
133  bool closed_;
134  bool stall_reads_forever_;
135  int num_chunks_;
136  bool is_complete_;
137  ScopedRunnableMethodFactory<MockHttpStream> method_factory_;
138};
139
140int MockHttpStream::ReadResponseBody(
141    IOBuffer* buf, int buf_len, CompletionCallback* callback) {
142  DCHECK(callback);
143  DCHECK(!user_callback_);
144  DCHECK(buf);
145
146  if (stall_reads_forever_)
147    return ERR_IO_PENDING;
148
149  if (num_chunks_ == 0)
150    return ERR_UNEXPECTED;
151
152  if (buf_len > kMagicChunkSize && num_chunks_ > 1) {
153    user_buf_ = buf;
154    user_callback_ = callback;
155    MessageLoop::current()->PostTask(
156        FROM_HERE,
157        method_factory_.NewRunnableMethod(&MockHttpStream::CompleteRead));
158    return ERR_IO_PENDING;
159  }
160
161  num_chunks_--;
162  if (!num_chunks_)
163    is_complete_ = true;
164
165  return buf_len;
166}
167
168void MockHttpStream::CompleteRead() {
169  CompletionCallback* callback = user_callback_;
170  std::memset(user_buf_->data(), 1, kMagicChunkSize);
171  user_buf_ = NULL;
172  user_callback_ = NULL;
173  num_chunks_--;
174  if (!num_chunks_)
175    is_complete_ = true;
176  callback->Run(kMagicChunkSize);
177}
178
179class HttpResponseBodyDrainerTest : public testing::Test {
180 protected:
181  HttpResponseBodyDrainerTest()
182      : proxy_service_(ProxyService::CreateDirect()),
183        ssl_config_service_(new SSLConfigServiceDefaults),
184        session_(CreateNetworkSession()),
185        mock_stream_(new MockHttpStream(&result_waiter_)),
186        drainer_(new HttpResponseBodyDrainer(mock_stream_)) {}
187
188  ~HttpResponseBodyDrainerTest() {}
189
190  HttpNetworkSession* CreateNetworkSession() const {
191    HttpNetworkSession::Params params;
192    params.proxy_service = proxy_service_;
193    params.ssl_config_service = ssl_config_service_;
194    return new HttpNetworkSession(params);
195  }
196
197  scoped_refptr<ProxyService> proxy_service_;
198  scoped_refptr<SSLConfigService> ssl_config_service_;
199  const scoped_refptr<HttpNetworkSession> session_;
200  CloseResultWaiter result_waiter_;
201  MockHttpStream* const mock_stream_;  // Owned by |drainer_|.
202  HttpResponseBodyDrainer* const drainer_;  // Deletes itself.
203};
204
205TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncOK) {
206  mock_stream_->set_num_chunks(1);
207  drainer_->Start(session_);
208  EXPECT_FALSE(result_waiter_.WaitForResult());
209}
210
211TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncOK) {
212  mock_stream_->set_num_chunks(3);
213  drainer_->Start(session_);
214  EXPECT_FALSE(result_waiter_.WaitForResult());
215}
216
217TEST_F(HttpResponseBodyDrainerTest, DrainBodySizeEqualsDrainBuffer) {
218  mock_stream_->set_num_chunks(
219      HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize);
220  drainer_->Start(session_);
221  EXPECT_FALSE(result_waiter_.WaitForResult());
222}
223
224TEST_F(HttpResponseBodyDrainerTest, DrainBodyTimeOut) {
225  mock_stream_->set_num_chunks(2);
226  mock_stream_->StallReadsForever();
227  drainer_->Start(session_);
228  EXPECT_TRUE(result_waiter_.WaitForResult());
229}
230
231TEST_F(HttpResponseBodyDrainerTest, CancelledBySession) {
232  mock_stream_->set_num_chunks(2);
233  mock_stream_->StallReadsForever();
234  drainer_->Start(session_);
235  // HttpNetworkSession should delete |drainer_|.
236}
237
238TEST_F(HttpResponseBodyDrainerTest, DrainBodyTooLarge) {
239  TestCompletionCallback callback;
240  int too_many_chunks =
241      HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize;
242  too_many_chunks += 1;  // Now it's too large.
243
244  mock_stream_->set_num_chunks(too_many_chunks);
245  drainer_->Start(session_);
246  EXPECT_TRUE(result_waiter_.WaitForResult());
247}
248
249}  // namespace
250
251}  // namespace net
252