1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/channel_multiplexer.h"
6
7#include "base/bind.h"
8#include "base/message_loop/message_loop.h"
9#include "base/run_loop.h"
10#include "net/base/net_errors.h"
11#include "net/socket/socket.h"
12#include "net/socket/stream_socket.h"
13#include "remoting/base/constants.h"
14#include "remoting/protocol/connection_tester.h"
15#include "remoting/protocol/fake_session.h"
16#include "testing/gmock/include/gmock/gmock.h"
17#include "testing/gtest/include/gtest/gtest.h"
18
19using testing::_;
20using testing::AtMost;
21using testing::InvokeWithoutArgs;
22
23namespace remoting {
24namespace protocol {
25
26namespace {
27
// Base channel name handed to both ChannelMultiplexer instances in SetUp().
const char kMuxChannelName[] = "mux";

// Names of the logical channels created through the multiplexers.
const char kTestChannelName[] = "test";
const char kTestChannelName2[] = "test2";

// Message size and message count passed to every StreamConnectionTester.
const int kMessageSize = 1024;
const int kMessages = 100;
34
35
36void QuitCurrentThread() {
37  base::MessageLoop::current()->PostTask(FROM_HERE,
38                                         base::MessageLoop::QuitClosure());
39}
40
41class MockSocketCallback {
42 public:
43  MOCK_METHOD1(OnDone, void(int result));
44};
45
46class MockConnectCallback {
47 public:
48  MOCK_METHOD1(OnConnectedPtr, void(net::StreamSocket* socket));
49  void OnConnected(scoped_ptr<net::StreamSocket> socket) {
50    OnConnectedPtr(socket.release());
51  }
52};
53
54}  // namespace
55
56class ChannelMultiplexerTest : public testing::Test {
57 public:
58  void DeleteAll() {
59    host_socket1_.reset();
60    host_socket2_.reset();
61    client_socket1_.reset();
62    client_socket2_.reset();
63    host_mux_.reset();
64    client_mux_.reset();
65  }
66
67  void DeleteAfterSessionFail() {
68    host_mux_->CancelChannelCreation(kTestChannelName2);
69    DeleteAll();
70  }
71
72 protected:
73  virtual void SetUp() OVERRIDE {
74    // Create pair of multiplexers and connect them to each other.
75    host_mux_.reset(new ChannelMultiplexer(
76        host_session_.GetTransportChannelFactory(), kMuxChannelName));
77    client_mux_.reset(new ChannelMultiplexer(
78        client_session_.GetTransportChannelFactory(), kMuxChannelName));
79  }
80
81  // Connect sockets to each other. Must be called after we've created at least
82  // one channel with each multiplexer.
83  void ConnectSockets() {
84    FakeStreamSocket* host_socket =
85        host_session_.fake_channel_factory().GetFakeChannel(
86            ChannelMultiplexer::kMuxChannelName);
87    FakeStreamSocket* client_socket =
88        client_session_.fake_channel_factory().GetFakeChannel(
89            ChannelMultiplexer::kMuxChannelName);
90    host_socket->PairWith(client_socket);
91
92    // Make writes asynchronous in one direction.
93    host_socket->set_async_write(true);
94  }
95
96  void CreateChannel(const std::string& name,
97                     scoped_ptr<net::StreamSocket>* host_socket,
98                     scoped_ptr<net::StreamSocket>* client_socket) {
99    int counter = 2;
100    host_mux_->CreateChannel(name, base::Bind(
101        &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
102        host_socket, &counter));
103    client_mux_->CreateChannel(name, base::Bind(
104        &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
105        client_socket, &counter));
106
107    message_loop_.Run();
108
109    EXPECT_TRUE(host_socket->get());
110    EXPECT_TRUE(client_socket->get());
111  }
112
113  void OnChannelConnected(
114      scoped_ptr<net::StreamSocket>* storage,
115      int* counter,
116      scoped_ptr<net::StreamSocket> socket) {
117    *storage = socket.Pass();
118    --(*counter);
119    EXPECT_GE(*counter, 0);
120    if (*counter == 0)
121      QuitCurrentThread();
122  }
123
124  scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) {
125    scoped_refptr<net::IOBufferWithSize> result =
126        new net::IOBufferWithSize(size);
127    for (int i = 0; i< size; ++i) {
128      result->data()[i] = rand() % 256;
129    }
130    return result;
131  }
132
133  base::MessageLoop message_loop_;
134
135  FakeSession host_session_;
136  FakeSession client_session_;
137
138  scoped_ptr<ChannelMultiplexer> host_mux_;
139  scoped_ptr<ChannelMultiplexer> client_mux_;
140
141  scoped_ptr<net::StreamSocket> host_socket1_;
142  scoped_ptr<net::StreamSocket> client_socket1_;
143  scoped_ptr<net::StreamSocket> host_socket2_;
144  scoped_ptr<net::StreamSocket> client_socket2_;
145};
146
147
148TEST_F(ChannelMultiplexerTest, OneChannel) {
149  scoped_ptr<net::StreamSocket> host_socket;
150  scoped_ptr<net::StreamSocket> client_socket;
151  ASSERT_NO_FATAL_FAILURE(
152      CreateChannel(kTestChannelName, &host_socket, &client_socket));
153
154  ConnectSockets();
155
156  StreamConnectionTester tester(host_socket.get(), client_socket.get(),
157                                kMessageSize, kMessages);
158  tester.Start();
159  message_loop_.Run();
160  tester.CheckResults();
161}
162
163TEST_F(ChannelMultiplexerTest, TwoChannels) {
164  scoped_ptr<net::StreamSocket> host_socket1_;
165  scoped_ptr<net::StreamSocket> client_socket1_;
166  ASSERT_NO_FATAL_FAILURE(
167      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
168
169  scoped_ptr<net::StreamSocket> host_socket2_;
170  scoped_ptr<net::StreamSocket> client_socket2_;
171  ASSERT_NO_FATAL_FAILURE(
172      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
173
174  ConnectSockets();
175
176  StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
177                                kMessageSize, kMessages);
178  StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
179                                 kMessageSize, kMessages);
180  tester1.Start();
181  tester2.Start();
182  while (!tester1.done() || !tester2.done()) {
183    message_loop_.Run();
184  }
185  tester1.CheckResults();
186  tester2.CheckResults();
187}
188
189// Four channels, two in each direction
190TEST_F(ChannelMultiplexerTest, FourChannels) {
191  scoped_ptr<net::StreamSocket> host_socket1_;
192  scoped_ptr<net::StreamSocket> client_socket1_;
193  ASSERT_NO_FATAL_FAILURE(
194      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
195
196  scoped_ptr<net::StreamSocket> host_socket2_;
197  scoped_ptr<net::StreamSocket> client_socket2_;
198  ASSERT_NO_FATAL_FAILURE(
199      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
200
201  scoped_ptr<net::StreamSocket> host_socket3;
202  scoped_ptr<net::StreamSocket> client_socket3;
203  ASSERT_NO_FATAL_FAILURE(
204      CreateChannel("test3", &host_socket3, &client_socket3));
205
206  scoped_ptr<net::StreamSocket> host_socket4;
207  scoped_ptr<net::StreamSocket> client_socket4;
208  ASSERT_NO_FATAL_FAILURE(
209      CreateChannel("ch4", &host_socket4, &client_socket4));
210
211  ConnectSockets();
212
213  StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
214                                kMessageSize, kMessages);
215  StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
216                                 kMessageSize, kMessages);
217  StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(),
218                                 kMessageSize, kMessages);
219  StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(),
220                                 kMessageSize, kMessages);
221  tester1.Start();
222  tester2.Start();
223  tester3.Start();
224  tester4.Start();
225  while (!tester1.done() || !tester2.done() ||
226         !tester3.done() || !tester4.done()) {
227    message_loop_.Run();
228  }
229  tester1.CheckResults();
230  tester2.CheckResults();
231  tester3.CheckResults();
232  tester4.CheckResults();
233}
234
235TEST_F(ChannelMultiplexerTest, WriteFailSync) {
236  scoped_ptr<net::StreamSocket> host_socket1_;
237  scoped_ptr<net::StreamSocket> client_socket1_;
238  ASSERT_NO_FATAL_FAILURE(
239      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
240
241  scoped_ptr<net::StreamSocket> host_socket2_;
242  scoped_ptr<net::StreamSocket> client_socket2_;
243  ASSERT_NO_FATAL_FAILURE(
244      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
245
246  ConnectSockets();
247
248  host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)->
249      set_next_write_error(net::ERR_FAILED);
250  host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)->
251      set_async_write(false);
252
253  scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
254
255  MockSocketCallback cb1;
256  MockSocketCallback cb2;
257
258  EXPECT_CALL(cb1, OnDone(_))
259      .Times(0);
260  EXPECT_CALL(cb2, OnDone(_))
261      .Times(0);
262
263  EXPECT_EQ(net::ERR_FAILED,
264            host_socket1_->Write(buf.get(),
265                                 buf->size(),
266                                 base::Bind(&MockSocketCallback::OnDone,
267                                            base::Unretained(&cb1))));
268  EXPECT_EQ(net::ERR_FAILED,
269            host_socket2_->Write(buf.get(),
270                                 buf->size(),
271                                 base::Bind(&MockSocketCallback::OnDone,
272                                            base::Unretained(&cb2))));
273
274  base::RunLoop().RunUntilIdle();
275}
276
277TEST_F(ChannelMultiplexerTest, WriteFailAsync) {
278  ASSERT_NO_FATAL_FAILURE(
279      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
280
281  ASSERT_NO_FATAL_FAILURE(
282      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
283
284  ConnectSockets();
285
286  host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)->
287      set_next_write_error(net::ERR_FAILED);
288  host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)->
289      set_async_write(true);
290
291  scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
292
293  MockSocketCallback cb1;
294  MockSocketCallback cb2;
295  EXPECT_CALL(cb1, OnDone(net::ERR_FAILED));
296  EXPECT_CALL(cb2, OnDone(net::ERR_FAILED));
297
298  EXPECT_EQ(net::ERR_IO_PENDING,
299            host_socket1_->Write(buf.get(),
300                                 buf->size(),
301                                 base::Bind(&MockSocketCallback::OnDone,
302                                            base::Unretained(&cb1))));
303  EXPECT_EQ(net::ERR_IO_PENDING,
304            host_socket2_->Write(buf.get(),
305                                 buf->size(),
306                                 base::Bind(&MockSocketCallback::OnDone,
307                                            base::Unretained(&cb2))));
308
309  base::RunLoop().RunUntilIdle();
310}
311
312TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) {
313  ASSERT_NO_FATAL_FAILURE(
314      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
315  ASSERT_NO_FATAL_FAILURE(
316      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
317
318  ConnectSockets();
319
320  host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)->
321      set_next_write_error(net::ERR_FAILED);
322  host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)->
323      set_async_write(true);
324
325  scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
326
327  MockSocketCallback cb1;
328  MockSocketCallback cb2;
329
330  EXPECT_CALL(cb1, OnDone(net::ERR_FAILED))
331      .Times(AtMost(1))
332      .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
333  EXPECT_CALL(cb2, OnDone(net::ERR_FAILED))
334      .Times(AtMost(1))
335      .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
336
337  EXPECT_EQ(net::ERR_IO_PENDING,
338            host_socket1_->Write(buf.get(),
339                                 buf->size(),
340                                 base::Bind(&MockSocketCallback::OnDone,
341                                            base::Unretained(&cb1))));
342  EXPECT_EQ(net::ERR_IO_PENDING,
343            host_socket2_->Write(buf.get(),
344                                 buf->size(),
345                                 base::Bind(&MockSocketCallback::OnDone,
346                                            base::Unretained(&cb2))));
347
348  base::RunLoop().RunUntilIdle();
349
350  // Check that the sockets were destroyed.
351  EXPECT_FALSE(host_mux_.get());
352}
353
354TEST_F(ChannelMultiplexerTest, SessionFail) {
355  host_session_.fake_channel_factory().set_asynchronous_create(true);
356  host_session_.fake_channel_factory().set_fail_create(true);
357
358  MockConnectCallback cb1;
359  MockConnectCallback cb2;
360
361  host_mux_->CreateChannel(kTestChannelName, base::Bind(
362      &MockConnectCallback::OnConnected, base::Unretained(&cb1)));
363  host_mux_->CreateChannel(kTestChannelName2, base::Bind(
364      &MockConnectCallback::OnConnected, base::Unretained(&cb2)));
365
366  EXPECT_CALL(cb1, OnConnectedPtr(NULL))
367      .Times(AtMost(1))
368      .WillOnce(InvokeWithoutArgs(
369          this, &ChannelMultiplexerTest::DeleteAfterSessionFail));
370  EXPECT_CALL(cb2, OnConnectedPtr(_))
371      .Times(0);
372
373  base::RunLoop().RunUntilIdle();
374}
375
376}  // namespace protocol
377}  // namespace remoting
378