1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/quic/test_tools/quic_test_utils.h"
6
7#include "base/sha1.h"
8#include "base/stl_util.h"
9#include "base/strings/string_number_conversions.h"
10#include "net/quic/crypto/crypto_framer.h"
11#include "net/quic/crypto/crypto_handshake.h"
12#include "net/quic/crypto/crypto_utils.h"
13#include "net/quic/crypto/null_encrypter.h"
14#include "net/quic/crypto/quic_decrypter.h"
15#include "net/quic/crypto/quic_encrypter.h"
16#include "net/quic/quic_framer.h"
17#include "net/quic/quic_packet_creator.h"
18#include "net/quic/quic_utils.h"
19#include "net/quic/test_tools/quic_connection_peer.h"
20#include "net/spdy/spdy_frame_builder.h"
21
22using base::StringPiece;
23using std::max;
24using std::min;
25using std::string;
26using testing::AnyNumber;
27using testing::_;
28
29namespace net {
30namespace test {
31namespace {
32
33// No-op alarm implementation used by MockHelper.
34class TestAlarm : public QuicAlarm {
35 public:
36  explicit TestAlarm(QuicAlarm::Delegate* delegate)
37      : QuicAlarm(delegate) {
38  }
39
40  virtual void SetImpl() OVERRIDE {}
41  virtual void CancelImpl() OVERRIDE {}
42};
43
44}  // namespace
45
46QuicAckFrame MakeAckFrame(QuicPacketSequenceNumber largest_observed,
47                          QuicPacketSequenceNumber least_unacked) {
48  QuicAckFrame ack;
49  ack.received_info.largest_observed = largest_observed;
50  ack.received_info.entropy_hash = 0;
51  ack.sent_info.least_unacked = least_unacked;
52  ack.sent_info.entropy_hash = 0;
53  return ack;
54}
55
56QuicAckFrame MakeAckFrameWithNackRanges(
57    size_t num_nack_ranges, QuicPacketSequenceNumber least_unacked) {
58  QuicAckFrame ack = MakeAckFrame(2 * num_nack_ranges + least_unacked,
59                                  least_unacked);
60  // Add enough missing packets to get num_nack_ranges nack ranges.
61  for (QuicPacketSequenceNumber i = 1; i < 2 * num_nack_ranges; i += 2) {
62    ack.received_info.missing_packets.insert(least_unacked + i);
63  }
64  return ack;
65}
66
67SerializedPacket BuildUnsizedDataPacket(QuicFramer* framer,
68                                        const QuicPacketHeader& header,
69                                        const QuicFrames& frames) {
70  const size_t max_plaintext_size = framer->GetMaxPlaintextSize(kMaxPacketSize);
71  size_t packet_size = GetPacketHeaderSize(header);
72  for (size_t i = 0; i < frames.size(); ++i) {
73    DCHECK_LE(packet_size, max_plaintext_size);
74    bool first_frame = i == 0;
75    bool last_frame = i == frames.size() - 1;
76    const size_t frame_size = framer->GetSerializedFrameLength(
77        frames[i], max_plaintext_size - packet_size, first_frame, last_frame,
78        header.is_in_fec_group,
79        header.public_header.sequence_number_length);
80    DCHECK(frame_size);
81    packet_size += frame_size;
82  }
83  return framer->BuildDataPacket(header, frames, packet_size);
84}
85
86uint64 SimpleRandom::RandUint64() {
87  unsigned char hash[base::kSHA1Length];
88  base::SHA1HashBytes(reinterpret_cast<unsigned char*>(&seed_), sizeof(seed_),
89                      hash);
90  memcpy(&seed_, hash, sizeof(seed_));
91  return seed_;
92}
93
94MockFramerVisitor::MockFramerVisitor() {
95  // By default, we want to accept packets.
96  ON_CALL(*this, OnProtocolVersionMismatch(_))
97      .WillByDefault(testing::Return(false));
98
99  // By default, we want to accept packets.
100  ON_CALL(*this, OnUnauthenticatedHeader(_))
101      .WillByDefault(testing::Return(true));
102
103  ON_CALL(*this, OnUnauthenticatedPublicHeader(_))
104      .WillByDefault(testing::Return(true));
105
106  ON_CALL(*this, OnPacketHeader(_))
107      .WillByDefault(testing::Return(true));
108
109  ON_CALL(*this, OnStreamFrame(_))
110      .WillByDefault(testing::Return(true));
111
112  ON_CALL(*this, OnAckFrame(_))
113      .WillByDefault(testing::Return(true));
114
115  ON_CALL(*this, OnCongestionFeedbackFrame(_))
116      .WillByDefault(testing::Return(true));
117
118  ON_CALL(*this, OnStopWaitingFrame(_))
119      .WillByDefault(testing::Return(true));
120
121  ON_CALL(*this, OnPingFrame(_))
122      .WillByDefault(testing::Return(true));
123
124  ON_CALL(*this, OnRstStreamFrame(_))
125      .WillByDefault(testing::Return(true));
126
127  ON_CALL(*this, OnConnectionCloseFrame(_))
128      .WillByDefault(testing::Return(true));
129
130  ON_CALL(*this, OnGoAwayFrame(_))
131      .WillByDefault(testing::Return(true));
132}
133
134MockFramerVisitor::~MockFramerVisitor() {
135}
136
137bool NoOpFramerVisitor::OnProtocolVersionMismatch(QuicVersion version) {
138  return false;
139}
140
141bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
142    const QuicPacketPublicHeader& header) {
143  return true;
144}
145
146bool NoOpFramerVisitor::OnUnauthenticatedHeader(
147    const QuicPacketHeader& header) {
148  return true;
149}
150
151bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& header) {
152  return true;
153}
154
155bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& frame) {
156  return true;
157}
158
159bool NoOpFramerVisitor::OnAckFrame(const QuicAckFrame& frame) {
160  return true;
161}
162
163bool NoOpFramerVisitor::OnCongestionFeedbackFrame(
164    const QuicCongestionFeedbackFrame& frame) {
165  return true;
166}
167
168bool NoOpFramerVisitor::OnStopWaitingFrame(
169    const QuicStopWaitingFrame& frame) {
170  return true;
171}
172
173bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame& frame) {
174  return true;
175}
176
177bool NoOpFramerVisitor::OnRstStreamFrame(
178    const QuicRstStreamFrame& frame) {
179  return true;
180}
181
182bool NoOpFramerVisitor::OnConnectionCloseFrame(
183    const QuicConnectionCloseFrame& frame) {
184  return true;
185}
186
187bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& frame) {
188  return true;
189}
190
191bool NoOpFramerVisitor::OnWindowUpdateFrame(
192    const QuicWindowUpdateFrame& frame) {
193  return true;
194}
195
196bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& frame) {
197  return true;
198}
199
200MockConnectionVisitor::MockConnectionVisitor() {
201}
202
203MockConnectionVisitor::~MockConnectionVisitor() {
204}
205
206MockHelper::MockHelper() {
207}
208
209MockHelper::~MockHelper() {
210}
211
212const QuicClock* MockHelper::GetClock() const {
213  return &clock_;
214}
215
216QuicRandom* MockHelper::GetRandomGenerator() {
217  return &random_generator_;
218}
219
220QuicAlarm* MockHelper::CreateAlarm(QuicAlarm::Delegate* delegate) {
221  return new TestAlarm(delegate);
222}
223
224void MockHelper::AdvanceTime(QuicTime::Delta delta) {
225  clock_.AdvanceTime(delta);
226}
227
228MockConnection::MockConnection(bool is_server)
229    : QuicConnection(kTestConnectionId,
230                     IPEndPoint(TestPeerIPAddress(), kTestPort),
231                     new testing::NiceMock<MockHelper>(),
232                     new testing::NiceMock<MockPacketWriter>(),
233                     is_server, QuicSupportedVersions()),
234      writer_(QuicConnectionPeer::GetWriter(this)),
235      helper_(helper()) {
236}
237
238MockConnection::MockConnection(IPEndPoint address,
239                               bool is_server)
240    : QuicConnection(kTestConnectionId, address,
241                     new testing::NiceMock<MockHelper>(),
242                     new testing::NiceMock<MockPacketWriter>(),
243                     is_server, QuicSupportedVersions()),
244      writer_(QuicConnectionPeer::GetWriter(this)),
245      helper_(helper()) {
246}
247
248MockConnection::MockConnection(QuicConnectionId connection_id,
249                               bool is_server)
250    : QuicConnection(connection_id,
251                     IPEndPoint(TestPeerIPAddress(), kTestPort),
252                     new testing::NiceMock<MockHelper>(),
253                     new testing::NiceMock<MockPacketWriter>(),
254                     is_server, QuicSupportedVersions()),
255      writer_(QuicConnectionPeer::GetWriter(this)),
256      helper_(helper()) {
257}
258
259MockConnection::MockConnection(bool is_server,
260                               const QuicVersionVector& supported_versions)
261    : QuicConnection(kTestConnectionId,
262                     IPEndPoint(TestPeerIPAddress(), kTestPort),
263                     new testing::NiceMock<MockHelper>(),
264                     new testing::NiceMock<MockPacketWriter>(),
265                     is_server, supported_versions),
266      writer_(QuicConnectionPeer::GetWriter(this)),
267      helper_(helper()) {
268}
269
270MockConnection::~MockConnection() {
271}
272
273void MockConnection::AdvanceTime(QuicTime::Delta delta) {
274  static_cast<MockHelper*>(helper())->AdvanceTime(delta);
275}
276
277PacketSavingConnection::PacketSavingConnection(bool is_server)
278    : MockConnection(is_server) {
279}
280
281PacketSavingConnection::PacketSavingConnection(
282    bool is_server,
283    const QuicVersionVector& supported_versions)
284    : MockConnection(is_server, supported_versions) {
285}
286
287PacketSavingConnection::~PacketSavingConnection() {
288  STLDeleteElements(&packets_);
289  STLDeleteElements(&encrypted_packets_);
290}
291
292bool PacketSavingConnection::SendOrQueuePacket(
293    EncryptionLevel level,
294    const SerializedPacket& packet,
295    TransmissionType transmission_type) {
296  packets_.push_back(packet.packet);
297  QuicEncryptedPacket* encrypted = QuicConnectionPeer::GetFramer(this)->
298      EncryptPacket(level, packet.sequence_number, *packet.packet);
299  encrypted_packets_.push_back(encrypted);
300  return true;
301}
302
303MockSession::MockSession(QuicConnection* connection)
304    : QuicSession(connection, DefaultQuicConfig()) {
305  ON_CALL(*this, WritevData(_, _, _, _, _, _))
306      .WillByDefault(testing::Return(QuicConsumedData(0, false)));
307}
308
309MockSession::~MockSession() {
310}
311
312TestSession::TestSession(QuicConnection* connection, const QuicConfig& config)
313    : QuicSession(connection, config),
314      crypto_stream_(NULL) {}
315
316TestSession::~TestSession() {}
317
318void TestSession::SetCryptoStream(QuicCryptoStream* stream) {
319  crypto_stream_ = stream;
320}
321
322QuicCryptoStream* TestSession::GetCryptoStream() {
323  return crypto_stream_;
324}
325
326TestClientSession::TestClientSession(QuicConnection* connection,
327                                     const QuicConfig& config)
328    : QuicClientSessionBase(connection,
329                            config),
330      crypto_stream_(NULL) {
331    EXPECT_CALL(*this, OnProofValid(_)).Times(AnyNumber());
332}
333
334TestClientSession::~TestClientSession() {}
335
336void TestClientSession::SetCryptoStream(QuicCryptoStream* stream) {
337  crypto_stream_ = stream;
338}
339
340QuicCryptoStream* TestClientSession::GetCryptoStream() {
341  return crypto_stream_;
342}
343
344MockPacketWriter::MockPacketWriter() {
345}
346
347MockPacketWriter::~MockPacketWriter() {
348}
349
350MockSendAlgorithm::MockSendAlgorithm() {
351}
352
353MockSendAlgorithm::~MockSendAlgorithm() {
354}
355
356MockLossAlgorithm::MockLossAlgorithm() {
357}
358
359MockLossAlgorithm::~MockLossAlgorithm() {
360}
361
362MockAckNotifierDelegate::MockAckNotifierDelegate() {
363}
364
365MockAckNotifierDelegate::~MockAckNotifierDelegate() {
366}
367
368namespace {
369
// Renders the first |length| bytes of |data| as a hex dump with kColumns
// bytes per row.  Each byte is printed as two lowercase hex digits padded by
// one character on either side; the padding is '*' when the corresponding
// entry of |marks| is true and ' ' otherwise.  A short final row is padded
// with blanks so the trailing ASCII column lines up.  After two spaces, each
// row ends with the bytes rendered as ASCII, with '.' standing in for
// anything outside the printable range 0x20..0x7e.
//
// |marks| may be NULL, in which case nothing is marked; otherwise only its
// first |mark_length| entries are consulted.  Both lengths are clamped to
// 1024 (an error is logged when that happens).  A non-positive |length|
// yields an empty string.
string HexDumpWithMarks(const char* data, int length,
                        const bool* marks, int mark_length) {
  static const char kHexChars[] = "0123456789abcdef";
  static const int kColumns = 4;

  const int kSizeLimit = 1024;
  if (length > kSizeLimit || mark_length > kSizeLimit) {
    LOG(ERROR) << "Only dumping first " << kSizeLimit << " bytes.";
    length = min(length, kSizeLimit);
    mark_length = min(mark_length, kSizeLimit);
  }

  string hex;
  // Walk the buffer by index rather than by advancing a row pointer, so that
  // no pointer beyond one-past-the-end of |data| is ever formed when the last
  // row is only partially filled.
  for (int row_start = 0; row_start < length; row_start += kColumns) {
    const int row_end = min(row_start + kColumns, length);

    // Hex column: always kColumns cells wide.
    for (int i = row_start; i < row_start + kColumns; ++i) {
      if (i < row_end) {
        const unsigned char byte = static_cast<unsigned char>(data[i]);
        const bool mark = (marks && i < mark_length && marks[i]);
        hex += mark ? '*' : ' ';
        hex += kHexChars[byte >> 4];
        hex += kHexChars[byte & 0x0f];
        hex += mark ? '*' : ' ';
      } else {
        hex += "    ";
      }
    }
    hex += "  ";

    // ASCII column: only the bytes actually present in this row.  0x7f (DEL)
    // is a control character and is therefore shown as '.'.
    for (int i = row_start; i < row_end; ++i) {
      const char c = data[i];
      hex += (c >= 0x20 && c < 0x7f) ? c : '.';
    }

    hex += '\n';
  }
  return hex;
}
406
407}  // namespace
408
409IPAddressNumber TestPeerIPAddress() { return Loopback4(); }
410
411QuicVersion QuicVersionMax() { return QuicSupportedVersions().front(); }
412
413QuicVersion QuicVersionMin() { return QuicSupportedVersions().back(); }
414
415IPAddressNumber Loopback4() {
416  IPAddressNumber addr;
417  CHECK(ParseIPLiteralToNumber("127.0.0.1", &addr));
418  return addr;
419}
420
421IPAddressNumber Loopback6() {
422  IPAddressNumber addr;
423  CHECK(ParseIPLiteralToNumber("::1", &addr));
424  return addr;
425}
426
// Replaces the contents of |body| with |length| characters that cycle through
// the character codes 32..125 (starting at the space character).  A zero or
// negative |length| leaves |body| empty.
void GenerateBody(string* body, int length) {
  body->clear();
  // Guard before reserve(): a negative int would otherwise be converted to an
  // enormous unsigned size and make reserve() fail.
  if (length <= 0)
    return;

  body->reserve(static_cast<size_t>(length));
  for (int i = 0; i < length; ++i) {
    body->push_back(static_cast<char>(32 + i % (126 - 32)));
  }
}
434
435QuicEncryptedPacket* ConstructEncryptedPacket(
436    QuicConnectionId connection_id,
437    bool version_flag,
438    bool reset_flag,
439    QuicPacketSequenceNumber sequence_number,
440    const string& data) {
441  QuicPacketHeader header;
442  header.public_header.connection_id = connection_id;
443  header.public_header.connection_id_length = PACKET_8BYTE_CONNECTION_ID;
444  header.public_header.version_flag = version_flag;
445  header.public_header.reset_flag = reset_flag;
446  header.public_header.sequence_number_length = PACKET_6BYTE_SEQUENCE_NUMBER;
447  header.packet_sequence_number = sequence_number;
448  header.entropy_flag = false;
449  header.entropy_hash = 0;
450  header.fec_flag = false;
451  header.is_in_fec_group = NOT_IN_FEC_GROUP;
452  header.fec_group = 0;
453  QuicStreamFrame stream_frame(1, false, 0, MakeIOVector(data));
454  QuicFrame frame(&stream_frame);
455  QuicFrames frames;
456  frames.push_back(frame);
457  QuicFramer framer(QuicSupportedVersions(), QuicTime::Zero(), false);
458  scoped_ptr<QuicPacket> packet(
459      BuildUnsizedDataPacket(&framer, header, frames).packet);
460  EXPECT_TRUE(packet != NULL);
461  QuicEncryptedPacket* encrypted = framer.EncryptPacket(ENCRYPTION_NONE,
462                                                        sequence_number,
463                                                        *packet);
464  EXPECT_TRUE(encrypted != NULL);
465  return encrypted;
466}
467
468void CompareCharArraysWithHexError(
469    const string& description,
470    const char* actual,
471    const int actual_len,
472    const char* expected,
473    const int expected_len) {
474  EXPECT_EQ(actual_len, expected_len);
475  const int min_len = min(actual_len, expected_len);
476  const int max_len = max(actual_len, expected_len);
477  scoped_ptr<bool[]> marks(new bool[max_len]);
478  bool identical = (actual_len == expected_len);
479  for (int i = 0; i < min_len; ++i) {
480    if (actual[i] != expected[i]) {
481      marks[i] = true;
482      identical = false;
483    } else {
484      marks[i] = false;
485    }
486  }
487  for (int i = min_len; i < max_len; ++i) {
488    marks[i] = true;
489  }
490  if (identical) return;
491  ADD_FAILURE()
492      << "Description:\n"
493      << description
494      << "\n\nExpected:\n"
495      << HexDumpWithMarks(expected, expected_len, marks.get(), max_len)
496      << "\nActual:\n"
497      << HexDumpWithMarks(actual, actual_len, marks.get(), max_len);
498}
499
500bool DecodeHexString(const base::StringPiece& hex, std::string* bytes) {
501  bytes->clear();
502  if (hex.empty())
503    return true;
504  std::vector<uint8> v;
505  if (!base::HexStringToBytes(hex.as_string(), &v))
506    return false;
507  if (!v.empty())
508    bytes->assign(reinterpret_cast<const char*>(&v[0]), v.size());
509  return true;
510}
511
512static QuicPacket* ConstructPacketFromHandshakeMessage(
513    QuicConnectionId connection_id,
514    const CryptoHandshakeMessage& message,
515    bool should_include_version) {
516  CryptoFramer crypto_framer;
517  scoped_ptr<QuicData> data(crypto_framer.ConstructHandshakeMessage(message));
518  QuicFramer quic_framer(QuicSupportedVersions(), QuicTime::Zero(), false);
519
520  QuicPacketHeader header;
521  header.public_header.connection_id = connection_id;
522  header.public_header.reset_flag = false;
523  header.public_header.version_flag = should_include_version;
524  header.packet_sequence_number = 1;
525  header.entropy_flag = false;
526  header.entropy_hash = 0;
527  header.fec_flag = false;
528  header.fec_group = 0;
529
530  QuicStreamFrame stream_frame(kCryptoStreamId, false, 0,
531                               MakeIOVector(data->AsStringPiece()));
532
533  QuicFrame frame(&stream_frame);
534  QuicFrames frames;
535  frames.push_back(frame);
536  return BuildUnsizedDataPacket(&quic_framer, header, frames).packet;
537}
538
539QuicPacket* ConstructHandshakePacket(QuicConnectionId connection_id,
540                                     QuicTag tag) {
541  CryptoHandshakeMessage message;
542  message.set_tag(tag);
543  return ConstructPacketFromHandshakeMessage(connection_id, message, false);
544}
545
546size_t GetPacketLengthForOneStream(
547    QuicVersion version,
548    bool include_version,
549    QuicSequenceNumberLength sequence_number_length,
550    InFecGroup is_in_fec_group,
551    size_t* payload_length) {
552  *payload_length = 1;
553  const size_t stream_length =
554      NullEncrypter().GetCiphertextSize(*payload_length) +
555      QuicPacketCreator::StreamFramePacketOverhead(
556          version, PACKET_8BYTE_CONNECTION_ID, include_version,
557          sequence_number_length, 0u, is_in_fec_group);
558  const size_t ack_length = NullEncrypter().GetCiphertextSize(
559      QuicFramer::GetMinAckFrameSize(
560          version, sequence_number_length, PACKET_1BYTE_SEQUENCE_NUMBER)) +
561      GetPacketHeaderSize(PACKET_8BYTE_CONNECTION_ID, include_version,
562                          sequence_number_length, is_in_fec_group);
563  if (stream_length < ack_length) {
564    *payload_length = 1 + ack_length - stream_length;
565  }
566
567  return NullEncrypter().GetCiphertextSize(*payload_length) +
568      QuicPacketCreator::StreamFramePacketOverhead(
569          version, PACKET_8BYTE_CONNECTION_ID, include_version,
570          sequence_number_length, 0u, is_in_fec_group);
571}
572
573TestEntropyCalculator::TestEntropyCalculator() {}
574
575TestEntropyCalculator::~TestEntropyCalculator() {}
576
577QuicPacketEntropyHash TestEntropyCalculator::EntropyHash(
578    QuicPacketSequenceNumber sequence_number) const {
579  return 1u;
580}
581
582MockEntropyCalculator::MockEntropyCalculator() {}
583
584MockEntropyCalculator::~MockEntropyCalculator() {}
585
586QuicConfig DefaultQuicConfig() {
587  QuicConfig config;
588  config.SetDefaults();
589  config.SetInitialFlowControlWindowToSend(
590      kInitialSessionFlowControlWindowForTest);
591  config.SetInitialStreamFlowControlWindowToSend(
592      kInitialStreamFlowControlWindowForTest);
593  config.SetInitialSessionFlowControlWindowToSend(
594      kInitialSessionFlowControlWindowForTest);
595  return config;
596}
597
598QuicVersionVector SupportedVersions(QuicVersion version) {
599  QuicVersionVector versions;
600  versions.push_back(version);
601  return versions;
602}
603
604}  // namespace test
605}  // namespace net
606