1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/connection_tester.h"
6
7#include "base/bind.h"
8#include "base/message_loop/message_loop.h"
9#include "net/base/io_buffer.h"
10#include "net/base/net_errors.h"
11#include "net/socket/stream_socket.h"
12#include "testing/gtest/include/gtest/gtest.h"
13
14namespace remoting {
15namespace protocol {
16
// Creates a tester that transfers |message_count| messages of |message_size|
// bytes each, written to |client_socket| and read back from |host_socket|.
// NOTE(review): the sockets are held as raw pointers; presumably the caller
// keeps them alive for the tester's lifetime — confirm against the header.
// The message loop that is current at construction time is the one Done()
// later posts its quit task to.
StreamConnectionTester::StreamConnectionTester(net::StreamSocket* client_socket,
                                               net::StreamSocket* host_socket,
                                               int message_size,
                                               int message_count)
    : message_loop_(base::MessageLoop::current()),
      host_socket_(host_socket),
      client_socket_(client_socket),
      message_size_(message_size),
      // Total number of bytes that must make it from client to host.
      test_data_size_(message_size * message_count),
      done_(false),
      write_errors_(0),
      read_errors_(0) {
}
30
31StreamConnectionTester::~StreamConnectionTester() {
32}
33
// Starts the transfer: prepares the outgoing and incoming buffers, then
// issues the first read on the host socket followed by the first write on
// the client socket. Completion is signalled through Done().
void StreamConnectionTester::Start() {
  InitBuffers();
  DoRead();
  DoWrite();
}
39
// Verifies the outcome of the transfer with gtest assertions: no read or
// write errors were recorded, exactly |test_data_size_| bytes were received,
// and the received bytes match the bytes that were sent.
void StreamConnectionTester::CheckResults() {
  EXPECT_EQ(0, write_errors_);
  EXPECT_EQ(0, read_errors_);

  ASSERT_EQ(test_data_size_, input_buffer_->offset());

  // Rewind the drainable buffer so data() points at the first sent byte.
  output_buffer_->SetOffset(0);
  ASSERT_EQ(test_data_size_, output_buffer_->size());

  EXPECT_EQ(0, memcmp(output_buffer_->data(),
                      input_buffer_->StartOfBuffer(), test_data_size_));
}
52
53void StreamConnectionTester::Done() {
54  done_ = true;
55  message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
56}
57
58void StreamConnectionTester::InitBuffers() {
59  output_buffer_ = new net::DrainableIOBuffer(
60      new net::IOBuffer(test_data_size_), test_data_size_);
61  for (int i = 0; i < test_data_size_; ++i) {
62    output_buffer_->data()[i] = static_cast<char>(i);
63  }
64
65  input_buffer_ = new net::GrowableIOBuffer();
66}
67
68void StreamConnectionTester::DoWrite() {
69  int result = 1;
70  while (result > 0) {
71    if (output_buffer_->BytesRemaining() == 0)
72      break;
73
74    int bytes_to_write = std::min(output_buffer_->BytesRemaining(),
75                                  message_size_);
76    result = client_socket_->Write(
77        output_buffer_.get(),
78        bytes_to_write,
79        base::Bind(&StreamConnectionTester::OnWritten, base::Unretained(this)));
80    HandleWriteResult(result);
81  }
82}
83
84void StreamConnectionTester::OnWritten(int result) {
85  HandleWriteResult(result);
86  DoWrite();
87}
88
89void StreamConnectionTester::HandleWriteResult(int result) {
90  if (result <= 0 && result != net::ERR_IO_PENDING) {
91    LOG(ERROR) << "Received error " << result << " when trying to write";
92    write_errors_++;
93    Done();
94  } else if (result > 0) {
95    output_buffer_->DidConsume(result);
96  }
97}
98
99void StreamConnectionTester::DoRead() {
100  int result = 1;
101  while (result > 0) {
102    input_buffer_->SetCapacity(input_buffer_->offset() + message_size_);
103    result = host_socket_->Read(
104        input_buffer_.get(),
105        message_size_,
106        base::Bind(&StreamConnectionTester::OnRead, base::Unretained(this)));
107    HandleReadResult(result);
108  };
109}
110
111void StreamConnectionTester::OnRead(int result) {
112  HandleReadResult(result);
113  if (!done_)
114    DoRead();  // Don't try to read again when we are done reading.
115}
116
117void StreamConnectionTester::HandleReadResult(int result) {
118  if (result <= 0 && result != net::ERR_IO_PENDING) {
119    LOG(ERROR) << "Received error " << result << " when trying to read";
120    read_errors_++;
121    Done();
122  } else if (result > 0) {
123    // Allocate memory for the next read.
124    input_buffer_->set_offset(input_buffer_->offset() + result);
125    if (input_buffer_->offset() == test_data_size_)
126      Done();
127  }
128}
129
// Creates a tester that sends |message_count| packets of |message_size| bytes
// from |client_socket| to |host_socket|, waiting |delay_ms| milliseconds
// between consecutive sends.
// NOTE(review): the sockets are held as raw pointers; presumably the caller
// keeps them alive for the tester's lifetime — confirm against the header.
DatagramConnectionTester::DatagramConnectionTester(net::Socket* client_socket,
                                                   net::Socket* host_socket,
                                                   int message_size,
                                                   int message_count,
                                                   int delay_ms)
    : message_loop_(base::MessageLoop::current()),
      host_socket_(host_socket),
      client_socket_(client_socket),
      message_size_(message_size),
      message_count_(message_count),
      delay_ms_(delay_ms),
      done_(false),
      write_errors_(0),
      read_errors_(0),
      packets_sent_(0),
      packets_received_(0),
      bad_packets_received_(0) {
  // One slot per packet; DoWrite() stores each sent packet at its index so
  // received packets can be compared against it.
  sent_packets_.resize(message_count_);
}
149
150DatagramConnectionTester::~DatagramConnectionTester() {
151}
152
// Starts the test: begins reading on the host socket, then sends the first
// packet from the client socket. Subsequent packets are scheduled from
// HandleWriteResult().
void DatagramConnectionTester::Start() {
  DoRead();
  DoWrite();
}
157
// Verifies the outcome with gtest assertions: no read/write errors and no
// malformed packets. Because datagrams may be dropped, it only requires that
// at least one packet arrived, and logs how many did.
void DatagramConnectionTester::CheckResults() {
  EXPECT_EQ(0, write_errors_);
  EXPECT_EQ(0, read_errors_);

  EXPECT_EQ(0, bad_packets_received_);

  // Verify that we've received at least one packet.
  EXPECT_GT(packets_received_, 0);
  VLOG(0) << "Received " << packets_received_ << " packets out of "
          << message_count_;
}
169
170void DatagramConnectionTester::Done() {
171  done_ = true;
172  message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
173}
174
175void DatagramConnectionTester::DoWrite() {
176  if (packets_sent_ >= message_count_) {
177    Done();
178    return;
179  }
180
181  scoped_refptr<net::IOBuffer> packet(new net::IOBuffer(message_size_));
182  for (int i = 0; i < message_size_; ++i) {
183    packet->data()[i] = static_cast<char>(i);
184  }
185  sent_packets_[packets_sent_] = packet;
186  // Put index of this packet in the beginning of the packet body.
187  memcpy(packet->data(), &packets_sent_, sizeof(packets_sent_));
188
189  int result = client_socket_->Write(
190      packet.get(),
191      message_size_,
192      base::Bind(&DatagramConnectionTester::OnWritten, base::Unretained(this)));
193  HandleWriteResult(result);
194}
195
// Completion callback for an asynchronous client-socket write; the result is
// processed exactly like a synchronous one (HandleWriteResult() schedules the
// next packet on success).
void DatagramConnectionTester::OnWritten(int result) {
  HandleWriteResult(result);
}
199
200void DatagramConnectionTester::HandleWriteResult(int result) {
201  if (result <= 0 && result != net::ERR_IO_PENDING) {
202    LOG(ERROR) << "Received error " << result << " when trying to write";
203    write_errors_++;
204    Done();
205  } else if (result > 0) {
206    EXPECT_EQ(message_size_, result);
207    packets_sent_++;
208    message_loop_->PostDelayedTask(
209        FROM_HERE,
210        base::Bind(&DatagramConnectionTester::DoWrite, base::Unretained(this)),
211        base::TimeDelta::FromMilliseconds(delay_ms_));
212  }
213}
214
215void DatagramConnectionTester::DoRead() {
216  int result = 1;
217  while (result > 0) {
218    int kReadSize = message_size_ * 2;
219    read_buffer_ = new net::IOBuffer(kReadSize);
220
221    result = host_socket_->Read(
222        read_buffer_.get(),
223        kReadSize,
224        base::Bind(&DatagramConnectionTester::OnRead, base::Unretained(this)));
225    HandleReadResult(result);
226  };
227}
228
// Completion callback for an asynchronous host-socket read: processes the
// received packet, then issues the next read.
void DatagramConnectionTester::OnRead(int result) {
  // Must run before DoRead(), which replaces |read_buffer_|.
  HandleReadResult(result);
  DoRead();
}
233
234void DatagramConnectionTester::HandleReadResult(int result) {
235  if (result <= 0 && result != net::ERR_IO_PENDING) {
236    // Error will be received after the socket is closed.
237    LOG(ERROR) << "Received error " << result << " when trying to read";
238    read_errors_++;
239    Done();
240  } else if (result > 0) {
241    packets_received_++;
242    if (message_size_ != result) {
243      // Invalid packet size;
244      bad_packets_received_++;
245    } else {
246      // Validate packet body.
247      int packet_id;
248      memcpy(&packet_id, read_buffer_->data(), sizeof(packet_id));
249      if (packet_id < 0 || packet_id >= message_count_) {
250        bad_packets_received_++;
251      } else {
252        if (memcmp(read_buffer_->data(), sent_packets_[packet_id]->data(),
253                   message_size_) != 0)
254          bad_packets_received_++;
255      }
256    }
257  }
258}
259
260}  // namespace protocol
261}  // namespace remoting
262