1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/base/mock_file_stream.h"
6
7#include "base/bind.h"
8#include "base/message_loop/message_loop.h"
9
10namespace net {
11
12namespace testing {
13
14MockFileStream::MockFileStream(
15    const scoped_refptr<base::TaskRunner>& task_runner)
16    : net::FileStream(task_runner),
17      forced_error_(net::OK),
18      async_error_(false),
19      throttled_(false),
20      weak_factory_(this) {
21}
22
23MockFileStream::MockFileStream(
24    base::File file,
25    const scoped_refptr<base::TaskRunner>& task_runner)
26    : net::FileStream(file.Pass(), task_runner),
27      forced_error_(net::OK),
28      async_error_(false),
29      throttled_(false),
30      weak_factory_(this) {
31}
32
33MockFileStream::~MockFileStream() {
34}
35
36int MockFileStream::Seek(base::File::Whence whence, int64 offset,
37                         const Int64CompletionCallback& callback) {
38  Int64CompletionCallback wrapped_callback =
39      base::Bind(&MockFileStream::DoCallback64,
40                 weak_factory_.GetWeakPtr(), callback);
41  if (forced_error_ == net::OK)
42    return FileStream::Seek(whence, offset, wrapped_callback);
43  return ErrorCallback64(wrapped_callback);
44}
45
46int MockFileStream::Read(IOBuffer* buf,
47                         int buf_len,
48                         const CompletionCallback& callback) {
49  CompletionCallback wrapped_callback = base::Bind(&MockFileStream::DoCallback,
50                                                   weak_factory_.GetWeakPtr(),
51                                                   callback);
52  if (forced_error_ == net::OK)
53    return FileStream::Read(buf, buf_len, wrapped_callback);
54  return ErrorCallback(wrapped_callback);
55}
56
57int MockFileStream::Write(IOBuffer* buf,
58                          int buf_len,
59                          const CompletionCallback& callback) {
60  CompletionCallback wrapped_callback = base::Bind(&MockFileStream::DoCallback,
61                                                   weak_factory_.GetWeakPtr(),
62                                                   callback);
63  if (forced_error_ == net::OK)
64    return FileStream::Write(buf, buf_len, wrapped_callback);
65  return ErrorCallback(wrapped_callback);
66}
67
68int MockFileStream::Flush(const CompletionCallback& callback) {
69  CompletionCallback wrapped_callback = base::Bind(&MockFileStream::DoCallback,
70                                                   weak_factory_.GetWeakPtr(),
71                                                   callback);
72  if (forced_error_ == net::OK)
73    return FileStream::Flush(wrapped_callback);
74  return ErrorCallback(wrapped_callback);
75}
76
77void MockFileStream::ThrottleCallbacks() {
78  CHECK(!throttled_);
79  throttled_ = true;
80}
81
82void MockFileStream::ReleaseCallbacks() {
83  CHECK(throttled_);
84  throttled_ = false;
85
86  if (!throttled_task_.is_null()) {
87    base::Closure throttled_task = throttled_task_;
88    throttled_task_.Reset();
89    base::MessageLoop::current()->PostTask(FROM_HERE, throttled_task);
90  }
91}
92
93void MockFileStream::DoCallback(const CompletionCallback& callback,
94                                int result) {
95  if (!throttled_) {
96    callback.Run(result);
97    return;
98  }
99  CHECK(throttled_task_.is_null());
100  throttled_task_ = base::Bind(callback, result);
101}
102
103void MockFileStream::DoCallback64(const Int64CompletionCallback& callback,
104                                  int64 result) {
105  if (!throttled_) {
106    callback.Run(result);
107    return;
108  }
109  CHECK(throttled_task_.is_null());
110  throttled_task_ = base::Bind(callback, result);
111}
112
113int MockFileStream::ErrorCallback(const CompletionCallback& callback) {
114  CHECK_NE(net::OK, forced_error_);
115  if (async_error_) {
116    base::MessageLoop::current()->PostTask(
117        FROM_HERE, base::Bind(callback, forced_error_));
118    clear_forced_error();
119    return net::ERR_IO_PENDING;
120  }
121  int ret = forced_error_;
122  clear_forced_error();
123  return ret;
124}
125
126int64 MockFileStream::ErrorCallback64(const Int64CompletionCallback& callback) {
127  CHECK_NE(net::OK, forced_error_);
128  if (async_error_) {
129    base::MessageLoop::current()->PostTask(
130        FROM_HERE, base::Bind(callback, forced_error_));
131    clear_forced_error();
132    return net::ERR_IO_PENDING;
133  }
134  int64 ret = forced_error_;
135  clear_forced_error();
136  return ret;
137}
138
139}  // namespace testing
140
141}  // namespace net
142