crypto_test_utils.cc revision 5d1f7b1de12d16ceb2c938c56701a3e8bfa558f7
1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/quic/test_tools/crypto_test_utils.h"
6
7#include "net/quic/crypto/channel_id.h"
8#include "net/quic/crypto/common_cert_set.h"
9#include "net/quic/crypto/crypto_handshake.h"
10#include "net/quic/crypto/quic_crypto_server_config.h"
11#include "net/quic/crypto/quic_decrypter.h"
12#include "net/quic/crypto/quic_encrypter.h"
13#include "net/quic/crypto/quic_random.h"
14#include "net/quic/quic_clock.h"
15#include "net/quic/quic_crypto_client_stream.h"
16#include "net/quic/quic_crypto_server_stream.h"
17#include "net/quic/quic_crypto_stream.h"
18#include "net/quic/test_tools/quic_connection_peer.h"
19#include "net/quic/test_tools/quic_test_utils.h"
20#include "net/quic/test_tools/simple_quic_framer.h"
21
22using base::StringPiece;
23using std::make_pair;
24using std::pair;
25using std::string;
26using std::vector;
27
28namespace net {
29namespace test {
30
31namespace {
32
33// CryptoFramerVisitor is a framer visitor that records handshake messages.
34class CryptoFramerVisitor : public CryptoFramerVisitorInterface {
35 public:
36  CryptoFramerVisitor()
37      : error_(false) {
38  }
39
40  virtual void OnError(CryptoFramer* framer) OVERRIDE {
41    error_ = true;
42  }
43
44  virtual void OnHandshakeMessage(
45      const CryptoHandshakeMessage& message) OVERRIDE {
46    messages_.push_back(message);
47  }
48
49  bool error() const {
50    return error_;
51  }
52
53  const vector<CryptoHandshakeMessage>& messages() const {
54    return messages_;
55  }
56
57 private:
58  bool error_;
59  vector<CryptoHandshakeMessage> messages_;
60};
61
62// MovePackets parses crypto handshake messages from packet number
63// |*inout_packet_index| through to the last packet and has |dest_stream|
64// process them. |*inout_packet_index| is updated with an index one greater
65// than the last packet processed.
66void MovePackets(PacketSavingConnection* source_conn,
67                 size_t *inout_packet_index,
68                 QuicCryptoStream* dest_stream,
69                 PacketSavingConnection* dest_conn) {
70  SimpleQuicFramer framer(source_conn->supported_versions());
71  CryptoFramer crypto_framer;
72  CryptoFramerVisitor crypto_visitor;
73
74  // In order to properly test the code we need to perform encryption and
75  // decryption so that the crypters latch when expected. The crypters are in
76  // |dest_conn|, but we don't want to try and use them there. Instead we swap
77  // them into |framer|, perform the decryption with them, and then swap them
78  // back.
79  QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer());
80
81  crypto_framer.set_visitor(&crypto_visitor);
82
83  size_t index = *inout_packet_index;
84  for (; index < source_conn->encrypted_packets_.size(); index++) {
85    ASSERT_TRUE(framer.ProcessPacket(*source_conn->encrypted_packets_[index]));
86    for (vector<QuicStreamFrame>::const_iterator
87         i =  framer.stream_frames().begin();
88         i != framer.stream_frames().end(); ++i) {
89      scoped_ptr<string> frame_data(i->GetDataAsString());
90      ASSERT_TRUE(crypto_framer.ProcessInput(*frame_data));
91      ASSERT_FALSE(crypto_visitor.error());
92    }
93  }
94  *inout_packet_index = index;
95
96  QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer());
97
98  ASSERT_EQ(0u, crypto_framer.InputBytesRemaining());
99
100  for (vector<CryptoHandshakeMessage>::const_iterator
101       i = crypto_visitor.messages().begin();
102       i != crypto_visitor.messages().end(); ++i) {
103    dest_stream->OnHandshakeMessage(*i);
104  }
105}
106
107// HexChar parses |c| as a hex character. If valid, it sets |*value| to the
108// value of the hex character and returns true. Otherwise it returns false.
109bool HexChar(char c, uint8* value) {
110  if (c >= '0' && c <= '9') {
111    *value = c - '0';
112    return true;
113  }
114  if (c >= 'a' && c <= 'f') {
115    *value = c - 'a' + 10;
116    return true;
117  }
118  if (c >= 'A' && c <= 'F') {
119    *value = c - 'A' + 10;
120    return true;
121  }
122  return false;
123}
124
125}  // anonymous namespace
126
127CryptoTestUtils::FakeClientOptions::FakeClientOptions()
128    : dont_verify_certs(false),
129      channel_id_enabled(false) {
130}
131
132// static
133int CryptoTestUtils::HandshakeWithFakeServer(
134    PacketSavingConnection* client_conn,
135    QuicCryptoClientStream* client) {
136  PacketSavingConnection* server_conn =
137      new PacketSavingConnection(true, client_conn->supported_versions());
138  TestSession server_session(server_conn, DefaultQuicConfig());
139
140  QuicCryptoServerConfig crypto_config(QuicCryptoServerConfig::TESTING,
141                                       QuicRandom::GetInstance());
142  SetupCryptoServerConfigForTest(
143      server_session.connection()->clock(),
144      server_session.connection()->random_generator(),
145      server_session.config(), &crypto_config);
146
147  QuicCryptoServerStream server(crypto_config, &server_session);
148  server_session.SetCryptoStream(&server);
149
150  // The client's handshake must have been started already.
151  CHECK_NE(0u, client_conn->packets_.size());
152
153  CommunicateHandshakeMessages(client_conn, client, server_conn, &server);
154
155  CompareClientAndServerKeys(client, &server);
156
157  return client->num_sent_client_hellos();
158}
159
160// static
161int CryptoTestUtils::HandshakeWithFakeClient(
162    PacketSavingConnection* server_conn,
163    QuicCryptoServerStream* server,
164    const FakeClientOptions& options) {
165  PacketSavingConnection* client_conn = new PacketSavingConnection(false);
166  TestSession client_session(client_conn, DefaultQuicConfig());
167  QuicCryptoClientConfig crypto_config;
168
169  client_session.config()->SetDefaults();
170  crypto_config.SetDefaults();
171  // TODO(rtenneti): Enable testing of ProofVerifier.
172  // if (!options.dont_verify_certs) {
173  //   crypto_config.SetProofVerifier(ProofVerifierForTesting());
174  // }
175  if (options.channel_id_enabled) {
176    crypto_config.SetChannelIDSigner(ChannelIDSignerForTesting());
177  }
178  QuicCryptoClientStream client("test.example.com", &client_session,
179                                &crypto_config);
180  client_session.SetCryptoStream(&client);
181
182  CHECK(client.CryptoConnect());
183  CHECK_EQ(1u, client_conn->packets_.size());
184
185  CommunicateHandshakeMessages(client_conn, &client, server_conn, server);
186
187  CompareClientAndServerKeys(&client, server);
188
189  if (options.channel_id_enabled) {
190    EXPECT_EQ(crypto_config.channel_id_signer()->GetKeyForHostname(
191                  "test.example.com"),
192              server->crypto_negotiated_params().channel_id);
193  }
194
195  return client.num_sent_client_hellos();
196}
197
198// static
199void CryptoTestUtils::SetupCryptoServerConfigForTest(
200    const QuicClock* clock,
201    QuicRandom* rand,
202    QuicConfig* config,
203    QuicCryptoServerConfig* crypto_config) {
204  config->SetDefaults();
205  QuicCryptoServerConfig::ConfigOptions options;
206  options.channel_id_enabled = true;
207  scoped_ptr<CryptoHandshakeMessage> scfg(
208      crypto_config->AddDefaultConfig(rand, clock, options));
209}
210
211// static
212void CryptoTestUtils::CommunicateHandshakeMessages(
213    PacketSavingConnection* a_conn,
214    QuicCryptoStream* a,
215    PacketSavingConnection* b_conn,
216    QuicCryptoStream* b) {
217  size_t a_i = 0, b_i = 0;
218  while (!a->handshake_confirmed()) {
219    ASSERT_GT(a_conn->packets_.size(), a_i);
220    LOG(INFO) << "Processing " << a_conn->packets_.size() - a_i
221              << " packets a->b";
222    MovePackets(a_conn, &a_i, b, b_conn);
223
224    ASSERT_GT(b_conn->packets_.size(), b_i);
225    LOG(INFO) << "Processing " << b_conn->packets_.size() - b_i
226              << " packets b->a";
227    if (b_conn->packets_.size() - b_i == 2) {
228      LOG(INFO) << "here";
229    }
230    MovePackets(b_conn, &b_i, a, a_conn);
231  }
232}
233
234pair<size_t, size_t> CryptoTestUtils::AdvanceHandshake(
235    PacketSavingConnection* a_conn,
236    QuicCryptoStream* a,
237    size_t a_i,
238    PacketSavingConnection* b_conn,
239    QuicCryptoStream* b,
240    size_t b_i) {
241  LOG(INFO) << "Processing " << a_conn->packets_.size() - a_i
242            << " packets a->b";
243  MovePackets(a_conn, &a_i, b, b_conn);
244
245  LOG(INFO) << "Processing " << b_conn->packets_.size() - b_i
246            << " packets b->a";
247  if (b_conn->packets_.size() - b_i == 2) {
248    LOG(INFO) << "here";
249  }
250  MovePackets(b_conn, &b_i, a, a_conn);
251
252  return make_pair(a_i, b_i);
253}
254
255// static
256string CryptoTestUtils::GetValueForTag(const CryptoHandshakeMessage& message,
257                                       QuicTag tag) {
258  QuicTagValueMap::const_iterator it = message.tag_value_map().find(tag);
259  if (it == message.tag_value_map().end()) {
260    return string();
261  }
262  return it->second;
263}
264
265class MockCommonCertSets : public CommonCertSets {
266 public:
267  MockCommonCertSets(StringPiece cert, uint64 hash, uint32 index)
268      : cert_(cert.as_string()),
269        hash_(hash),
270        index_(index) {
271  }
272
273  virtual StringPiece GetCommonHashes() const OVERRIDE {
274    CHECK(false) << "not implemented";
275    return StringPiece();
276  }
277
278  virtual StringPiece GetCert(uint64 hash, uint32 index) const OVERRIDE {
279    if (hash == hash_ && index == index_) {
280      return cert_;
281    }
282    return StringPiece();
283  }
284
285  virtual bool MatchCert(StringPiece cert,
286                         StringPiece common_set_hashes,
287                         uint64* out_hash,
288                         uint32* out_index) const OVERRIDE {
289    if (cert != cert_) {
290      return false;
291    }
292
293    if (common_set_hashes.size() % sizeof(uint64) != 0) {
294      return false;
295    }
296    bool client_has_set = false;
297    for (size_t i = 0; i < common_set_hashes.size(); i += sizeof(uint64)) {
298      uint64 hash;
299      memcpy(&hash, common_set_hashes.data() + i, sizeof(hash));
300      if (hash == hash_) {
301        client_has_set = true;
302        break;
303      }
304    }
305
306    if (!client_has_set) {
307      return false;
308    }
309
310    *out_hash = hash_;
311    *out_index = index_;
312    return true;
313  }
314
315 private:
316  const string cert_;
317  const uint64 hash_;
318  const uint32 index_;
319};
320
321CommonCertSets* CryptoTestUtils::MockCommonCertSets(StringPiece cert,
322                                                    uint64 hash,
323                                                    uint32 index) {
324  return new class MockCommonCertSets(cert, hash, index);
325}
326
327void CryptoTestUtils::CompareClientAndServerKeys(
328    QuicCryptoClientStream* client,
329    QuicCryptoServerStream* server) {
330  const QuicEncrypter* client_encrypter(
331      client->session()->connection()->encrypter(ENCRYPTION_INITIAL));
332  const QuicDecrypter* client_decrypter(
333      client->session()->connection()->decrypter());
334  const QuicEncrypter* client_forward_secure_encrypter(
335      client->session()->connection()->encrypter(ENCRYPTION_FORWARD_SECURE));
336  const QuicDecrypter* client_forward_secure_decrypter(
337      client->session()->connection()->alternative_decrypter());
338  const QuicEncrypter* server_encrypter(
339      server->session()->connection()->encrypter(ENCRYPTION_INITIAL));
340  const QuicDecrypter* server_decrypter(
341      server->session()->connection()->decrypter());
342  const QuicEncrypter* server_forward_secure_encrypter(
343      server->session()->connection()->encrypter(ENCRYPTION_FORWARD_SECURE));
344  const QuicDecrypter* server_forward_secure_decrypter(
345      server->session()->connection()->alternative_decrypter());
346
347  StringPiece client_encrypter_key = client_encrypter->GetKey();
348  StringPiece client_encrypter_iv = client_encrypter->GetNoncePrefix();
349  StringPiece client_decrypter_key = client_decrypter->GetKey();
350  StringPiece client_decrypter_iv = client_decrypter->GetNoncePrefix();
351  StringPiece client_forward_secure_encrypter_key =
352      client_forward_secure_encrypter->GetKey();
353  StringPiece client_forward_secure_encrypter_iv =
354      client_forward_secure_encrypter->GetNoncePrefix();
355  StringPiece client_forward_secure_decrypter_key =
356      client_forward_secure_decrypter->GetKey();
357  StringPiece client_forward_secure_decrypter_iv =
358      client_forward_secure_decrypter->GetNoncePrefix();
359  StringPiece server_encrypter_key = server_encrypter->GetKey();
360  StringPiece server_encrypter_iv = server_encrypter->GetNoncePrefix();
361  StringPiece server_decrypter_key = server_decrypter->GetKey();
362  StringPiece server_decrypter_iv = server_decrypter->GetNoncePrefix();
363  StringPiece server_forward_secure_encrypter_key =
364      server_forward_secure_encrypter->GetKey();
365  StringPiece server_forward_secure_encrypter_iv =
366      server_forward_secure_encrypter->GetNoncePrefix();
367  StringPiece server_forward_secure_decrypter_key =
368      server_forward_secure_decrypter->GetKey();
369  StringPiece server_forward_secure_decrypter_iv =
370      server_forward_secure_decrypter->GetNoncePrefix();
371
372  CompareCharArraysWithHexError("client write key",
373                                client_encrypter_key.data(),
374                                client_encrypter_key.length(),
375                                server_decrypter_key.data(),
376                                server_decrypter_key.length());
377  CompareCharArraysWithHexError("client write IV",
378                                client_encrypter_iv.data(),
379                                client_encrypter_iv.length(),
380                                server_decrypter_iv.data(),
381                                server_decrypter_iv.length());
382  CompareCharArraysWithHexError("server write key",
383                                server_encrypter_key.data(),
384                                server_encrypter_key.length(),
385                                client_decrypter_key.data(),
386                                client_decrypter_key.length());
387  CompareCharArraysWithHexError("server write IV",
388                                server_encrypter_iv.data(),
389                                server_encrypter_iv.length(),
390                                client_decrypter_iv.data(),
391                                client_decrypter_iv.length());
392  CompareCharArraysWithHexError("client forward secure write key",
393                                client_forward_secure_encrypter_key.data(),
394                                client_forward_secure_encrypter_key.length(),
395                                server_forward_secure_decrypter_key.data(),
396                                server_forward_secure_decrypter_key.length());
397  CompareCharArraysWithHexError("client forward secure write IV",
398                                client_forward_secure_encrypter_iv.data(),
399                                client_forward_secure_encrypter_iv.length(),
400                                server_forward_secure_decrypter_iv.data(),
401                                server_forward_secure_decrypter_iv.length());
402  CompareCharArraysWithHexError("server forward secure write key",
403                                server_forward_secure_encrypter_key.data(),
404                                server_forward_secure_encrypter_key.length(),
405                                client_forward_secure_decrypter_key.data(),
406                                client_forward_secure_decrypter_key.length());
407  CompareCharArraysWithHexError("server forward secure write IV",
408                                server_forward_secure_encrypter_iv.data(),
409                                server_forward_secure_encrypter_iv.length(),
410                                client_forward_secure_decrypter_iv.data(),
411                                client_forward_secure_decrypter_iv.length());
412}
413
414// static
415QuicTag CryptoTestUtils::ParseTag(const char* tagstr) {
416  const size_t len = strlen(tagstr);
417  CHECK_NE(0u, len);
418
419  QuicTag tag = 0;
420
421  if (tagstr[0] == '#') {
422    CHECK_EQ(static_cast<size_t>(1 + 2*4), len);
423    tagstr++;
424
425    for (size_t i = 0; i < 8; i++) {
426      tag <<= 4;
427
428      uint8 v = 0;
429      CHECK(HexChar(tagstr[i], &v));
430      tag |= v;
431    }
432
433    return tag;
434  }
435
436  CHECK_LE(len, 4u);
437  for (size_t i = 0; i < 4; i++) {
438    tag >>= 8;
439    if (i < len) {
440      tag |= static_cast<uint32>(tagstr[i]) << 24;
441    }
442  }
443
444  return tag;
445}
446
447// static
448CryptoHandshakeMessage CryptoTestUtils::Message(const char* message_tag, ...) {
449  va_list ap;
450  va_start(ap, message_tag);
451
452  CryptoHandshakeMessage message = BuildMessage(message_tag, ap);
453  va_end(ap);
454  return message;
455}
456
457// static
458CryptoHandshakeMessage CryptoTestUtils::BuildMessage(const char* message_tag,
459                                                     va_list ap) {
460  CryptoHandshakeMessage msg;
461  msg.set_tag(ParseTag(message_tag));
462
463  for (;;) {
464    const char* tagstr = va_arg(ap, const char*);
465    if (tagstr == NULL) {
466      break;
467    }
468
469    if (tagstr[0] == '$') {
470      // Special value.
471      const char* const special = tagstr + 1;
472      if (strcmp(special, "padding") == 0) {
473        const int min_bytes = va_arg(ap, int);
474        msg.set_minimum_size(min_bytes);
475      } else {
476        CHECK(false) << "Unknown special value: " << special;
477      }
478
479      continue;
480    }
481
482    const QuicTag tag = ParseTag(tagstr);
483    const char* valuestr = va_arg(ap, const char*);
484
485    size_t len = strlen(valuestr);
486    if (len > 0 && valuestr[0] == '#') {
487      valuestr++;
488      len--;
489
490      CHECK(len % 2 == 0);
491      scoped_ptr<uint8[]> buf(new uint8[len/2]);
492
493      for (size_t i = 0; i < len/2; i++) {
494        uint8 v = 0;
495        CHECK(HexChar(valuestr[i*2], &v));
496        buf[i] = v << 4;
497        CHECK(HexChar(valuestr[i*2 + 1], &v));
498        buf[i] |= v;
499      }
500
501      msg.SetStringPiece(
502          tag, StringPiece(reinterpret_cast<char*>(buf.get()), len/2));
503      continue;
504    }
505
506    msg.SetStringPiece(tag, valuestr);
507  }
508
509  // The CryptoHandshakeMessage needs to be serialized and parsed to ensure
510  // that any padding is included.
511  scoped_ptr<QuicData> bytes(CryptoFramer::ConstructHandshakeMessage(msg));
512  scoped_ptr<CryptoHandshakeMessage> parsed(
513      CryptoFramer::ParseMessage(bytes->AsStringPiece()));
514  CHECK(parsed.get());
515
516  return *parsed;
517}
518
519}  // namespace test
520}  // namespace net
521