1/*
2 *  Copyright 2011 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11
#include <algorithm>
#include <cstring>
#include <set>
#include <string>
#include <vector>

#include "webrtc/base/gunit.h"
#include "webrtc/base/helpers.h"
#include "webrtc/base/scoped_ptr.h"
#include "webrtc/base/ssladapter.h"
#include "webrtc/base/sslconfig.h"
#include "webrtc/base/sslidentity.h"
#include "webrtc/base/sslstreamadapter.h"
#include "webrtc/base/stream.h"
#include "webrtc/test/testsupport/gtest_disable.h"
25
// Size of the chunks the TLS transfer test reads from its send buffer.
static const int kBlockSize = 4096;
// DTLS-SRTP protection profile names used by the SRTP negotiation tests.
static const char kAES_CM_HMAC_SHA1_80[] = "AES_CM_128_HMAC_SHA1_80";
static const char kAES_CM_HMAC_SHA1_32[] = "AES_CM_128_HMAC_SHA1_32";
// Label and context fed to ExportKeyingMaterial() by the exporter test.
static const char kExporterLabel[] = "label";
static const unsigned char kExporterContext[] = "context";
// Context length handed to ExportKeyingMaterial(). sizeof() counts the
// trailing NUL, so this is 8 rather than 7; both peers use the same value, so
// the exported material still matches. Declared const size_t because it is a
// fixed length passed to a size_t parameter.
static const size_t kExporterContextLen = sizeof(kExporterContext);
32
// Fixed RSA private key used (together with kCERT_PEM) to build the client
// identity in SSLStreamAdapterTestDTLSFromPEMStrings via
// rtc::SSLIdentity::FromPEMStrings(). Its public modulus appears identical to
// the one embedded in kCERT_PEM.
// NOTE(review): the DER body looks like a PKCS#8 PrivateKeyInfo even though
// the armor says "RSA PRIVATE KEY"; the SSL backend evidently accepts it, but
// verify before reusing this key elsewhere.
static const char kRSA_PRIVATE_KEY_PEM[] =
    "-----BEGIN RSA PRIVATE KEY-----\n"
    "MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBAMYRkbhmI7kVA/rM\n"
    "czsZ+6JDhDvnkF+vn6yCAGuRPV03zuRqZtDy4N4to7PZu9PjqrRl7nDMXrG3YG9y\n"
    "rlIAZ72KjcKKFAJxQyAKLCIdawKRyp8RdK3LEySWEZb0AV58IadqPZDTNHHRX8dz\n"
    "5aTSMsbbkZ+C/OzTnbiMqLL/vg6jAgMBAAECgYAvgOs4FJcgvp+TuREx7YtiYVsH\n"
    "mwQPTum2z/8VzWGwR8BBHBvIpVe1MbD/Y4seyI2aco/7UaisatSgJhsU46/9Y4fq\n"
    "2TwXH9QANf4at4d9n/R6rzwpAJOpgwZgKvdQjkfrKTtgLV+/dawvpxUYkRH4JZM1\n"
    "CVGukMfKNrSVH4Ap4QJBAOJmGV1ASPnB4r4nc99at7JuIJmd7fmuVUwUgYi4XgaR\n"
    "WhScBsgYwZ/JoywdyZJgnbcrTDuVcWG56B3vXbhdpMsCQQDf9zeJrjnPZ3Cqm79y\n"
    "kdqANep0uwZciiNiWxsQrCHztywOvbFhdp8iYVFG9EK8DMY41Y5TxUwsHD+67zao\n"
    "ZNqJAkEA1suLUP/GvL8IwuRneQd2tWDqqRQ/Td3qq03hP7e77XtF/buya3Ghclo5\n"
    "54czUR89QyVfJEC6278nzA7n2h1uVQJAcG6mztNL6ja/dKZjYZye2CY44QjSlLo0\n"
    "MTgTSjdfg/28fFn2Jjtqf9Pi/X+50LWI/RcYMC2no606wRk9kyOuIQJBAK6VSAim\n"
    "1pOEjsYQn0X5KEIrz1G3bfCbB848Ime3U2/FWlCHMr6ch8kCZ5d1WUeJD3LbwMNG\n"
    "UCXiYxSsu20QNVw=\n"
    "-----END RSA PRIVATE KEY-----\n";
50
// Self-signed test certificate (issuer and subject both appear to be
// CN=WebRTC) for kRSA_PRIVATE_KEY_PEM. TestDTLSGetPeerCertificate expects the
// server to report exactly this PEM text as its peer certificate.
// NOTE(review): the encoded validity window appears to be 2014-01-02 to
// 2014-02-01, i.e. already expired; the handshake tests that use it therefore
// rely on certificate expiry not being enforced (compare TestCertExpired).
static const char kCERT_PEM[] =
    "-----BEGIN CERTIFICATE-----\n"
    "MIIBmTCCAQKgAwIBAgIEbzBSAjANBgkqhkiG9w0BAQsFADARMQ8wDQYDVQQDEwZX\n"
    "ZWJSVEMwHhcNMTQwMTAyMTgyNDQ3WhcNMTQwMjAxMTgyNDQ3WjARMQ8wDQYDVQQD\n"
    "EwZXZWJSVEMwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMYRkbhmI7kVA/rM\n"
    "czsZ+6JDhDvnkF+vn6yCAGuRPV03zuRqZtDy4N4to7PZu9PjqrRl7nDMXrG3YG9y\n"
    "rlIAZ72KjcKKFAJxQyAKLCIdawKRyp8RdK3LEySWEZb0AV58IadqPZDTNHHRX8dz\n"
    "5aTSMsbbkZ+C/OzTnbiMqLL/vg6jAgMBAAEwDQYJKoZIhvcNAQELBQADgYEAUflI\n"
    "VUe5Krqf5RVa5C3u/UTAOAUJBiDS3VANTCLBxjuMsvqOG0WvaYWP3HYPgrz0jXK2\n"
    "LJE/mGw3MyFHEqi81jh95J+ypl6xKW6Rm8jKLR87gUvCaVYn/Z4/P3AqcQTB7wOv\n"
    "UD0A8qfhfDM+LK6rPAnCsVN0NRDY3jvd6rzix9M=\n"
    "-----END CERTIFICATE-----\n";
63
// Returns early from the current test body (after logging) when
// rtc::SSLStreamAdapter::feature() reports the feature is unavailable in this
// build. The do/while(0) wrapper makes the macro a single statement, so
// "MAYBE_SKIP_TEST(X);" is safe inside an unbraced if/else.
#define MAYBE_SKIP_TEST(feature)                      \
  do {                                                \
    if (!(rtc::SSLStreamAdapter::feature())) {        \
      LOG(LS_INFO) << "Feature disabled... skipping"; \
      return;                                         \
    }                                                 \
  } while (0)
69
70class SSLStreamAdapterTestBase;
71
72class SSLDummyStream : public rtc::StreamInterface,
73                       public sigslot::has_slots<> {
74 public:
75  explicit SSLDummyStream(SSLStreamAdapterTestBase *test,
76                          const std::string &side,
77                          rtc::FifoBuffer *in,
78                          rtc::FifoBuffer *out) :
79      test_(test),
80      side_(side),
81      in_(in),
82      out_(out),
83      first_packet_(true) {
84    in_->SignalEvent.connect(this, &SSLDummyStream::OnEventIn);
85    out_->SignalEvent.connect(this, &SSLDummyStream::OnEventOut);
86  }
87
88  virtual rtc::StreamState GetState() const { return rtc::SS_OPEN; }
89
90  virtual rtc::StreamResult Read(void* buffer, size_t buffer_len,
91                                       size_t* read, int* error) {
92    rtc::StreamResult r;
93
94    r = in_->Read(buffer, buffer_len, read, error);
95    if (r == rtc::SR_BLOCK)
96      return rtc::SR_BLOCK;
97    if (r == rtc::SR_EOS)
98      return rtc::SR_EOS;
99
100    if (r != rtc::SR_SUCCESS) {
101      ADD_FAILURE();
102      return rtc::SR_ERROR;
103    }
104
105    return rtc::SR_SUCCESS;
106  }
107
108  // Catch readability events on in and pass them up.
109  virtual void OnEventIn(rtc::StreamInterface *stream, int sig,
110                         int err) {
111    int mask = (rtc::SE_READ | rtc::SE_CLOSE);
112
113    if (sig & mask) {
114      LOG(LS_INFO) << "SSLDummyStream::OnEvent side=" << side_ <<  " sig="
115        << sig << " forwarding upward";
116      PostEvent(sig & mask, 0);
117    }
118  }
119
120  // Catch writeability events on out and pass them up.
121  virtual void OnEventOut(rtc::StreamInterface *stream, int sig,
122                          int err) {
123    if (sig & rtc::SE_WRITE) {
124      LOG(LS_INFO) << "SSLDummyStream::OnEvent side=" << side_ <<  " sig="
125        << sig << " forwarding upward";
126
127      PostEvent(sig & rtc::SE_WRITE, 0);
128    }
129  }
130
131  // Write to the outgoing FifoBuffer
132  rtc::StreamResult WriteData(const void* data, size_t data_len,
133                                    size_t* written, int* error) {
134    return out_->Write(data, data_len, written, error);
135  }
136
137  // Defined later
138  virtual rtc::StreamResult Write(const void* data, size_t data_len,
139                                        size_t* written, int* error);
140
141  virtual void Close() {
142    LOG(LS_INFO) << "Closing outbound stream";
143    out_->Close();
144  }
145
146 private:
147  SSLStreamAdapterTestBase *test_;
148  const std::string side_;
149  rtc::FifoBuffer *in_;
150  rtc::FifoBuffer *out_;
151  bool first_packet_;
152};
153
// Capacity, in bytes, of each direction's in-memory FifoBuffer (4 KiB).
static const int kFifoBufferSize = 4 * 1024;
155
156class SSLStreamAdapterTestBase : public testing::Test,
157                                 public sigslot::has_slots<> {
158 public:
159  SSLStreamAdapterTestBase(const std::string& client_cert_pem,
160                           const std::string& client_private_key_pem,
161                           bool dtls) :
162      client_buffer_(kFifoBufferSize), server_buffer_(kFifoBufferSize),
163      client_stream_(
164          new SSLDummyStream(this, "c2s", &client_buffer_, &server_buffer_)),
165      server_stream_(
166          new SSLDummyStream(this, "s2c", &server_buffer_, &client_buffer_)),
167      client_ssl_(rtc::SSLStreamAdapter::Create(client_stream_)),
168      server_ssl_(rtc::SSLStreamAdapter::Create(server_stream_)),
169      client_identity_(NULL), server_identity_(NULL),
170      delay_(0), mtu_(1460), loss_(0), lose_first_packet_(false),
171      damage_(false), dtls_(dtls),
172      handshake_wait_(5000), identities_set_(false) {
173    // Set use of the test RNG to get predictable loss patterns.
174    rtc::SetRandomTestMode(true);
175
176    // Set up the slots
177    client_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent);
178    server_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent);
179
180    if (!client_cert_pem.empty() && !client_private_key_pem.empty()) {
181      client_identity_ = rtc::SSLIdentity::FromPEMStrings(
182          client_private_key_pem, client_cert_pem);
183    } else {
184      client_identity_ = rtc::SSLIdentity::Generate("client");
185    }
186    server_identity_ = rtc::SSLIdentity::Generate("server");
187
188    client_ssl_->SetIdentity(client_identity_);
189    server_ssl_->SetIdentity(server_identity_);
190  }
191
192  ~SSLStreamAdapterTestBase() {
193    // Put it back for the next test.
194    rtc::SetRandomTestMode(false);
195  }
196
197  static void SetUpTestCase() {
198    rtc::InitializeSSL();
199  }
200
201  static void TearDownTestCase() {
202    rtc::CleanupSSL();
203  }
204
205  // Recreate the client/server identities with the specified validity period.
206  // |not_before| and |not_after| are offsets from the current time in number
207  // of seconds.
208  void ResetIdentitiesWithValidity(int not_before, int not_after) {
209    client_stream_ =
210        new SSLDummyStream(this, "c2s", &client_buffer_, &server_buffer_);
211    server_stream_ =
212        new SSLDummyStream(this, "s2c", &server_buffer_, &client_buffer_);
213
214    client_ssl_.reset(rtc::SSLStreamAdapter::Create(client_stream_));
215    server_ssl_.reset(rtc::SSLStreamAdapter::Create(server_stream_));
216
217    client_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent);
218    server_ssl_->SignalEvent.connect(this, &SSLStreamAdapterTestBase::OnEvent);
219
220    rtc::SSLIdentityParams client_params;
221    client_params.common_name = "client";
222    client_params.not_before = not_before;
223    client_params.not_after = not_after;
224    client_identity_ = rtc::SSLIdentity::GenerateForTest(client_params);
225
226    rtc::SSLIdentityParams server_params;
227    server_params.common_name = "server";
228    server_params.not_before = not_before;
229    server_params.not_after = not_after;
230    server_identity_ = rtc::SSLIdentity::GenerateForTest(server_params);
231
232    client_ssl_->SetIdentity(client_identity_);
233    server_ssl_->SetIdentity(server_identity_);
234  }
235
236  virtual void OnEvent(rtc::StreamInterface *stream, int sig, int err) {
237    LOG(LS_INFO) << "SSLStreamAdapterTestBase::OnEvent sig=" << sig;
238
239    if (sig & rtc::SE_READ) {
240      ReadData(stream);
241    }
242
243    if ((stream == client_ssl_.get()) && (sig & rtc::SE_WRITE)) {
244      WriteData();
245    }
246  }
247
248  void SetPeerIdentitiesByDigest(bool correct) {
249    unsigned char digest[20];
250    size_t digest_len;
251    bool rv;
252
253    LOG(LS_INFO) << "Setting peer identities by digest";
254
255    rv = server_identity_->certificate().ComputeDigest(rtc::DIGEST_SHA_1,
256                                                       digest, 20,
257                                                       &digest_len);
258    ASSERT_TRUE(rv);
259    if (!correct) {
260      LOG(LS_INFO) << "Setting bogus digest for server cert";
261      digest[0]++;
262    }
263    rv = client_ssl_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest,
264                                               digest_len);
265    ASSERT_TRUE(rv);
266
267
268    rv = client_identity_->certificate().ComputeDigest(rtc::DIGEST_SHA_1,
269                                                       digest, 20, &digest_len);
270    ASSERT_TRUE(rv);
271    if (!correct) {
272      LOG(LS_INFO) << "Setting bogus digest for client cert";
273      digest[0]++;
274    }
275    rv = server_ssl_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest,
276                                               digest_len);
277    ASSERT_TRUE(rv);
278
279    identities_set_ = true;
280  }
281
282  void TestHandshake(bool expect_success = true) {
283    server_ssl_->SetMode(dtls_ ? rtc::SSL_MODE_DTLS :
284                         rtc::SSL_MODE_TLS);
285    client_ssl_->SetMode(dtls_ ? rtc::SSL_MODE_DTLS :
286                         rtc::SSL_MODE_TLS);
287
288    if (!dtls_) {
289      // Make sure we simulate a reliable network for TLS.
290      // This is just a check to make sure that people don't write wrong
291      // tests.
292      ASSERT((mtu_ == 1460) && (loss_ == 0) && (lose_first_packet_ == 0));
293    }
294
295    if (!identities_set_)
296      SetPeerIdentitiesByDigest(true);
297
298    // Start the handshake
299    int rv;
300
301    server_ssl_->SetServerRole();
302    rv = server_ssl_->StartSSLWithPeer();
303    ASSERT_EQ(0, rv);
304
305    rv = client_ssl_->StartSSLWithPeer();
306    ASSERT_EQ(0, rv);
307
308    // Now run the handshake
309    if (expect_success) {
310      EXPECT_TRUE_WAIT((client_ssl_->GetState() == rtc::SS_OPEN)
311                       && (server_ssl_->GetState() == rtc::SS_OPEN),
312                       handshake_wait_);
313    } else {
314      EXPECT_TRUE_WAIT(client_ssl_->GetState() == rtc::SS_CLOSED,
315                       handshake_wait_);
316    }
317  }
318
319  rtc::StreamResult DataWritten(SSLDummyStream *from, const void *data,
320                                      size_t data_len, size_t *written,
321                                      int *error) {
322    // Randomly drop loss_ percent of packets
323    if (rtc::CreateRandomId() % 100 < static_cast<uint32>(loss_)) {
324      LOG(LS_INFO) << "Randomly dropping packet, size=" << data_len;
325      *written = data_len;
326      return rtc::SR_SUCCESS;
327    }
328    if (dtls_ && (data_len > mtu_)) {
329      LOG(LS_INFO) << "Dropping packet > mtu, size=" << data_len;
330      *written = data_len;
331      return rtc::SR_SUCCESS;
332    }
333
334    // Optionally damage application data (type 23). Note that we don't damage
335    // handshake packets and we damage the last byte to keep the header
336    // intact but break the MAC.
337    if (damage_ && (*static_cast<const unsigned char *>(data) == 23)) {
338      std::vector<char> buf(data_len);
339
340      LOG(LS_INFO) << "Damaging packet";
341
342      memcpy(&buf[0], data, data_len);
343      buf[data_len - 1]++;
344
345      return from->WriteData(&buf[0], data_len, written, error);
346    }
347
348    return from->WriteData(data, data_len, written, error);
349  }
350
351  void SetDelay(int delay) {
352    delay_ = delay;
353  }
354  int GetDelay() { return delay_; }
355
356  void SetLoseFirstPacket(bool lose) {
357    lose_first_packet_ = lose;
358  }
359  bool GetLoseFirstPacket() { return lose_first_packet_; }
360
361  void SetLoss(int percent) {
362    loss_ = percent;
363  }
364
365  void SetDamage() {
366    damage_ = true;
367  }
368
369  void SetMtu(size_t mtu) {
370    mtu_ = mtu;
371  }
372
373  void SetHandshakeWait(int wait) {
374    handshake_wait_ = wait;
375  }
376
377  void SetDtlsSrtpCiphers(const std::vector<std::string> &ciphers,
378    bool client) {
379    if (client)
380      client_ssl_->SetDtlsSrtpCiphers(ciphers);
381    else
382      server_ssl_->SetDtlsSrtpCiphers(ciphers);
383  }
384
385  bool GetDtlsSrtpCipher(bool client, std::string *retval) {
386    if (client)
387      return client_ssl_->GetDtlsSrtpCipher(retval);
388    else
389      return server_ssl_->GetDtlsSrtpCipher(retval);
390  }
391
392  bool GetPeerCertificate(bool client, rtc::SSLCertificate** cert) {
393    if (client)
394      return client_ssl_->GetPeerCertificate(cert);
395    else
396      return server_ssl_->GetPeerCertificate(cert);
397  }
398
399  bool ExportKeyingMaterial(const char *label,
400                            const unsigned char *context,
401                            size_t context_len,
402                            bool use_context,
403                            bool client,
404                            unsigned char *result,
405                            size_t result_len) {
406    if (client)
407      return client_ssl_->ExportKeyingMaterial(label,
408                                               context, context_len,
409                                               use_context,
410                                               result, result_len);
411    else
412      return server_ssl_->ExportKeyingMaterial(label,
413                                               context, context_len,
414                                               use_context,
415                                               result, result_len);
416  }
417
418  // To be implemented by subclasses.
419  virtual void WriteData() = 0;
420  virtual void ReadData(rtc::StreamInterface *stream) = 0;
421  virtual void TestTransfer(int size) = 0;
422
423 protected:
424  rtc::FifoBuffer client_buffer_;
425  rtc::FifoBuffer server_buffer_;
426  SSLDummyStream *client_stream_;  // freed by client_ssl_ destructor
427  SSLDummyStream *server_stream_;  // freed by server_ssl_ destructor
428  rtc::scoped_ptr<rtc::SSLStreamAdapter> client_ssl_;
429  rtc::scoped_ptr<rtc::SSLStreamAdapter> server_ssl_;
430  rtc::SSLIdentity *client_identity_;  // freed by client_ssl_ destructor
431  rtc::SSLIdentity *server_identity_;  // freed by server_ssl_ destructor
432  int delay_;
433  size_t mtu_;
434  int loss_;
435  bool lose_first_packet_;
436  bool damage_;
437  bool dtls_;
438  int handshake_wait_;
439  bool identities_set_;
440};
441
442class SSLStreamAdapterTestTLS : public SSLStreamAdapterTestBase {
443 public:
444  SSLStreamAdapterTestTLS() :
445      SSLStreamAdapterTestBase("", "", false) {
446  };
447
448  // Test data transfer for TLS
449  virtual void TestTransfer(int size) {
450    LOG(LS_INFO) << "Starting transfer test with " << size << " bytes";
451    // Create some dummy data to send.
452    size_t received;
453
454    send_stream_.ReserveSize(size);
455    for (int i = 0; i < size; ++i) {
456      char ch = static_cast<char>(i);
457      send_stream_.Write(&ch, 1, NULL, NULL);
458    }
459    send_stream_.Rewind();
460
461    // Prepare the receive stream.
462    recv_stream_.ReserveSize(size);
463
464    // Start sending
465    WriteData();
466
467    // Wait for the client to close
468    EXPECT_TRUE_WAIT(server_ssl_->GetState() == rtc::SS_CLOSED, 10000);
469
470    // Now check the data
471    recv_stream_.GetSize(&received);
472
473    EXPECT_EQ(static_cast<size_t>(size), received);
474    EXPECT_EQ(0, memcmp(send_stream_.GetBuffer(),
475                        recv_stream_.GetBuffer(), size));
476  }
477
478  void WriteData() {
479    size_t position, tosend, size;
480    rtc::StreamResult rv;
481    size_t sent;
482    char block[kBlockSize];
483
484    send_stream_.GetSize(&size);
485    if (!size)
486      return;
487
488    for (;;) {
489      send_stream_.GetPosition(&position);
490      if (send_stream_.Read(block, sizeof(block), &tosend, NULL) !=
491          rtc::SR_EOS) {
492        rv = client_ssl_->Write(block, tosend, &sent, 0);
493
494        if (rv == rtc::SR_SUCCESS) {
495          send_stream_.SetPosition(position + sent);
496          LOG(LS_VERBOSE) << "Sent: " << position + sent;
497        } else if (rv == rtc::SR_BLOCK) {
498          LOG(LS_VERBOSE) << "Blocked...";
499          send_stream_.SetPosition(position);
500          break;
501        } else {
502          ADD_FAILURE();
503          break;
504        }
505      } else {
506        // Now close
507        LOG(LS_INFO) << "Wrote " << position << " bytes. Closing";
508        client_ssl_->Close();
509        break;
510      }
511    }
512  };
513
514  virtual void ReadData(rtc::StreamInterface *stream) {
515    char buffer[1600];
516    size_t bread;
517    int err2;
518    rtc::StreamResult r;
519
520    for (;;) {
521      r = stream->Read(buffer, sizeof(buffer), &bread, &err2);
522
523      if (r == rtc::SR_ERROR || r == rtc::SR_EOS) {
524        // Unfortunately, errors are the way that the stream adapter
525        // signals close in OpenSSL
526        stream->Close();
527        return;
528      }
529
530      if (r == rtc::SR_BLOCK)
531        break;
532
533      ASSERT_EQ(rtc::SR_SUCCESS, r);
534      LOG(LS_INFO) << "Read " << bread;
535
536      recv_stream_.Write(buffer, bread, NULL, NULL);
537    }
538  }
539
540 private:
541  rtc::MemoryStream send_stream_;
542  rtc::MemoryStream recv_stream_;
543};
544
545class SSLStreamAdapterTestDTLS : public SSLStreamAdapterTestBase {
546 public:
547  SSLStreamAdapterTestDTLS() :
548      SSLStreamAdapterTestBase("", "", true),
549      packet_size_(1000), count_(0), sent_(0) {
550  }
551
552  SSLStreamAdapterTestDTLS(const std::string& cert_pem,
553                           const std::string& private_key_pem) :
554      SSLStreamAdapterTestBase(cert_pem, private_key_pem, true),
555      packet_size_(1000), count_(0), sent_(0) {
556  }
557
558  virtual void WriteData() {
559    unsigned char *packet = new unsigned char[1600];
560
561    do {
562      memset(packet, sent_ & 0xff, packet_size_);
563      *(reinterpret_cast<uint32_t *>(packet)) = sent_;
564
565      size_t sent;
566      int rv = client_ssl_->Write(packet, packet_size_, &sent, 0);
567      if (rv == rtc::SR_SUCCESS) {
568        LOG(LS_VERBOSE) << "Sent: " << sent_;
569        sent_++;
570      } else if (rv == rtc::SR_BLOCK) {
571        LOG(LS_VERBOSE) << "Blocked...";
572        break;
573      } else {
574        ADD_FAILURE();
575        break;
576      }
577    } while (sent_ < count_);
578
579    delete [] packet;
580  }
581
582  virtual void ReadData(rtc::StreamInterface *stream) {
583    unsigned char buffer[2000];
584    size_t bread;
585    int err2;
586    rtc::StreamResult r;
587
588    for (;;) {
589      r = stream->Read(buffer, 2000, &bread, &err2);
590
591      if (r == rtc::SR_ERROR) {
592        // Unfortunately, errors are the way that the stream adapter
593        // signals close right now
594        stream->Close();
595        return;
596      }
597
598      if (r == rtc::SR_BLOCK)
599        break;
600
601      ASSERT_EQ(rtc::SR_SUCCESS, r);
602      LOG(LS_INFO) << "Read " << bread;
603
604      // Now parse the datagram
605      ASSERT_EQ(packet_size_, bread);
606      unsigned char* ptr_to_buffer = buffer;
607      uint32_t packet_num = *(reinterpret_cast<uint32_t *>(ptr_to_buffer));
608
609      for (size_t i = 4; i < packet_size_; i++) {
610        ASSERT_EQ((packet_num & 0xff), buffer[i]);
611      }
612      received_.insert(packet_num);
613    }
614  }
615
616  virtual void TestTransfer(int count) {
617    count_ = count;
618
619    WriteData();
620
621    EXPECT_TRUE_WAIT(sent_ == count_, 10000);
622    LOG(LS_INFO) << "sent_ == " << sent_;
623
624    if (damage_) {
625      WAIT(false, 2000);
626      EXPECT_EQ(0U, received_.size());
627    } else if (loss_ == 0) {
628        EXPECT_EQ_WAIT(static_cast<size_t>(sent_), received_.size(), 1000);
629    } else {
630      LOG(LS_INFO) << "Sent " << sent_ << " packets; received " <<
631          received_.size();
632    }
633  };
634
635 private:
636  size_t packet_size_;
637  int count_;
638  int sent_;
639  std::set<int> received_;
640};
641
642
643rtc::StreamResult SSLDummyStream::Write(const void* data, size_t data_len,
644                                              size_t* written, int* error) {
645  *written = data_len;
646
647  LOG(LS_INFO) << "Writing to loopback " << data_len;
648
649  if (first_packet_) {
650    first_packet_ = false;
651    if (test_->GetLoseFirstPacket()) {
652      LOG(LS_INFO) << "Losing initial packet of length " << data_len;
653      return rtc::SR_SUCCESS;
654    }
655  }
656
657  return test_->DataWritten(this, data, data_len, written, error);
658
659  return rtc::SR_SUCCESS;
660};
661
662class SSLStreamAdapterTestDTLSFromPEMStrings : public SSLStreamAdapterTestDTLS {
663 public:
664  SSLStreamAdapterTestDTLSFromPEMStrings() :
665      SSLStreamAdapterTestDTLS(kCERT_PEM, kRSA_PRIVATE_KEY_PEM) {
666  }
667};
668
669// Basic tests: TLS
670
671// Test that we cannot read/write if we have not yet handshaked.
672// This test only applies to NSS because OpenSSL has passthrough
673// semantics for I/O before the handshake is started.
#if SSL_USE_NSS
TEST_F(SSLStreamAdapterTestTLS, TestNoReadWriteBeforeConnect) {
  char buf[kBlockSize];
  size_t unused;

  // Before StartSSLWithPeer(), both directions must report SR_BLOCK.
  ASSERT_EQ(rtc::SR_BLOCK,
            client_ssl_->Write(buf, sizeof(buf), &unused, NULL));
  ASSERT_EQ(rtc::SR_BLOCK,
            client_ssl_->Read(buf, sizeof(buf), &unused, NULL));
}
#endif
687
688
689// Test that we can make a handshake work
690TEST_F(SSLStreamAdapterTestTLS, TestTLSConnect) {
691  TestHandshake();
692};
693
694// Test transfer -- trivial
695TEST_F(SSLStreamAdapterTestTLS, TestTLSTransfer) {
696  TestHandshake();
697  TestTransfer(100000);
698};
699
700// Test read-write after close.
701TEST_F(SSLStreamAdapterTestTLS, ReadWriteAfterClose) {
702  TestHandshake();
703  TestTransfer(100000);
704  client_ssl_->Close();
705
706  rtc::StreamResult rv;
707  char block[kBlockSize];
708  size_t dummy;
709
710  // It's an error to write after closed.
711  rv = client_ssl_->Write(block, sizeof(block), &dummy, NULL);
712  ASSERT_EQ(rtc::SR_ERROR, rv);
713
714  // But after closed read gives you EOS.
715  rv = client_ssl_->Read(block, sizeof(block), &dummy, NULL);
716  ASSERT_EQ(rtc::SR_EOS, rv);
717};
718
719// Test a handshake with a bogus peer digest
720TEST_F(SSLStreamAdapterTestTLS, TestTLSBogusDigest) {
721  SetPeerIdentitiesByDigest(false);
722  TestHandshake(false);
723};
724
725// Test moving a bunch of data
726
727// Basic tests: DTLS
728// Test that we can make a handshake work
729TEST_F(SSLStreamAdapterTestDTLS, TestDTLSConnect) {
730  MAYBE_SKIP_TEST(HaveDtls);
731  TestHandshake();
732};
733
// Test that the handshake completes if the first packet in
// each direction is lost. This gives us predictable loss
// rather than having to tune random loss.
737TEST_F(SSLStreamAdapterTestDTLS, TestDTLSConnectWithLostFirstPacket) {
738  MAYBE_SKIP_TEST(HaveDtls);
739  SetLoseFirstPacket(true);
740  TestHandshake();
741};
742
743// Test a handshake with loss and delay
744TEST_F(SSLStreamAdapterTestDTLS,
745       TestDTLSConnectWithLostFirstPacketDelay2s) {
746  MAYBE_SKIP_TEST(HaveDtls);
747  SetLoseFirstPacket(true);
748  SetDelay(2000);
749  SetHandshakeWait(20000);
750  TestHandshake();
751};
752
753// Test a handshake with small MTU
754TEST_F(SSLStreamAdapterTestDTLS, DISABLED_ON_MAC(TestDTLSConnectWithSmallMtu)) {
755  MAYBE_SKIP_TEST(HaveDtls);
756  SetMtu(700);
757  SetHandshakeWait(20000);
758  TestHandshake();
759};
760
761// Test transfer -- trivial
762TEST_F(SSLStreamAdapterTestDTLS, TestDTLSTransfer) {
763  MAYBE_SKIP_TEST(HaveDtls);
764  TestHandshake();
765  TestTransfer(100);
766};
767
768TEST_F(SSLStreamAdapterTestDTLS, TestDTLSTransferWithLoss) {
769  MAYBE_SKIP_TEST(HaveDtls);
770  TestHandshake();
771  SetLoss(10);
772  TestTransfer(100);
773};
774
775TEST_F(SSLStreamAdapterTestDTLS, TestDTLSTransferWithDamage) {
776  MAYBE_SKIP_TEST(HaveDtls);
777  SetDamage();  // Must be called first because first packet
778                // write happens at end of handshake.
779  TestHandshake();
780  TestTransfer(100);
781};
782
783// Test DTLS-SRTP with all high ciphers
784TEST_F(SSLStreamAdapterTestDTLS, TestDTLSSrtpHigh) {
785  MAYBE_SKIP_TEST(HaveDtlsSrtp);
786  std::vector<std::string> high;
787  high.push_back(kAES_CM_HMAC_SHA1_80);
788  SetDtlsSrtpCiphers(high, true);
789  SetDtlsSrtpCiphers(high, false);
790  TestHandshake();
791
792  std::string client_cipher;
793  ASSERT_TRUE(GetDtlsSrtpCipher(true, &client_cipher));
794  std::string server_cipher;
795  ASSERT_TRUE(GetDtlsSrtpCipher(false, &server_cipher));
796
797  ASSERT_EQ(client_cipher, server_cipher);
798  ASSERT_EQ(client_cipher, kAES_CM_HMAC_SHA1_80);
799};
800
801// Test DTLS-SRTP with all low ciphers
802TEST_F(SSLStreamAdapterTestDTLS, TestDTLSSrtpLow) {
803  MAYBE_SKIP_TEST(HaveDtlsSrtp);
804  std::vector<std::string> low;
805  low.push_back(kAES_CM_HMAC_SHA1_32);
806  SetDtlsSrtpCiphers(low, true);
807  SetDtlsSrtpCiphers(low, false);
808  TestHandshake();
809
810  std::string client_cipher;
811  ASSERT_TRUE(GetDtlsSrtpCipher(true, &client_cipher));
812  std::string server_cipher;
813  ASSERT_TRUE(GetDtlsSrtpCipher(false, &server_cipher));
814
815  ASSERT_EQ(client_cipher, server_cipher);
816  ASSERT_EQ(client_cipher, kAES_CM_HMAC_SHA1_32);
817};
818
819
820// Test DTLS-SRTP with a mismatch -- should not converge
821TEST_F(SSLStreamAdapterTestDTLS, TestDTLSSrtpHighLow) {
822  MAYBE_SKIP_TEST(HaveDtlsSrtp);
823  std::vector<std::string> high;
824  high.push_back(kAES_CM_HMAC_SHA1_80);
825  std::vector<std::string> low;
826  low.push_back(kAES_CM_HMAC_SHA1_32);
827  SetDtlsSrtpCiphers(high, true);
828  SetDtlsSrtpCiphers(low, false);
829  TestHandshake();
830
831  std::string client_cipher;
832  ASSERT_FALSE(GetDtlsSrtpCipher(true, &client_cipher));
833  std::string server_cipher;
834  ASSERT_FALSE(GetDtlsSrtpCipher(false, &server_cipher));
835};
836
837// Test DTLS-SRTP with each side being mixed -- should select high
838TEST_F(SSLStreamAdapterTestDTLS, TestDTLSSrtpMixed) {
839  MAYBE_SKIP_TEST(HaveDtlsSrtp);
840  std::vector<std::string> mixed;
841  mixed.push_back(kAES_CM_HMAC_SHA1_80);
842  mixed.push_back(kAES_CM_HMAC_SHA1_32);
843  SetDtlsSrtpCiphers(mixed, true);
844  SetDtlsSrtpCiphers(mixed, false);
845  TestHandshake();
846
847  std::string client_cipher;
848  ASSERT_TRUE(GetDtlsSrtpCipher(true, &client_cipher));
849  std::string server_cipher;
850  ASSERT_TRUE(GetDtlsSrtpCipher(false, &server_cipher));
851
852  ASSERT_EQ(client_cipher, server_cipher);
853  ASSERT_EQ(client_cipher, kAES_CM_HMAC_SHA1_80);
854};
855
856// Test an exporter
857TEST_F(SSLStreamAdapterTestDTLS, TestDTLSExporter) {
858  MAYBE_SKIP_TEST(HaveExporter);
859  TestHandshake();
860  unsigned char client_out[20];
861  unsigned char server_out[20];
862
863  bool result;
864  result = ExportKeyingMaterial(kExporterLabel,
865                                kExporterContext, kExporterContextLen,
866                                true, true,
867                                client_out, sizeof(client_out));
868  ASSERT_TRUE(result);
869
870  result = ExportKeyingMaterial(kExporterLabel,
871                                kExporterContext, kExporterContextLen,
872                                true, false,
873                                server_out, sizeof(server_out));
874  ASSERT_TRUE(result);
875
876  ASSERT_TRUE(!memcmp(client_out, server_out, sizeof(client_out)));
877}
878
// Test that not-yet-valid certificates are not rejected.
880TEST_F(SSLStreamAdapterTestDTLS, TestCertNotYetValid) {
881  MAYBE_SKIP_TEST(HaveDtls);
882  long one_day = 60 * 60 * 24;
883  // Make the certificates not valid until one day later.
884  ResetIdentitiesWithValidity(one_day, one_day);
885  TestHandshake();
886}
887
// Test that expired certificates are not rejected.
889TEST_F(SSLStreamAdapterTestDTLS, TestCertExpired) {
890  MAYBE_SKIP_TEST(HaveDtls);
891  long one_day = 60 * 60 * 24;
892  // Make the certificates already expired.
893  ResetIdentitiesWithValidity(-one_day, -one_day);
894  TestHandshake();
895}
896
897// Test data transfer using certs created from strings.
898TEST_F(SSLStreamAdapterTestDTLSFromPEMStrings, TestTransfer) {
899  MAYBE_SKIP_TEST(HaveDtls);
900  TestHandshake();
901  TestTransfer(100);
902}
903
904// Test getting the remote certificate.
905TEST_F(SSLStreamAdapterTestDTLSFromPEMStrings, TestDTLSGetPeerCertificate) {
906  MAYBE_SKIP_TEST(HaveDtls);
907
908  // Peer certificates haven't been received yet.
909  rtc::scoped_ptr<rtc::SSLCertificate> client_peer_cert;
910  ASSERT_FALSE(GetPeerCertificate(true, client_peer_cert.accept()));
911  ASSERT_FALSE(client_peer_cert != NULL);
912
913  rtc::scoped_ptr<rtc::SSLCertificate> server_peer_cert;
914  ASSERT_FALSE(GetPeerCertificate(false, server_peer_cert.accept()));
915  ASSERT_FALSE(server_peer_cert != NULL);
916
917  TestHandshake();
918
919  // The client should have a peer certificate after the handshake.
920  ASSERT_TRUE(GetPeerCertificate(true, client_peer_cert.accept()));
921  ASSERT_TRUE(client_peer_cert != NULL);
922
923  // It's not kCERT_PEM.
924  std::string client_peer_string = client_peer_cert->ToPEMString();
925  ASSERT_NE(kCERT_PEM, client_peer_string);
926
927  // It must not have a chain, because the test certs are self-signed.
928  rtc::SSLCertChain* client_peer_chain;
929  ASSERT_FALSE(client_peer_cert->GetChain(&client_peer_chain));
930
931  // The server should have a peer certificate after the handshake.
932  ASSERT_TRUE(GetPeerCertificate(false, server_peer_cert.accept()));
933  ASSERT_TRUE(server_peer_cert != NULL);
934
935  // It's kCERT_PEM
936  ASSERT_EQ(kCERT_PEM, server_peer_cert->ToPEMString());
937
938  // It must not have a chain, because the test certs are self-signed.
939  rtc::SSLCertChain* server_peer_chain;
940  ASSERT_FALSE(server_peer_cert->GetChain(&server_peer_chain));
941}
942