1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/quic/test_tools/quic_test_utils.h"
6
7#include "base/stl_util.h"
8#include "net/quic/crypto/crypto_framer.h"
9#include "net/quic/crypto/crypto_handshake.h"
10#include "net/quic/crypto/crypto_utils.h"
11#include "net/quic/crypto/null_encrypter.h"
12#include "net/quic/crypto/quic_decrypter.h"
13#include "net/quic/crypto/quic_encrypter.h"
14#include "net/quic/quic_framer.h"
15#include "net/quic/quic_packet_creator.h"
16#include "net/spdy/spdy_frame_builder.h"
17
18using base::StringPiece;
19using std::max;
20using std::min;
21using std::string;
22using testing::_;
23
24namespace net {
25namespace test {
26namespace {
27
28// No-op alarm implementation used by MockHelper.
29class TestAlarm : public QuicAlarm {
30 public:
31  explicit TestAlarm(QuicAlarm::Delegate* delegate)
32      : QuicAlarm(delegate) {
33  }
34
35  virtual void SetImpl() OVERRIDE {}
36  virtual void CancelImpl() OVERRIDE {}
37};
38
39}  // namespace
40
41MockFramerVisitor::MockFramerVisitor() {
42  // By default, we want to accept packets.
43  ON_CALL(*this, OnProtocolVersionMismatch(_))
44      .WillByDefault(testing::Return(false));
45
46  // By default, we want to accept packets.
47  ON_CALL(*this, OnPacketHeader(_))
48      .WillByDefault(testing::Return(true));
49
50  ON_CALL(*this, OnStreamFrame(_))
51      .WillByDefault(testing::Return(true));
52
53  ON_CALL(*this, OnAckFrame(_))
54      .WillByDefault(testing::Return(true));
55
56  ON_CALL(*this, OnCongestionFeedbackFrame(_))
57      .WillByDefault(testing::Return(true));
58
59  ON_CALL(*this, OnRstStreamFrame(_))
60      .WillByDefault(testing::Return(true));
61
62  ON_CALL(*this, OnConnectionCloseFrame(_))
63      .WillByDefault(testing::Return(true));
64
65  ON_CALL(*this, OnGoAwayFrame(_))
66      .WillByDefault(testing::Return(true));
67}
68
69MockFramerVisitor::~MockFramerVisitor() {
70}
71
72bool NoOpFramerVisitor::OnProtocolVersionMismatch(QuicVersion version) {
73  return false;
74}
75
76bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& header) {
77  return true;
78}
79
80bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& frame) {
81  return true;
82}
83
84bool NoOpFramerVisitor::OnAckFrame(const QuicAckFrame& frame) {
85  return true;
86}
87
88bool NoOpFramerVisitor::OnCongestionFeedbackFrame(
89    const QuicCongestionFeedbackFrame& frame) {
90  return true;
91}
92
93bool NoOpFramerVisitor::OnRstStreamFrame(
94    const QuicRstStreamFrame& frame) {
95  return true;
96}
97
98bool NoOpFramerVisitor::OnConnectionCloseFrame(
99    const QuicConnectionCloseFrame& frame) {
100  return true;
101}
102
103bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& frame) {
104  return true;
105}
106
107FramerVisitorCapturingFrames::FramerVisitorCapturingFrames() : frame_count_(0) {
108}
109
110FramerVisitorCapturingFrames::~FramerVisitorCapturingFrames() {
111}
112
113bool FramerVisitorCapturingFrames::OnPacketHeader(
114    const QuicPacketHeader& header) {
115  header_ = header;
116  frame_count_ = 0;
117  return true;
118}
119
120bool FramerVisitorCapturingFrames::OnStreamFrame(const QuicStreamFrame& frame) {
121  // TODO(ianswett): Own the underlying string, so it will not exist outside
122  // this callback.
123  stream_frames_.push_back(frame);
124  ++frame_count_;
125  return true;
126}
127
128bool FramerVisitorCapturingFrames::OnAckFrame(const QuicAckFrame& frame) {
129  ack_.reset(new QuicAckFrame(frame));
130  ++frame_count_;
131  return true;
132}
133
134bool FramerVisitorCapturingFrames::OnCongestionFeedbackFrame(
135    const QuicCongestionFeedbackFrame& frame) {
136  feedback_.reset(new QuicCongestionFeedbackFrame(frame));
137  ++frame_count_;
138  return true;
139}
140
141bool FramerVisitorCapturingFrames::OnRstStreamFrame(
142    const QuicRstStreamFrame& frame) {
143  rst_.reset(new QuicRstStreamFrame(frame));
144  ++frame_count_;
145  return true;
146}
147
148bool FramerVisitorCapturingFrames::OnConnectionCloseFrame(
149    const QuicConnectionCloseFrame& frame) {
150  close_.reset(new QuicConnectionCloseFrame(frame));
151  ++frame_count_;
152  return true;
153}
154
155bool FramerVisitorCapturingFrames::OnGoAwayFrame(const QuicGoAwayFrame& frame) {
156  goaway_.reset(new QuicGoAwayFrame(frame));
157  ++frame_count_;
158  return true;
159}
160
161void FramerVisitorCapturingFrames::OnVersionNegotiationPacket(
162    const QuicVersionNegotiationPacket& packet) {
163  version_negotiation_packet_.reset(new QuicVersionNegotiationPacket(packet));
164  frame_count_ = 0;
165}
166
167FramerVisitorCapturingPublicReset::FramerVisitorCapturingPublicReset() {
168}
169
170FramerVisitorCapturingPublicReset::~FramerVisitorCapturingPublicReset() {
171}
172
173void FramerVisitorCapturingPublicReset::OnPublicResetPacket(
174    const QuicPublicResetPacket& public_reset) {
175  public_reset_packet_ = public_reset;
176}
177
178MockConnectionVisitor::MockConnectionVisitor() {
179}
180
181MockConnectionVisitor::~MockConnectionVisitor() {
182}
183
184MockHelper::MockHelper() {
185}
186
187MockHelper::~MockHelper() {
188}
189
190const QuicClock* MockHelper::GetClock() const {
191  return &clock_;
192}
193
194QuicRandom* MockHelper::GetRandomGenerator() {
195  return &random_generator_;
196}
197
198QuicAlarm* MockHelper::CreateAlarm(QuicAlarm::Delegate* delegate) {
199  return new TestAlarm(delegate);
200}
201
202void MockHelper::AdvanceTime(QuicTime::Delta delta) {
203  clock_.AdvanceTime(delta);
204}
205
206MockConnection::MockConnection(QuicGuid guid,
207                               IPEndPoint address,
208                               bool is_server)
209    : QuicConnection(guid, address, new testing::NiceMock<MockHelper>(),
210                     is_server, QuicVersionMax()),
211      has_mock_helper_(true) {
212}
213
214MockConnection::MockConnection(QuicGuid guid,
215                               IPEndPoint address,
216                               QuicConnectionHelperInterface* helper,
217                               bool is_server)
218    : QuicConnection(guid, address, helper, is_server, QuicVersionMax()),
219      has_mock_helper_(false) {
220}
221
222MockConnection::~MockConnection() {
223}
224
225void MockConnection::AdvanceTime(QuicTime::Delta delta) {
226  CHECK(has_mock_helper_) << "Cannot advance time unless a MockClock is being"
227                             " used";
228  static_cast<MockHelper*>(helper())->AdvanceTime(delta);
229}
230
231PacketSavingConnection::PacketSavingConnection(QuicGuid guid,
232                                               IPEndPoint address,
233                                               bool is_server)
234    : MockConnection(guid, address, is_server) {
235}
236
237PacketSavingConnection::~PacketSavingConnection() {
238  STLDeleteElements(&packets_);
239  STLDeleteElements(&encrypted_packets_);
240}
241
242bool PacketSavingConnection::SendOrQueuePacket(
243    EncryptionLevel level,
244    QuicPacketSequenceNumber sequence_number,
245    QuicPacket* packet,
246    QuicPacketEntropyHash entropy_hash,
247    HasRetransmittableData retransmittable) {
248  packets_.push_back(packet);
249  QuicEncryptedPacket* encrypted =
250      framer_.EncryptPacket(level, sequence_number, *packet);
251  encrypted_packets_.push_back(encrypted);
252  return true;
253}
254
255MockSession::MockSession(QuicConnection* connection, bool is_server)
256    : QuicSession(connection, DefaultQuicConfig(), is_server) {
257  ON_CALL(*this, WriteData(_, _, _, _))
258      .WillByDefault(testing::Return(QuicConsumedData(0, false)));
259}
260
261MockSession::~MockSession() {
262}
263
264TestSession::TestSession(QuicConnection* connection,
265                         const QuicConfig& config,
266                         bool is_server)
267    : QuicSession(connection, config, is_server),
268      crypto_stream_(NULL) {
269}
270
271TestSession::~TestSession() {}
272
273void TestSession::SetCryptoStream(QuicCryptoStream* stream) {
274  crypto_stream_ = stream;
275}
276
277QuicCryptoStream* TestSession::GetCryptoStream() {
278  return crypto_stream_;
279}
280
281MockSendAlgorithm::MockSendAlgorithm() {
282}
283
284MockSendAlgorithm::~MockSendAlgorithm() {
285}
286
287namespace {
288
// Renders the first |length| bytes of |data| as a hex dump with kColumns bytes
// per row. Each row holds the hex columns, two spaces, the same bytes as
// printable ASCII ('.' for anything else), and a newline.
//
// |marks| may be NULL. Otherwise a byte at offset i is surrounded by '*'
// instead of spaces when i < |mark_length| and marks[i] is true.
//
// At most kSizeLimit bytes are dumped; longer input is truncated and an error
// is logged. A non-positive |length| yields an empty string.
string HexDumpWithMarks(const char* data, int length,
                        const bool* marks, int mark_length) {
  static const char kHexChars[] = "0123456789abcdef";
  static const int kColumns = 4;

  const int kSizeLimit = 1024;
  if (length > kSizeLimit || mark_length > kSizeLimit) {
    LOG(ERROR) << "Only dumping first " << kSizeLimit << " bytes.";
    length = min(length, kSizeLimit);
    mark_length = min(mark_length, kSizeLimit);
  }

  string hex;
  // Index-based iteration: no pointer is ever formed beyond |data + length|,
  // even when the final row is shorter than kColumns.
  for (int row_start = 0; row_start < length; row_start += kColumns) {
    const int row_end = min(row_start + kColumns, length);

    // Hex columns; a short final row is padded so the ASCII column lines up.
    for (int i = row_start; i < row_start + kColumns; ++i) {
      if (i < row_end) {
        // Go through unsigned char so high-bit bytes index kHexChars safely
        // regardless of whether plain char is signed.
        const unsigned char byte = static_cast<unsigned char>(data[i]);
        const bool mark = (marks && i < mark_length && marks[i]);
        hex += mark ? '*' : ' ';
        hex += kHexChars[byte >> 4];
        hex += kHexChars[byte & 0x0f];
        hex += mark ? '*' : ' ';
      } else {
        hex += "    ";
      }
    }
    hex += "  ";

    // ASCII column. Only 0x20..0x7e are printable; 0x7f (DEL) is a control
    // character and is shown as '.' like every other non-printable byte.
    for (int i = row_start; i < row_end; ++i) {
      const unsigned char byte = static_cast<unsigned char>(data[i]);
      hex += (byte >= 0x20 && byte < 0x7f) ? data[i] : '.';
    }

    hex += '\n';
  }
  return hex;
}
325
326}  // namespace
327
328void CompareCharArraysWithHexError(
329    const string& description,
330    const char* actual,
331    const int actual_len,
332    const char* expected,
333    const int expected_len) {
334  const int min_len = min(actual_len, expected_len);
335  const int max_len = max(actual_len, expected_len);
336  scoped_ptr<bool[]> marks(new bool[max_len]);
337  bool identical = (actual_len == expected_len);
338  for (int i = 0; i < min_len; ++i) {
339    if (actual[i] != expected[i]) {
340      marks[i] = true;
341      identical = false;
342    } else {
343      marks[i] = false;
344    }
345  }
346  for (int i = min_len; i < max_len; ++i) {
347    marks[i] = true;
348  }
349  if (identical) return;
350  ADD_FAILURE()
351      << "Description:\n"
352      << description
353      << "\n\nExpected:\n"
354      << HexDumpWithMarks(expected, expected_len, marks.get(), max_len)
355      << "\nActual:\n"
356      << HexDumpWithMarks(actual, actual_len, marks.get(), max_len);
357}
358
359void CompareQuicDataWithHexError(
360    const string& description,
361    QuicData* actual,
362    QuicData* expected) {
363  CompareCharArraysWithHexError(
364      description,
365      actual->data(), actual->length(),
366      expected->data(), expected->length());
367}
368
369static QuicPacket* ConstructPacketFromHandshakeMessage(
370    QuicGuid guid,
371    const CryptoHandshakeMessage& message,
372    bool should_include_version) {
373  CryptoFramer crypto_framer;
374  scoped_ptr<QuicData> data(crypto_framer.ConstructHandshakeMessage(message));
375  QuicFramer quic_framer(QuicVersionMax(), QuicTime::Zero(), false);
376
377  QuicPacketHeader header;
378  header.public_header.guid = guid;
379  header.public_header.reset_flag = false;
380  header.public_header.version_flag = should_include_version;
381  header.packet_sequence_number = 1;
382  header.entropy_flag = false;
383  header.entropy_hash = 0;
384  header.fec_flag = false;
385  header.fec_group = 0;
386
387  QuicStreamFrame stream_frame(kCryptoStreamId, false, 0,
388                               data->AsStringPiece());
389
390  QuicFrame frame(&stream_frame);
391  QuicFrames frames;
392  frames.push_back(frame);
393  return quic_framer.BuildUnsizedDataPacket(header, frames).packet;
394}
395
396QuicPacket* ConstructHandshakePacket(QuicGuid guid, QuicTag tag) {
397  CryptoHandshakeMessage message;
398  message.set_tag(tag);
399  return ConstructPacketFromHandshakeMessage(guid, message, false);
400}
401
402size_t GetPacketLengthForOneStream(QuicVersion version,
403                                   bool include_version,
404                                   InFecGroup is_in_fec_group,
405                                   size_t* payload_length) {
406  *payload_length = 1;
407  const size_t stream_length =
408      NullEncrypter().GetCiphertextSize(*payload_length) +
409      QuicPacketCreator::StreamFramePacketOverhead(
410          version, PACKET_8BYTE_GUID, include_version,
411          PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group);
412  const size_t ack_length = NullEncrypter().GetCiphertextSize(
413      QuicFramer::GetMinAckFrameSize()) +
414      GetPacketHeaderSize(PACKET_8BYTE_GUID, include_version,
415                          PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group);
416  if (stream_length < ack_length) {
417    *payload_length = 1 + ack_length - stream_length;
418  }
419
420  return NullEncrypter().GetCiphertextSize(*payload_length) +
421      QuicPacketCreator::StreamFramePacketOverhead(
422          version, PACKET_8BYTE_GUID, include_version,
423          PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group);
424}
425
426// Size in bytes of the stream frame fields for an arbitrary StreamID and
427// offset and the last frame in a packet.
428size_t GetMinStreamFrameSize(QuicVersion version) {
429  return kQuicFrameTypeSize + kQuicMaxStreamIdSize + kQuicMaxStreamOffsetSize;
430}
431
432QuicPacketEntropyHash TestEntropyCalculator::EntropyHash(
433    QuicPacketSequenceNumber sequence_number) const {
434  return 1u;
435}
436
437QuicConfig DefaultQuicConfig() {
438  QuicConfig config;
439  config.SetDefaults();
440  return config;
441}
442
443bool TestDecompressorVisitor::OnDecompressedData(StringPiece data) {
444  data.AppendToString(&data_);
445  return true;
446}
447
448void TestDecompressorVisitor::OnDecompressionError() {
449  error_ = true;
450}
451
452}  // namespace test
453}  // namespace net
454