1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/tools/quic/quic_dispatcher.h"
6
7#include <string>
8
9#include "base/strings/string_piece.h"
10#include "net/quic/crypto/crypto_handshake.h"
11#include "net/quic/crypto/quic_crypto_server_config.h"
12#include "net/quic/crypto/quic_random.h"
13#include "net/quic/quic_crypto_stream.h"
14#include "net/quic/quic_flags.h"
15#include "net/quic/quic_utils.h"
16#include "net/quic/test_tools/quic_test_utils.h"
17#include "net/tools/epoll_server/epoll_server.h"
18#include "net/tools/quic/quic_packet_writer_wrapper.h"
19#include "net/tools/quic/quic_time_wait_list_manager.h"
20#include "net/tools/quic/test_tools/quic_dispatcher_peer.h"
21#include "net/tools/quic/test_tools/quic_test_utils.h"
22#include "testing/gmock/include/gmock/gmock.h"
23#include "testing/gtest/include/gtest/gtest.h"
24
25using base::StringPiece;
26using net::EpollServer;
27using net::test::ConstructEncryptedPacket;
28using net::test::MockSession;
29using net::test::ValueRestore;
30using net::tools::test::MockConnection;
31using std::make_pair;
32using testing::DoAll;
33using testing::InSequence;
34using testing::Invoke;
35using testing::WithoutArgs;
36using testing::_;
37
38namespace net {
39namespace tools {
40namespace test {
41namespace {
42
43class TestDispatcher : public QuicDispatcher {
44 public:
45  explicit TestDispatcher(const QuicConfig& config,
46                          const QuicCryptoServerConfig& crypto_config,
47                          EpollServer* eps)
48      : QuicDispatcher(config,
49                       crypto_config,
50                       QuicSupportedVersions(),
51                       new QuicDispatcher::DefaultPacketWriterFactory(),
52                       eps) {
53  }
54
55  MOCK_METHOD3(CreateQuicSession, QuicSession*(
56      QuicConnectionId connection_id,
57      const IPEndPoint& server_address,
58      const IPEndPoint& client_address));
59
60  using QuicDispatcher::current_server_address;
61  using QuicDispatcher::current_client_address;
62};
63
64// A Connection class which unregisters the session from the dispatcher
65// when sending connection close.
66// It'd be slightly more realistic to do this from the Session but it would
67// involve a lot more mocking.
68class MockServerConnection : public MockConnection {
69 public:
70  MockServerConnection(QuicConnectionId connection_id,
71                       QuicDispatcher* dispatcher)
72      : MockConnection(connection_id, true),
73        dispatcher_(dispatcher) {}
74
75  void UnregisterOnConnectionClosed() {
76    LOG(ERROR) << "Unregistering " << connection_id();
77    dispatcher_->OnConnectionClosed(connection_id(), QUIC_NO_ERROR);
78  }
79 private:
80  QuicDispatcher* dispatcher_;
81};
82
83QuicSession* CreateSession(QuicDispatcher* dispatcher,
84                           QuicConnectionId connection_id,
85                           const IPEndPoint& client_address,
86                           MockSession** session) {
87  MockServerConnection* connection =
88      new MockServerConnection(connection_id, dispatcher);
89  *session = new MockSession(connection);
90  ON_CALL(*connection, SendConnectionClose(_)).WillByDefault(
91      WithoutArgs(Invoke(
92          connection, &MockServerConnection::UnregisterOnConnectionClosed)));
93  EXPECT_CALL(*reinterpret_cast<MockConnection*>((*session)->connection()),
94              ProcessUdpPacket(_, client_address, _));
95
96  return *session;
97}
98
99class QuicDispatcherTest : public ::testing::Test {
100 public:
101  QuicDispatcherTest()
102      : crypto_config_(QuicCryptoServerConfig::TESTING,
103                       QuicRandom::GetInstance()),
104        dispatcher_(config_, crypto_config_, &eps_),
105        session1_(NULL),
106        session2_(NULL) {
107    dispatcher_.Initialize(1);
108  }
109
110  virtual ~QuicDispatcherTest() {}
111
112  MockConnection* connection1() {
113    return reinterpret_cast<MockConnection*>(session1_->connection());
114  }
115
116  MockConnection* connection2() {
117    return reinterpret_cast<MockConnection*>(session2_->connection());
118  }
119
120  void ProcessPacket(IPEndPoint client_address,
121                     QuicConnectionId connection_id,
122                     bool has_version_flag,
123                     const string& data) {
124    scoped_ptr<QuicEncryptedPacket> packet(ConstructEncryptedPacket(
125        connection_id, has_version_flag, false, 1, data));
126    data_ = string(packet->data(), packet->length());
127    dispatcher_.ProcessPacket(server_address_, client_address, *packet);
128  }
129
130  void ValidatePacket(const QuicEncryptedPacket& packet) {
131    EXPECT_EQ(data_.length(), packet.AsStringPiece().length());
132    EXPECT_EQ(data_, packet.AsStringPiece());
133  }
134
135  EpollServer eps_;
136  QuicConfig config_;
137  QuicCryptoServerConfig crypto_config_;
138  IPEndPoint server_address_;
139  TestDispatcher dispatcher_;
140  MockSession* session1_;
141  MockSession* session2_;
142  string data_;
143};
144
145TEST_F(QuicDispatcherTest, ProcessPackets) {
146  IPEndPoint client_address(net::test::Loopback4(), 1);
147  IPAddressNumber any4;
148  CHECK(net::ParseIPLiteralToNumber("0.0.0.0", &any4));
149  server_address_ = IPEndPoint(any4, 5);
150
151  EXPECT_CALL(dispatcher_, CreateQuicSession(1, _, client_address))
152      .WillOnce(testing::Return(CreateSession(
153          &dispatcher_, 1, client_address, &session1_)));
154  ProcessPacket(client_address, 1, true, "foo");
155  EXPECT_EQ(client_address, dispatcher_.current_client_address());
156  EXPECT_EQ(server_address_, dispatcher_.current_server_address());
157
158
159  EXPECT_CALL(dispatcher_, CreateQuicSession(2, _, client_address))
160      .WillOnce(testing::Return(CreateSession(
161                    &dispatcher_, 2, client_address, &session2_)));
162  ProcessPacket(client_address, 2, true, "bar");
163
164  EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()),
165              ProcessUdpPacket(_, _, _)).Times(1).
166      WillOnce(testing::WithArgs<2>(Invoke(
167          this, &QuicDispatcherTest::ValidatePacket)));
168  ProcessPacket(client_address, 1, false, "eep");
169}
170
171TEST_F(QuicDispatcherTest, Shutdown) {
172  IPEndPoint client_address(net::test::Loopback4(), 1);
173
174  EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address))
175      .WillOnce(testing::Return(CreateSession(
176                    &dispatcher_, 1, client_address, &session1_)));
177
178  ProcessPacket(client_address, 1, true, "foo");
179
180  EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()),
181              SendConnectionClose(QUIC_PEER_GOING_AWAY));
182
183  dispatcher_.Shutdown();
184}
185
186class MockTimeWaitListManager : public QuicTimeWaitListManager {
187 public:
188  MockTimeWaitListManager(QuicPacketWriter* writer,
189                          QuicServerSessionVisitor* visitor,
190                          EpollServer* eps)
191      : QuicTimeWaitListManager(writer, visitor, eps, QuicSupportedVersions()) {
192  }
193
194  MOCK_METHOD5(ProcessPacket, void(const IPEndPoint& server_address,
195                                   const IPEndPoint& client_address,
196                                   QuicConnectionId connection_id,
197                                   QuicPacketSequenceNumber sequence_number,
198                                   const QuicEncryptedPacket& packet));
199};
200
201TEST_F(QuicDispatcherTest, TimeWaitListManager) {
202  MockTimeWaitListManager* time_wait_list_manager =
203      new MockTimeWaitListManager(
204          QuicDispatcherPeer::GetWriter(&dispatcher_), &dispatcher_, &eps_);
205  // dispatcher takes the ownership of time_wait_list_manager.
206  QuicDispatcherPeer::SetTimeWaitListManager(&dispatcher_,
207                                             time_wait_list_manager);
208  // Create a new session.
209  IPEndPoint client_address(net::test::Loopback4(), 1);
210  QuicConnectionId connection_id = 1;
211  EXPECT_CALL(dispatcher_, CreateQuicSession(connection_id, _, client_address))
212      .WillOnce(testing::Return(CreateSession(
213                    &dispatcher_, connection_id, client_address, &session1_)));
214  ProcessPacket(client_address, connection_id, true, "foo");
215
216  // Close the connection by sending public reset packet.
217  QuicPublicResetPacket packet;
218  packet.public_header.connection_id = connection_id;
219  packet.public_header.reset_flag = true;
220  packet.public_header.version_flag = false;
221  packet.rejected_sequence_number = 19191;
222  packet.nonce_proof = 132232;
223  scoped_ptr<QuicEncryptedPacket> encrypted(
224      QuicFramer::BuildPublicResetPacket(packet));
225  EXPECT_CALL(*session1_, OnConnectionClosed(QUIC_PUBLIC_RESET, true)).Times(1)
226      .WillOnce(WithoutArgs(Invoke(
227          reinterpret_cast<MockServerConnection*>(session1_->connection()),
228          &MockServerConnection::UnregisterOnConnectionClosed)));
229  EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()),
230              ProcessUdpPacket(_, _, _))
231      .WillOnce(Invoke(
232          reinterpret_cast<MockConnection*>(session1_->connection()),
233          &MockConnection::ReallyProcessUdpPacket));
234  dispatcher_.ProcessPacket(IPEndPoint(), client_address, *encrypted);
235  EXPECT_TRUE(time_wait_list_manager->IsConnectionIdInTimeWait(connection_id));
236
237  // Dispatcher forwards subsequent packets for this connection_id to the time
238  // wait list manager.
239  EXPECT_CALL(*time_wait_list_manager,
240              ProcessPacket(_, _, connection_id, _, _)).Times(1);
241  ProcessPacket(client_address, connection_id, true, "foo");
242}
243
244TEST_F(QuicDispatcherTest, StrayPacketToTimeWaitListManager) {
245  MockTimeWaitListManager* time_wait_list_manager =
246      new MockTimeWaitListManager(
247          QuicDispatcherPeer::GetWriter(&dispatcher_), &dispatcher_, &eps_);
248  // dispatcher takes the ownership of time_wait_list_manager.
249  QuicDispatcherPeer::SetTimeWaitListManager(&dispatcher_,
250                                             time_wait_list_manager);
251
252  IPEndPoint client_address(net::test::Loopback4(), 1);
253  QuicConnectionId connection_id = 1;
254  // Dispatcher forwards all packets for this connection_id to the time wait
255  // list manager.
256  EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, _)).Times(0);
257  EXPECT_CALL(*time_wait_list_manager,
258              ProcessPacket(_, _, connection_id, _, _)).Times(1);
259  string data = "foo";
260  ProcessPacket(client_address, connection_id, false, "foo");
261}
262
263class BlockingWriter : public QuicPacketWriterWrapper {
264 public:
265  BlockingWriter() : write_blocked_(false) {}
266
267  virtual bool IsWriteBlocked() const OVERRIDE { return write_blocked_; }
268  virtual void SetWritable() OVERRIDE { write_blocked_ = false; }
269
270  virtual WriteResult WritePacket(
271      const char* buffer,
272      size_t buf_len,
273      const IPAddressNumber& self_client_address,
274      const IPEndPoint& peer_client_address) OVERRIDE {
275    // It would be quite possible to actually implement this method here with
276    // the fake blocked status, but it would be significantly more work in
277    // Chromium, and since it's not called anyway, don't bother.
278    LOG(DFATAL) << "Not supported";
279    return WriteResult();
280  }
281
282  bool write_blocked_;
283};
284
285class QuicDispatcherWriteBlockedListTest : public QuicDispatcherTest {
286 public:
287  virtual void SetUp() {
288    writer_ = new BlockingWriter;
289    QuicDispatcherPeer::SetPacketWriterFactory(&dispatcher_,
290                                               new TestWriterFactory());
291    QuicDispatcherPeer::UseWriter(&dispatcher_, writer_);
292
293    IPEndPoint client_address(net::test::Loopback4(), 1);
294
295    EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address))
296        .WillOnce(testing::Return(CreateSession(
297                      &dispatcher_, 1, client_address, &session1_)));
298    ProcessPacket(client_address, 1, true, "foo");
299
300    EXPECT_CALL(dispatcher_, CreateQuicSession(_, _, client_address))
301        .WillOnce(testing::Return(CreateSession(
302                      &dispatcher_, 2, client_address, &session2_)));
303    ProcessPacket(client_address, 2, true, "bar");
304
305    blocked_list_ = QuicDispatcherPeer::GetWriteBlockedList(&dispatcher_);
306  }
307
308  virtual void TearDown() {
309    EXPECT_CALL(*connection1(), SendConnectionClose(QUIC_PEER_GOING_AWAY));
310    EXPECT_CALL(*connection2(), SendConnectionClose(QUIC_PEER_GOING_AWAY));
311    dispatcher_.Shutdown();
312  }
313
314  void SetBlocked() {
315    writer_->write_blocked_ = true;
316  }
317
318  void BlockConnection2() {
319    writer_->write_blocked_ = true;
320    dispatcher_.OnWriteBlocked(connection2());
321  }
322
323 protected:
324  BlockingWriter* writer_;
325  QuicDispatcher::WriteBlockedList* blocked_list_;
326};
327
// A single write-blocked connection is notified once per registration, and
// nothing remains pending afterwards.
TEST_F(QuicDispatcherWriteBlockedListTest, BasicOnCanWrite) {
  // No OnCanWrite calls because no connections are blocked.
  // NOTE(review): no OnCanWrite expectation exists yet, so an unexpected call
  // here would only be an "uninteresting call" (a warning) unless
  // MockConnection is a StrictMock.
  dispatcher_.OnCanWrite();

  // Register connection 1 for events, and make sure it's notified.
  // SetBlocked() makes the shared writer report blocked before registering.
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  EXPECT_CALL(*connection1(), OnCanWrite());
  dispatcher_.OnCanWrite();

  // It should get only one notification.
  EXPECT_CALL(*connection1(), OnCanWrite()).Times(0);
  dispatcher_.OnCanWrite();
  EXPECT_FALSE(dispatcher_.HasPendingWrites());
}
343
// Write-blocked connections are notified in the order they registered.
// InSequence makes gmock enforce the call order of the expectations below.
TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteOrder) {
  // Make sure we handle events in order.
  InSequence s;
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  dispatcher_.OnWriteBlocked(connection2());
  EXPECT_CALL(*connection1(), OnCanWrite());
  EXPECT_CALL(*connection2(), OnCanWrite());
  dispatcher_.OnCanWrite();

  // Check the other ordering.
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection2());
  dispatcher_.OnWriteBlocked(connection1());
  EXPECT_CALL(*connection2(), OnCanWrite());
  EXPECT_CALL(*connection1(), OnCanWrite());
  dispatcher_.OnCanWrite();
}
362
// Erasing a connection from the dispatcher's write-blocked list (through
// blocked_list_) cancels its pending notification without affecting others.
TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteRemove) {
  // Add and remove one connection.
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  blocked_list_->erase(connection1());
  EXPECT_CALL(*connection1(), OnCanWrite()).Times(0);
  dispatcher_.OnCanWrite();

  // Add and remove one connection and make sure it doesn't affect others.
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  dispatcher_.OnWriteBlocked(connection2());
  blocked_list_->erase(connection1());
  EXPECT_CALL(*connection2(), OnCanWrite());
  dispatcher_.OnCanWrite();

  // Add it, remove it, and add it back and make sure things are OK.
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  blocked_list_->erase(connection1());
  dispatcher_.OnWriteBlocked(connection1());
  EXPECT_CALL(*connection1(), OnCanWrite()).Times(1);
  dispatcher_.OnCanWrite();
}
387
// Registering the same connection twice behaves like registering it once.
// One erase removes it, and it is notified only once.
TEST_F(QuicDispatcherWriteBlockedListTest, DoubleAdd) {
  // Make sure a double add does not necessitate a double remove.
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  dispatcher_.OnWriteBlocked(connection1());
  blocked_list_->erase(connection1());
  EXPECT_CALL(*connection1(), OnCanWrite()).Times(0);
  dispatcher_.OnCanWrite();

  // Make sure a double add does not result in two OnCanWrite calls.
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  dispatcher_.OnWriteBlocked(connection1());
  EXPECT_CALL(*connection1(), OnCanWrite()).Times(1);
  dispatcher_.OnCanWrite();
}
404
// If the writer becomes blocked again while a connection is being notified,
// the dispatcher stops notifying the remaining connections. It resumes with
// them on the next OnCanWrite().
TEST_F(QuicDispatcherWriteBlockedListTest, OnCanWriteHandleBlock) {
  // Make sure that if the writer blocks during a write call, we stop calling.
  InSequence s;
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  dispatcher_.OnWriteBlocked(connection2());
  // connection1's OnCanWrite re-blocks the writer (SetBlocked) without
  // re-registering connection1.
  EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce(
      Invoke(this, &QuicDispatcherWriteBlockedListTest::SetBlocked));
  EXPECT_CALL(*connection2(), OnCanWrite()).Times(0);
  dispatcher_.OnCanWrite();

  // And we'll resume where we left off when we get another call.
  EXPECT_CALL(*connection2(), OnCanWrite());
  dispatcher_.OnCanWrite();
}
420
421TEST_F(QuicDispatcherWriteBlockedListTest, LimitedWrites) {
422  // Make sure we call both writers.  The first will register for more writing
423  // but should not be immediately called due to limits.
424  InSequence s;
425  SetBlocked();
426  dispatcher_.OnWriteBlocked(connection1());
427  dispatcher_.OnWriteBlocked(connection2());
428  EXPECT_CALL(*connection1(), OnCanWrite());
429  EXPECT_CALL(*connection2(), OnCanWrite()).WillOnce(
430      Invoke(this, &QuicDispatcherWriteBlockedListTest::BlockConnection2));
431  dispatcher_.OnCanWrite();
432  EXPECT_TRUE(dispatcher_.HasPendingWrites());
433
434  // Now call OnCanWrite again, and connection1 should get its second chance
435  EXPECT_CALL(*connection2(), OnCanWrite());
436  dispatcher_.OnCanWrite();
437  EXPECT_FALSE(dispatcher_.HasPendingWrites());
438}
439
// Same scenario as OnCanWriteHandleBlock, additionally checking that
// HasPendingWrites() reports the connection left waiting after the writer
// re-blocks, and reports nothing once it has been served.
TEST_F(QuicDispatcherWriteBlockedListTest, TestWriteLimits) {
  // If the writer blocks during connection1's write, connection2 must not be
  // called and must remain pending.
  InSequence s;
  SetBlocked();
  dispatcher_.OnWriteBlocked(connection1());
  dispatcher_.OnWriteBlocked(connection2());
  EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce(
      Invoke(this, &QuicDispatcherWriteBlockedListTest::SetBlocked));
  EXPECT_CALL(*connection2(), OnCanWrite()).Times(0);
  dispatcher_.OnCanWrite();
  EXPECT_TRUE(dispatcher_.HasPendingWrites());

  // And we'll resume where we left off when we get another call.
  EXPECT_CALL(*connection2(), OnCanWrite());
  dispatcher_.OnCanWrite();
  EXPECT_FALSE(dispatcher_.HasPendingWrites());
}
457
458}  // namespace
459}  // namespace test
460}  // namespace tools
461}  // namespace net
462