1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/quic/test_tools/quic_test_utils.h"
6
7#include "base/stl_util.h"
8#include "base/strings/string_number_conversions.h"
9#include "net/quic/crypto/crypto_framer.h"
10#include "net/quic/crypto/crypto_handshake.h"
11#include "net/quic/crypto/crypto_utils.h"
12#include "net/quic/crypto/null_encrypter.h"
13#include "net/quic/crypto/quic_decrypter.h"
14#include "net/quic/crypto/quic_encrypter.h"
15#include "net/quic/quic_framer.h"
16#include "net/quic/quic_packet_creator.h"
17#include "net/quic/test_tools/quic_connection_peer.h"
18#include "net/spdy/spdy_frame_builder.h"
19
20using base::StringPiece;
21using std::max;
22using std::min;
23using std::string;
24using testing::_;
25
26namespace net {
27namespace test {
28namespace {
29
30// No-op alarm implementation used by MockHelper.
31class TestAlarm : public QuicAlarm {
32 public:
33  explicit TestAlarm(QuicAlarm::Delegate* delegate)
34      : QuicAlarm(delegate) {
35  }
36
37  virtual void SetImpl() OVERRIDE {}
38  virtual void CancelImpl() OVERRIDE {}
39};
40
41}  // namespace
42
43MockFramerVisitor::MockFramerVisitor() {
44  // By default, we want to accept packets.
45  ON_CALL(*this, OnProtocolVersionMismatch(_))
46      .WillByDefault(testing::Return(false));
47
48  // By default, we want to accept packets.
49  ON_CALL(*this, OnUnauthenticatedHeader(_))
50      .WillByDefault(testing::Return(true));
51
52  ON_CALL(*this, OnPacketHeader(_))
53      .WillByDefault(testing::Return(true));
54
55  ON_CALL(*this, OnStreamFrame(_))
56      .WillByDefault(testing::Return(true));
57
58  ON_CALL(*this, OnAckFrame(_))
59      .WillByDefault(testing::Return(true));
60
61  ON_CALL(*this, OnCongestionFeedbackFrame(_))
62      .WillByDefault(testing::Return(true));
63
64  ON_CALL(*this, OnRstStreamFrame(_))
65      .WillByDefault(testing::Return(true));
66
67  ON_CALL(*this, OnConnectionCloseFrame(_))
68      .WillByDefault(testing::Return(true));
69
70  ON_CALL(*this, OnGoAwayFrame(_))
71      .WillByDefault(testing::Return(true));
72}
73
74MockFramerVisitor::~MockFramerVisitor() {
75}
76
77bool NoOpFramerVisitor::OnProtocolVersionMismatch(QuicVersion version) {
78  return false;
79}
80
81bool NoOpFramerVisitor::OnUnauthenticatedHeader(
82    const QuicPacketHeader& header) {
83  return true;
84}
85
86bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& header) {
87  return true;
88}
89
90bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& frame) {
91  return true;
92}
93
94bool NoOpFramerVisitor::OnAckFrame(const QuicAckFrame& frame) {
95  return true;
96}
97
98bool NoOpFramerVisitor::OnCongestionFeedbackFrame(
99    const QuicCongestionFeedbackFrame& frame) {
100  return true;
101}
102
103bool NoOpFramerVisitor::OnRstStreamFrame(
104    const QuicRstStreamFrame& frame) {
105  return true;
106}
107
108bool NoOpFramerVisitor::OnConnectionCloseFrame(
109    const QuicConnectionCloseFrame& frame) {
110  return true;
111}
112
113bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& frame) {
114  return true;
115}
116
117FramerVisitorCapturingFrames::FramerVisitorCapturingFrames() : frame_count_(0) {
118}
119
120FramerVisitorCapturingFrames::~FramerVisitorCapturingFrames() {
121  Reset();
122}
123
124void FramerVisitorCapturingFrames::Reset() {
125  STLDeleteElements(&stream_data_);
126  stream_frames_.clear();
127  frame_count_ = 0;
128  ack_.reset();
129  feedback_.reset();
130  rst_.reset();
131  close_.reset();
132  goaway_.reset();
133  version_negotiation_packet_.reset();
134}
135
136bool FramerVisitorCapturingFrames::OnPacketHeader(
137    const QuicPacketHeader& header) {
138  header_ = header;
139  frame_count_ = 0;
140  return true;
141}
142
143bool FramerVisitorCapturingFrames::OnStreamFrame(const QuicStreamFrame& frame) {
144  // Make a copy of the frame and store a copy of underlying string, since
145  // frame.data may not exist outside this callback.
146  stream_data_.push_back(frame.GetDataAsString());
147  QuicStreamFrame frame_copy = frame;
148  frame_copy.data.Clear();
149  frame_copy.data.Append(const_cast<char*>(stream_data_.back()->data()),
150                         stream_data_.back()->size());
151  stream_frames_.push_back(frame_copy);
152  ++frame_count_;
153  return true;
154}
155
156bool FramerVisitorCapturingFrames::OnAckFrame(const QuicAckFrame& frame) {
157  ack_.reset(new QuicAckFrame(frame));
158  ++frame_count_;
159  return true;
160}
161
162bool FramerVisitorCapturingFrames::OnCongestionFeedbackFrame(
163    const QuicCongestionFeedbackFrame& frame) {
164  feedback_.reset(new QuicCongestionFeedbackFrame(frame));
165  ++frame_count_;
166  return true;
167}
168
169bool FramerVisitorCapturingFrames::OnRstStreamFrame(
170    const QuicRstStreamFrame& frame) {
171  rst_.reset(new QuicRstStreamFrame(frame));
172  ++frame_count_;
173  return true;
174}
175
176bool FramerVisitorCapturingFrames::OnConnectionCloseFrame(
177    const QuicConnectionCloseFrame& frame) {
178  close_.reset(new QuicConnectionCloseFrame(frame));
179  ++frame_count_;
180  return true;
181}
182
183bool FramerVisitorCapturingFrames::OnGoAwayFrame(const QuicGoAwayFrame& frame) {
184  goaway_.reset(new QuicGoAwayFrame(frame));
185  ++frame_count_;
186  return true;
187}
188
189void FramerVisitorCapturingFrames::OnVersionNegotiationPacket(
190    const QuicVersionNegotiationPacket& packet) {
191  version_negotiation_packet_.reset(new QuicVersionNegotiationPacket(packet));
192  frame_count_ = 0;
193}
194
195FramerVisitorCapturingPublicReset::FramerVisitorCapturingPublicReset() {
196}
197
198FramerVisitorCapturingPublicReset::~FramerVisitorCapturingPublicReset() {
199}
200
201void FramerVisitorCapturingPublicReset::OnPublicResetPacket(
202    const QuicPublicResetPacket& public_reset) {
203  public_reset_packet_ = public_reset;
204}
205
206MockConnectionVisitor::MockConnectionVisitor() {
207}
208
209MockConnectionVisitor::~MockConnectionVisitor() {
210}
211
212MockHelper::MockHelper() {
213}
214
215MockHelper::~MockHelper() {
216}
217
218const QuicClock* MockHelper::GetClock() const {
219  return &clock_;
220}
221
222QuicRandom* MockHelper::GetRandomGenerator() {
223  return &random_generator_;
224}
225
226QuicAlarm* MockHelper::CreateAlarm(QuicAlarm::Delegate* delegate) {
227  return new TestAlarm(delegate);
228}
229
230void MockHelper::AdvanceTime(QuicTime::Delta delta) {
231  clock_.AdvanceTime(delta);
232}
233
234MockConnection::MockConnection(bool is_server)
235    : QuicConnection(kTestGuid,
236                     IPEndPoint(Loopback4(), kTestPort),
237                     new testing::NiceMock<MockHelper>(),
238                     new testing::NiceMock<MockPacketWriter>(),
239                     is_server, QuicSupportedVersions()),
240      writer_(QuicConnectionPeer::GetWriter(this)),
241      helper_(helper()) {
242}
243
244MockConnection::MockConnection(IPEndPoint address,
245                               bool is_server)
246    : QuicConnection(kTestGuid, address,
247                     new testing::NiceMock<MockHelper>(),
248                     new testing::NiceMock<MockPacketWriter>(),
249                     is_server, QuicSupportedVersions()),
250      writer_(QuicConnectionPeer::GetWriter(this)),
251      helper_(helper()) {
252}
253
254MockConnection::MockConnection(QuicGuid guid,
255                               bool is_server)
256    : QuicConnection(guid,
257                     IPEndPoint(Loopback4(), kTestPort),
258                     new testing::NiceMock<MockHelper>(),
259                     new testing::NiceMock<MockPacketWriter>(),
260                     is_server, QuicSupportedVersions()),
261      writer_(QuicConnectionPeer::GetWriter(this)),
262      helper_(helper()) {
263}
264
265MockConnection::~MockConnection() {
266}
267
268void MockConnection::AdvanceTime(QuicTime::Delta delta) {
269  static_cast<MockHelper*>(helper())->AdvanceTime(delta);
270}
271
272PacketSavingConnection::PacketSavingConnection(bool is_server)
273    : MockConnection(is_server) {
274}
275
276PacketSavingConnection::~PacketSavingConnection() {
277  STLDeleteElements(&packets_);
278  STLDeleteElements(&encrypted_packets_);
279}
280
281bool PacketSavingConnection::SendOrQueuePacket(
282    EncryptionLevel level,
283    const SerializedPacket& packet,
284    TransmissionType transmission_type) {
285  packets_.push_back(packet.packet);
286  QuicEncryptedPacket* encrypted =
287      framer_.EncryptPacket(level, packet.sequence_number, *packet.packet);
288  encrypted_packets_.push_back(encrypted);
289  return true;
290}
291
292MockSession::MockSession(QuicConnection* connection)
293    : QuicSession(connection, DefaultQuicConfig()) {
294  ON_CALL(*this, WritevData(_, _, _, _, _, _))
295      .WillByDefault(testing::Return(QuicConsumedData(0, false)));
296}
297
298MockSession::~MockSession() {
299}
300
301TestSession::TestSession(QuicConnection* connection,
302                         const QuicConfig& config)
303    : QuicSession(connection, config),
304      crypto_stream_(NULL) {
305}
306
307TestSession::~TestSession() {}
308
309void TestSession::SetCryptoStream(QuicCryptoStream* stream) {
310  crypto_stream_ = stream;
311}
312
313QuicCryptoStream* TestSession::GetCryptoStream() {
314  return crypto_stream_;
315}
316
317MockPacketWriter::MockPacketWriter() {
318}
319
320MockPacketWriter::~MockPacketWriter() {
321}
322
323MockSendAlgorithm::MockSendAlgorithm() {
324}
325
326MockSendAlgorithm::~MockSendAlgorithm() {
327}
328
329MockAckNotifierDelegate::MockAckNotifierDelegate() {
330}
331
332MockAckNotifierDelegate::~MockAckNotifierDelegate() {
333}
334
335namespace {
336
337string HexDumpWithMarks(const char* data, int length,
338                        const bool* marks, int mark_length) {
339  static const char kHexChars[] = "0123456789abcdef";
340  static const int kColumns = 4;
341
342  const int kSizeLimit = 1024;
343  if (length > kSizeLimit || mark_length > kSizeLimit) {
344    LOG(ERROR) << "Only dumping first " << kSizeLimit << " bytes.";
345    length = min(length, kSizeLimit);
346    mark_length = min(mark_length, kSizeLimit);
347  }
348
349  string hex;
350  for (const char* row = data; length > 0;
351       row += kColumns, length -= kColumns) {
352    for (const char *p = row; p < row + 4; ++p) {
353      if (p < row + length) {
354        const bool mark =
355            (marks && (p - data) < mark_length && marks[p - data]);
356        hex += mark ? '*' : ' ';
357        hex += kHexChars[(*p & 0xf0) >> 4];
358        hex += kHexChars[*p & 0x0f];
359        hex += mark ? '*' : ' ';
360      } else {
361        hex += "    ";
362      }
363    }
364    hex = hex + "  ";
365
366    for (const char *p = row; p < row + 4 && p < row + length; ++p)
367      hex += (*p >= 0x20 && *p <= 0x7f) ? (*p) : '.';
368
369    hex = hex + '\n';
370  }
371  return hex;
372}
373
374}  // namespace
375
376QuicVersion QuicVersionMax() { return QuicSupportedVersions().front(); }
377
378QuicVersion QuicVersionMin() { return QuicSupportedVersions().back(); }
379
380IPAddressNumber Loopback4() {
381  net::IPAddressNumber addr;
382  CHECK(net::ParseIPLiteralToNumber("127.0.0.1", &addr));
383  return addr;
384}
385
386void CompareCharArraysWithHexError(
387    const string& description,
388    const char* actual,
389    const int actual_len,
390    const char* expected,
391    const int expected_len) {
392  EXPECT_EQ(actual_len, expected_len);
393  const int min_len = min(actual_len, expected_len);
394  const int max_len = max(actual_len, expected_len);
395  scoped_ptr<bool[]> marks(new bool[max_len]);
396  bool identical = (actual_len == expected_len);
397  for (int i = 0; i < min_len; ++i) {
398    if (actual[i] != expected[i]) {
399      marks[i] = true;
400      identical = false;
401    } else {
402      marks[i] = false;
403    }
404  }
405  for (int i = min_len; i < max_len; ++i) {
406    marks[i] = true;
407  }
408  if (identical) return;
409  ADD_FAILURE()
410      << "Description:\n"
411      << description
412      << "\n\nExpected:\n"
413      << HexDumpWithMarks(expected, expected_len, marks.get(), max_len)
414      << "\nActual:\n"
415      << HexDumpWithMarks(actual, actual_len, marks.get(), max_len);
416}
417
418bool DecodeHexString(const base::StringPiece& hex, std::string* bytes) {
419  bytes->clear();
420  if (hex.empty())
421    return true;
422  std::vector<uint8> v;
423  if (!base::HexStringToBytes(hex.as_string(), &v))
424    return false;
425  if (!v.empty())
426    bytes->assign(reinterpret_cast<const char*>(&v[0]), v.size());
427  return true;
428}
429
430static QuicPacket* ConstructPacketFromHandshakeMessage(
431    QuicGuid guid,
432    const CryptoHandshakeMessage& message,
433    bool should_include_version) {
434  CryptoFramer crypto_framer;
435  scoped_ptr<QuicData> data(crypto_framer.ConstructHandshakeMessage(message));
436  QuicFramer quic_framer(QuicSupportedVersions(), QuicTime::Zero(), false);
437
438  QuicPacketHeader header;
439  header.public_header.guid = guid;
440  header.public_header.reset_flag = false;
441  header.public_header.version_flag = should_include_version;
442  header.packet_sequence_number = 1;
443  header.entropy_flag = false;
444  header.entropy_hash = 0;
445  header.fec_flag = false;
446  header.fec_group = 0;
447
448  QuicStreamFrame stream_frame(kCryptoStreamId, false, 0,
449                               MakeIOVector(data->AsStringPiece()));
450
451  QuicFrame frame(&stream_frame);
452  QuicFrames frames;
453  frames.push_back(frame);
454  return quic_framer.BuildUnsizedDataPacket(header, frames).packet;
455}
456
457QuicPacket* ConstructHandshakePacket(QuicGuid guid, QuicTag tag) {
458  CryptoHandshakeMessage message;
459  message.set_tag(tag);
460  return ConstructPacketFromHandshakeMessage(guid, message, false);
461}
462
463size_t GetPacketLengthForOneStream(
464    QuicVersion version,
465    bool include_version,
466    QuicSequenceNumberLength sequence_number_length,
467    InFecGroup is_in_fec_group,
468    size_t* payload_length) {
469  *payload_length = 1;
470  const size_t stream_length =
471      NullEncrypter().GetCiphertextSize(*payload_length) +
472      QuicPacketCreator::StreamFramePacketOverhead(
473          version, PACKET_8BYTE_GUID, include_version,
474          sequence_number_length, is_in_fec_group);
475  const size_t ack_length = NullEncrypter().GetCiphertextSize(
476      QuicFramer::GetMinAckFrameSize(
477          version, sequence_number_length, PACKET_1BYTE_SEQUENCE_NUMBER)) +
478      GetPacketHeaderSize(PACKET_8BYTE_GUID, include_version,
479                          sequence_number_length, is_in_fec_group);
480  if (stream_length < ack_length) {
481    *payload_length = 1 + ack_length - stream_length;
482  }
483
484  return NullEncrypter().GetCiphertextSize(*payload_length) +
485      QuicPacketCreator::StreamFramePacketOverhead(
486          version, PACKET_8BYTE_GUID, include_version,
487          sequence_number_length, is_in_fec_group);
488}
489
490// Size in bytes of the stream frame fields for an arbitrary StreamID and
491// offset and the last frame in a packet.
492size_t GetMinStreamFrameSize(QuicVersion version) {
493  return kQuicFrameTypeSize + kQuicMaxStreamIdSize + kQuicMaxStreamOffsetSize;
494}
495
496TestEntropyCalculator::TestEntropyCalculator() { }
497
498TestEntropyCalculator::~TestEntropyCalculator() { }
499
500QuicPacketEntropyHash TestEntropyCalculator::EntropyHash(
501    QuicPacketSequenceNumber sequence_number) const {
502  return 1u;
503}
504
505MockEntropyCalculator::MockEntropyCalculator() { }
506
507MockEntropyCalculator::~MockEntropyCalculator() { }
508
509QuicConfig DefaultQuicConfig() {
510  QuicConfig config;
511  config.SetDefaults();
512  return config;
513}
514
515bool TestDecompressorVisitor::OnDecompressedData(StringPiece data) {
516  data.AppendToString(&data_);
517  return true;
518}
519
520void TestDecompressorVisitor::OnDecompressionError() {
521  error_ = true;
522}
523
524}  // namespace test
525}  // namespace net
526