1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "remoting/protocol/channel_multiplexer.h" 6 7#include "base/bind.h" 8#include "base/message_loop/message_loop.h" 9#include "base/run_loop.h" 10#include "net/base/net_errors.h" 11#include "net/socket/socket.h" 12#include "net/socket/stream_socket.h" 13#include "remoting/base/constants.h" 14#include "remoting/protocol/connection_tester.h" 15#include "remoting/protocol/fake_session.h" 16#include "testing/gmock/include/gmock/gmock.h" 17#include "testing/gtest/include/gtest/gtest.h" 18 19using testing::_; 20using testing::AtMost; 21using testing::InvokeWithoutArgs; 22 23namespace remoting { 24namespace protocol { 25 26namespace { 27 28const int kMessageSize = 1024; 29const int kMessages = 100; 30const char kMuxChannelName[] = "mux"; 31 32const char kTestChannelName[] = "test"; 33const char kTestChannelName2[] = "test2"; 34 35 36void QuitCurrentThread() { 37 base::MessageLoop::current()->PostTask(FROM_HERE, 38 base::MessageLoop::QuitClosure()); 39} 40 41class MockSocketCallback { 42 public: 43 MOCK_METHOD1(OnDone, void(int result)); 44}; 45 46class MockConnectCallback { 47 public: 48 MOCK_METHOD1(OnConnectedPtr, void(net::StreamSocket* socket)); 49 void OnConnected(scoped_ptr<net::StreamSocket> socket) { 50 OnConnectedPtr(socket.release()); 51 } 52}; 53 54} // namespace 55 56class ChannelMultiplexerTest : public testing::Test { 57 public: 58 void DeleteAll() { 59 host_socket1_.reset(); 60 host_socket2_.reset(); 61 client_socket1_.reset(); 62 client_socket2_.reset(); 63 host_mux_.reset(); 64 client_mux_.reset(); 65 } 66 67 void DeleteAfterSessionFail() { 68 host_mux_->CancelChannelCreation(kTestChannelName2); 69 DeleteAll(); 70 } 71 72 protected: 73 virtual void SetUp() OVERRIDE { 74 // Create pair of multiplexers and connect them to each other. 75 host_mux_.reset(new ChannelMultiplexer( 76 host_session_.GetTransportChannelFactory(), kMuxChannelName)); 77 client_mux_.reset(new ChannelMultiplexer( 78 client_session_.GetTransportChannelFactory(), kMuxChannelName)); 79 } 80 81 // Connect sockets to each other. Must be called after we've created at least 82 // one channel with each multiplexer. 83 void ConnectSockets() { 84 FakeStreamSocket* host_socket = 85 host_session_.fake_channel_factory().GetFakeChannel( 86 ChannelMultiplexer::kMuxChannelName); 87 FakeStreamSocket* client_socket = 88 client_session_.fake_channel_factory().GetFakeChannel( 89 ChannelMultiplexer::kMuxChannelName); 90 host_socket->PairWith(client_socket); 91 92 // Make writes asynchronous in one direction. 93 host_socket->set_async_write(true); 94 } 95 96 void CreateChannel(const std::string& name, 97 scoped_ptr<net::StreamSocket>* host_socket, 98 scoped_ptr<net::StreamSocket>* client_socket) { 99 int counter = 2; 100 host_mux_->CreateChannel(name, base::Bind( 101 &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this), 102 host_socket, &counter)); 103 client_mux_->CreateChannel(name, base::Bind( 104 &ChannelMultiplexerTest::OnChannelConnected, base::Unretained(this), 105 client_socket, &counter)); 106 107 message_loop_.Run(); 108 109 EXPECT_TRUE(host_socket->get()); 110 EXPECT_TRUE(client_socket->get()); 111 } 112 113 void OnChannelConnected( 114 scoped_ptr<net::StreamSocket>* storage, 115 int* counter, 116 scoped_ptr<net::StreamSocket> socket) { 117 *storage = socket.Pass(); 118 --(*counter); 119 EXPECT_GE(*counter, 0); 120 if (*counter == 0) 121 QuitCurrentThread(); 122 } 123 124 scoped_refptr<net::IOBufferWithSize> CreateTestBuffer(int size) { 125 scoped_refptr<net::IOBufferWithSize> result = 126 new net::IOBufferWithSize(size); 127 for (int i = 0; i< size; ++i) { 128 result->data()[i] = rand() % 256; 129 } 130 return result; 131 } 132 133 base::MessageLoop message_loop_; 134 135 FakeSession host_session_; 136 FakeSession client_session_; 137 138 scoped_ptr<ChannelMultiplexer> host_mux_; 139 scoped_ptr<ChannelMultiplexer> client_mux_; 140 141 scoped_ptr<net::StreamSocket> host_socket1_; 142 scoped_ptr<net::StreamSocket> client_socket1_; 143 scoped_ptr<net::StreamSocket> host_socket2_; 144 scoped_ptr<net::StreamSocket> client_socket2_; 145}; 146 147 148TEST_F(ChannelMultiplexerTest, OneChannel) { 149 scoped_ptr<net::StreamSocket> host_socket; 150 scoped_ptr<net::StreamSocket> client_socket; 151 ASSERT_NO_FATAL_FAILURE( 152 CreateChannel(kTestChannelName, &host_socket, &client_socket)); 153 154 ConnectSockets(); 155 156 StreamConnectionTester tester(host_socket.get(), client_socket.get(), 157 kMessageSize, kMessages); 158 tester.Start(); 159 message_loop_.Run(); 160 tester.CheckResults(); 161} 162 163TEST_F(ChannelMultiplexerTest, TwoChannels) { 164 scoped_ptr<net::StreamSocket> host_socket1_; 165 scoped_ptr<net::StreamSocket> client_socket1_; 166 ASSERT_NO_FATAL_FAILURE( 167 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); 168 169 scoped_ptr<net::StreamSocket> host_socket2_; 170 scoped_ptr<net::StreamSocket> client_socket2_; 171 ASSERT_NO_FATAL_FAILURE( 172 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); 173 174 ConnectSockets(); 175 176 StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(), 177 kMessageSize, kMessages); 178 StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(), 179 kMessageSize, kMessages); 180 tester1.Start(); 181 tester2.Start(); 182 while (!tester1.done() || !tester2.done()) { 183 message_loop_.Run(); 184 } 185 tester1.CheckResults(); 186 tester2.CheckResults(); 187} 188 189// Four channels, two in each direction 190TEST_F(ChannelMultiplexerTest, FourChannels) { 191 scoped_ptr<net::StreamSocket> host_socket1_; 192 scoped_ptr<net::StreamSocket> client_socket1_; 193 ASSERT_NO_FATAL_FAILURE( 194 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); 195 196 scoped_ptr<net::StreamSocket> host_socket2_; 197 scoped_ptr<net::StreamSocket> client_socket2_; 198 ASSERT_NO_FATAL_FAILURE( 199 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); 200 201 scoped_ptr<net::StreamSocket> host_socket3; 202 scoped_ptr<net::StreamSocket> client_socket3; 203 ASSERT_NO_FATAL_FAILURE( 204 CreateChannel("test3", &host_socket3, &client_socket3)); 205 206 scoped_ptr<net::StreamSocket> host_socket4; 207 scoped_ptr<net::StreamSocket> client_socket4; 208 ASSERT_NO_FATAL_FAILURE( 209 CreateChannel("ch4", &host_socket4, &client_socket4)); 210 211 ConnectSockets(); 212 213 StreamConnectionTester tester1(host_socket1_.get(), client_socket1_.get(), 214 kMessageSize, kMessages); 215 StreamConnectionTester tester2(host_socket2_.get(), client_socket2_.get(), 216 kMessageSize, kMessages); 217 StreamConnectionTester tester3(client_socket3.get(), host_socket3.get(), 218 kMessageSize, kMessages); 219 StreamConnectionTester tester4(client_socket4.get(), host_socket4.get(), 220 kMessageSize, kMessages); 221 tester1.Start(); 222 tester2.Start(); 223 tester3.Start(); 224 tester4.Start(); 225 while (!tester1.done() || !tester2.done() || 226 !tester3.done() || !tester4.done()) { 227 message_loop_.Run(); 228 } 229 tester1.CheckResults(); 230 tester2.CheckResults(); 231 tester3.CheckResults(); 232 tester4.CheckResults(); 233} 234 235TEST_F(ChannelMultiplexerTest, WriteFailSync) { 236 scoped_ptr<net::StreamSocket> host_socket1_; 237 scoped_ptr<net::StreamSocket> client_socket1_; 238 ASSERT_NO_FATAL_FAILURE( 239 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); 240 241 scoped_ptr<net::StreamSocket> host_socket2_; 242 scoped_ptr<net::StreamSocket> client_socket2_; 243 ASSERT_NO_FATAL_FAILURE( 244 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); 245 246 ConnectSockets(); 247 248 host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)-> 249 set_next_write_error(net::ERR_FAILED); 250 host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)-> 251 set_async_write(false); 252 253 scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); 254 255 MockSocketCallback cb1; 256 MockSocketCallback cb2; 257 258 EXPECT_CALL(cb1, OnDone(_)) 259 .Times(0); 260 EXPECT_CALL(cb2, OnDone(_)) 261 .Times(0); 262 263 EXPECT_EQ(net::ERR_FAILED, 264 host_socket1_->Write(buf.get(), 265 buf->size(), 266 base::Bind(&MockSocketCallback::OnDone, 267 base::Unretained(&cb1)))); 268 EXPECT_EQ(net::ERR_FAILED, 269 host_socket2_->Write(buf.get(), 270 buf->size(), 271 base::Bind(&MockSocketCallback::OnDone, 272 base::Unretained(&cb2)))); 273 274 base::RunLoop().RunUntilIdle(); 275} 276 277TEST_F(ChannelMultiplexerTest, WriteFailAsync) { 278 ASSERT_NO_FATAL_FAILURE( 279 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); 280 281 ASSERT_NO_FATAL_FAILURE( 282 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); 283 284 ConnectSockets(); 285 286 host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)-> 287 set_next_write_error(net::ERR_FAILED); 288 host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)-> 289 set_async_write(true); 290 291 scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); 292 293 MockSocketCallback cb1; 294 MockSocketCallback cb2; 295 EXPECT_CALL(cb1, OnDone(net::ERR_FAILED)); 296 EXPECT_CALL(cb2, OnDone(net::ERR_FAILED)); 297 298 EXPECT_EQ(net::ERR_IO_PENDING, 299 host_socket1_->Write(buf.get(), 300 buf->size(), 301 base::Bind(&MockSocketCallback::OnDone, 302 base::Unretained(&cb1)))); 303 EXPECT_EQ(net::ERR_IO_PENDING, 304 host_socket2_->Write(buf.get(), 305 buf->size(), 306 base::Bind(&MockSocketCallback::OnDone, 307 base::Unretained(&cb2)))); 308 309 base::RunLoop().RunUntilIdle(); 310} 311 312TEST_F(ChannelMultiplexerTest, DeleteWhenFailed) { 313 ASSERT_NO_FATAL_FAILURE( 314 CreateChannel(kTestChannelName, &host_socket1_, &client_socket1_)); 315 ASSERT_NO_FATAL_FAILURE( 316 CreateChannel(kTestChannelName2, &host_socket2_, &client_socket2_)); 317 318 ConnectSockets(); 319 320 host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)-> 321 set_next_write_error(net::ERR_FAILED); 322 host_session_.fake_channel_factory().GetFakeChannel(kMuxChannelName)-> 323 set_async_write(true); 324 325 scoped_refptr<net::IOBufferWithSize> buf = CreateTestBuffer(100); 326 327 MockSocketCallback cb1; 328 MockSocketCallback cb2; 329 330 EXPECT_CALL(cb1, OnDone(net::ERR_FAILED)) 331 .Times(AtMost(1)) 332 .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll)); 333 EXPECT_CALL(cb2, OnDone(net::ERR_FAILED)) 334 .Times(AtMost(1)) 335 .WillOnce(InvokeWithoutArgs(this, &ChannelMultiplexerTest::DeleteAll)); 336 337 EXPECT_EQ(net::ERR_IO_PENDING, 338 host_socket1_->Write(buf.get(), 339 buf->size(), 340 base::Bind(&MockSocketCallback::OnDone, 341 base::Unretained(&cb1)))); 342 EXPECT_EQ(net::ERR_IO_PENDING, 343 host_socket2_->Write(buf.get(), 344 buf->size(), 345 base::Bind(&MockSocketCallback::OnDone, 346 base::Unretained(&cb2)))); 347 348 base::RunLoop().RunUntilIdle(); 349 350 // Check that the sockets were destroyed. 351 EXPECT_FALSE(host_mux_.get()); 352} 353 354TEST_F(ChannelMultiplexerTest, SessionFail) { 355 host_session_.fake_channel_factory().set_asynchronous_create(true); 356 host_session_.fake_channel_factory().set_fail_create(true); 357 358 MockConnectCallback cb1; 359 MockConnectCallback cb2; 360 361 host_mux_->CreateChannel(kTestChannelName, base::Bind( 362 &MockConnectCallback::OnConnected, base::Unretained(&cb1))); 363 host_mux_->CreateChannel(kTestChannelName2, base::Bind( 364 &MockConnectCallback::OnConnected, base::Unretained(&cb2))); 365 366 EXPECT_CALL(cb1, OnConnectedPtr(NULL)) 367 .Times(AtMost(1)) 368 .WillOnce(InvokeWithoutArgs( 369 this, &ChannelMultiplexerTest::DeleteAfterSessionFail)); 370 EXPECT_CALL(cb2, OnConnectedPtr(_)) 371 .Times(0); 372 373 base::RunLoop().RunUntilIdle(); 374} 375 376} // namespace protocol 377} // namespace remoting 378