1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "net/quic/test_tools/quic_test_utils.h" 6 7#include "base/stl_util.h" 8#include "net/quic/crypto/crypto_framer.h" 9#include "net/quic/crypto/crypto_handshake.h" 10#include "net/quic/crypto/crypto_utils.h" 11#include "net/quic/crypto/null_encrypter.h" 12#include "net/quic/crypto/quic_decrypter.h" 13#include "net/quic/crypto/quic_encrypter.h" 14#include "net/quic/quic_framer.h" 15#include "net/quic/quic_packet_creator.h" 16#include "net/spdy/spdy_frame_builder.h" 17 18using base::StringPiece; 19using std::max; 20using std::min; 21using std::string; 22using testing::_; 23 24namespace net { 25namespace test { 26namespace { 27 28// No-op alarm implementation used by MockHelper. 29class TestAlarm : public QuicAlarm { 30 public: 31 explicit TestAlarm(QuicAlarm::Delegate* delegate) 32 : QuicAlarm(delegate) { 33 } 34 35 virtual void SetImpl() OVERRIDE {} 36 virtual void CancelImpl() OVERRIDE {} 37}; 38 39} // namespace 40 41MockFramerVisitor::MockFramerVisitor() { 42 // By default, we want to accept packets. 43 ON_CALL(*this, OnProtocolVersionMismatch(_)) 44 .WillByDefault(testing::Return(false)); 45 46 // By default, we want to accept packets. 47 ON_CALL(*this, OnPacketHeader(_)) 48 .WillByDefault(testing::Return(true)); 49 50 ON_CALL(*this, OnStreamFrame(_)) 51 .WillByDefault(testing::Return(true)); 52 53 ON_CALL(*this, OnAckFrame(_)) 54 .WillByDefault(testing::Return(true)); 55 56 ON_CALL(*this, OnCongestionFeedbackFrame(_)) 57 .WillByDefault(testing::Return(true)); 58 59 ON_CALL(*this, OnRstStreamFrame(_)) 60 .WillByDefault(testing::Return(true)); 61 62 ON_CALL(*this, OnConnectionCloseFrame(_)) 63 .WillByDefault(testing::Return(true)); 64 65 ON_CALL(*this, OnGoAwayFrame(_)) 66 .WillByDefault(testing::Return(true)); 67} 68 69MockFramerVisitor::~MockFramerVisitor() { 70} 71 72bool NoOpFramerVisitor::OnProtocolVersionMismatch(QuicVersion version) { 73 return false; 74} 75 76bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& header) { 77 return true; 78} 79 80bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& frame) { 81 return true; 82} 83 84bool NoOpFramerVisitor::OnAckFrame(const QuicAckFrame& frame) { 85 return true; 86} 87 88bool NoOpFramerVisitor::OnCongestionFeedbackFrame( 89 const QuicCongestionFeedbackFrame& frame) { 90 return true; 91} 92 93bool NoOpFramerVisitor::OnRstStreamFrame( 94 const QuicRstStreamFrame& frame) { 95 return true; 96} 97 98bool NoOpFramerVisitor::OnConnectionCloseFrame( 99 const QuicConnectionCloseFrame& frame) { 100 return true; 101} 102 103bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& frame) { 104 return true; 105} 106 107FramerVisitorCapturingFrames::FramerVisitorCapturingFrames() : frame_count_(0) { 108} 109 110FramerVisitorCapturingFrames::~FramerVisitorCapturingFrames() { 111} 112 113bool FramerVisitorCapturingFrames::OnPacketHeader( 114 const QuicPacketHeader& header) { 115 header_ = header; 116 frame_count_ = 0; 117 return true; 118} 119 120bool FramerVisitorCapturingFrames::OnStreamFrame(const QuicStreamFrame& frame) { 121 // TODO(ianswett): Own the underlying string, so it will not exist outside 122 // this callback. 123 stream_frames_.push_back(frame); 124 ++frame_count_; 125 return true; 126} 127 128bool FramerVisitorCapturingFrames::OnAckFrame(const QuicAckFrame& frame) { 129 ack_.reset(new QuicAckFrame(frame)); 130 ++frame_count_; 131 return true; 132} 133 134bool FramerVisitorCapturingFrames::OnCongestionFeedbackFrame( 135 const QuicCongestionFeedbackFrame& frame) { 136 feedback_.reset(new QuicCongestionFeedbackFrame(frame)); 137 ++frame_count_; 138 return true; 139} 140 141bool FramerVisitorCapturingFrames::OnRstStreamFrame( 142 const QuicRstStreamFrame& frame) { 143 rst_.reset(new QuicRstStreamFrame(frame)); 144 ++frame_count_; 145 return true; 146} 147 148bool FramerVisitorCapturingFrames::OnConnectionCloseFrame( 149 const QuicConnectionCloseFrame& frame) { 150 close_.reset(new QuicConnectionCloseFrame(frame)); 151 ++frame_count_; 152 return true; 153} 154 155bool FramerVisitorCapturingFrames::OnGoAwayFrame(const QuicGoAwayFrame& frame) { 156 goaway_.reset(new QuicGoAwayFrame(frame)); 157 ++frame_count_; 158 return true; 159} 160 161void FramerVisitorCapturingFrames::OnVersionNegotiationPacket( 162 const QuicVersionNegotiationPacket& packet) { 163 version_negotiation_packet_.reset(new QuicVersionNegotiationPacket(packet)); 164 frame_count_ = 0; 165} 166 167FramerVisitorCapturingPublicReset::FramerVisitorCapturingPublicReset() { 168} 169 170FramerVisitorCapturingPublicReset::~FramerVisitorCapturingPublicReset() { 171} 172 173void FramerVisitorCapturingPublicReset::OnPublicResetPacket( 174 const QuicPublicResetPacket& public_reset) { 175 public_reset_packet_ = public_reset; 176} 177 178MockConnectionVisitor::MockConnectionVisitor() { 179} 180 181MockConnectionVisitor::~MockConnectionVisitor() { 182} 183 184MockHelper::MockHelper() { 185} 186 187MockHelper::~MockHelper() { 188} 189 190const QuicClock* MockHelper::GetClock() const { 191 return &clock_; 192} 193 194QuicRandom* MockHelper::GetRandomGenerator() { 195 return &random_generator_; 196} 197 198QuicAlarm* MockHelper::CreateAlarm(QuicAlarm::Delegate* delegate) { 199 return new TestAlarm(delegate); 200} 201 202void MockHelper::AdvanceTime(QuicTime::Delta delta) { 203 clock_.AdvanceTime(delta); 204} 205 206MockConnection::MockConnection(QuicGuid guid, 207 IPEndPoint address, 208 bool is_server) 209 : QuicConnection(guid, address, new testing::NiceMock<MockHelper>(), 210 is_server, QuicVersionMax()), 211 has_mock_helper_(true) { 212} 213 214MockConnection::MockConnection(QuicGuid guid, 215 IPEndPoint address, 216 QuicConnectionHelperInterface* helper, 217 bool is_server) 218 : QuicConnection(guid, address, helper, is_server, QuicVersionMax()), 219 has_mock_helper_(false) { 220} 221 222MockConnection::~MockConnection() { 223} 224 225void MockConnection::AdvanceTime(QuicTime::Delta delta) { 226 CHECK(has_mock_helper_) << "Cannot advance time unless a MockClock is being" 227 " used"; 228 static_cast<MockHelper*>(helper())->AdvanceTime(delta); 229} 230 231PacketSavingConnection::PacketSavingConnection(QuicGuid guid, 232 IPEndPoint address, 233 bool is_server) 234 : MockConnection(guid, address, is_server) { 235} 236 237PacketSavingConnection::~PacketSavingConnection() { 238 STLDeleteElements(&packets_); 239 STLDeleteElements(&encrypted_packets_); 240} 241 242bool PacketSavingConnection::SendOrQueuePacket( 243 EncryptionLevel level, 244 QuicPacketSequenceNumber sequence_number, 245 QuicPacket* packet, 246 QuicPacketEntropyHash entropy_hash, 247 HasRetransmittableData retransmittable) { 248 packets_.push_back(packet); 249 QuicEncryptedPacket* encrypted = 250 framer_.EncryptPacket(level, sequence_number, *packet); 251 encrypted_packets_.push_back(encrypted); 252 return true; 253} 254 255MockSession::MockSession(QuicConnection* connection, bool is_server) 256 : QuicSession(connection, DefaultQuicConfig(), is_server) { 257 ON_CALL(*this, WriteData(_, _, _, _)) 258 .WillByDefault(testing::Return(QuicConsumedData(0, false))); 259} 260 261MockSession::~MockSession() { 262} 263 264TestSession::TestSession(QuicConnection* connection, 265 const QuicConfig& config, 266 bool is_server) 267 : QuicSession(connection, config, is_server), 268 crypto_stream_(NULL) { 269} 270 271TestSession::~TestSession() {} 272 273void TestSession::SetCryptoStream(QuicCryptoStream* stream) { 274 crypto_stream_ = stream; 275} 276 277QuicCryptoStream* TestSession::GetCryptoStream() { 278 return crypto_stream_; 279} 280 281MockSendAlgorithm::MockSendAlgorithm() { 282} 283 284MockSendAlgorithm::~MockSendAlgorithm() { 285} 286 287namespace { 288 289string HexDumpWithMarks(const char* data, int length, 290 const bool* marks, int mark_length) { 291 static const char kHexChars[] = "0123456789abcdef"; 292 static const int kColumns = 4; 293 294 const int kSizeLimit = 1024; 295 if (length > kSizeLimit || mark_length > kSizeLimit) { 296 LOG(ERROR) << "Only dumping first " << kSizeLimit << " bytes."; 297 length = min(length, kSizeLimit); 298 mark_length = min(mark_length, kSizeLimit); 299 } 300 301 string hex; 302 for (const char* row = data; length > 0; 303 row += kColumns, length -= kColumns) { 304 for (const char *p = row; p < row + 4; ++p) { 305 if (p < row + length) { 306 const bool mark = 307 (marks && (p - data) < mark_length && marks[p - data]); 308 hex += mark ? '*' : ' '; 309 hex += kHexChars[(*p & 0xf0) >> 4]; 310 hex += kHexChars[*p & 0x0f]; 311 hex += mark ? '*' : ' '; 312 } else { 313 hex += " "; 314 } 315 } 316 hex = hex + " "; 317 318 for (const char *p = row; p < row + 4 && p < row + length; ++p) 319 hex += (*p >= 0x20 && *p <= 0x7f) ? (*p) : '.'; 320 321 hex = hex + '\n'; 322 } 323 return hex; 324} 325 326} // namespace 327 328void CompareCharArraysWithHexError( 329 const string& description, 330 const char* actual, 331 const int actual_len, 332 const char* expected, 333 const int expected_len) { 334 const int min_len = min(actual_len, expected_len); 335 const int max_len = max(actual_len, expected_len); 336 scoped_ptr<bool[]> marks(new bool[max_len]); 337 bool identical = (actual_len == expected_len); 338 for (int i = 0; i < min_len; ++i) { 339 if (actual[i] != expected[i]) { 340 marks[i] = true; 341 identical = false; 342 } else { 343 marks[i] = false; 344 } 345 } 346 for (int i = min_len; i < max_len; ++i) { 347 marks[i] = true; 348 } 349 if (identical) return; 350 ADD_FAILURE() 351 << "Description:\n" 352 << description 353 << "\n\nExpected:\n" 354 << HexDumpWithMarks(expected, expected_len, marks.get(), max_len) 355 << "\nActual:\n" 356 << HexDumpWithMarks(actual, actual_len, marks.get(), max_len); 357} 358 359void CompareQuicDataWithHexError( 360 const string& description, 361 QuicData* actual, 362 QuicData* expected) { 363 CompareCharArraysWithHexError( 364 description, 365 actual->data(), actual->length(), 366 expected->data(), expected->length()); 367} 368 369static QuicPacket* ConstructPacketFromHandshakeMessage( 370 QuicGuid guid, 371 const CryptoHandshakeMessage& message, 372 bool should_include_version) { 373 CryptoFramer crypto_framer; 374 scoped_ptr<QuicData> data(crypto_framer.ConstructHandshakeMessage(message)); 375 QuicFramer quic_framer(QuicVersionMax(), QuicTime::Zero(), false); 376 377 QuicPacketHeader header; 378 header.public_header.guid = guid; 379 header.public_header.reset_flag = false; 380 header.public_header.version_flag = should_include_version; 381 header.packet_sequence_number = 1; 382 header.entropy_flag = false; 383 header.entropy_hash = 0; 384 header.fec_flag = false; 385 header.fec_group = 0; 386 387 QuicStreamFrame stream_frame(kCryptoStreamId, false, 0, 388 data->AsStringPiece()); 389 390 QuicFrame frame(&stream_frame); 391 QuicFrames frames; 392 frames.push_back(frame); 393 return quic_framer.BuildUnsizedDataPacket(header, frames).packet; 394} 395 396QuicPacket* ConstructHandshakePacket(QuicGuid guid, QuicTag tag) { 397 CryptoHandshakeMessage message; 398 message.set_tag(tag); 399 return ConstructPacketFromHandshakeMessage(guid, message, false); 400} 401 402size_t GetPacketLengthForOneStream(QuicVersion version, 403 bool include_version, 404 InFecGroup is_in_fec_group, 405 size_t* payload_length) { 406 *payload_length = 1; 407 const size_t stream_length = 408 NullEncrypter().GetCiphertextSize(*payload_length) + 409 QuicPacketCreator::StreamFramePacketOverhead( 410 version, PACKET_8BYTE_GUID, include_version, 411 PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group); 412 const size_t ack_length = NullEncrypter().GetCiphertextSize( 413 QuicFramer::GetMinAckFrameSize()) + 414 GetPacketHeaderSize(PACKET_8BYTE_GUID, include_version, 415 PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group); 416 if (stream_length < ack_length) { 417 *payload_length = 1 + ack_length - stream_length; 418 } 419 420 return NullEncrypter().GetCiphertextSize(*payload_length) + 421 QuicPacketCreator::StreamFramePacketOverhead( 422 version, PACKET_8BYTE_GUID, include_version, 423 PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group); 424} 425 426// Size in bytes of the stream frame fields for an arbitrary StreamID and 427// offset and the last frame in a packet. 428size_t GetMinStreamFrameSize(QuicVersion version) { 429 return kQuicFrameTypeSize + kQuicMaxStreamIdSize + kQuicMaxStreamOffsetSize; 430} 431 432QuicPacketEntropyHash TestEntropyCalculator::EntropyHash( 433 QuicPacketSequenceNumber sequence_number) const { 434 return 1u; 435} 436 437QuicConfig DefaultQuicConfig() { 438 QuicConfig config; 439 config.SetDefaults(); 440 return config; 441} 442 443bool TestDecompressorVisitor::OnDecompressedData(StringPiece data) { 444 data.AppendToString(&data_); 445 return true; 446} 447 448void TestDecompressorVisitor::OnDecompressionError() { 449 error_ = true; 450} 451 452} // namespace test 453} // namespace net 454