1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/quic/test_tools/crypto_test_utils.h"
6
7#include "net/quic/crypto/channel_id.h"
8#include "net/quic/crypto/common_cert_set.h"
9#include "net/quic/crypto/crypto_handshake.h"
10#include "net/quic/crypto/quic_crypto_server_config.h"
11#include "net/quic/crypto/quic_decrypter.h"
12#include "net/quic/crypto/quic_encrypter.h"
13#include "net/quic/crypto/quic_random.h"
14#include "net/quic/quic_clock.h"
15#include "net/quic/quic_crypto_client_stream.h"
16#include "net/quic/quic_crypto_server_stream.h"
17#include "net/quic/quic_crypto_stream.h"
18#include "net/quic/quic_server_id.h"
19#include "net/quic/test_tools/quic_connection_peer.h"
20#include "net/quic/test_tools/quic_test_utils.h"
21#include "net/quic/test_tools/simple_quic_framer.h"
22
23using base::StringPiece;
24using std::make_pair;
25using std::pair;
26using std::string;
27using std::vector;
28
29namespace net {
30namespace test {
31
32namespace {
33
34const char kServerHostname[] = "test.example.com";
35const uint16 kServerPort = 80;
36
37// CryptoFramerVisitor is a framer visitor that records handshake messages.
38class CryptoFramerVisitor : public CryptoFramerVisitorInterface {
39 public:
40  CryptoFramerVisitor()
41      : error_(false) {
42  }
43
44  virtual void OnError(CryptoFramer* framer) OVERRIDE { error_ = true; }
45
46  virtual void OnHandshakeMessage(
47      const CryptoHandshakeMessage& message) OVERRIDE {
48    messages_.push_back(message);
49  }
50
51  bool error() const {
52    return error_;
53  }
54
55  const vector<CryptoHandshakeMessage>& messages() const {
56    return messages_;
57  }
58
59 private:
60  bool error_;
61  vector<CryptoHandshakeMessage> messages_;
62};
63
64// MovePackets parses crypto handshake messages from packet number
65// |*inout_packet_index| through to the last packet (or until a packet fails to
66// decrypt) and has |dest_stream| process them. |*inout_packet_index| is updated
67// with an index one greater than the last packet processed.
68void MovePackets(PacketSavingConnection* source_conn,
69                 size_t *inout_packet_index,
70                 QuicCryptoStream* dest_stream,
71                 PacketSavingConnection* dest_conn) {
72  SimpleQuicFramer framer(source_conn->supported_versions());
73  CryptoFramer crypto_framer;
74  CryptoFramerVisitor crypto_visitor;
75
76  // In order to properly test the code we need to perform encryption and
77  // decryption so that the crypters latch when expected. The crypters are in
78  // |dest_conn|, but we don't want to try and use them there. Instead we swap
79  // them into |framer|, perform the decryption with them, and then swap them
80  // back.
81  QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer());
82
83  crypto_framer.set_visitor(&crypto_visitor);
84
85  size_t index = *inout_packet_index;
86  for (; index < source_conn->encrypted_packets_.size(); index++) {
87    if (!framer.ProcessPacket(*source_conn->encrypted_packets_[index])) {
88      // The framer will be unable to decrypt forward-secure packets sent after
89      // the handshake is complete. Don't treat them as handshake packets.
90      break;
91    }
92
93    for (vector<QuicStreamFrame>::const_iterator
94         i =  framer.stream_frames().begin();
95         i != framer.stream_frames().end(); ++i) {
96      scoped_ptr<string> frame_data(i->GetDataAsString());
97      ASSERT_TRUE(crypto_framer.ProcessInput(*frame_data));
98      ASSERT_FALSE(crypto_visitor.error());
99    }
100  }
101  *inout_packet_index = index;
102
103  QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer());
104
105  ASSERT_EQ(0u, crypto_framer.InputBytesRemaining());
106
107  for (vector<CryptoHandshakeMessage>::const_iterator
108       i = crypto_visitor.messages().begin();
109       i != crypto_visitor.messages().end(); ++i) {
110    dest_stream->OnHandshakeMessage(*i);
111  }
112}
113
114// HexChar parses |c| as a hex character. If valid, it sets |*value| to the
115// value of the hex character and returns true. Otherwise it returns false.
116bool HexChar(char c, uint8* value) {
117  if (c >= '0' && c <= '9') {
118    *value = c - '0';
119    return true;
120  }
121  if (c >= 'a' && c <= 'f') {
122    *value = c - 'a' + 10;
123    return true;
124  }
125  if (c >= 'A' && c <= 'F') {
126    *value = c - 'A' + 10;
127    return true;
128  }
129  return false;
130}
131
132// A ChannelIDSource that works in asynchronous mode unless the |callback|
133// argument to GetChannelIDKey is NULL.
134class AsyncTestChannelIDSource : public ChannelIDSource,
135                                 public CryptoTestUtils::CallbackSource {
136 public:
137  // Takes ownership of |sync_source|, a synchronous ChannelIDSource.
138  explicit AsyncTestChannelIDSource(ChannelIDSource* sync_source)
139      : sync_source_(sync_source) {}
140  virtual ~AsyncTestChannelIDSource() {}
141
142  // ChannelIDSource implementation.
143  virtual QuicAsyncStatus GetChannelIDKey(
144      const string& hostname,
145      scoped_ptr<ChannelIDKey>* channel_id_key,
146      ChannelIDSourceCallback* callback) OVERRIDE {
147    // Synchronous mode.
148    if (!callback) {
149      return sync_source_->GetChannelIDKey(hostname, channel_id_key, NULL);
150    }
151
152    // Asynchronous mode.
153    QuicAsyncStatus status =
154        sync_source_->GetChannelIDKey(hostname, &channel_id_key_, NULL);
155    if (status != QUIC_SUCCESS) {
156      return QUIC_FAILURE;
157    }
158    callback_.reset(callback);
159    return QUIC_PENDING;
160  }
161
162  // CallbackSource implementation.
163  virtual void RunPendingCallbacks() OVERRIDE {
164    if (callback_.get()) {
165      callback_->Run(&channel_id_key_);
166      callback_.reset();
167    }
168  }
169
170 private:
171  scoped_ptr<ChannelIDSource> sync_source_;
172  scoped_ptr<ChannelIDSourceCallback> callback_;
173  scoped_ptr<ChannelIDKey> channel_id_key_;
174};
175
176}  // anonymous namespace
177
178CryptoTestUtils::FakeClientOptions::FakeClientOptions()
179    : dont_verify_certs(false),
180      channel_id_enabled(false),
181      channel_id_source_async(false) {
182}
183
184// static
185int CryptoTestUtils::HandshakeWithFakeServer(
186    PacketSavingConnection* client_conn,
187    QuicCryptoClientStream* client) {
188  PacketSavingConnection* server_conn =
189      new PacketSavingConnection(true, client_conn->supported_versions());
190  TestSession server_session(server_conn, DefaultQuicConfig());
191  server_session.InitializeSession();
192  QuicCryptoServerConfig crypto_config(QuicCryptoServerConfig::TESTING,
193                                       QuicRandom::GetInstance());
194
195  SetupCryptoServerConfigForTest(
196      server_session.connection()->clock(),
197      server_session.connection()->random_generator(),
198      server_session.config(), &crypto_config);
199
200  QuicCryptoServerStream server(crypto_config, &server_session);
201  server_session.SetCryptoStream(&server);
202
203  // The client's handshake must have been started already.
204  CHECK_NE(0u, client_conn->packets_.size());
205
206  CommunicateHandshakeMessages(client_conn, client, server_conn, &server);
207
208  CompareClientAndServerKeys(client, &server);
209
210  return client->num_sent_client_hellos();
211}
212
213// static
214int CryptoTestUtils::HandshakeWithFakeClient(
215    PacketSavingConnection* server_conn,
216    QuicCryptoServerStream* server,
217    const FakeClientOptions& options) {
218  PacketSavingConnection* client_conn = new PacketSavingConnection(false);
219  TestClientSession client_session(client_conn, DefaultQuicConfig());
220  QuicCryptoClientConfig crypto_config;
221
222  client_session.config()->SetDefaults();
223  crypto_config.SetDefaults();
224  if (!options.dont_verify_certs) {
225    // TODO(wtc): replace this with ProofVerifierForTesting() when we have
226    // a working ProofSourceForTesting().
227    crypto_config.SetProofVerifier(FakeProofVerifierForTesting());
228  }
229  bool is_https = false;
230  AsyncTestChannelIDSource* async_channel_id_source = NULL;
231  if (options.channel_id_enabled) {
232    is_https = true;
233
234    ChannelIDSource* source = ChannelIDSourceForTesting();
235    if (options.channel_id_source_async) {
236      async_channel_id_source = new AsyncTestChannelIDSource(source);
237      source = async_channel_id_source;
238    }
239    crypto_config.SetChannelIDSource(source);
240  }
241  QuicServerId server_id(kServerHostname, kServerPort, is_https,
242                         PRIVACY_MODE_DISABLED);
243  QuicCryptoClientStream client(server_id, &client_session,
244                                ProofVerifyContextForTesting(),
245                                &crypto_config);
246  client_session.SetCryptoStream(&client);
247
248  CHECK(client.CryptoConnect());
249  CHECK_EQ(1u, client_conn->packets_.size());
250
251  CommunicateHandshakeMessagesAndRunCallbacks(
252      client_conn, &client, server_conn, server, async_channel_id_source);
253
254  CompareClientAndServerKeys(&client, server);
255
256  if (options.channel_id_enabled) {
257    scoped_ptr<ChannelIDKey> channel_id_key;
258    QuicAsyncStatus status =
259        crypto_config.channel_id_source()->GetChannelIDKey(kServerHostname,
260                                                           &channel_id_key,
261                                                           NULL);
262    EXPECT_EQ(QUIC_SUCCESS, status);
263    EXPECT_EQ(channel_id_key->SerializeKey(),
264              server->crypto_negotiated_params().channel_id);
265    EXPECT_EQ(options.channel_id_source_async,
266              client.WasChannelIDSourceCallbackRun());
267  }
268
269  return client.num_sent_client_hellos();
270}
271
272// static
273void CryptoTestUtils::SetupCryptoServerConfigForTest(
274    const QuicClock* clock,
275    QuicRandom* rand,
276    QuicConfig* config,
277    QuicCryptoServerConfig* crypto_config) {
278  config->SetDefaults();
279  QuicCryptoServerConfig::ConfigOptions options;
280  options.channel_id_enabled = true;
281  scoped_ptr<CryptoHandshakeMessage> scfg(
282      crypto_config->AddDefaultConfig(rand, clock, options));
283}
284
285// static
286void CryptoTestUtils::CommunicateHandshakeMessages(
287    PacketSavingConnection* a_conn,
288    QuicCryptoStream* a,
289    PacketSavingConnection* b_conn,
290    QuicCryptoStream* b) {
291  CommunicateHandshakeMessagesAndRunCallbacks(a_conn, a, b_conn, b, NULL);
292}
293
294// static
295void CryptoTestUtils::CommunicateHandshakeMessagesAndRunCallbacks(
296    PacketSavingConnection* a_conn,
297    QuicCryptoStream* a,
298    PacketSavingConnection* b_conn,
299    QuicCryptoStream* b,
300    CallbackSource* callback_source) {
301  size_t a_i = 0, b_i = 0;
302  while (!a->handshake_confirmed()) {
303    ASSERT_GT(a_conn->packets_.size(), a_i);
304    LOG(INFO) << "Processing " << a_conn->packets_.size() - a_i
305              << " packets a->b";
306    MovePackets(a_conn, &a_i, b, b_conn);
307    if (callback_source) {
308      callback_source->RunPendingCallbacks();
309    }
310
311    ASSERT_GT(b_conn->packets_.size(), b_i);
312    LOG(INFO) << "Processing " << b_conn->packets_.size() - b_i
313              << " packets b->a";
314    MovePackets(b_conn, &b_i, a, a_conn);
315    if (callback_source) {
316      callback_source->RunPendingCallbacks();
317    }
318  }
319}
320
321// static
322pair<size_t, size_t> CryptoTestUtils::AdvanceHandshake(
323    PacketSavingConnection* a_conn,
324    QuicCryptoStream* a,
325    size_t a_i,
326    PacketSavingConnection* b_conn,
327    QuicCryptoStream* b,
328    size_t b_i) {
329  LOG(INFO) << "Processing " << a_conn->packets_.size() - a_i
330            << " packets a->b";
331  MovePackets(a_conn, &a_i, b, b_conn);
332
333  LOG(INFO) << "Processing " << b_conn->packets_.size() - b_i
334            << " packets b->a";
335  if (b_conn->packets_.size() - b_i == 2) {
336    LOG(INFO) << "here";
337  }
338  MovePackets(b_conn, &b_i, a, a_conn);
339
340  return make_pair(a_i, b_i);
341}
342
343// static
344string CryptoTestUtils::GetValueForTag(const CryptoHandshakeMessage& message,
345                                       QuicTag tag) {
346  QuicTagValueMap::const_iterator it = message.tag_value_map().find(tag);
347  if (it == message.tag_value_map().end()) {
348    return string();
349  }
350  return it->second;
351}
352
353class MockCommonCertSets : public CommonCertSets {
354 public:
355  MockCommonCertSets(StringPiece cert, uint64 hash, uint32 index)
356      : cert_(cert.as_string()),
357        hash_(hash),
358        index_(index) {
359  }
360
361  virtual StringPiece GetCommonHashes() const OVERRIDE {
362    CHECK(false) << "not implemented";
363    return StringPiece();
364  }
365
366  virtual StringPiece GetCert(uint64 hash, uint32 index) const OVERRIDE {
367    if (hash == hash_ && index == index_) {
368      return cert_;
369    }
370    return StringPiece();
371  }
372
373  virtual bool MatchCert(StringPiece cert,
374                         StringPiece common_set_hashes,
375                         uint64* out_hash,
376                         uint32* out_index) const OVERRIDE {
377    if (cert != cert_) {
378      return false;
379    }
380
381    if (common_set_hashes.size() % sizeof(uint64) != 0) {
382      return false;
383    }
384    bool client_has_set = false;
385    for (size_t i = 0; i < common_set_hashes.size(); i += sizeof(uint64)) {
386      uint64 hash;
387      memcpy(&hash, common_set_hashes.data() + i, sizeof(hash));
388      if (hash == hash_) {
389        client_has_set = true;
390        break;
391      }
392    }
393
394    if (!client_has_set) {
395      return false;
396    }
397
398    *out_hash = hash_;
399    *out_index = index_;
400    return true;
401  }
402
403 private:
404  const string cert_;
405  const uint64 hash_;
406  const uint32 index_;
407};
408
409CommonCertSets* CryptoTestUtils::MockCommonCertSets(StringPiece cert,
410                                                    uint64 hash,
411                                                    uint32 index) {
412  return new class MockCommonCertSets(cert, hash, index);
413}
414
415void CryptoTestUtils::CompareClientAndServerKeys(
416    QuicCryptoClientStream* client,
417    QuicCryptoServerStream* server) {
418  const QuicEncrypter* client_encrypter(
419      client->session()->connection()->encrypter(ENCRYPTION_INITIAL));
420  const QuicDecrypter* client_decrypter(
421      client->session()->connection()->decrypter());
422  const QuicEncrypter* client_forward_secure_encrypter(
423      client->session()->connection()->encrypter(ENCRYPTION_FORWARD_SECURE));
424  const QuicDecrypter* client_forward_secure_decrypter(
425      client->session()->connection()->alternative_decrypter());
426  const QuicEncrypter* server_encrypter(
427      server->session()->connection()->encrypter(ENCRYPTION_INITIAL));
428  const QuicDecrypter* server_decrypter(
429      server->session()->connection()->decrypter());
430  const QuicEncrypter* server_forward_secure_encrypter(
431      server->session()->connection()->encrypter(ENCRYPTION_FORWARD_SECURE));
432  const QuicDecrypter* server_forward_secure_decrypter(
433      server->session()->connection()->alternative_decrypter());
434
435  StringPiece client_encrypter_key = client_encrypter->GetKey();
436  StringPiece client_encrypter_iv = client_encrypter->GetNoncePrefix();
437  StringPiece client_decrypter_key = client_decrypter->GetKey();
438  StringPiece client_decrypter_iv = client_decrypter->GetNoncePrefix();
439  StringPiece client_forward_secure_encrypter_key =
440      client_forward_secure_encrypter->GetKey();
441  StringPiece client_forward_secure_encrypter_iv =
442      client_forward_secure_encrypter->GetNoncePrefix();
443  StringPiece client_forward_secure_decrypter_key =
444      client_forward_secure_decrypter->GetKey();
445  StringPiece client_forward_secure_decrypter_iv =
446      client_forward_secure_decrypter->GetNoncePrefix();
447  StringPiece server_encrypter_key = server_encrypter->GetKey();
448  StringPiece server_encrypter_iv = server_encrypter->GetNoncePrefix();
449  StringPiece server_decrypter_key = server_decrypter->GetKey();
450  StringPiece server_decrypter_iv = server_decrypter->GetNoncePrefix();
451  StringPiece server_forward_secure_encrypter_key =
452      server_forward_secure_encrypter->GetKey();
453  StringPiece server_forward_secure_encrypter_iv =
454      server_forward_secure_encrypter->GetNoncePrefix();
455  StringPiece server_forward_secure_decrypter_key =
456      server_forward_secure_decrypter->GetKey();
457  StringPiece server_forward_secure_decrypter_iv =
458      server_forward_secure_decrypter->GetNoncePrefix();
459
460  StringPiece client_subkey_secret =
461      client->crypto_negotiated_params().subkey_secret;
462  StringPiece server_subkey_secret =
463      server->crypto_negotiated_params().subkey_secret;
464
465
466  const char kSampleLabel[] = "label";
467  const char kSampleContext[] = "context";
468  const size_t kSampleOutputLength = 32;
469  string client_key_extraction;
470  string server_key_extraction;
471  EXPECT_TRUE(client->ExportKeyingMaterial(kSampleLabel,
472                                           kSampleContext,
473                                           kSampleOutputLength,
474                                           &client_key_extraction));
475  EXPECT_TRUE(server->ExportKeyingMaterial(kSampleLabel,
476                                           kSampleContext,
477                                           kSampleOutputLength,
478                                           &server_key_extraction));
479
480  CompareCharArraysWithHexError("client write key",
481                                client_encrypter_key.data(),
482                                client_encrypter_key.length(),
483                                server_decrypter_key.data(),
484                                server_decrypter_key.length());
485  CompareCharArraysWithHexError("client write IV",
486                                client_encrypter_iv.data(),
487                                client_encrypter_iv.length(),
488                                server_decrypter_iv.data(),
489                                server_decrypter_iv.length());
490  CompareCharArraysWithHexError("server write key",
491                                server_encrypter_key.data(),
492                                server_encrypter_key.length(),
493                                client_decrypter_key.data(),
494                                client_decrypter_key.length());
495  CompareCharArraysWithHexError("server write IV",
496                                server_encrypter_iv.data(),
497                                server_encrypter_iv.length(),
498                                client_decrypter_iv.data(),
499                                client_decrypter_iv.length());
500  CompareCharArraysWithHexError("client forward secure write key",
501                                client_forward_secure_encrypter_key.data(),
502                                client_forward_secure_encrypter_key.length(),
503                                server_forward_secure_decrypter_key.data(),
504                                server_forward_secure_decrypter_key.length());
505  CompareCharArraysWithHexError("client forward secure write IV",
506                                client_forward_secure_encrypter_iv.data(),
507                                client_forward_secure_encrypter_iv.length(),
508                                server_forward_secure_decrypter_iv.data(),
509                                server_forward_secure_decrypter_iv.length());
510  CompareCharArraysWithHexError("server forward secure write key",
511                                server_forward_secure_encrypter_key.data(),
512                                server_forward_secure_encrypter_key.length(),
513                                client_forward_secure_decrypter_key.data(),
514                                client_forward_secure_decrypter_key.length());
515  CompareCharArraysWithHexError("server forward secure write IV",
516                                server_forward_secure_encrypter_iv.data(),
517                                server_forward_secure_encrypter_iv.length(),
518                                client_forward_secure_decrypter_iv.data(),
519                                client_forward_secure_decrypter_iv.length());
520  CompareCharArraysWithHexError("subkey secret",
521                                client_subkey_secret.data(),
522                                client_subkey_secret.length(),
523                                server_subkey_secret.data(),
524                                server_subkey_secret.length());
525  CompareCharArraysWithHexError("sample key extraction",
526                                client_key_extraction.data(),
527                                client_key_extraction.length(),
528                                server_key_extraction.data(),
529                                server_key_extraction.length());
530}
531
532// static
533QuicTag CryptoTestUtils::ParseTag(const char* tagstr) {
534  const size_t len = strlen(tagstr);
535  CHECK_NE(0u, len);
536
537  QuicTag tag = 0;
538
539  if (tagstr[0] == '#') {
540    CHECK_EQ(static_cast<size_t>(1 + 2*4), len);
541    tagstr++;
542
543    for (size_t i = 0; i < 8; i++) {
544      tag <<= 4;
545
546      uint8 v = 0;
547      CHECK(HexChar(tagstr[i], &v));
548      tag |= v;
549    }
550
551    return tag;
552  }
553
554  CHECK_LE(len, 4u);
555  for (size_t i = 0; i < 4; i++) {
556    tag >>= 8;
557    if (i < len) {
558      tag |= static_cast<uint32>(tagstr[i]) << 24;
559    }
560  }
561
562  return tag;
563}
564
565// static
566CryptoHandshakeMessage CryptoTestUtils::Message(const char* message_tag, ...) {
567  va_list ap;
568  va_start(ap, message_tag);
569
570  CryptoHandshakeMessage message = BuildMessage(message_tag, ap);
571  va_end(ap);
572  return message;
573}
574
575// static
576CryptoHandshakeMessage CryptoTestUtils::BuildMessage(const char* message_tag,
577                                                     va_list ap) {
578  CryptoHandshakeMessage msg;
579  msg.set_tag(ParseTag(message_tag));
580
581  for (;;) {
582    const char* tagstr = va_arg(ap, const char*);
583    if (tagstr == NULL) {
584      break;
585    }
586
587    if (tagstr[0] == '$') {
588      // Special value.
589      const char* const special = tagstr + 1;
590      if (strcmp(special, "padding") == 0) {
591        const int min_bytes = va_arg(ap, int);
592        msg.set_minimum_size(min_bytes);
593      } else {
594        CHECK(false) << "Unknown special value: " << special;
595      }
596
597      continue;
598    }
599
600    const QuicTag tag = ParseTag(tagstr);
601    const char* valuestr = va_arg(ap, const char*);
602
603    size_t len = strlen(valuestr);
604    if (len > 0 && valuestr[0] == '#') {
605      valuestr++;
606      len--;
607
608      CHECK_EQ(0u, len % 2);
609      scoped_ptr<uint8[]> buf(new uint8[len/2]);
610
611      for (size_t i = 0; i < len/2; i++) {
612        uint8 v = 0;
613        CHECK(HexChar(valuestr[i*2], &v));
614        buf[i] = v << 4;
615        CHECK(HexChar(valuestr[i*2 + 1], &v));
616        buf[i] |= v;
617      }
618
619      msg.SetStringPiece(
620          tag, StringPiece(reinterpret_cast<char*>(buf.get()), len/2));
621      continue;
622    }
623
624    msg.SetStringPiece(tag, valuestr);
625  }
626
627  // The CryptoHandshakeMessage needs to be serialized and parsed to ensure
628  // that any padding is included.
629  scoped_ptr<QuicData> bytes(CryptoFramer::ConstructHandshakeMessage(msg));
630  scoped_ptr<CryptoHandshakeMessage> parsed(
631      CryptoFramer::ParseMessage(bytes->AsStringPiece()));
632  CHECK(parsed.get());
633
634  return *parsed;
635}
636
637}  // namespace test
638}  // namespace net
639