1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/http/http_response_body_drainer.h"
6
7#include <cstring>
8
9#include "base/bind.h"
10#include "base/compiler_specific.h"
11#include "base/memory/weak_ptr.h"
12#include "base/message_loop/message_loop.h"
13#include "net/base/io_buffer.h"
14#include "net/base/net_errors.h"
15#include "net/base/test_completion_callback.h"
16#include "net/http/http_network_session.h"
17#include "net/http/http_server_properties_impl.h"
18#include "net/http/http_stream.h"
19#include "net/http/transport_security_state.h"
20#include "net/proxy/proxy_service.h"
21#include "net/ssl/ssl_config_service_defaults.h"
22#include "testing/gtest/include/gtest/gtest.h"
23
24namespace net {
25
26namespace {
27
28const int kMagicChunkSize = 1024;
29COMPILE_ASSERT(
30    (HttpResponseBodyDrainer::kDrainBodyBufferSize % kMagicChunkSize) == 0,
31    chunk_size_needs_to_divide_evenly_into_buffer_size);
32
33class CloseResultWaiter {
34 public:
35  CloseResultWaiter()
36      : result_(false),
37        have_result_(false),
38        waiting_for_result_(false) {}
39
40  int WaitForResult() {
41    CHECK(!waiting_for_result_);
42    while (!have_result_) {
43      waiting_for_result_ = true;
44      base::MessageLoop::current()->Run();
45      waiting_for_result_ = false;
46    }
47    return result_;
48  }
49
50  void set_result(bool result) {
51    result_ = result;
52    have_result_ = true;
53    if (waiting_for_result_)
54      base::MessageLoop::current()->Quit();
55  }
56
57 private:
58  int result_;
59  bool have_result_;
60  bool waiting_for_result_;
61
62  DISALLOW_COPY_AND_ASSIGN(CloseResultWaiter);
63};
64
65class MockHttpStream : public HttpStream {
66 public:
67  MockHttpStream(CloseResultWaiter* result_waiter)
68      : result_waiter_(result_waiter),
69        buf_len_(0),
70        closed_(false),
71        stall_reads_forever_(false),
72        num_chunks_(0),
73        is_sync_(false),
74        is_last_chunk_zero_size_(false),
75        is_complete_(false),
76        weak_factory_(this) {}
77  virtual ~MockHttpStream() {}
78
79  // HttpStream implementation.
80  virtual int InitializeStream(const HttpRequestInfo* request_info,
81                               RequestPriority priority,
82                               const BoundNetLog& net_log,
83                               const CompletionCallback& callback) OVERRIDE {
84    return ERR_UNEXPECTED;
85  }
86  virtual int SendRequest(const HttpRequestHeaders& request_headers,
87                          HttpResponseInfo* response,
88                          const CompletionCallback& callback) OVERRIDE {
89    return ERR_UNEXPECTED;
90  }
91  virtual UploadProgress GetUploadProgress() const OVERRIDE {
92    return UploadProgress();
93  }
94  virtual int ReadResponseHeaders(const CompletionCallback& callback) OVERRIDE {
95    return ERR_UNEXPECTED;
96  }
97
98  virtual bool CanFindEndOfResponse() const OVERRIDE { return true; }
99  virtual bool IsConnectionReused() const OVERRIDE { return false; }
100  virtual void SetConnectionReused() OVERRIDE {}
101  virtual bool IsConnectionReusable() const OVERRIDE { return false; }
102  virtual int64 GetTotalReceivedBytes() const OVERRIDE { return 0; }
103  virtual void GetSSLInfo(SSLInfo* ssl_info) OVERRIDE {}
104  virtual void GetSSLCertRequestInfo(
105      SSLCertRequestInfo* cert_request_info) OVERRIDE {}
106
107  // Mocked API
108  virtual int ReadResponseBody(IOBuffer* buf, int buf_len,
109                               const CompletionCallback& callback) OVERRIDE;
110  virtual void Close(bool not_reusable) OVERRIDE {
111    CHECK(!closed_);
112    closed_ = true;
113    result_waiter_->set_result(not_reusable);
114  }
115
116  virtual HttpStream* RenewStreamForAuth() OVERRIDE {
117    return NULL;
118  }
119
120  virtual bool IsResponseBodyComplete() const OVERRIDE { return is_complete_; }
121
122  virtual bool IsSpdyHttpStream() const OVERRIDE { return false; }
123
124  virtual bool GetLoadTimingInfo(
125      LoadTimingInfo* load_timing_info) const OVERRIDE { return false; }
126
127  virtual void Drain(HttpNetworkSession*) OVERRIDE {}
128
129  virtual void SetPriority(RequestPriority priority) OVERRIDE {}
130
131  // Methods to tweak/observer mock behavior:
132  void set_stall_reads_forever() { stall_reads_forever_ = true; }
133
134  void set_num_chunks(int num_chunks) { num_chunks_ = num_chunks; }
135
136  void set_sync() { is_sync_ = true; }
137
138  void set_is_last_chunk_zero_size() { is_last_chunk_zero_size_ = true; }
139
140 private:
141  int ReadResponseBodyImpl(IOBuffer* buf, int buf_len);
142  void CompleteRead();
143
144  bool closed() const { return closed_; }
145
146  CloseResultWaiter* const result_waiter_;
147  scoped_refptr<IOBuffer> user_buf_;
148  CompletionCallback callback_;
149  int buf_len_;
150  bool closed_;
151  bool stall_reads_forever_;
152  int num_chunks_;
153  bool is_sync_;
154  bool is_last_chunk_zero_size_;
155  bool is_complete_;
156  base::WeakPtrFactory<MockHttpStream> weak_factory_;
157};
158
159int MockHttpStream::ReadResponseBody(IOBuffer* buf,
160                                     int buf_len,
161                                     const CompletionCallback& callback) {
162  CHECK(!callback.is_null());
163  CHECK(callback_.is_null());
164  CHECK(buf);
165
166  if (stall_reads_forever_)
167    return ERR_IO_PENDING;
168
169  if (is_complete_)
170    return ERR_UNEXPECTED;
171
172  if (!is_sync_) {
173    user_buf_ = buf;
174    buf_len_ = buf_len;
175    callback_ = callback;
176    base::MessageLoop::current()->PostTask(
177        FROM_HERE,
178        base::Bind(&MockHttpStream::CompleteRead, weak_factory_.GetWeakPtr()));
179    return ERR_IO_PENDING;
180  } else {
181    return ReadResponseBodyImpl(buf, buf_len);
182  }
183}
184
185int MockHttpStream::ReadResponseBodyImpl(IOBuffer* buf, int buf_len) {
186  if (is_last_chunk_zero_size_ && num_chunks_ == 1) {
187    buf_len = 0;
188  } else {
189    if (buf_len > kMagicChunkSize)
190      buf_len = kMagicChunkSize;
191    std::memset(buf->data(), 1, buf_len);
192  }
193  num_chunks_--;
194  if (!num_chunks_)
195    is_complete_ = true;
196
197  return buf_len;
198}
199
200void MockHttpStream::CompleteRead() {
201  int result = ReadResponseBodyImpl(user_buf_.get(), buf_len_);
202  user_buf_ = NULL;
203  CompletionCallback callback = callback_;
204  callback_.Reset();
205  callback.Run(result);
206}
207
208class HttpResponseBodyDrainerTest : public testing::Test {
209 protected:
210  HttpResponseBodyDrainerTest()
211      : proxy_service_(ProxyService::CreateDirect()),
212        ssl_config_service_(new SSLConfigServiceDefaults),
213        http_server_properties_(new HttpServerPropertiesImpl()),
214        transport_security_state_(new TransportSecurityState()),
215        session_(CreateNetworkSession()),
216        mock_stream_(new MockHttpStream(&result_waiter_)),
217        drainer_(new HttpResponseBodyDrainer(mock_stream_)) {}
218
219  virtual ~HttpResponseBodyDrainerTest() {}
220
221  HttpNetworkSession* CreateNetworkSession() const {
222    HttpNetworkSession::Params params;
223    params.proxy_service = proxy_service_.get();
224    params.ssl_config_service = ssl_config_service_.get();
225    params.http_server_properties = http_server_properties_->GetWeakPtr();
226    params.transport_security_state = transport_security_state_.get();
227    return new HttpNetworkSession(params);
228  }
229
230  scoped_ptr<ProxyService> proxy_service_;
231  scoped_refptr<SSLConfigService> ssl_config_service_;
232  scoped_ptr<HttpServerPropertiesImpl> http_server_properties_;
233  scoped_ptr<TransportSecurityState> transport_security_state_;
234  const scoped_refptr<HttpNetworkSession> session_;
235  CloseResultWaiter result_waiter_;
236  MockHttpStream* const mock_stream_;  // Owned by |drainer_|.
237  HttpResponseBodyDrainer* const drainer_;  // Deletes itself.
238};
239
240TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncSingleOK) {
241  mock_stream_->set_num_chunks(1);
242  mock_stream_->set_sync();
243  drainer_->Start(session_.get());
244  EXPECT_FALSE(result_waiter_.WaitForResult());
245}
246
247TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncOK) {
248  mock_stream_->set_num_chunks(3);
249  mock_stream_->set_sync();
250  drainer_->Start(session_.get());
251  EXPECT_FALSE(result_waiter_.WaitForResult());
252}
253
254TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncOK) {
255  mock_stream_->set_num_chunks(3);
256  drainer_->Start(session_.get());
257  EXPECT_FALSE(result_waiter_.WaitForResult());
258}
259
260// Test the case when the final chunk is 0 bytes. This can happen when
261// the final 0-byte chunk of a chunk-encoded http response is read in a last
262// call to ReadResponseBody, after all data were returned from HttpStream.
263TEST_F(HttpResponseBodyDrainerTest, DrainBodyAsyncEmptyChunk) {
264  mock_stream_->set_num_chunks(4);
265  mock_stream_->set_is_last_chunk_zero_size();
266  drainer_->Start(session_.get());
267  EXPECT_FALSE(result_waiter_.WaitForResult());
268}
269
270TEST_F(HttpResponseBodyDrainerTest, DrainBodySyncEmptyChunk) {
271  mock_stream_->set_num_chunks(4);
272  mock_stream_->set_sync();
273  mock_stream_->set_is_last_chunk_zero_size();
274  drainer_->Start(session_.get());
275  EXPECT_FALSE(result_waiter_.WaitForResult());
276}
277
278TEST_F(HttpResponseBodyDrainerTest, DrainBodySizeEqualsDrainBuffer) {
279  mock_stream_->set_num_chunks(
280      HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize);
281  drainer_->Start(session_.get());
282  EXPECT_FALSE(result_waiter_.WaitForResult());
283}
284
285TEST_F(HttpResponseBodyDrainerTest, DrainBodyTimeOut) {
286  mock_stream_->set_num_chunks(2);
287  mock_stream_->set_stall_reads_forever();
288  drainer_->Start(session_.get());
289  EXPECT_TRUE(result_waiter_.WaitForResult());
290}
291
292TEST_F(HttpResponseBodyDrainerTest, CancelledBySession) {
293  mock_stream_->set_num_chunks(2);
294  mock_stream_->set_stall_reads_forever();
295  drainer_->Start(session_.get());
296  // HttpNetworkSession should delete |drainer_|.
297}
298
299TEST_F(HttpResponseBodyDrainerTest, DrainBodyTooLarge) {
300  int too_many_chunks =
301      HttpResponseBodyDrainer::kDrainBodyBufferSize / kMagicChunkSize;
302  too_many_chunks += 1;  // Now it's too large.
303
304  mock_stream_->set_num_chunks(too_many_chunks);
305  drainer_->Start(session_.get());
306  EXPECT_TRUE(result_waiter_.WaitForResult());
307}
308
309}  // namespace
310
311}  // namespace net
312