1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/fake_stream_socket.h"
6
7#include "base/bind.h"
8#include "base/single_thread_task_runner.h"
9#include "base/thread_task_runner_handle.h"
10#include "net/base/address_list.h"
11#include "net/base/io_buffer.h"
12#include "net/base/net_errors.h"
13#include "net/base/net_util.h"
14#include "testing/gtest/include/gtest/gtest.h"
15
16namespace remoting {
17namespace protocol {
18
// Creates an unpaired fake socket. Defaults:
// - writes complete synchronously, with no size cap;
// - no read or write errors are injected;
// - there is no buffered input.
// The socket captures the task runner of the constructing thread, and most
// methods EXPECT to be called on that same thread.
FakeStreamSocket::FakeStreamSocket()
    : async_write_(false),
      write_pending_(false),
      write_limit_(0),  // Zero disables the per-Write() size cap.
      next_write_error_(net::OK),
      next_read_error_(net::OK),
      read_buffer_size_(0),
      input_pos_(0),
      task_runner_(base::ThreadTaskRunnerHandle::Get()),
      weak_factory_(this) {
}
30
31FakeStreamSocket::~FakeStreamSocket() {
32  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
33}
34
35void FakeStreamSocket::AppendInputData(const std::string& data) {
36  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
37  input_data_.insert(input_data_.end(), data.begin(), data.end());
38  // Complete pending read if any.
39  if (!read_callback_.is_null()) {
40    int result = std::min(read_buffer_size_,
41                          static_cast<int>(input_data_.size() - input_pos_));
42    EXPECT_GT(result, 0);
43    memcpy(read_buffer_->data(),
44           &(*input_data_.begin()) + input_pos_, result);
45    input_pos_ += result;
46    read_buffer_ = NULL;
47
48    net::CompletionCallback callback = read_callback_;
49    read_callback_.Reset();
50    callback.Run(result);
51  }
52}
53
54void FakeStreamSocket::PairWith(FakeStreamSocket* peer_socket) {
55  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
56  peer_socket_ = peer_socket->GetWeakPtr();
57  peer_socket->peer_socket_ = GetWeakPtr();
58}
59
60base::WeakPtr<FakeStreamSocket> FakeStreamSocket::GetWeakPtr() {
61  return weak_factory_.GetWeakPtr();
62}
63
64int FakeStreamSocket::Read(net::IOBuffer* buf, int buf_len,
65                           const net::CompletionCallback& callback) {
66  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
67
68  if (next_read_error_ != net::OK) {
69    int r = next_read_error_;
70    next_read_error_ = net::OK;
71    return r;
72  }
73
74  if (input_pos_ < static_cast<int>(input_data_.size())) {
75    int result = std::min(buf_len,
76                          static_cast<int>(input_data_.size()) - input_pos_);
77    memcpy(buf->data(), &(*input_data_.begin()) + input_pos_, result);
78    input_pos_ += result;
79    return result;
80  } else {
81    read_buffer_ = buf;
82    read_buffer_size_ = buf_len;
83    read_callback_ = callback;
84    return net::ERR_IO_PENDING;
85  }
86}
87
88int FakeStreamSocket::Write(net::IOBuffer* buf, int buf_len,
89                      const net::CompletionCallback& callback) {
90  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
91  EXPECT_FALSE(write_pending_);
92
93  if (write_limit_ > 0)
94    buf_len = std::min(write_limit_, buf_len);
95
96  if (async_write_) {
97    task_runner_->PostTask(FROM_HERE, base::Bind(
98        &FakeStreamSocket::DoAsyncWrite, weak_factory_.GetWeakPtr(),
99        scoped_refptr<net::IOBuffer>(buf), buf_len, callback));
100    write_pending_ = true;
101    return net::ERR_IO_PENDING;
102  } else {
103    if (next_write_error_ != net::OK) {
104      int r = next_write_error_;
105      next_write_error_ = net::OK;
106      return r;
107    }
108
109    DoWrite(buf, buf_len);
110    return buf_len;
111  }
112}
113
114void FakeStreamSocket::DoAsyncWrite(scoped_refptr<net::IOBuffer> buf,
115                                    int buf_len,
116                                    const net::CompletionCallback& callback) {
117  write_pending_ = false;
118
119  if (next_write_error_ != net::OK) {
120    int r = next_write_error_;
121    next_write_error_ = net::OK;
122    callback.Run(r);
123    return;
124  }
125
126  DoWrite(buf.get(), buf_len);
127  callback.Run(buf_len);
128}
129
130void FakeStreamSocket::DoWrite(net::IOBuffer* buf, int buf_len) {
131  written_data_.insert(written_data_.end(),
132                       buf->data(), buf->data() + buf_len);
133
134  if (peer_socket_.get()) {
135    task_runner_->PostTask(
136        FROM_HERE,
137        base::Bind(&FakeStreamSocket::AppendInputData,
138                   peer_socket_,
139                   std::string(buf->data(), buf->data() + buf_len)));
140  }
141}
142
143int FakeStreamSocket::SetReceiveBufferSize(int32 size) {
144  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
145  NOTIMPLEMENTED();
146  return net::ERR_NOT_IMPLEMENTED;
147}
148
149int FakeStreamSocket::SetSendBufferSize(int32 size) {
150  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
151  NOTIMPLEMENTED();
152  return net::ERR_NOT_IMPLEMENTED;
153}
154
155int FakeStreamSocket::Connect(const net::CompletionCallback& callback) {
156  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
157  return net::OK;
158}
159
160void FakeStreamSocket::Disconnect() {
161  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
162  peer_socket_.reset();
163}
164
165bool FakeStreamSocket::IsConnected() const {
166  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
167  return true;
168}
169
170bool FakeStreamSocket::IsConnectedAndIdle() const {
171  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
172  NOTIMPLEMENTED();
173  return false;
174}
175
176int FakeStreamSocket::GetPeerAddress(net::IPEndPoint* address) const {
177  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
178  net::IPAddressNumber ip(net::kIPv4AddressSize);
179  *address = net::IPEndPoint(ip, 0);
180  return net::OK;
181}
182
183int FakeStreamSocket::GetLocalAddress(net::IPEndPoint* address) const {
184  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
185  NOTIMPLEMENTED();
186  return net::ERR_NOT_IMPLEMENTED;
187}
188
189const net::BoundNetLog& FakeStreamSocket::NetLog() const {
190  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
191  return net_log_;
192}
193
194void FakeStreamSocket::SetSubresourceSpeculation() {
195  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
196  NOTIMPLEMENTED();
197}
198
199void FakeStreamSocket::SetOmniboxSpeculation() {
200  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
201  NOTIMPLEMENTED();
202}
203
204bool FakeStreamSocket::WasEverUsed() const {
205  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
206  NOTIMPLEMENTED();
207  return true;
208}
209
210bool FakeStreamSocket::UsingTCPFastOpen() const {
211  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
212  NOTIMPLEMENTED();
213  return true;
214}
215
216bool FakeStreamSocket::WasNpnNegotiated() const {
217  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
218  return false;
219}
220
221net::NextProto FakeStreamSocket::GetNegotiatedProtocol() const {
222  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
223  NOTIMPLEMENTED();
224  return net::kProtoUnknown;
225}
226
227bool FakeStreamSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
228  EXPECT_TRUE(task_runner_->BelongsToCurrentThread());
229  return false;
230}
231
// Creates a factory bound to the current thread's task runner. By default,
// channels are created synchronously and creation succeeds.
FakeStreamChannelFactory::FakeStreamChannelFactory()
    : task_runner_(base::ThreadTaskRunnerHandle::Get()),
      asynchronous_create_(false),
      fail_create_(false),
      weak_factory_(this) {
}
238
239FakeStreamChannelFactory::~FakeStreamChannelFactory() {}
240
241FakeStreamSocket* FakeStreamChannelFactory::GetFakeChannel(
242    const std::string& name) {
243  return channels_[name].get();
244}
245
246void FakeStreamChannelFactory::CreateChannel(
247    const std::string& name,
248    const ChannelCreatedCallback& callback) {
249  scoped_ptr<FakeStreamSocket> channel(new FakeStreamSocket());
250  channels_[name] = channel->GetWeakPtr();
251
252  if (fail_create_)
253    channel.reset();
254
255  if (asynchronous_create_) {
256    task_runner_->PostTask(FROM_HERE, base::Bind(
257        &FakeStreamChannelFactory::NotifyChannelCreated,
258        weak_factory_.GetWeakPtr(), base::Passed(&channel), name, callback));
259  } else {
260    NotifyChannelCreated(channel.Pass(), name, callback);
261  }
262}
263
264void FakeStreamChannelFactory::NotifyChannelCreated(
265    scoped_ptr<FakeStreamSocket> owned_channel,
266    const std::string& name,
267    const ChannelCreatedCallback& callback) {
268  if (channels_.find(name) != channels_.end())
269    callback.Run(owned_channel.PassAs<net::StreamSocket>());
270}
271
272void FakeStreamChannelFactory::CancelChannelCreation(const std::string& name) {
273  channels_.erase(name);
274}
275
276}  // namespace protocol
277}  // namespace remoting
278