1/*
2 *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include <vector>
12
13#if HAVE_CONFIG_H
14#include "config.h"
15#endif  // HAVE_CONFIG_H
16
17#if HAVE_NSS_SSL_H
18
19#include "webrtc/base/nssstreamadapter.h"
20
21#include "keyhi.h"
22#include "nspr.h"
23#include "nss.h"
24#include "pk11pub.h"
25#include "secerr.h"
26
27#ifdef NSS_SSL_RELATIVE_PATH
28#include "ssl.h"
29#include "sslerr.h"
30#include "sslproto.h"
31#else
32#include "net/third_party/nss/ssl/ssl.h"
33#include "net/third_party/nss/ssl/sslerr.h"
34#include "net/third_party/nss/ssl/sslproto.h"
35#endif
36
37#include "webrtc/base/nssidentity.h"
38#include "webrtc/base/safe_conversions.h"
39#include "webrtc/base/thread.h"
40
41namespace rtc {
42
// NSPR layer identity shared by all NSSStreamAdapter instances. It starts out
// invalid and is assigned by NSSStreamAdapter::Init() through
// PR_GetUniqueIdentity() the first time a descriptor is created.
PRDescIdentity NSSStreamAdapter::nspr_layer_identity = PR_INVALID_IO_LAYER;
44
// Marks an NSPR I/O method that this adapter does not provide: sets the NSPR
// error to PR_NOT_IMPLEMENTED_ERROR, logs the name of the calling function and
// asserts. The expansion is wrapped in do { } while (0) so that it behaves as
// a single statement and stays correct inside unbraced if/else bodies.
#define UNIMPLEMENTED \
  do { \
    PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); \
    LOG(LS_ERROR) \
    << "Call to unimplemented function "<< __FUNCTION__; \
    ASSERT(false); \
  } while (0)
49
// DTLS-SRTP support is compiled in only when the SRTP cipher constant is
// available (presumably supplied by the NSS SSL headers included above).
#if defined(SRTP_AES128_CM_HMAC_SHA1_80)
#define HAVE_DTLS_SRTP
#endif  // SRTP_AES128_CM_HMAC_SHA1_80
53
#if defined(HAVE_DTLS_SRTP)
// One row of the SRTP cipher suite table: the name used by callers of the
// adapter paired with the cipher identifier handed to NSS.
struct SrtpCipherMapEntry {
  const char* external_name;
  PRUint16 cipher_id;
};

// Cipher suite lookup table. It is scanned linearly and terminated by the
// {NULL, 0} sentinel. Keeping the table local is not elegant, but it avoids
// an external reference.
static const SrtpCipherMapEntry kSrtpCipherMap[] = {
  { "AES_CM_128_HMAC_SHA1_80", SRTP_AES128_CM_HMAC_SHA1_80 },
  { "AES_CM_128_HMAC_SHA1_32", SRTP_AES128_CM_HMAC_SHA1_32 },
  { NULL, 0 }
};
#endif  // HAVE_DTLS_SRTP
68
69
70// Implementation of NSPR methods
71static PRStatus StreamClose(PRFileDesc *socket) {
72  ASSERT(!socket->lower);
73  socket->dtor(socket);
74  return PR_SUCCESS;
75}
76
77static PRInt32 StreamRead(PRFileDesc *socket, void *buf, PRInt32 length) {
78  StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
79  size_t read;
80  int error;
81  StreamResult result = stream->Read(buf, length, &read, &error);
82  if (result == SR_SUCCESS) {
83    return checked_cast<PRInt32>(read);
84  }
85
86  if (result == SR_EOS) {
87    return 0;
88  }
89
90  if (result == SR_BLOCK) {
91    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
92    return -1;
93  }
94
95  PR_SetError(PR_UNKNOWN_ERROR, error);
96  return -1;
97}
98
99static PRInt32 StreamWrite(PRFileDesc *socket, const void *buf,
100                           PRInt32 length) {
101  StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
102  size_t written;
103  int error;
104  StreamResult result = stream->Write(buf, length, &written, &error);
105  if (result == SR_SUCCESS) {
106    return checked_cast<PRInt32>(written);
107  }
108
109  if (result == SR_BLOCK) {
110    LOG(LS_INFO) <<
111        "NSSStreamAdapter: write to underlying transport would block";
112    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
113    return -1;
114  }
115
116  LOG(LS_ERROR) << "Write error";
117  PR_SetError(PR_UNKNOWN_ERROR, error);
118  return -1;
119}
120
121static PRInt32 StreamAvailable(PRFileDesc *socket) {
122  UNIMPLEMENTED;
123  return -1;
124}
125
126PRInt64 StreamAvailable64(PRFileDesc *socket) {
127  UNIMPLEMENTED;
128  return -1;
129}
130
131static PRStatus StreamSync(PRFileDesc *socket) {
132  UNIMPLEMENTED;
133  return PR_FAILURE;
134}
135
136static PROffset32 StreamSeek(PRFileDesc *socket, PROffset32 offset,
137                             PRSeekWhence how) {
138  UNIMPLEMENTED;
139  return -1;
140}
141
142static PROffset64 StreamSeek64(PRFileDesc *socket, PROffset64 offset,
143                               PRSeekWhence how) {
144  UNIMPLEMENTED;
145  return -1;
146}
147
148static PRStatus StreamFileInfo(PRFileDesc *socket, PRFileInfo *info) {
149  UNIMPLEMENTED;
150  return PR_FAILURE;
151}
152
153static PRStatus StreamFileInfo64(PRFileDesc *socket, PRFileInfo64 *info) {
154  UNIMPLEMENTED;
155  return PR_FAILURE;
156}
157
158static PRInt32 StreamWritev(PRFileDesc *socket, const PRIOVec *iov,
159                     PRInt32 iov_size, PRIntervalTime timeout) {
160  UNIMPLEMENTED;
161  return -1;
162}
163
164static PRStatus StreamConnect(PRFileDesc *socket, const PRNetAddr *addr,
165                              PRIntervalTime timeout) {
166  UNIMPLEMENTED;
167  return PR_FAILURE;
168}
169
170static PRFileDesc *StreamAccept(PRFileDesc *sd, PRNetAddr *addr,
171                                PRIntervalTime timeout) {
172  UNIMPLEMENTED;
173  return NULL;
174}
175
176static PRStatus StreamBind(PRFileDesc *socket, const PRNetAddr *addr) {
177  UNIMPLEMENTED;
178  return PR_FAILURE;
179}
180
181static PRStatus StreamListen(PRFileDesc *socket, PRIntn depth) {
182  UNIMPLEMENTED;
183  return PR_FAILURE;
184}
185
186static PRStatus StreamShutdown(PRFileDesc *socket, PRIntn how) {
187  UNIMPLEMENTED;
188  return PR_FAILURE;
189}
190
191// Note: this is always nonblocking and ignores the timeout.
192// TODO(ekr@rtfm.com): In future verify that the socket is
193// actually in non-blocking mode.
194// This function does not support peek.
195static PRInt32 StreamRecv(PRFileDesc *socket, void *buf, PRInt32 amount,
196                   PRIntn flags, PRIntervalTime to) {
197  ASSERT(flags == 0);
198
199  if (flags != 0) {
200    PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
201    return -1;
202  }
203
204  return StreamRead(socket, buf, amount);
205}
206
207// Note: this is always nonblocking and assumes a zero timeout.
208// This function does not support peek.
209static PRInt32 StreamSend(PRFileDesc *socket, const void *buf,
210                          PRInt32 amount, PRIntn flags,
211                          PRIntervalTime to) {
212  ASSERT(flags == 0);
213
214  return StreamWrite(socket, buf, amount);
215}
216
217static PRInt32 StreamRecvfrom(PRFileDesc *socket, void *buf,
218                              PRInt32 amount, PRIntn flags,
219                              PRNetAddr *addr, PRIntervalTime to) {
220  UNIMPLEMENTED;
221  return -1;
222}
223
224static PRInt32 StreamSendto(PRFileDesc *socket, const void *buf,
225                            PRInt32 amount, PRIntn flags,
226                            const PRNetAddr *addr, PRIntervalTime to) {
227  UNIMPLEMENTED;
228  return -1;
229}
230
231static PRInt16 StreamPoll(PRFileDesc *socket, PRInt16 in_flags,
232                          PRInt16 *out_flags) {
233  UNIMPLEMENTED;
234  return -1;
235}
236
237static PRInt32 StreamAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
238                                PRNetAddr **raddr,
239                                void *buf, PRInt32 amount, PRIntervalTime t) {
240  UNIMPLEMENTED;
241  return -1;
242}
243
244static PRInt32 StreamTransmitFile(PRFileDesc *sd, PRFileDesc *socket,
245                                  const void *headers, PRInt32 hlen,
246                                  PRTransmitFileFlags flags, PRIntervalTime t) {
247  UNIMPLEMENTED;
248  return -1;
249}
250
251static PRStatus StreamGetPeerName(PRFileDesc *socket, PRNetAddr *addr) {
252  // TODO(ekr@rtfm.com): Modify to return unique names for each channel
253  // somehow, as opposed to always the same static address. The current
254  // implementation messes up the session cache, which is why it's off
255  // elsewhere
256  addr->inet.family = PR_AF_INET;
257  addr->inet.port = 0;
258  addr->inet.ip = 0;
259
260  return PR_SUCCESS;
261}
262
263static PRStatus StreamGetSockName(PRFileDesc *socket, PRNetAddr *addr) {
264  UNIMPLEMENTED;
265  return PR_FAILURE;
266}
267
268static PRStatus StreamGetSockOption(PRFileDesc *socket, PRSocketOptionData *opt) {
269  switch (opt->option) {
270    case PR_SockOpt_Nonblocking:
271      opt->value.non_blocking = PR_TRUE;
272      return PR_SUCCESS;
273    default:
274      UNIMPLEMENTED;
275      break;
276  }
277
278  return PR_FAILURE;
279}
280
281// Imitate setting socket options. These are mostly noops.
282static PRStatus StreamSetSockOption(PRFileDesc *socket,
283                                    const PRSocketOptionData *opt) {
284  switch (opt->option) {
285    case PR_SockOpt_Nonblocking:
286      return PR_SUCCESS;
287    case PR_SockOpt_NoDelay:
288      return PR_SUCCESS;
289    default:
290      UNIMPLEMENTED;
291      break;
292  }
293
294  return PR_FAILURE;
295}
296
297static PRInt32 StreamSendfile(PRFileDesc *out, PRSendFileData *in,
298                              PRTransmitFileFlags flags, PRIntervalTime to) {
299  UNIMPLEMENTED;
300  return -1;
301}
302
303static PRStatus StreamConnectContinue(PRFileDesc *socket, PRInt16 flags) {
304  UNIMPLEMENTED;
305  return PR_FAILURE;
306}
307
308static PRIntn StreamReserved(PRFileDesc *socket) {
309  UNIMPLEMENTED;
310  return -1;
311}
312
// NSPR method table for the stream-backed layer; NSSStreamAdapter::Init()
// passes it to PR_CreateIOLayerStub(). The initializer is positional, so the
// entries must stay in the field order of PRIOMethods as declared by NSPR
// (NOTE(review): confirm against prio.h before reordering or adding entries).
// Only close, read, write, recv, send, getpeername and the two socket-option
// methods do real work; every other entry is an UNIMPLEMENTED stub.
static const struct PRIOMethods nss_methods = {
  PR_DESC_LAYERED,
  StreamClose,
  StreamRead,
  StreamWrite,
  StreamAvailable,
  StreamAvailable64,
  StreamSync,
  StreamSeek,
  StreamSeek64,
  StreamFileInfo,
  StreamFileInfo64,
  StreamWritev,
  StreamConnect,
  StreamAccept,
  StreamBind,
  StreamListen,
  StreamShutdown,
  StreamRecv,
  StreamSend,
  StreamRecvfrom,
  StreamSendto,
  StreamPoll,
  StreamAcceptRead,
  StreamTransmitFile,
  StreamGetSockName,
  StreamGetPeerName,
  StreamReserved,
  StreamReserved,
  StreamGetSockOption,
  StreamSetSockOption,
  StreamSendfile,
  StreamConnectContinue,
  StreamReserved,
  StreamReserved,
  StreamReserved,
  StreamReserved
};
351
352NSSStreamAdapter::NSSStreamAdapter(StreamInterface *stream)
353    : SSLStreamAdapterHelper(stream),
354      ssl_fd_(NULL),
355      cert_ok_(false) {
356}
357
358bool NSSStreamAdapter::Init() {
359  if (nspr_layer_identity == PR_INVALID_IO_LAYER) {
360    nspr_layer_identity = PR_GetUniqueIdentity("nssstreamadapter");
361  }
362  PRFileDesc *pr_fd = PR_CreateIOLayerStub(nspr_layer_identity, &nss_methods);
363  if (!pr_fd)
364    return false;
365  pr_fd->secret = reinterpret_cast<PRFilePrivate *>(stream());
366
367  PRFileDesc *ssl_fd;
368  if (ssl_mode_ == SSL_MODE_DTLS) {
369    ssl_fd = DTLS_ImportFD(NULL, pr_fd);
370  } else {
371    ssl_fd = SSL_ImportFD(NULL, pr_fd);
372  }
373  ASSERT(ssl_fd != NULL);  // This should never happen
374  if (!ssl_fd) {
375    PR_Close(pr_fd);
376    return false;
377  }
378
379  SECStatus rv;
380  // Turn on security.
381  rv = SSL_OptionSet(ssl_fd, SSL_SECURITY, PR_TRUE);
382  if (rv != SECSuccess) {
383    LOG(LS_ERROR) << "Error enabling security on SSL Socket";
384    return false;
385  }
386
387  // Disable SSLv2.
388  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SSL2, PR_FALSE);
389  if (rv != SECSuccess) {
390    LOG(LS_ERROR) << "Error disabling SSL2";
391    return false;
392  }
393
394  // Disable caching.
395  // TODO(ekr@rtfm.com): restore this when I have the caching
396  // identity set.
397  rv = SSL_OptionSet(ssl_fd, SSL_NO_CACHE, PR_TRUE);
398  if (rv != SECSuccess) {
399    LOG(LS_ERROR) << "Error disabling cache";
400    return false;
401  }
402
403  // Disable session tickets.
404  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
405  if (rv != SECSuccess) {
406    LOG(LS_ERROR) << "Error enabling tickets";
407    return false;
408  }
409
410  // Disable renegotiation.
411  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_RENEGOTIATION,
412                     SSL_RENEGOTIATE_NEVER);
413  if (rv != SECSuccess) {
414    LOG(LS_ERROR) << "Error disabling renegotiation";
415    return false;
416  }
417
418  // Disable false start.
419  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_FALSE_START, PR_FALSE);
420  if (rv != SECSuccess) {
421    LOG(LS_ERROR) << "Error disabling false start";
422    return false;
423  }
424
425  ssl_fd_ = ssl_fd;
426
427  return true;
428}
429
430NSSStreamAdapter::~NSSStreamAdapter() {
431  if (ssl_fd_)
432    PR_Close(ssl_fd_);
433};
434
435
436int NSSStreamAdapter::BeginSSL() {
437  SECStatus rv;
438
439  if (!Init()) {
440    Error("Init", -1, false);
441    return -1;
442  }
443
444  ASSERT(state_ == SSL_CONNECTING);
445  // The underlying stream has been opened. If we are in peer-to-peer mode
446  // then a peer certificate must have been specified by now.
447  ASSERT(!ssl_server_name_.empty() ||
448         peer_certificate_.get() != NULL ||
449         !peer_certificate_digest_algorithm_.empty());
450  LOG(LS_INFO) << "BeginSSL: "
451               << (!ssl_server_name_.empty() ? ssl_server_name_ :
452                                               "with peer");
453
454  if (role_ == SSL_CLIENT) {
455    LOG(LS_INFO) << "BeginSSL: as client";
456
457    rv = SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
458                                   this);
459    if (rv != SECSuccess) {
460      Error("BeginSSL", -1, false);
461      return -1;
462    }
463  } else {
464    LOG(LS_INFO) << "BeginSSL: as server";
465    NSSIdentity *identity;
466
467    if (identity_.get()) {
468      identity = static_cast<NSSIdentity *>(identity_.get());
469    } else {
470      LOG(LS_ERROR) << "Can't be an SSL server without an identity";
471      Error("BeginSSL", -1, false);
472      return -1;
473    }
474    rv = SSL_ConfigSecureServer(ssl_fd_, identity->certificate().certificate(),
475                                identity->keypair()->privkey(),
476                                kt_rsa);
477    if (rv != SECSuccess) {
478      Error("BeginSSL", -1, false);
479      return -1;
480    }
481
482    // Insist on a certificate from the client
483    rv = SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE);
484    if (rv != SECSuccess) {
485      Error("BeginSSL", -1, false);
486      return -1;
487    }
488
489    rv = SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, PR_TRUE);
490    if (rv != SECSuccess) {
491      Error("BeginSSL", -1, false);
492      return -1;
493    }
494  }
495
496  // Set the version range.
497  SSLVersionRange vrange;
498  vrange.min =  (ssl_mode_ == SSL_MODE_DTLS) ?
499      SSL_LIBRARY_VERSION_TLS_1_1 :
500      SSL_LIBRARY_VERSION_TLS_1_0;
501  vrange.max = SSL_LIBRARY_VERSION_TLS_1_1;
502
503  rv = SSL_VersionRangeSet(ssl_fd_, &vrange);
504  if (rv != SECSuccess) {
505    Error("BeginSSL", -1, false);
506    return -1;
507  }
508
509  // SRTP
510#ifdef HAVE_DTLS_SRTP
511  if (!srtp_ciphers_.empty()) {
512    rv = SSL_SetSRTPCiphers(
513        ssl_fd_, &srtp_ciphers_[0],
514        checked_cast<unsigned int>(srtp_ciphers_.size()));
515    if (rv != SECSuccess) {
516      Error("BeginSSL", -1, false);
517      return -1;
518    }
519  }
520#endif
521
522  // Certificate validation
523  rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
524  if (rv != SECSuccess) {
525    Error("BeginSSL", -1, false);
526    return -1;
527  }
528
529  // Now start the handshake
530  rv = SSL_ResetHandshake(ssl_fd_, role_ == SSL_SERVER ? PR_TRUE : PR_FALSE);
531  if (rv != SECSuccess) {
532    Error("BeginSSL", -1, false);
533    return -1;
534  }
535
536  return ContinueSSL();
537}
538
539int NSSStreamAdapter::ContinueSSL() {
540  LOG(LS_INFO) << "ContinueSSL";
541  ASSERT(state_ == SSL_CONNECTING);
542
543  // Clear the DTLS timer
544  Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
545
546  SECStatus rv = SSL_ForceHandshake(ssl_fd_);
547
548  if (rv == SECSuccess) {
549    LOG(LS_INFO) << "Handshake complete";
550
551    ASSERT(cert_ok_);
552    if (!cert_ok_) {
553      Error("ContinueSSL", -1, true);
554      return -1;
555    }
556
557    state_ = SSL_CONNECTED;
558    StreamAdapterInterface::OnEvent(stream(), SE_OPEN|SE_READ|SE_WRITE, 0);
559    return 0;
560  }
561
562  PRInt32 err = PR_GetError();
563  switch (err) {
564    case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
565      if (ssl_mode_ != SSL_MODE_DTLS) {
566        Error("ContinueSSL", -1, true);
567        return -1;
568      } else {
569        LOG(LS_INFO) << "Malformed DTLS message. Ignoring.";
570        // Fall through
571      }
572    case PR_WOULD_BLOCK_ERROR:
573      LOG(LS_INFO) << "Would have blocked";
574      if (ssl_mode_ == SSL_MODE_DTLS) {
575        PRIntervalTime timeout;
576
577        SECStatus rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
578        if (rv == SECSuccess) {
579          LOG(LS_INFO) << "Timeout is " << timeout << " ms";
580          Thread::Current()->PostDelayed(PR_IntervalToMilliseconds(timeout),
581                                         this, MSG_DTLS_TIMEOUT, 0);
582        }
583      }
584
585      return 0;
586    default:
587      LOG(LS_INFO) << "Error " << err;
588      break;
589  }
590
591  Error("ContinueSSL", -1, true);
592  return -1;
593}
594
595void NSSStreamAdapter::Cleanup() {
596  if (state_ != SSL_ERROR) {
597    state_ = SSL_CLOSED;
598  }
599
600  if (ssl_fd_) {
601    PR_Close(ssl_fd_);
602    ssl_fd_ = NULL;
603  }
604
605  identity_.reset();
606  peer_certificate_.reset();
607
608  Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
609}
610
611StreamResult NSSStreamAdapter::Read(void* data, size_t data_len,
612                                    size_t* read, int* error) {
613  // SSL_CONNECTED sanity check.
614  switch (state_) {
615    case SSL_NONE:
616    case SSL_WAIT:
617    case SSL_CONNECTING:
618      return SR_BLOCK;
619
620    case SSL_CONNECTED:
621      break;
622
623    case SSL_CLOSED:
624      return SR_EOS;
625
626    case SSL_ERROR:
627    default:
628      if (error)
629        *error = ssl_error_code_;
630      return SR_ERROR;
631  }
632
633  PRInt32 rv = PR_Read(ssl_fd_, data, checked_cast<PRInt32>(data_len));
634
635  if (rv == 0) {
636    return SR_EOS;
637  }
638
639  // Error
640  if (rv < 0) {
641    PRInt32 err = PR_GetError();
642
643    switch (err) {
644      case PR_WOULD_BLOCK_ERROR:
645        return SR_BLOCK;
646      default:
647        Error("Read", -1, false);
648        *error = err;  // libjingle semantics are that this is impl-specific
649        return SR_ERROR;
650    }
651  }
652
653  // Success
654  *read = rv;
655
656  return SR_SUCCESS;
657}
658
659StreamResult NSSStreamAdapter::Write(const void* data, size_t data_len,
660                                     size_t* written, int* error) {
661  // SSL_CONNECTED sanity check.
662  switch (state_) {
663    case SSL_NONE:
664    case SSL_WAIT:
665    case SSL_CONNECTING:
666      return SR_BLOCK;
667
668    case SSL_CONNECTED:
669      break;
670
671    case SSL_ERROR:
672    case SSL_CLOSED:
673    default:
674      if (error)
675        *error = ssl_error_code_;
676      return SR_ERROR;
677  }
678
679  PRInt32 rv = PR_Write(ssl_fd_, data, checked_cast<PRInt32>(data_len));
680
681  // Error
682  if (rv < 0) {
683    PRInt32 err = PR_GetError();
684
685    switch (err) {
686      case PR_WOULD_BLOCK_ERROR:
687        return SR_BLOCK;
688      default:
689        Error("Write", -1, false);
690        *error = err;  // libjingle semantics are that this is impl-specific
691        return SR_ERROR;
692    }
693  }
694
695  // Success
696  *written = rv;
697
698  return SR_SUCCESS;
699}
700
701void NSSStreamAdapter::OnEvent(StreamInterface* stream, int events,
702                               int err) {
703  int events_to_signal = 0;
704  int signal_error = 0;
705  ASSERT(stream == this->stream());
706  if ((events & SE_OPEN)) {
707    LOG(LS_INFO) << "NSSStreamAdapter::OnEvent SE_OPEN";
708    if (state_ != SSL_WAIT) {
709      ASSERT(state_ == SSL_NONE);
710      events_to_signal |= SE_OPEN;
711    } else {
712      state_ = SSL_CONNECTING;
713      if (int err = BeginSSL()) {
714        Error("BeginSSL", err, true);
715        return;
716      }
717    }
718  }
719  if ((events & (SE_READ|SE_WRITE))) {
720    LOG(LS_INFO) << "NSSStreamAdapter::OnEvent"
721                 << ((events & SE_READ) ? " SE_READ" : "")
722                 << ((events & SE_WRITE) ? " SE_WRITE" : "");
723    if (state_ == SSL_NONE) {
724      events_to_signal |= events & (SE_READ|SE_WRITE);
725    } else if (state_ == SSL_CONNECTING) {
726      if (int err = ContinueSSL()) {
727        Error("ContinueSSL", err, true);
728        return;
729      }
730    } else if (state_ == SSL_CONNECTED) {
731      if (events & SE_WRITE) {
732        LOG(LS_INFO) << " -- onStreamWriteable";
733        events_to_signal |= SE_WRITE;
734      }
735      if (events & SE_READ) {
736        LOG(LS_INFO) << " -- onStreamReadable";
737        events_to_signal |= SE_READ;
738      }
739    }
740  }
741  if ((events & SE_CLOSE)) {
742    LOG(LS_INFO) << "NSSStreamAdapter::OnEvent(SE_CLOSE, " << err << ")";
743    Cleanup();
744    events_to_signal |= SE_CLOSE;
745    // SE_CLOSE is the only event that uses the final parameter to OnEvent().
746    ASSERT(signal_error == 0);
747    signal_error = err;
748  }
749  if (events_to_signal)
750    StreamAdapterInterface::OnEvent(stream, events_to_signal, signal_error);
751}
752
753void NSSStreamAdapter::OnMessage(Message* msg) {
754  // Process our own messages and then pass others to the superclass
755  if (MSG_DTLS_TIMEOUT == msg->message_id) {
756    LOG(LS_INFO) << "DTLS timeout expired";
757    ContinueSSL();
758  } else {
759    StreamInterface::OnMessage(msg);
760  }
761}
762
763// Certificate verification callback. Called to check any certificate
764SECStatus NSSStreamAdapter::AuthCertificateHook(void *arg,
765                                                PRFileDesc *fd,
766                                                PRBool checksig,
767                                                PRBool isServer) {
768  LOG(LS_INFO) << "NSSStreamAdapter::AuthCertificateHook";
769  // SSL_PeerCertificate returns a pointer that is owned by the caller, and
770  // the NSSCertificate constructor copies its argument, so |raw_peer_cert|
771  // must be destroyed in this function.
772  CERTCertificate* raw_peer_cert = SSL_PeerCertificate(fd);
773  NSSCertificate peer_cert(raw_peer_cert);
774  CERT_DestroyCertificate(raw_peer_cert);
775
776  NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
777  stream->cert_ok_ = false;
778
779  // Read the peer's certificate chain.
780  CERTCertList* cert_list = SSL_PeerCertificateChain(fd);
781  ASSERT(cert_list != NULL);
782
783  // If the peer provided multiple certificates, check that they form a valid
784  // chain as defined by RFC 5246 Section 7.4.2: "Each following certificate
785  // MUST directly certify the one preceding it.".  This check does NOT
786  // verify other requirements, such as whether the chain reaches a trusted
787  // root, self-signed certificates have valid signatures, certificates are not
788  // expired, etc.
789  // Even if the chain is valid, the leaf certificate must still match a
790  // provided certificate or digest.
791  if (!NSSCertificate::IsValidChain(cert_list)) {
792    CERT_DestroyCertList(cert_list);
793    PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
794    return SECFailure;
795  }
796
797  if (stream->peer_certificate_.get()) {
798    LOG(LS_INFO) << "Checking against specified certificate";
799
800    // The peer certificate was specified
801    if (reinterpret_cast<NSSCertificate *>(stream->peer_certificate_.get())->
802        Equals(&peer_cert)) {
803      LOG(LS_INFO) << "Accepted peer certificate";
804      stream->cert_ok_ = true;
805    }
806  } else if (!stream->peer_certificate_digest_algorithm_.empty()) {
807    LOG(LS_INFO) << "Checking against specified digest";
808    // The peer certificate digest was specified
809    unsigned char digest[64];  // Maximum size
810    size_t digest_length;
811
812    if (!peer_cert.ComputeDigest(
813            stream->peer_certificate_digest_algorithm_,
814            digest, sizeof(digest), &digest_length)) {
815      LOG(LS_ERROR) << "Digest computation failed";
816    } else {
817      Buffer computed_digest(digest, digest_length);
818      if (computed_digest == stream->peer_certificate_digest_value_) {
819        LOG(LS_INFO) << "Accepted peer certificate";
820        stream->cert_ok_ = true;
821      }
822    }
823  } else {
824    // Other modes, but we haven't implemented yet
825    // TODO(ekr@rtfm.com): Implement real certificate validation
826    UNIMPLEMENTED;
827  }
828
829  if (!stream->cert_ok_ && stream->ignore_bad_cert()) {
830    LOG(LS_WARNING) << "Ignoring cert error while verifying cert chain";
831    stream->cert_ok_ = true;
832  }
833
834  if (stream->cert_ok_)
835    stream->peer_certificate_.reset(new NSSCertificate(cert_list));
836
837  CERT_DestroyCertList(cert_list);
838
839  if (stream->cert_ok_)
840    return SECSuccess;
841
842  PORT_SetError(SEC_ERROR_UNTRUSTED_CERT);
843  return SECFailure;
844}
845
846
847SECStatus NSSStreamAdapter::GetClientAuthDataHook(void *arg, PRFileDesc *fd,
848                                                  CERTDistNames *caNames,
849                                                  CERTCertificate **pRetCert,
850                                                  SECKEYPrivateKey **pRetKey) {
851  LOG(LS_INFO) << "Client cert requested";
852  NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
853
854  if (!stream->identity_.get()) {
855    LOG(LS_ERROR) << "No identity available";
856    return SECFailure;
857  }
858
859  NSSIdentity *identity = static_cast<NSSIdentity *>(stream->identity_.get());
860  // Destroyed internally by NSS
861  *pRetCert = CERT_DupCertificate(identity->certificate().certificate());
862  *pRetKey = SECKEY_CopyPrivateKey(identity->keypair()->privkey());
863
864  return SECSuccess;
865}
866
867// RFC 5705 Key Exporter
868bool NSSStreamAdapter::ExportKeyingMaterial(const std::string& label,
869                                            const uint8* context,
870                                            size_t context_len,
871                                            bool use_context,
872                                            uint8* result,
873                                            size_t result_len) {
874  SECStatus rv = SSL_ExportKeyingMaterial(
875      ssl_fd_,
876      label.c_str(),
877      checked_cast<unsigned int>(label.size()),
878      use_context,
879      context,
880      checked_cast<unsigned int>(context_len),
881      result,
882      checked_cast<unsigned int>(result_len));
883
884  return rv == SECSuccess;
885}
886
887bool NSSStreamAdapter::SetDtlsSrtpCiphers(
888    const std::vector<std::string>& ciphers) {
889#ifdef HAVE_DTLS_SRTP
890  std::vector<PRUint16> internal_ciphers;
891  if (state_ != SSL_NONE)
892    return false;
893
894  for (std::vector<std::string>::const_iterator cipher = ciphers.begin();
895       cipher != ciphers.end(); ++cipher) {
896    bool found = false;
897    for (const SrtpCipherMapEntry *entry = kSrtpCipherMap; entry->cipher_id;
898         ++entry) {
899      if (*cipher == entry->external_name) {
900        found = true;
901        internal_ciphers.push_back(entry->cipher_id);
902        break;
903      }
904    }
905
906    if (!found) {
907      LOG(LS_ERROR) << "Could not find cipher: " << *cipher;
908      return false;
909    }
910  }
911
912  if (internal_ciphers.empty())
913    return false;
914
915  srtp_ciphers_ = internal_ciphers;
916
917  return true;
918#else
919  return false;
920#endif
921}
922
923bool NSSStreamAdapter::GetDtlsSrtpCipher(std::string* cipher) {
924#ifdef HAVE_DTLS_SRTP
925  ASSERT(state_ == SSL_CONNECTED);
926  if (state_ != SSL_CONNECTED)
927    return false;
928
929  PRUint16 selected_cipher;
930
931  SECStatus rv = SSL_GetSRTPCipher(ssl_fd_, &selected_cipher);
932  if (rv == SECFailure)
933    return false;
934
935  for (const SrtpCipherMapEntry *entry = kSrtpCipherMap;
936       entry->cipher_id; ++entry) {
937    if (selected_cipher == entry->cipher_id) {
938      *cipher = entry->external_name;
939      return true;
940    }
941  }
942
943  ASSERT(false);  // This should never happen
944#endif
945  return false;
946}
947
948
// Set to true by InitializeSSL() once NSS_NoDB_Init() has succeeded; static
// storage, so it starts out false.
bool NSSContext::initialized;
// Process-wide context created lazily by Instance(); starts out NULL.
NSSContext *NSSContext::global_nss_context;
951
952// Static initialization and shutdown
953NSSContext *NSSContext::Instance() {
954  if (!global_nss_context) {
955    NSSContext *new_ctx = new NSSContext();
956
957    if (!(new_ctx->slot_ = PK11_GetInternalSlot())) {
958      delete new_ctx;
959      goto fail;
960    }
961
962    global_nss_context = new_ctx;
963  }
964
965 fail:
966  return global_nss_context;
967}
968
969
970
971bool NSSContext::InitializeSSL(VerificationCallback callback) {
972  ASSERT(!callback);
973
974  if (!initialized) {
975    SECStatus rv;
976
977    rv = NSS_NoDB_Init(NULL);
978    if (rv != SECSuccess) {
979      LOG(LS_ERROR) << "Couldn't initialize NSS error=" <<
980          PORT_GetError();
981      return false;
982    }
983
984    NSS_SetDomesticPolicy();
985
986    initialized = true;
987  }
988
989  return true;
990}
991
// Per-thread SSL initialization hook. This implementation has nothing to do
// and always reports success.
bool NSSContext::InitializeSSLThread() {
  // Not needed
  return true;
}
996
// SSL shutdown hook. This implementation releases nothing and always reports
// success.
bool NSSContext::CleanupSSL() {
  // Not needed
  return true;
}
1001
// Capability query: this adapter always reports DTLS as available.
bool NSSStreamAdapter::HaveDtls() {
  return true;
}
1005
1006bool NSSStreamAdapter::HaveDtlsSrtp() {
1007#ifdef HAVE_DTLS_SRTP
1008  return true;
1009#else
1010  return false;
1011#endif
1012}
1013
// Capability query: this adapter always reports the keying material exporter
// (see ExportKeyingMaterial()) as available.
bool NSSStreamAdapter::HaveExporter() {
  return true;
}
1017
1018}  // namespace rtc
1019
1020#endif  // HAVE_NSS_SSL_H
1021