1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "jingle/glue/pseudotcp_adapter.h"
6
7#include <vector>
8
9#include "base/bind.h"
10#include "base/bind_helpers.h"
11#include "base/compiler_specific.h"
12#include "jingle/glue/thread_wrapper.h"
13#include "net/base/io_buffer.h"
14#include "net/base/net_errors.h"
15#include "net/base/test_completion_callback.h"
16#include "net/udp/udp_socket.h"
17#include "testing/gmock/include/gmock/gmock.h"
18#include "testing/gtest/include/gtest/gtest.h"
19
20
// Forward declaration of the test-only fake transport, which is defined
// further down in this file inside the same anonymous namespace.
namespace jingle_glue {
namespace {

class FakeSocket;

}  // namespace
}  // namespace jingle_glue
26
27namespace jingle_glue {
28
29namespace {
30
// Maximum size of a single write, and size of each read window.
const int kMessageSize = 1024;
// Number of full-size messages that make up one test transfer.
const int kMessages = 100;
// Total number of bytes transferred by each test.
const int kTestDataSize = kMessageSize * kMessages;
34
// Policy consulted by FakeSocket for every incoming packet to decide whether
// the packet is lost, so tests can simulate a lossy or throttled link.
class RateLimiter {
 public:
  virtual ~RateLimiter() {}

  // Returns true when the packet about to be delivered must be discarded and
  // false when it should be delivered normally.
  virtual bool DropNextPacket() = 0;
};
41
42class LeakyBucket : public RateLimiter {
43 public:
44  // |rate| is in drops per second.
45  LeakyBucket(double volume, double rate)
46      : volume_(volume),
47        rate_(rate),
48        level_(0.0),
49        last_update_(base::TimeTicks::HighResNow()) {
50  }
51
52  virtual ~LeakyBucket() { }
53
54  virtual bool DropNextPacket() OVERRIDE {
55    base::TimeTicks now = base::TimeTicks::HighResNow();
56    double interval = (now - last_update_).InSecondsF();
57    last_update_ = now;
58    level_ = level_ + 1.0 - interval * rate_;
59    if (level_ > volume_) {
60      level_ = volume_;
61      return true;
62    } else if (level_ < 0.0) {
63      level_ = 0.0;
64    }
65    return false;
66  }
67
68 private:
69  double volume_;
70  double rate_;
71  double level_;
72  base::TimeTicks last_update_;
73};
74
75class FakeSocket : public net::Socket {
76 public:
77  FakeSocket()
78      : rate_limiter_(NULL),
79        latency_ms_(0) {
80  }
81  virtual ~FakeSocket() { }
82
83  void AppendInputPacket(const std::vector<char>& data) {
84    if (rate_limiter_ && rate_limiter_->DropNextPacket())
85      return;  // Lose the packet.
86
87    if (!read_callback_.is_null()) {
88      int size = std::min(read_buffer_size_, static_cast<int>(data.size()));
89      memcpy(read_buffer_->data(), &data[0], data.size());
90      net::CompletionCallback cb = read_callback_;
91      read_callback_.Reset();
92      read_buffer_ = NULL;
93      cb.Run(size);
94    } else {
95      incoming_packets_.push_back(data);
96    }
97  }
98
99  void Connect(FakeSocket* peer_socket) {
100    peer_socket_ = peer_socket;
101  }
102
103  void set_rate_limiter(RateLimiter* rate_limiter) {
104    rate_limiter_ = rate_limiter;
105  };
106
107  void set_latency(int latency_ms) { latency_ms_ = latency_ms; };
108
109  // net::Socket interface.
110  virtual int Read(net::IOBuffer* buf, int buf_len,
111                   const net::CompletionCallback& callback) OVERRIDE {
112    CHECK(read_callback_.is_null());
113    CHECK(buf);
114
115    if (incoming_packets_.size() > 0) {
116      scoped_refptr<net::IOBuffer> buffer(buf);
117      int size = std::min(
118          static_cast<int>(incoming_packets_.front().size()), buf_len);
119      memcpy(buffer->data(), &*incoming_packets_.front().begin(), size);
120      incoming_packets_.pop_front();
121      return size;
122    } else {
123      read_callback_ = callback;
124      read_buffer_ = buf;
125      read_buffer_size_ = buf_len;
126      return net::ERR_IO_PENDING;
127    }
128  }
129
130  virtual int Write(net::IOBuffer* buf, int buf_len,
131                    const net::CompletionCallback& callback) OVERRIDE {
132    DCHECK(buf);
133    if (peer_socket_) {
134      base::MessageLoop::current()->PostDelayedTask(
135          FROM_HERE,
136          base::Bind(&FakeSocket::AppendInputPacket,
137                     base::Unretained(peer_socket_),
138                     std::vector<char>(buf->data(), buf->data() + buf_len)),
139          base::TimeDelta::FromMilliseconds(latency_ms_));
140    }
141
142    return buf_len;
143  }
144
145  virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
146    NOTIMPLEMENTED();
147    return net::ERR_NOT_IMPLEMENTED;
148  }
149  virtual int SetSendBufferSize(int32 size) OVERRIDE {
150    NOTIMPLEMENTED();
151    return net::ERR_NOT_IMPLEMENTED;
152  }
153
154 private:
155  scoped_refptr<net::IOBuffer> read_buffer_;
156  int read_buffer_size_;
157  net::CompletionCallback read_callback_;
158
159  std::deque<std::vector<char> > incoming_packets_;
160
161  FakeSocket* peer_socket_;
162  RateLimiter* rate_limiter_;
163  int latency_ms_;
164};
165
166class TCPChannelTester : public base::RefCountedThreadSafe<TCPChannelTester> {
167 public:
168  TCPChannelTester(base::MessageLoop* message_loop,
169                   net::Socket* client_socket,
170                   net::Socket* host_socket)
171      : message_loop_(message_loop),
172        host_socket_(host_socket),
173        client_socket_(client_socket),
174        done_(false),
175        write_errors_(0),
176        read_errors_(0) {}
177
178  void Start() {
179    message_loop_->PostTask(
180        FROM_HERE, base::Bind(&TCPChannelTester::DoStart, this));
181  }
182
183  void CheckResults() {
184    EXPECT_EQ(0, write_errors_);
185    EXPECT_EQ(0, read_errors_);
186
187    ASSERT_EQ(kTestDataSize + kMessageSize, input_buffer_->capacity());
188
189    output_buffer_->SetOffset(0);
190    ASSERT_EQ(kTestDataSize, output_buffer_->size());
191
192    EXPECT_EQ(0, memcmp(output_buffer_->data(),
193                        input_buffer_->StartOfBuffer(), kTestDataSize));
194  }
195
196 protected:
197  virtual ~TCPChannelTester() {}
198
199  void Done() {
200    done_ = true;
201    message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
202  }
203
204  void DoStart() {
205    InitBuffers();
206    DoRead();
207    DoWrite();
208  }
209
210  void InitBuffers() {
211    output_buffer_ = new net::DrainableIOBuffer(
212        new net::IOBuffer(kTestDataSize), kTestDataSize);
213    memset(output_buffer_->data(), 123, kTestDataSize);
214
215    input_buffer_ = new net::GrowableIOBuffer();
216    // Always keep kMessageSize bytes available at the end of the input buffer.
217    input_buffer_->SetCapacity(kMessageSize);
218  }
219
220  void DoWrite() {
221    int result = 1;
222    while (result > 0) {
223      if (output_buffer_->BytesRemaining() == 0)
224        break;
225
226      int bytes_to_write = std::min(output_buffer_->BytesRemaining(),
227                                    kMessageSize);
228      result = client_socket_->Write(
229          output_buffer_.get(),
230          bytes_to_write,
231          base::Bind(&TCPChannelTester::OnWritten, base::Unretained(this)));
232      HandleWriteResult(result);
233    }
234  }
235
236  void OnWritten(int result) {
237    HandleWriteResult(result);
238    DoWrite();
239  }
240
241  void HandleWriteResult(int result) {
242    if (result <= 0 && result != net::ERR_IO_PENDING) {
243      LOG(ERROR) << "Received error " << result << " when trying to write";
244      write_errors_++;
245      Done();
246    } else if (result > 0) {
247      output_buffer_->DidConsume(result);
248    }
249  }
250
251  void DoRead() {
252    int result = 1;
253    while (result > 0) {
254      input_buffer_->set_offset(input_buffer_->capacity() - kMessageSize);
255
256      result = host_socket_->Read(
257          input_buffer_.get(),
258          kMessageSize,
259          base::Bind(&TCPChannelTester::OnRead, base::Unretained(this)));
260      HandleReadResult(result);
261    };
262  }
263
264  void OnRead(int result) {
265    HandleReadResult(result);
266    DoRead();
267  }
268
269  void HandleReadResult(int result) {
270    if (result <= 0 && result != net::ERR_IO_PENDING) {
271      if (!done_) {
272        LOG(ERROR) << "Received error " << result << " when trying to read";
273        read_errors_++;
274        Done();
275      }
276    } else if (result > 0) {
277      // Allocate memory for the next read.
278      input_buffer_->SetCapacity(input_buffer_->capacity() + result);
279      if (input_buffer_->capacity() == kTestDataSize + kMessageSize)
280        Done();
281    }
282  }
283
284 private:
285  friend class base::RefCountedThreadSafe<TCPChannelTester>;
286
287  base::MessageLoop* message_loop_;
288  net::Socket* host_socket_;
289  net::Socket* client_socket_;
290  bool done_;
291
292  scoped_refptr<net::DrainableIOBuffer> output_buffer_;
293  scoped_refptr<net::GrowableIOBuffer> input_buffer_;
294
295  int write_errors_;
296  int read_errors_;
297};
298
299class PseudoTcpAdapterTest : public testing::Test {
300 protected:
301  virtual void SetUp() OVERRIDE {
302    JingleThreadWrapper::EnsureForCurrentMessageLoop();
303
304    host_socket_ = new FakeSocket();
305    client_socket_ = new FakeSocket();
306
307    host_socket_->Connect(client_socket_);
308    client_socket_->Connect(host_socket_);
309
310    host_pseudotcp_.reset(new PseudoTcpAdapter(host_socket_));
311    client_pseudotcp_.reset(new PseudoTcpAdapter(client_socket_));
312  }
313
314  FakeSocket* host_socket_;
315  FakeSocket* client_socket_;
316
317  scoped_ptr<PseudoTcpAdapter> host_pseudotcp_;
318  scoped_ptr<PseudoTcpAdapter> client_pseudotcp_;
319  base::MessageLoop message_loop_;
320};
321
322TEST_F(PseudoTcpAdapterTest, DataTransfer) {
323  net::TestCompletionCallback host_connect_cb;
324  net::TestCompletionCallback client_connect_cb;
325
326  int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
327  int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());
328
329  if (rv1 == net::ERR_IO_PENDING)
330    rv1 = host_connect_cb.WaitForResult();
331  if (rv2 == net::ERR_IO_PENDING)
332    rv2 = client_connect_cb.WaitForResult();
333  ASSERT_EQ(net::OK, rv1);
334  ASSERT_EQ(net::OK, rv2);
335
336  scoped_refptr<TCPChannelTester> tester =
337      new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
338                           client_pseudotcp_.get());
339
340  tester->Start();
341  message_loop_.Run();
342  tester->CheckResults();
343}
344
345TEST_F(PseudoTcpAdapterTest, LimitedChannel) {
346  const int kLatencyMs = 20;
347  const int kPacketsPerSecond = 400;
348  const int kBurstPackets = 10;
349
350  LeakyBucket host_limiter(kBurstPackets, kPacketsPerSecond);
351  host_socket_->set_latency(kLatencyMs);
352  host_socket_->set_rate_limiter(&host_limiter);
353
354  LeakyBucket client_limiter(kBurstPackets, kPacketsPerSecond);
355  host_socket_->set_latency(kLatencyMs);
356  client_socket_->set_rate_limiter(&client_limiter);
357
358  net::TestCompletionCallback host_connect_cb;
359  net::TestCompletionCallback client_connect_cb;
360
361  int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
362  int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());
363
364  if (rv1 == net::ERR_IO_PENDING)
365    rv1 = host_connect_cb.WaitForResult();
366  if (rv2 == net::ERR_IO_PENDING)
367    rv2 = client_connect_cb.WaitForResult();
368  ASSERT_EQ(net::OK, rv1);
369  ASSERT_EQ(net::OK, rv2);
370
371  scoped_refptr<TCPChannelTester> tester =
372      new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
373                           client_pseudotcp_.get());
374
375  tester->Start();
376  message_loop_.Run();
377  tester->CheckResults();
378}
379
380class DeleteOnConnected {
381 public:
382  DeleteOnConnected(base::MessageLoop* message_loop,
383                    scoped_ptr<PseudoTcpAdapter>* adapter)
384      : message_loop_(message_loop), adapter_(adapter) {}
385  void OnConnected(int error) {
386    adapter_->reset();
387    message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
388  }
389  base::MessageLoop* message_loop_;
390  scoped_ptr<PseudoTcpAdapter>* adapter_;
391};
392
393TEST_F(PseudoTcpAdapterTest, DeleteOnConnected) {
394  // This test verifies that deleting the adapter mid-callback doesn't lead
395  // to deleted structures being touched as the stack unrolls, so the failure
396  // mode is a crash rather than a normal test failure.
397  net::TestCompletionCallback client_connect_cb;
398  DeleteOnConnected host_delete(&message_loop_, &host_pseudotcp_);
399
400  host_pseudotcp_->Connect(base::Bind(&DeleteOnConnected::OnConnected,
401                                      base::Unretained(&host_delete)));
402  client_pseudotcp_->Connect(client_connect_cb.callback());
403  message_loop_.Run();
404
405  ASSERT_EQ(NULL, host_pseudotcp_.get());
406}
407
408// Verify that we can send/receive data with the write-waits-for-send
409// flag set.
410TEST_F(PseudoTcpAdapterTest, WriteWaitsForSendLetsDataThrough) {
411  net::TestCompletionCallback host_connect_cb;
412  net::TestCompletionCallback client_connect_cb;
413
414  host_pseudotcp_->SetWriteWaitsForSend(true);
415  client_pseudotcp_->SetWriteWaitsForSend(true);
416
417  // Disable Nagle's algorithm because the test is slow when it is
418  // enabled.
419  host_pseudotcp_->SetNoDelay(true);
420
421  int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
422  int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());
423
424  if (rv1 == net::ERR_IO_PENDING)
425    rv1 = host_connect_cb.WaitForResult();
426  if (rv2 == net::ERR_IO_PENDING)
427    rv2 = client_connect_cb.WaitForResult();
428  ASSERT_EQ(net::OK, rv1);
429  ASSERT_EQ(net::OK, rv2);
430
431  scoped_refptr<TCPChannelTester> tester =
432      new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
433                           client_pseudotcp_.get());
434
435  tester->Start();
436  message_loop_.Run();
437  tester->CheckResults();
438}
439
440}  // namespace
441
442}  // namespace jingle_glue
443