1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/tools/quic/quic_dispatcher.h" 6 7#include <string> 8 9#include "base/strings/string_piece.h" 10#include "net/quic/crypto/crypto_handshake.h" 11#include "net/quic/crypto/quic_crypto_server_config.h" 12#include "net/quic/crypto/quic_random.h" 13#include "net/quic/quic_crypto_stream.h" 14#include "net/quic/quic_flags.h" 15#include "net/quic/quic_utils.h" 16#include "net/quic/test_tools/quic_test_utils.h" 17#include "net/tools/epoll_server/epoll_server.h" 18#include "net/tools/quic/quic_packet_writer_wrapper.h" 19#include "net/tools/quic/quic_time_wait_list_manager.h" 20#include "net/tools/quic/test_tools/quic_dispatcher_peer.h" 21#include "net/tools/quic/test_tools/quic_test_utils.h" 22#include "testing/gmock/include/gmock/gmock.h" 23#include "testing/gtest/include/gtest/gtest.h" 24 25using base::StringPiece; 26using net::EpollServer; 27using net::test::ConstructEncryptedPacket; 28using net::test::MockSession; 29using net::test::ValueRestore; 30using net::tools::test::MockConnection; 31using std::make_pair; 32using testing::DoAll; 33using testing::InSequence; 34using testing::Invoke; 35using testing::WithoutArgs; 36using testing::_; 37 38namespace net { 39namespace tools { 40namespace test { 41namespace { 42 43class TestDispatcher : public QuicDispatcher { 44 public: 45 explicit TestDispatcher(const QuicConfig& config, 46 const QuicCryptoServerConfig& crypto_config, 47 EpollServer* eps) 48 : QuicDispatcher(config, 49 crypto_config, 50 QuicSupportedVersions(), 51 new QuicDispatcher::DefaultPacketWriterFactory(), 52 eps) { 53 } 54 55 MOCK_METHOD3(CreateQuicSession, QuicSession*( 56 QuicConnectionId connection_id, 57 const IPEndPoint& server_address, 58 const IPEndPoint& client_address)); 59 60 using QuicDispatcher::current_server_address; 61 using QuicDispatcher::current_client_address; 62}; 63 64// A Connection class which unregisters the session from the dispatcher 65// when sending connection close. 66// It'd be slightly more realistic to do this from the Session but it would 67// involve a lot more mocking. 68class MockServerConnection : public MockConnection { 69 public: 70 MockServerConnection(QuicConnectionId connection_id, 71 QuicDispatcher* dispatcher) 72 : MockConnection(connection_id, true), 73 dispatcher_(dispatcher) {} 74 75 void UnregisterOnConnectionClosed() { 76 LOG(ERROR) << "Unregistering " << connection_id(); 77 dispatcher_->OnConnectionClosed(connection_id(), QUIC_NO_ERROR); 78 } 79 private: 80 QuicDispatcher* dispatcher_; 81}; 82 83QuicSession* CreateSession(QuicDispatcher* dispatcher, 84 QuicConnectionId connection_id, 85 const IPEndPoint& client_address, 86 MockSession** session) { 87 MockServerConnection* connection = 88 new MockServerConnection(connection_id, dispatcher); 89 *session = new MockSession(connection); 90 ON_CALL(*connection, SendConnectionClose(_)).WillByDefault( 91 WithoutArgs(Invoke( 92 connection, &MockServerConnection::UnregisterOnConnectionClosed))); 93 EXPECT_CALL(*reinterpret_cast<MockConnection*>((*session)->connection()), 94 ProcessUdpPacket(_, client_address, _)); 95 96 return *session; 97} 98 99class QuicDispatcherTest : public ::testing::Test { 100 public: 101 QuicDispatcherTest() 102 : crypto_config_(QuicCryptoServerConfig::TESTING, 103 QuicRandom::GetInstance()), 104 dispatcher_(config_, crypto_config_, &eps_), 105 session1_(NULL), 106 session2_(NULL) { 107 dispatcher_.Initialize(1); 108 } 109 110 virtual ~QuicDispatcherTest() {} 111 112 MockConnection* connection1() { 113 return reinterpret_cast<MockConnection*>(session1_->connection()); 114 } 115 116 MockConnection* connection2() { 117 return reinterpret_cast<MockConnection*>(session2_->connection()); 118 } 119 120 void ProcessPacket(IPEndPoint client_address, 121 QuicConnectionId connection_id, 122 bool has_version_flag, 123 const string& data) { 124 scoped_ptr<QuicEncryptedPacket> packet(ConstructEncryptedPacket( 125 connection_id, has_version_flag, false, 1, data)); 126 data_ = string(packet->data(), packet->length()); 127 dispatcher_.ProcessPacket(server_address_, client_address, *packet); 128 } 129 130 void ValidatePacket(const QuicEncryptedPacket& packet) { 131 EXPECT_EQ(data_.length(), packet.AsStringPiece().length()); 132 EXPECT_EQ(data_, packet.AsStringPiece()); 133 } 134 135 EpollServer eps_; 136 QuicConfig config_; 137 QuicCryptoServerConfig crypto_config_; 138 IPEndPoint server_address_; 139 TestDispatcher dispatcher_; 140 MockSession* session1_; 141 MockSession* session2_; 142 string data_; 143}; 144 145TEST_F(QuicDispatcherTest, ProcessPackets) { 146 IPEndPoint client_address(net::test::Loopback4(), 1); 147 IPAddressNumber any4; 148 CHECK(net::ParseIPLiteralToNumber("0.0.0.0", &any4)); 149 server_address_ = IPEndPoint(any4, 5); 150 151 EXPECT_CALL(dispatcher_, CreateQuicSession(1, _, client_address)) 152 .WillOnce(testing::Return(CreateSession( 153 &dispatcher_, 1, client_address, &session1_))); 154 ProcessPacket(client_address, 1, true, "foo"); 155 EXPECT_EQ(client_address, dispatcher_.current_client_address()); 156 EXPECT_EQ(server_address_, dispatcher_.current_server_address()); 157 158 159 EXPECT_CALL(dispatcher_, CreateQuicSession(2, _, client_address)) 160 .WillOnce(testing::Return(CreateSession( 161 &dispatcher_, 2, client_address, &session2_))); 162 ProcessPacket(client_address, 2, true, "bar"); 163 164 EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()), 165 ProcessUdpPacket(_, _, _)).Times(1). 166 WillOnce(testing::WithArgs<2>(Invoke( 167 this, &QuicDispatcherTest::ValidatePacket))); 168 ProcessPacket(client_address, 1, false, "eep"); 169} 170 171TEST_F(QuicDispatcherTest, Shutdown) { 172 IPEndPoint client_address(net::test::Loopback4(), 1); 173 174 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address)) 175 .WillOnce(testing::Return(CreateSession( 176 &dispatcher_, 1, client_address, &session1_))); 177 178 ProcessPacket(client_address, 1, true, "foo"); 179 180 EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()), 181 SendConnectionClose(QUIC_PEER_GOING_AWAY)); 182 183 dispatcher_.Shutdown(); 184} 185 186class MockTimeWaitListManager : public QuicTimeWaitListManager { 187 public: 188 MockTimeWaitListManager(QuicPacketWriter* writer, 189 QuicServerSessionVisitor* visitor, 190 EpollServer* eps) 191 : QuicTimeWaitListManager(writer, visitor, eps, QuicSupportedVersions()) { 192 } 193 194 MOCK_METHOD5(ProcessPacket, void(const IPEndPoint& server_address, 195 const IPEndPoint& client_address, 196 QuicConnectionId connection_id, 197 QuicPacketSequenceNumber sequence_number, 198 const QuicEncryptedPacket& packet)); 199}; 200 201TEST_F(QuicDispatcherTest, TimeWaitListManager) { 202 MockTimeWaitListManager* time_wait_list_manager = 203 new MockTimeWaitListManager( 204 QuicDispatcherPeer::GetWriter(&dispatcher_), &dispatcher_, &eps_); 205 // dispatcher takes the ownership of time_wait_list_manager. 206 QuicDispatcherPeer::SetTimeWaitListManager(&dispatcher_, 207 time_wait_list_manager); 208 // Create a new session. 209 IPEndPoint client_address(net::test::Loopback4(), 1); 210 QuicConnectionId connection_id = 1; 211 EXPECT_CALL(dispatcher_, CreateQuicSession(connection_id, _, client_address)) 212 .WillOnce(testing::Return(CreateSession( 213 &dispatcher_, connection_id, client_address, &session1_))); 214 ProcessPacket(client_address, connection_id, true, "foo"); 215 216 // Close the connection by sending public reset packet. 217 QuicPublicResetPacket packet; 218 packet.public_header.connection_id = connection_id; 219 packet.public_header.reset_flag = true; 220 packet.public_header.version_flag = false; 221 packet.rejected_sequence_number = 19191; 222 packet.nonce_proof = 132232; 223 scoped_ptr<QuicEncryptedPacket> encrypted( 224 QuicFramer::BuildPublicResetPacket(packet)); 225 EXPECT_CALL(*session1_, OnConnectionClosed(QUIC_PUBLIC_RESET, true)).Times(1) 226 .WillOnce(WithoutArgs(Invoke( 227 reinterpret_cast<MockServerConnection*>(session1_->connection()), 228 &MockServerConnection::UnregisterOnConnectionClosed))); 229 EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()), 230 ProcessUdpPacket(_, _, _)) 231 .WillOnce(Invoke( 232 reinterpret_cast<MockConnection*>(session1_->connection()), 233 &MockConnection::ReallyProcessUdpPacket)); 234 dispatcher_.ProcessPacket(IPEndPoint(), client_address, *encrypted); 235 EXPECT_TRUE(time_wait_list_manager->IsConnectionIdInTimeWait(connection_id)); 236 237 // Dispatcher forwards subsequent packets for this connection_id to the time 238 // wait list manager. 239 EXPECT_CALL(*time_wait_list_manager, 240 ProcessPacket(_, _, connection_id, _, _)).Times(1); 241 ProcessPacket(client_address, connection_id, true, "foo"); 242} 243 244TEST_F(QuicDispatcherTest, StrayPacketToTimeWaitListManager) { 245 MockTimeWaitListManager* time_wait_list_manager = 246 new MockTimeWaitListManager( 247 QuicDispatcherPeer::GetWriter(&dispatcher_), &dispatcher_, &eps_); 248 // dispatcher takes the ownership of time_wait_list_manager. 249 QuicDispatcherPeer::SetTimeWaitListManager(&dispatcher_, 250 time_wait_list_manager); 251 252 IPEndPoint client_address(net::test::Loopback4(), 1); 253 QuicConnectionId connection_id = 1; 254 // Dispatcher forwards all packets for this connection_id to the time wait 255 // list manager. 256 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, _)).Times(0); 257 EXPECT_CALL(*time_wait_list_manager, 258 ProcessPacket(_, _, connection_id, _, _)).Times(1); 259 string data = "foo"; 260 ProcessPacket(client_address, connection_id, false, "foo"); 261} 262 263class BlockingWriter : public QuicPacketWriterWrapper { 264 public: 265 BlockingWriter() : write_blocked_(false) {} 266 267 virtual bool IsWriteBlocked() const OVERRIDE { return write_blocked_; } 268 virtual void SetWritable() OVERRIDE { write_blocked_ = false; } 269 270 virtual WriteResult WritePacket( 271 const char* buffer, 272 size_t buf_len, 273 const IPAddressNumber& self_client_address, 274 const IPEndPoint& peer_client_address) OVERRIDE { 275 // It would be quite possible to actually implement this method here with 276 // the fake blocked status, but it would be significantly more work in 277 // Chromium, and since it's not called anyway, don't bother. 278 LOG(DFATAL) << "Not supported"; 279 return WriteResult(); 280 } 281 282 bool write_blocked_; 283}; 284 285class QuicDispatcherWriteBlockedListTest : public QuicDispatcherTest { 286 public: 287 virtual void SetUp() { 288 writer_ = new BlockingWriter; 289 QuicDispatcherPeer::SetPacketWriterFactory(&dispatcher_, 290 new TestWriterFactory()); 291 QuicDispatcherPeer::UseWriter(&dispatcher_, writer_); 292 293 IPEndPoint client_address(net::test::Loopback4(), 1); 294 295 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address)) 296 .WillOnce(testing::Return(CreateSession( 297 &dispatcher_, 1, client_address, &session1_))); 298 ProcessPacket(client_address, 1, true, "foo"); 299 300 EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address)) 301 .WillOnce(testing::Return(CreateSession( 302 &dispatcher_, 2, client_address, &session2_))); 303 ProcessPacket(client_address, 2, true, "bar"); 304 305 blocked_list_ = QuicDispatcherPeer::GetWriteBlockedList(&dispatcher_); 306 } 307 308 virtual void TearDown() { 309 EXPECT_CALL(*connection1(), SendConnectionClose(QUIC_PEER_GOING_AWAY)); 310 EXPECT_CALL(*connection2(), SendConnectionClose(QUIC_PEER_GOING_AWAY)); 311 dispatcher_.Shutdown(); 312 } 313 314 void SetBlocked() { 315 writer_->write_blocked_ = true; 316 } 317 318 void BlockConnection2() { 319 writer_->write_blocked_ = true; 320 dispatcher_.OnWriteBlocked(connection2()); 321 } 322 323 protected: 324 BlockingWriter* writer_; 325 QuicDispatcher::WriteBlockedList* blocked_list_; 326}; 327 328TEST_F(QuicDispatcherWriteBlockedListTest, BasicOnCanWrite) { 329 // No OnCanWrite calls because no connections are blocked. 330 dispatcher_.OnCanWrite(); 331 332 // Register connection 1 for events, and make sure it's notified. 333 SetBlocked(); 334 dispatcher_.OnWriteBlocked(connection1()); 335 EXPECT_CALL(*connection1(), OnCanWrite()); 336 dispatcher_.OnCanWrite(); 337 338 // It should get only one notification. 339 EXPECT_CALL(*connection1(), OnCanWrite()).Times(0); 340 dispatcher_.OnCanWrite(); 341 EXPECT_FALSE(dispatcher_.HasPendingWrites()); 342} 343 344TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteOrder) { 345 // Make sure we handle events in order. 346 InSequence s; 347 SetBlocked(); 348 dispatcher_.OnWriteBlocked(connection1()); 349 dispatcher_.OnWriteBlocked(connection2()); 350 EXPECT_CALL(*connection1(), OnCanWrite()); 351 EXPECT_CALL(*connection2(), OnCanWrite()); 352 dispatcher_.OnCanWrite(); 353 354 // Check the other ordering. 355 SetBlocked(); 356 dispatcher_.OnWriteBlocked(connection2()); 357 dispatcher_.OnWriteBlocked(connection1()); 358 EXPECT_CALL(*connection2(), OnCanWrite()); 359 EXPECT_CALL(*connection1(), OnCanWrite()); 360 dispatcher_.OnCanWrite(); 361} 362 363TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteRemove) { 364 // Add and remove one connction. 365 SetBlocked(); 366 dispatcher_.OnWriteBlocked(connection1()); 367 blocked_list_->erase(connection1()); 368 EXPECT_CALL(*connection1(), OnCanWrite()).Times(0); 369 dispatcher_.OnCanWrite(); 370 371 // Add and remove one connction and make sure it doesn't affect others. 372 SetBlocked(); 373 dispatcher_.OnWriteBlocked(connection1()); 374 dispatcher_.OnWriteBlocked(connection2()); 375 blocked_list_->erase(connection1()); 376 EXPECT_CALL(*connection2(), OnCanWrite()); 377 dispatcher_.OnCanWrite(); 378 379 // Add it, remove it, and add it back and make sure things are OK. 380 SetBlocked(); 381 dispatcher_.OnWriteBlocked(connection1()); 382 blocked_list_->erase(connection1()); 383 dispatcher_.OnWriteBlocked(connection1()); 384 EXPECT_CALL(*connection1(), OnCanWrite()).Times(1); 385 dispatcher_.OnCanWrite(); 386} 387 388TEST_F(QuicDispatcherWriteBlockedListTest, DoubleAdd) { 389 // Make sure a double add does not necessitate a double remove. 390 SetBlocked(); 391 dispatcher_.OnWriteBlocked(connection1()); 392 dispatcher_.OnWriteBlocked(connection1()); 393 blocked_list_->erase(connection1()); 394 EXPECT_CALL(*connection1(), OnCanWrite()).Times(0); 395 dispatcher_.OnCanWrite(); 396 397 // Make sure a double add does not result in two OnCanWrite calls. 398 SetBlocked(); 399 dispatcher_.OnWriteBlocked(connection1()); 400 dispatcher_.OnWriteBlocked(connection1()); 401 EXPECT_CALL(*connection1(), OnCanWrite()).Times(1); 402 dispatcher_.OnCanWrite(); 403} 404 405TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteHandleBlock) { 406 // Finally make sure if we write block on a write call, we stop calling. 407 InSequence s; 408 SetBlocked(); 409 dispatcher_.OnWriteBlocked(connection1()); 410 dispatcher_.OnWriteBlocked(connection2()); 411 EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce( 412 Invoke(this, &QuicDispatcherWriteBlockedListTest::SetBlocked)); 413 EXPECT_CALL(*connection2(), OnCanWrite()).Times(0); 414 dispatcher_.OnCanWrite(); 415 416 // And we'll resume where we left off when we get another call. 417 EXPECT_CALL(*connection2(), OnCanWrite()); 418 dispatcher_.OnCanWrite(); 419} 420 421TEST_F(QuicDispatcherWriteBlockedListTest, LimitedWrites) { 422 // Make sure we call both writers. The first will register for more writing 423 // but should not be immediately called due to limits. 424 InSequence s; 425 SetBlocked(); 426 dispatcher_.OnWriteBlocked(connection1()); 427 dispatcher_.OnWriteBlocked(connection2()); 428 EXPECT_CALL(*connection1(), OnCanWrite()); 429 EXPECT_CALL(*connection2(), OnCanWrite()).WillOnce( 430 Invoke(this, &QuicDispatcherWriteBlockedListTest::BlockConnection2)); 431 dispatcher_.OnCanWrite(); 432 EXPECT_TRUE(dispatcher_.HasPendingWrites()); 433 434 // Now call OnCanWrite again, and connection1 should get its second chance 435 EXPECT_CALL(*connection2(), OnCanWrite()); 436 dispatcher_.OnCanWrite(); 437 EXPECT_FALSE(dispatcher_.HasPendingWrites()); 438} 439 440TEST_F(QuicDispatcherWriteBlockedListTest, TestWriteLimits) { 441 // Finally make sure if we write block on a write call, we stop calling. 442 InSequence s; 443 SetBlocked(); 444 dispatcher_.OnWriteBlocked(connection1()); 445 dispatcher_.OnWriteBlocked(connection2()); 446 EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce( 447 Invoke(this, &QuicDispatcherWriteBlockedListTest::SetBlocked)); 448 EXPECT_CALL(*connection2(), OnCanWrite()).Times(0); 449 dispatcher_.OnCanWrite(); 450 EXPECT_TRUE(dispatcher_.HasPendingWrites()); 451 452 // And we'll resume where we left off when we get another call. 453 EXPECT_CALL(*connection2(), OnCanWrite()); 454 dispatcher_.OnCanWrite(); 455 EXPECT_FALSE(dispatcher_.HasPendingWrites()); 456} 457 458} // namespace 459} // namespace test 460} // namespace tools 461} // namespace net 462