1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/tools/quic/quic_dispatcher.h"
6
7#include <string>
8
9#include "base/strings/string_piece.h"
10#include "net/quic/crypto/crypto_handshake.h"
11#include "net/quic/crypto/crypto_server_config.h"
12#include "net/quic/crypto/quic_random.h"
13#include "net/quic/quic_crypto_stream.h"
14#include "net/quic/test_tools/quic_test_utils.h"
15#include "net/tools/flip_server/epoll_server.h"
16#include "net/tools/quic/quic_time_wait_list_manager.h"
17#include "net/tools/quic/test_tools/quic_test_utils.h"
18#include "testing/gmock/include/gmock/gmock.h"
19#include "testing/gtest/include/gtest/gtest.h"
20
21using base::StringPiece;
22using net::EpollServer;
23using net::test::MockSession;
24using net::tools::test::MockConnection;
25using testing::_;
26using testing::DoAll;
27using testing::Invoke;
28using testing::InSequence;
29using testing::Return;
30using testing::WithoutArgs;
31
32namespace net {
33namespace tools {
34namespace test {
35class QuicDispatcherPeer {
36 public:
37  static void SetTimeWaitListManager(
38      QuicDispatcher* dispatcher,
39      QuicTimeWaitListManager* time_wait_list_manager) {
40    dispatcher->time_wait_list_manager_.reset(time_wait_list_manager);
41  }
42
43  static void SetWriteBlocked(QuicDispatcher* dispatcher) {
44    dispatcher->write_blocked_ = true;
45  }
46};
47
48namespace {
49
50class TestDispatcher : public QuicDispatcher {
51 public:
52  explicit TestDispatcher(const QuicConfig& config,
53                          const QuicCryptoServerConfig& crypto_config,
54                          EpollServer* eps)
55      : QuicDispatcher(config, crypto_config, 1, eps) {}
56
57  MOCK_METHOD4(CreateQuicSession, QuicSession*(
58      QuicGuid guid,
59      const IPEndPoint& client_address,
60      int fd,
61      EpollServer* eps));
62  using QuicDispatcher::write_blocked_list;
63};
64
65// A Connection class which unregisters the session from the dispatcher
66// when sending connection close.
67// It'd be slightly more realistic to do this from the Session but it would
68// involve a lot more mocking.
69class MockServerConnection : public MockConnection {
70 public:
71  MockServerConnection(QuicGuid guid,
72                       IPEndPoint address,
73                       int fd,
74                       EpollServer* eps,
75                       QuicDispatcher* dispatcher)
76      : MockConnection(guid, address, fd, eps, true),
77        dispatcher_(dispatcher) {
78  }
79  void UnregisterOnConnectionClose() {
80    LOG(ERROR) << "Unregistering " << guid();
81    dispatcher_->OnConnectionClose(guid(), QUIC_NO_ERROR);
82  }
83 private:
84  QuicDispatcher* dispatcher_;
85};
86
87QuicSession* CreateSession(QuicDispatcher* dispatcher,
88                           QuicGuid guid,
89                           const IPEndPoint& addr,
90                           MockSession** session,
91                           EpollServer* eps) {
92  MockServerConnection* connection =
93      new MockServerConnection(guid, addr, 0, eps, dispatcher);
94  *session = new MockSession(connection, true);
95  ON_CALL(*connection, SendConnectionClose(_)).WillByDefault(
96      WithoutArgs(Invoke(
97          connection, &MockServerConnection::UnregisterOnConnectionClose)));
98  EXPECT_CALL(*reinterpret_cast<MockConnection*>((*session)->connection()),
99              ProcessUdpPacket(_, addr, _));
100
101  return *session;
102}
103
104class QuicDispatcherTest : public ::testing::Test {
105 public:
106  QuicDispatcherTest()
107      : crypto_config_(QuicCryptoServerConfig::TESTING,
108                       QuicRandom::GetInstance()),
109        dispatcher_(config_, crypto_config_, &eps_),
110        session1_(NULL),
111        session2_(NULL) {
112  }
113
114  virtual ~QuicDispatcherTest() {}
115
116  MockConnection* connection1() {
117    return reinterpret_cast<MockConnection*>(session1_->connection());
118  }
119
120  MockConnection* connection2() {
121    return reinterpret_cast<MockConnection*>(session2_->connection());
122  }
123
124  void ProcessPacket(IPEndPoint addr,
125                     QuicGuid guid,
126                     const string& data) {
127    QuicEncryptedPacket packet(data.data(), data.length());
128    dispatcher_.ProcessPacket(IPEndPoint(), addr, guid, packet);
129  }
130
131  void ValidatePacket(const QuicEncryptedPacket& packet) {
132    EXPECT_TRUE(packet.AsStringPiece().find(data_) != StringPiece::npos);
133  }
134
135  IPAddressNumber Loopback4() {
136    net::IPAddressNumber addr;
137    CHECK(net::ParseIPLiteralToNumber("127.0.0.1", &addr));
138    return addr;
139  }
140
141  EpollServer eps_;
142  QuicConfig config_;
143  QuicCryptoServerConfig crypto_config_;
144  TestDispatcher dispatcher_;
145  MockSession* session1_;
146  MockSession* session2_;
147  string data_;
148};
149
150TEST_F(QuicDispatcherTest, ProcessPackets) {
151  IPEndPoint addr(Loopback4(), 1);
152
153  EXPECT_CALL(dispatcher_, CreateQuicSession(1, addr, _, &eps_))
154      .WillOnce(testing::Return(CreateSession(
155          &dispatcher_, 1, addr, &session1_, &eps_)));
156  ProcessPacket(addr, 1, "foo");
157
158  EXPECT_CALL(dispatcher_, CreateQuicSession(2, addr, _, &eps_))
159      .WillOnce(testing::Return(CreateSession(
160                    &dispatcher_, 2, addr, &session2_, &eps_)));
161  ProcessPacket(addr, 2, "bar");
162
163  data_ = "eep";
164  EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()),
165              ProcessUdpPacket(_, _, _)).Times(1).
166      WillOnce(testing::WithArgs<2>(Invoke(
167          this, &QuicDispatcherTest::ValidatePacket)));
168  ProcessPacket(addr, 1, "eep");
169}
170
171TEST_F(QuicDispatcherTest, Shutdown) {
172  IPEndPoint addr(Loopback4(), 1);
173
174  EXPECT_CALL(dispatcher_, CreateQuicSession(_, addr, _, &eps_))
175      .WillOnce(testing::Return(CreateSession(
176                    &dispatcher_, 1, addr, &session1_, &eps_)));
177
178  ProcessPacket(addr, 1, "foo");
179
180  EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()),
181              SendConnectionClose(QUIC_PEER_GOING_AWAY));
182
183  dispatcher_.Shutdown();
184}
185
186class MockTimeWaitListManager : public QuicTimeWaitListManager {
187 public:
188  MockTimeWaitListManager(QuicPacketWriter* writer,
189                          EpollServer* eps)
190      : QuicTimeWaitListManager(writer, eps) {
191  }
192
193  MOCK_METHOD4(ProcessPacket, void(const IPEndPoint& server_address,
194                                   const IPEndPoint& client_address,
195                                   QuicGuid guid,
196                                   const QuicEncryptedPacket& packet));
197};
198
199TEST_F(QuicDispatcherTest, TimeWaitListManager) {
200  MockTimeWaitListManager* time_wait_list_manager =
201      new MockTimeWaitListManager(&dispatcher_, &eps_);
202  // dispatcher takes the ownership of time_wait_list_manager.
203  QuicDispatcherPeer::SetTimeWaitListManager(&dispatcher_,
204                                             time_wait_list_manager);
205  // Create a new session.
206  IPEndPoint addr(Loopback4(), 1);
207  QuicGuid guid = 1;
208  EXPECT_CALL(dispatcher_, CreateQuicSession(guid, addr, _, &eps_))
209      .WillOnce(testing::Return(CreateSession(
210                    &dispatcher_, guid, addr, &session1_, &eps_)));
211  ProcessPacket(addr, guid, "foo");
212
213  // Close the connection by sending public reset packet.
214  QuicPublicResetPacket packet;
215  packet.public_header.guid = guid;
216  packet.public_header.reset_flag = true;
217  packet.public_header.version_flag = false;
218  packet.rejected_sequence_number = 19191;
219  packet.nonce_proof = 132232;
220  scoped_ptr<QuicEncryptedPacket> encrypted(
221      QuicFramer::BuildPublicResetPacket(packet));
222  EXPECT_CALL(*session1_, ConnectionClose(QUIC_PUBLIC_RESET, true)).Times(1)
223      .WillOnce(WithoutArgs(Invoke(
224          reinterpret_cast<MockServerConnection*>(session1_->connection()),
225          &MockServerConnection::UnregisterOnConnectionClose)));
226  EXPECT_CALL(*reinterpret_cast<MockConnection*>(session1_->connection()),
227              ProcessUdpPacket(_, _, _))
228      .WillOnce(Invoke(
229          reinterpret_cast<MockConnection*>(session1_->connection()),
230          &MockConnection::ReallyProcessUdpPacket));
231  dispatcher_.ProcessPacket(IPEndPoint(), addr, guid, *encrypted);
232  EXPECT_TRUE(time_wait_list_manager->IsGuidInTimeWait(guid));
233
234  // Dispatcher forwards subsequent packets for this guid to the time wait list
235  // manager.
236  EXPECT_CALL(*time_wait_list_manager, ProcessPacket(_, _, guid, _)).Times(1);
237  ProcessPacket(addr, guid, "foo");
238}
239
240class WriteBlockedListTest : public QuicDispatcherTest {
241 public:
242  virtual void SetUp() {
243    IPEndPoint addr(Loopback4(), 1);
244
245    EXPECT_CALL(dispatcher_, CreateQuicSession(_, addr, _, &eps_))
246        .WillOnce(testing::Return(CreateSession(
247                      &dispatcher_, 1, addr, &session1_, &eps_)));
248    ProcessPacket(addr, 1, "foo");
249
250    EXPECT_CALL(dispatcher_, CreateQuicSession(_, addr, _, &eps_))
251        .WillOnce(testing::Return(CreateSession(
252                      &dispatcher_, 2, addr, &session2_, &eps_)));
253    ProcessPacket(addr, 2, "bar");
254
255    blocked_list_ = dispatcher_.write_blocked_list();
256  }
257
258  virtual void TearDown() {
259    EXPECT_CALL(*connection1(), SendConnectionClose(QUIC_PEER_GOING_AWAY));
260    EXPECT_CALL(*connection2(), SendConnectionClose(QUIC_PEER_GOING_AWAY));
261    dispatcher_.Shutdown();
262  }
263
264  bool SetBlocked() {
265    QuicDispatcherPeer::SetWriteBlocked(&dispatcher_);
266    return true;
267  }
268
269 protected:
270  QuicDispatcher::WriteBlockedList* blocked_list_;
271};
272
273TEST_F(WriteBlockedListTest, BasicOnCanWrite) {
274  // No OnCanWrite calls because no connections are blocked.
275  dispatcher_.OnCanWrite();
276
277  // Register connection 1 for events, and make sure it's nofitied.
278  blocked_list_->AddBlockedObject(connection1());
279  EXPECT_CALL(*connection1(), OnCanWrite());
280  dispatcher_.OnCanWrite();
281
282  // It should get only one notification.
283  EXPECT_CALL(*connection1(), OnCanWrite()).Times(0);
284  EXPECT_FALSE(dispatcher_.OnCanWrite());
285}
286
287TEST_F(WriteBlockedListTest, OnCanWriteOrder) {
288  // Make sure we handle events in order.
289  InSequence s;
290  blocked_list_->AddBlockedObject(connection1());
291  blocked_list_->AddBlockedObject(connection2());
292  EXPECT_CALL(*connection1(), OnCanWrite());
293  EXPECT_CALL(*connection2(), OnCanWrite());
294  dispatcher_.OnCanWrite();
295
296  // Check the other ordering.
297  blocked_list_->AddBlockedObject(connection2());
298  blocked_list_->AddBlockedObject(connection1());
299  EXPECT_CALL(*connection2(), OnCanWrite());
300  EXPECT_CALL(*connection1(), OnCanWrite());
301  dispatcher_.OnCanWrite();
302}
303
304TEST_F(WriteBlockedListTest, OnCanWriteRemove) {
305  // Add and remove one connction.
306  blocked_list_->AddBlockedObject(connection1());
307  blocked_list_->RemoveBlockedObject(connection1());
308  EXPECT_CALL(*connection1(), OnCanWrite()).Times(0);
309  dispatcher_.OnCanWrite();
310
311  // Add and remove one connction and make sure it doesn't affect others.
312  blocked_list_->AddBlockedObject(connection1());
313  blocked_list_->AddBlockedObject(connection2());
314  blocked_list_->RemoveBlockedObject(connection1());
315  EXPECT_CALL(*connection2(), OnCanWrite());
316  dispatcher_.OnCanWrite();
317
318  // Add it, remove it, and add it back and make sure things are OK.
319  blocked_list_->AddBlockedObject(connection1());
320  blocked_list_->RemoveBlockedObject(connection1());
321  blocked_list_->AddBlockedObject(connection1());
322  EXPECT_CALL(*connection1(), OnCanWrite()).Times(1);
323  dispatcher_.OnCanWrite();
324}
325
326TEST_F(WriteBlockedListTest, DoubleAdd) {
327  // Make sure a double add does not necessitate a double remove.
328  blocked_list_->AddBlockedObject(connection1());
329  blocked_list_->AddBlockedObject(connection1());
330  blocked_list_->RemoveBlockedObject(connection1());
331  EXPECT_CALL(*connection1(), OnCanWrite()).Times(0);
332  dispatcher_.OnCanWrite();
333
334  // Make sure a double add does not result in two OnCanWrite calls.
335  blocked_list_->AddBlockedObject(connection1());
336  blocked_list_->AddBlockedObject(connection1());
337  EXPECT_CALL(*connection1(), OnCanWrite()).Times(1);
338  dispatcher_.OnCanWrite();
339}
340
341TEST_F(WriteBlockedListTest, OnCanWriteHandleBlock) {
342  // Finally make sure if we write block on a write call, we stop calling.
343  InSequence s;
344  blocked_list_->AddBlockedObject(connection1());
345  blocked_list_->AddBlockedObject(connection2());
346  EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce(
347      Invoke(this, &WriteBlockedListTest::SetBlocked));
348  EXPECT_CALL(*connection2(), OnCanWrite()).Times(0);
349  dispatcher_.OnCanWrite();
350
351  // And we'll resume where we left off when we get another call.
352  EXPECT_CALL(*connection2(), OnCanWrite());
353  dispatcher_.OnCanWrite();
354}
355
356TEST_F(WriteBlockedListTest, LimitedWrites) {
357  // Make sure we call both writers.  The first will register for more writing
358  // but should not be immediately called due to limits.
359  InSequence s;
360  blocked_list_->AddBlockedObject(connection1());
361  blocked_list_->AddBlockedObject(connection2());
362  EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce(Return(true));
363  EXPECT_CALL(*connection2(), OnCanWrite()).WillOnce(Return(false));
364  dispatcher_.OnCanWrite();
365
366  // Now call OnCanWrite again, and connection1 should get its second chance
367  EXPECT_CALL(*connection1(), OnCanWrite());
368  dispatcher_.OnCanWrite();
369}
370
371TEST_F(WriteBlockedListTest, TestWriteLimits) {
372  // Finally make sure if we write block on a write call, we stop calling.
373  InSequence s;
374  blocked_list_->AddBlockedObject(connection1());
375  blocked_list_->AddBlockedObject(connection2());
376  EXPECT_CALL(*connection1(), OnCanWrite()).WillOnce(
377      Invoke(this, &WriteBlockedListTest::SetBlocked));
378  EXPECT_CALL(*connection2(), OnCanWrite()).Times(0);
379  dispatcher_.OnCanWrite();
380
381  // And we'll resume where we left off when we get another call.
382  EXPECT_CALL(*connection2(), OnCanWrite());
383  dispatcher_.OnCanWrite();
384}
385
386
387}  // namespace
388}  // namespace test
389}  // namespace tools
390}  // namespace net
391