1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/channel_multiplexer.h"
6
7#include "base/bind.h"
8#include "base/message_loop/message_loop.h"
9#include "net/base/net_errors.h"
10#include "net/socket/socket.h"
11#include "net/socket/stream_socket.h"
12#include "remoting/base/constants.h"
13#include "remoting/protocol/connection_tester.h"
14#include "remoting/protocol/fake_session.h"
15#include "testing/gmock/include/gmock/gmock.h"
16#include "testing/gtest/include/gtest/gtest.h"
17
18using testing::_;
19using testing::AtMost;
20using testing::InvokeWithoutArgs;
21
22namespace remoting {
23namespace protocol {
24
25namespace {
26
27const int kMessageSize = 1024;
28const int kMessages = 100;
29const char kMuxChannelName[] = "mux";
30
31const char kTestChannelName[] = "test";
32const char kTestChannelName2[] = "test2";
33
34
35void QuitCurrentThread() {
36  base::MessageLoop::current()->PostTask(FROM_HERE,
37                                         base::MessageLoop::QuitClosure());
38}
39
40class MockSocketCallback {
41 public:
42  MOCK_METHOD1(OnDone, void(int result));
43};
44
45class MockConnectCallback {
46 public:
47  MOCK_METHOD1(OnConnectedPtr, void(net::StreamSocket* socket));
48  void OnConnected(scoped_ptr<net::StreamSocket> socket) {
49    OnConnectedPtr(socket.release());
50  }
51};
52
53}  // namespace
54
55class ChannelMultiplexerTest : public testing::Test {
56 public:
57  void DeleteAll() {
58    host_socket1_.reset();
59    host_socket2_.reset();
60    client_socket1_.reset();
61    client_socket2_.reset();
62    host_mux_.reset();
63    client_mux_.reset();
64  }
65
66  void DeleteAfterSessionFail() {
67    host_mux_->CancelChannelCreation(kTestChannelName2);
68    DeleteAll();
69  }
70
71 protected:
72  virtual void SetUp() OVERRIDE {
73    // Create pair of multiplexers and connect them to each other.
74    host_mux_.reset(new ChannelMultiplexer(&host_session_, kMuxChannelName));
75    client_mux_.reset(new ChannelMultiplexer(&client_session_,
76                                             kMuxChannelName));
77  }
78
79  // Connect sockets to each other. Must be called after we've created at least
80  // one channel with each multiplexer.
81  void ConnectSockets() {
82    FakeSocket* host_socket =
83        host_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
84    FakeSocket* client_socket =
85        client_session_.GetStreamChannel(ChannelMultiplexer::kMuxChannelName);
86    host_socket->PairWith(client_socket);
87
88    // Make writes asynchronous in one direction.
89    host_socket->set_async_write(true);
90  }
91
92  void CreateChannel(const std::string& name,
93                     scoped_ptr<net::StreamSocket>* host_socket,
94                     scoped_ptr<net::StreamSocket>* client_socket) {
95    int counter = 2;
96    host_mux_->CreateStreamChannel(name, base::Bind(
97        &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
98        host_socket, &counter));
99    client_mux_->CreateStreamChannel(name, base::Bind(
100        &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this),
101        client_socket, &counter));
102
103    message_loop_.Run();
104
105    EXPECT_TRUE(host_socket->get());
106    EXPECT_TRUE(client_socket->get());
107  }
108
109  void OnChannelConnected(
110      scoped_ptr<net::StreamSocket>* storage,
111      int* counter,
112      scoped_ptr<net::StreamSocket> socket) {
113    *storage = socket.Pass();
114    --(*counter);
115    EXPECT_GE(*counter, 0);
116    if (*counter == 0)
117      QuitCurrentThread();
118  }
119
120  scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) {
121    scoped_refptr<net::IOBufferWithSize> result =
122        new net::IOBufferWithSize(size);
123    for (int i = 0; i< size; ++i) {
124      result->data()[i] = rand() % 256;
125    }
126    return result;
127  }
128
129  base::MessageLoop message_loop_;
130
131  FakeSession host_session_;
132  FakeSession client_session_;
133
134  scoped_ptr<ChannelMultiplexer> host_mux_;
135  scoped_ptr<ChannelMultiplexer> client_mux_;
136
137  scoped_ptr<net::StreamSocket> host_socket1_;
138  scoped_ptr<net::StreamSocket> client_socket1_;
139  scoped_ptr<net::StreamSocket> host_socket2_;
140  scoped_ptr<net::StreamSocket> client_socket2_;
141};
142
143
144TEST_F(ChannelMultiplexerTest, OneChannel) {
145  scoped_ptr<net::StreamSocket> host_socket;
146  scoped_ptr<net::StreamSocket> client_socket;
147  ASSERT_NO_FATAL_FAILURE(
148      CreateChannel(kTestChannelName, &host_socket, &client_socket));
149
150  ConnectSockets();
151
152  StreamConnectionTester tester(host_socket.get(), client_socket.get(),
153                                kMessageSize, kMessages);
154  tester.Start();
155  message_loop_.Run();
156  tester.CheckResults();
157}
158
159TEST_F(ChannelMultiplexerTest, TwoChannels) {
160  scoped_ptr<net::StreamSocket> host_socket1_;
161  scoped_ptr<net::StreamSocket> client_socket1_;
162  ASSERT_NO_FATAL_FAILURE(
163      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
164
165  scoped_ptr<net::StreamSocket> host_socket2_;
166  scoped_ptr<net::StreamSocket> client_socket2_;
167  ASSERT_NO_FATAL_FAILURE(
168      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
169
170  ConnectSockets();
171
172  StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
173                                kMessageSize, kMessages);
174  StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
175                                 kMessageSize, kMessages);
176  tester1.Start();
177  tester2.Start();
178  while (!tester1.done() || !tester2.done()) {
179    message_loop_.Run();
180  }
181  tester1.CheckResults();
182  tester2.CheckResults();
183}
184
185// Four channels, two in each direction
186TEST_F(ChannelMultiplexerTest, FourChannels) {
187  scoped_ptr<net::StreamSocket> host_socket1_;
188  scoped_ptr<net::StreamSocket> client_socket1_;
189  ASSERT_NO_FATAL_FAILURE(
190      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
191
192  scoped_ptr<net::StreamSocket> host_socket2_;
193  scoped_ptr<net::StreamSocket> client_socket2_;
194  ASSERT_NO_FATAL_FAILURE(
195      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
196
197  scoped_ptr<net::StreamSocket> host_socket3;
198  scoped_ptr<net::StreamSocket> client_socket3;
199  ASSERT_NO_FATAL_FAILURE(
200      CreateChannel("test3", &host_socket3, &client_socket3));
201
202  scoped_ptr<net::StreamSocket> host_socket4;
203  scoped_ptr<net::StreamSocket> client_socket4;
204  ASSERT_NO_FATAL_FAILURE(
205      CreateChannel("ch4", &host_socket4, &client_socket4));
206
207  ConnectSockets();
208
209  StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(),
210                                kMessageSize, kMessages);
211  StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(),
212                                 kMessageSize, kMessages);
213  StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(),
214                                 kMessageSize, kMessages);
215  StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(),
216                                 kMessageSize, kMessages);
217  tester1.Start();
218  tester2.Start();
219  tester3.Start();
220  tester4.Start();
221  while (!tester1.done() || !tester2.done() ||
222         !tester3.done() || !tester4.done()) {
223    message_loop_.Run();
224  }
225  tester1.CheckResults();
226  tester2.CheckResults();
227  tester3.CheckResults();
228  tester4.CheckResults();
229}
230
231TEST_F(ChannelMultiplexerTest, WriteFailSync) {
232  scoped_ptr<net::StreamSocket> host_socket1_;
233  scoped_ptr<net::StreamSocket> client_socket1_;
234  ASSERT_NO_FATAL_FAILURE(
235      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
236
237  scoped_ptr<net::StreamSocket> host_socket2_;
238  scoped_ptr<net::StreamSocket> client_socket2_;
239  ASSERT_NO_FATAL_FAILURE(
240      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
241
242  ConnectSockets();
243
244  host_session_.GetStreamChannel(kMuxChannelName)->
245      set_next_write_error(net::ERR_FAILED);
246  host_session_.GetStreamChannel(kMuxChannelName)->
247      set_async_write(false);
248
249  scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
250
251  MockSocketCallback cb1;
252  MockSocketCallback cb2;
253
254  EXPECT_CALL(cb1, OnDone(_))
255      .Times(0);
256  EXPECT_CALL(cb2, OnDone(_))
257      .Times(0);
258
259  EXPECT_EQ(net::ERR_FAILED,
260            host_socket1_->Write(buf.get(),
261                                 buf->size(),
262                                 base::Bind(&MockSocketCallback::OnDone,
263                                            base::Unretained(&cb1))));
264  EXPECT_EQ(net::ERR_FAILED,
265            host_socket2_->Write(buf.get(),
266                                 buf->size(),
267                                 base::Bind(&MockSocketCallback::OnDone,
268                                            base::Unretained(&cb2))));
269
270  message_loop_.RunUntilIdle();
271}
272
273TEST_F(ChannelMultiplexerTest, WriteFailAsync) {
274  ASSERT_NO_FATAL_FAILURE(
275      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
276
277  ASSERT_NO_FATAL_FAILURE(
278      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
279
280  ConnectSockets();
281
282  host_session_.GetStreamChannel(kMuxChannelName)->
283      set_next_write_error(net::ERR_FAILED);
284  host_session_.GetStreamChannel(kMuxChannelName)->
285      set_async_write(true);
286
287  scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
288
289  MockSocketCallback cb1;
290  MockSocketCallback cb2;
291  EXPECT_CALL(cb1, OnDone(net::ERR_FAILED));
292  EXPECT_CALL(cb2, OnDone(net::ERR_FAILED));
293
294  EXPECT_EQ(net::ERR_IO_PENDING,
295            host_socket1_->Write(buf.get(),
296                                 buf->size(),
297                                 base::Bind(&MockSocketCallback::OnDone,
298                                            base::Unretained(&cb1))));
299  EXPECT_EQ(net::ERR_IO_PENDING,
300            host_socket2_->Write(buf.get(),
301                                 buf->size(),
302                                 base::Bind(&MockSocketCallback::OnDone,
303                                            base::Unretained(&cb2))));
304
305  message_loop_.RunUntilIdle();
306}
307
308TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) {
309  ASSERT_NO_FATAL_FAILURE(
310      CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_));
311  ASSERT_NO_FATAL_FAILURE(
312      CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_));
313
314  ConnectSockets();
315
316  host_session_.GetStreamChannel(kMuxChannelName)->
317      set_next_write_error(net::ERR_FAILED);
318  host_session_.GetStreamChannel(kMuxChannelName)->
319      set_async_write(true);
320
321  scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100);
322
323  MockSocketCallback cb1;
324  MockSocketCallback cb2;
325
326  EXPECT_CALL(cb1, OnDone(net::ERR_FAILED))
327      .Times(AtMost(1))
328      .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
329  EXPECT_CALL(cb2, OnDone(net::ERR_FAILED))
330      .Times(AtMost(1))
331      .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll));
332
333  EXPECT_EQ(net::ERR_IO_PENDING,
334            host_socket1_->Write(buf.get(),
335                                 buf->size(),
336                                 base::Bind(&MockSocketCallback::OnDone,
337                                            base::Unretained(&cb1))));
338  EXPECT_EQ(net::ERR_IO_PENDING,
339            host_socket2_->Write(buf.get(),
340                                 buf->size(),
341                                 base::Bind(&MockSocketCallback::OnDone,
342                                            base::Unretained(&cb2))));
343
344  message_loop_.RunUntilIdle();
345
346  // Check that the sockets were destroyed.
347  EXPECT_FALSE(host_mux_.get());
348}
349
350TEST_F(ChannelMultiplexerTest, SessionFail) {
351  host_session_.set_async_creation(true);
352  host_session_.set_error(AUTHENTICATION_FAILED);
353
354  MockConnectCallback cb1;
355  MockConnectCallback cb2;
356
357  host_mux_->CreateStreamChannel(kTestChannelName, base::Bind(
358      &MockConnectCallback::OnConnected, base::Unretained(&cb1)));
359  host_mux_->CreateStreamChannel(kTestChannelName2, base::Bind(
360      &MockConnectCallback::OnConnected, base::Unretained(&cb2)));
361
362  EXPECT_CALL(cb1, OnConnectedPtr(NULL))
363      .Times(AtMost(1))
364      .WillOnce(InvokeWithoutArgs(
365          this, &ChannelMultiplexerTest::DeleteAfterSessionFail));
366  EXPECT_CALL(cb2, OnConnectedPtr(_))
367      .Times(0);
368
369  message_loop_.RunUntilIdle();
370}
371
372}  // namespace protocol
373}  // namespace remoting
374