1/*
2 *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include <vector>
12
13#if HAVE_CONFIG_H
14#include "config.h"
15#endif  // HAVE_CONFIG_H
16
17#if HAVE_NSS_SSL_H
18
19#include "webrtc/base/nssstreamadapter.h"
20
21#include "keyhi.h"
22#include "nspr.h"
23#include "nss.h"
24#include "pk11pub.h"
25#include "secerr.h"
26
27#ifdef NSS_SSL_RELATIVE_PATH
28#include "ssl.h"
29#include "sslerr.h"
30#include "sslproto.h"
31#else
32#include "net/third_party/nss/ssl/ssl.h"
33#include "net/third_party/nss/ssl/sslerr.h"
34#include "net/third_party/nss/ssl/sslproto.h"
35#endif
36
37#include "webrtc/base/nssidentity.h"
38#include "webrtc/base/safe_conversions.h"
39#include "webrtc/base/thread.h"
40
41namespace rtc {
42
43PRDescIdentity NSSStreamAdapter::nspr_layer_identity = PR_INVALID_IO_LAYER;
44
45#define UNIMPLEMENTED \
46  PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); \
47  LOG(LS_ERROR) \
48  << "Call to unimplemented function "<< __FUNCTION__; ASSERT(false)
49
// DTLS-SRTP support is compiled in only when the NSS headers provide the
// SRTP protection profile constants.
#ifdef SRTP_AES128_CM_HMAC_SHA1_80
#define HAVE_DTLS_SRTP
#endif

#ifdef HAVE_DTLS_SRTP
// One row of the SRTP profile table: the name used by SetDtlsSrtpCiphers()
// and GetDtlsSrtpCipher() paired with NSS's numeric identifier.
struct SrtpCipherMapEntry {
  const char* external_name;  // Profile name exposed through the API.
  PRUint16 cipher_id;         // NSS SRTP profile identifier.
};

// Local table instead of an external reference. The trailing row with a
// zero |cipher_id| terminates the lookup loops.
static const SrtpCipherMapEntry kSrtpCipherMap[] = {
  { "AES_CM_128_HMAC_SHA1_80", SRTP_AES128_CM_HMAC_SHA1_80 },
  { "AES_CM_128_HMAC_SHA1_32", SRTP_AES128_CM_HMAC_SHA1_32 },
  { NULL, 0 }
};
#endif
68
69
70// Implementation of NSPR methods
71static PRStatus StreamClose(PRFileDesc *socket) {
72  ASSERT(!socket->lower);
73  socket->dtor(socket);
74  return PR_SUCCESS;
75}
76
77static PRInt32 StreamRead(PRFileDesc *socket, void *buf, PRInt32 length) {
78  StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
79  size_t read;
80  int error;
81  StreamResult result = stream->Read(buf, length, &read, &error);
82  if (result == SR_SUCCESS) {
83    return checked_cast<PRInt32>(read);
84  }
85
86  if (result == SR_EOS) {
87    return 0;
88  }
89
90  if (result == SR_BLOCK) {
91    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
92    return -1;
93  }
94
95  PR_SetError(PR_UNKNOWN_ERROR, error);
96  return -1;
97}
98
99static PRInt32 StreamWrite(PRFileDesc *socket, const void *buf,
100                           PRInt32 length) {
101  StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
102  size_t written;
103  int error;
104  StreamResult result = stream->Write(buf, length, &written, &error);
105  if (result == SR_SUCCESS) {
106    return checked_cast<PRInt32>(written);
107  }
108
109  if (result == SR_BLOCK) {
110    LOG(LS_INFO) <<
111        "NSSStreamAdapter: write to underlying transport would block";
112    PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
113    return -1;
114  }
115
116  LOG(LS_ERROR) << "Write error";
117  PR_SetError(PR_UNKNOWN_ERROR, error);
118  return -1;
119}
120
121static PRInt32 StreamAvailable(PRFileDesc *socket) {
122  UNIMPLEMENTED;
123  return -1;
124}
125
126PRInt64 StreamAvailable64(PRFileDesc *socket) {
127  UNIMPLEMENTED;
128  return -1;
129}
130
131static PRStatus StreamSync(PRFileDesc *socket) {
132  UNIMPLEMENTED;
133  return PR_FAILURE;
134}
135
136static PROffset32 StreamSeek(PRFileDesc *socket, PROffset32 offset,
137                             PRSeekWhence how) {
138  UNIMPLEMENTED;
139  return -1;
140}
141
142static PROffset64 StreamSeek64(PRFileDesc *socket, PROffset64 offset,
143                               PRSeekWhence how) {
144  UNIMPLEMENTED;
145  return -1;
146}
147
148static PRStatus StreamFileInfo(PRFileDesc *socket, PRFileInfo *info) {
149  UNIMPLEMENTED;
150  return PR_FAILURE;
151}
152
153static PRStatus StreamFileInfo64(PRFileDesc *socket, PRFileInfo64 *info) {
154  UNIMPLEMENTED;
155  return PR_FAILURE;
156}
157
158static PRInt32 StreamWritev(PRFileDesc *socket, const PRIOVec *iov,
159                     PRInt32 iov_size, PRIntervalTime timeout) {
160  UNIMPLEMENTED;
161  return -1;
162}
163
164static PRStatus StreamConnect(PRFileDesc *socket, const PRNetAddr *addr,
165                              PRIntervalTime timeout) {
166  UNIMPLEMENTED;
167  return PR_FAILURE;
168}
169
170static PRFileDesc *StreamAccept(PRFileDesc *sd, PRNetAddr *addr,
171                                PRIntervalTime timeout) {
172  UNIMPLEMENTED;
173  return NULL;
174}
175
176static PRStatus StreamBind(PRFileDesc *socket, const PRNetAddr *addr) {
177  UNIMPLEMENTED;
178  return PR_FAILURE;
179}
180
181static PRStatus StreamListen(PRFileDesc *socket, PRIntn depth) {
182  UNIMPLEMENTED;
183  return PR_FAILURE;
184}
185
186static PRStatus StreamShutdown(PRFileDesc *socket, PRIntn how) {
187  UNIMPLEMENTED;
188  return PR_FAILURE;
189}
190
191// Note: this is always nonblocking and ignores the timeout.
192// TODO(ekr@rtfm.com): In future verify that the socket is
193// actually in non-blocking mode.
194// This function does not support peek.
195static PRInt32 StreamRecv(PRFileDesc *socket, void *buf, PRInt32 amount,
196                   PRIntn flags, PRIntervalTime to) {
197  ASSERT(flags == 0);
198
199  if (flags != 0) {
200    PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
201    return -1;
202  }
203
204  return StreamRead(socket, buf, amount);
205}
206
207// Note: this is always nonblocking and assumes a zero timeout.
208// This function does not support peek.
209static PRInt32 StreamSend(PRFileDesc *socket, const void *buf,
210                          PRInt32 amount, PRIntn flags,
211                          PRIntervalTime to) {
212  ASSERT(flags == 0);
213
214  return StreamWrite(socket, buf, amount);
215}
216
217static PRInt32 StreamRecvfrom(PRFileDesc *socket, void *buf,
218                              PRInt32 amount, PRIntn flags,
219                              PRNetAddr *addr, PRIntervalTime to) {
220  UNIMPLEMENTED;
221  return -1;
222}
223
224static PRInt32 StreamSendto(PRFileDesc *socket, const void *buf,
225                            PRInt32 amount, PRIntn flags,
226                            const PRNetAddr *addr, PRIntervalTime to) {
227  UNIMPLEMENTED;
228  return -1;
229}
230
231static PRInt16 StreamPoll(PRFileDesc *socket, PRInt16 in_flags,
232                          PRInt16 *out_flags) {
233  UNIMPLEMENTED;
234  return -1;
235}
236
237static PRInt32 StreamAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
238                                PRNetAddr **raddr,
239                                void *buf, PRInt32 amount, PRIntervalTime t) {
240  UNIMPLEMENTED;
241  return -1;
242}
243
244static PRInt32 StreamTransmitFile(PRFileDesc *sd, PRFileDesc *socket,
245                                  const void *headers, PRInt32 hlen,
246                                  PRTransmitFileFlags flags, PRIntervalTime t) {
247  UNIMPLEMENTED;
248  return -1;
249}
250
251static PRStatus StreamGetPeerName(PRFileDesc *socket, PRNetAddr *addr) {
252  // TODO(ekr@rtfm.com): Modify to return unique names for each channel
253  // somehow, as opposed to always the same static address. The current
254  // implementation messes up the session cache, which is why it's off
255  // elsewhere
256  addr->inet.family = PR_AF_INET;
257  addr->inet.port = 0;
258  addr->inet.ip = 0;
259
260  return PR_SUCCESS;
261}
262
263static PRStatus StreamGetSockName(PRFileDesc *socket, PRNetAddr *addr) {
264  UNIMPLEMENTED;
265  return PR_FAILURE;
266}
267
268static PRStatus StreamGetSockOption(PRFileDesc *socket, PRSocketOptionData *opt) {
269  switch (opt->option) {
270    case PR_SockOpt_Nonblocking:
271      opt->value.non_blocking = PR_TRUE;
272      return PR_SUCCESS;
273    default:
274      UNIMPLEMENTED;
275      break;
276  }
277
278  return PR_FAILURE;
279}
280
281// Imitate setting socket options. These are mostly noops.
282static PRStatus StreamSetSockOption(PRFileDesc *socket,
283                                    const PRSocketOptionData *opt) {
284  switch (opt->option) {
285    case PR_SockOpt_Nonblocking:
286      return PR_SUCCESS;
287    case PR_SockOpt_NoDelay:
288      return PR_SUCCESS;
289    default:
290      UNIMPLEMENTED;
291      break;
292  }
293
294  return PR_FAILURE;
295}
296
297static PRInt32 StreamSendfile(PRFileDesc *out, PRSendFileData *in,
298                              PRTransmitFileFlags flags, PRIntervalTime to) {
299  UNIMPLEMENTED;
300  return -1;
301}
302
303static PRStatus StreamConnectContinue(PRFileDesc *socket, PRInt16 flags) {
304  UNIMPLEMENTED;
305  return PR_FAILURE;
306}
307
308static PRIntn StreamReserved(PRFileDesc *socket) {
309  UNIMPLEMENTED;
310  return -1;
311}
312
// NSPR method table for the adapter's bottom I/O layer. It is handed to
// PR_CreateIOLayerStub() in NSSStreamAdapter::Init(). The initializer is
// positional, so entries must stay in the exact field order of NSPR's
// struct PRIOMethods. Only StreamClose, StreamRead/StreamWrite,
// StreamRecv/StreamSend, StreamGetPeerName and the socket-option accessors
// do real work. Every other entry, including each StreamReserved slot,
// reports PR_NOT_IMPLEMENTED_ERROR.
static const struct PRIOMethods nss_methods = {
  PR_DESC_LAYERED,
  StreamClose,
  StreamRead,
  StreamWrite,
  StreamAvailable,
  StreamAvailable64,
  StreamSync,
  StreamSeek,
  StreamSeek64,
  StreamFileInfo,
  StreamFileInfo64,
  StreamWritev,
  StreamConnect,
  StreamAccept,
  StreamBind,
  StreamListen,
  StreamShutdown,
  StreamRecv,
  StreamSend,
  StreamRecvfrom,
  StreamSendto,
  StreamPoll,
  StreamAcceptRead,
  StreamTransmitFile,
  StreamGetSockName,
  StreamGetPeerName,
  StreamReserved,
  StreamReserved,
  StreamGetSockOption,
  StreamSetSockOption,
  StreamSendfile,
  StreamConnectContinue,
  StreamReserved,
  StreamReserved,
  StreamReserved,
  StreamReserved
};
351
352NSSStreamAdapter::NSSStreamAdapter(StreamInterface *stream)
353    : SSLStreamAdapterHelper(stream),
354      ssl_fd_(NULL),
355      cert_ok_(false) {
356}
357
358bool NSSStreamAdapter::Init() {
359  if (nspr_layer_identity == PR_INVALID_IO_LAYER) {
360    nspr_layer_identity = PR_GetUniqueIdentity("nssstreamadapter");
361  }
362  PRFileDesc *pr_fd = PR_CreateIOLayerStub(nspr_layer_identity, &nss_methods);
363  if (!pr_fd)
364    return false;
365  pr_fd->secret = reinterpret_cast<PRFilePrivate *>(stream());
366
367  PRFileDesc *ssl_fd;
368  if (ssl_mode_ == SSL_MODE_DTLS) {
369    ssl_fd = DTLS_ImportFD(NULL, pr_fd);
370  } else {
371    ssl_fd = SSL_ImportFD(NULL, pr_fd);
372  }
373  ASSERT(ssl_fd != NULL);  // This should never happen
374  if (!ssl_fd) {
375    PR_Close(pr_fd);
376    return false;
377  }
378
379  SECStatus rv;
380  // Turn on security.
381  rv = SSL_OptionSet(ssl_fd, SSL_SECURITY, PR_TRUE);
382  if (rv != SECSuccess) {
383    LOG(LS_ERROR) << "Error enabling security on SSL Socket";
384    return false;
385  }
386
387  // Disable SSLv2.
388  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SSL2, PR_FALSE);
389  if (rv != SECSuccess) {
390    LOG(LS_ERROR) << "Error disabling SSL2";
391    return false;
392  }
393
394  // Disable caching.
395  // TODO(ekr@rtfm.com): restore this when I have the caching
396  // identity set.
397  rv = SSL_OptionSet(ssl_fd, SSL_NO_CACHE, PR_TRUE);
398  if (rv != SECSuccess) {
399    LOG(LS_ERROR) << "Error disabling cache";
400    return false;
401  }
402
403  // Disable session tickets.
404  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
405  if (rv != SECSuccess) {
406    LOG(LS_ERROR) << "Error enabling tickets";
407    return false;
408  }
409
410  // Disable renegotiation.
411  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_RENEGOTIATION,
412                     SSL_RENEGOTIATE_NEVER);
413  if (rv != SECSuccess) {
414    LOG(LS_ERROR) << "Error disabling renegotiation";
415    return false;
416  }
417
418  // Disable false start.
419  rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_FALSE_START, PR_FALSE);
420  if (rv != SECSuccess) {
421    LOG(LS_ERROR) << "Error disabling false start";
422    return false;
423  }
424
425  ssl_fd_ = ssl_fd;
426
427  return true;
428}
429
430NSSStreamAdapter::~NSSStreamAdapter() {
431  if (ssl_fd_)
432    PR_Close(ssl_fd_);
433};
434
435
436int NSSStreamAdapter::BeginSSL() {
437  SECStatus rv;
438
439  if (!Init()) {
440    Error("Init", -1, false);
441    return -1;
442  }
443
444  ASSERT(state_ == SSL_CONNECTING);
445  // The underlying stream has been opened. If we are in peer-to-peer mode
446  // then a peer certificate must have been specified by now.
447  ASSERT(!ssl_server_name_.empty() ||
448         peer_certificate_.get() != NULL ||
449         !peer_certificate_digest_algorithm_.empty());
450  LOG(LS_INFO) << "BeginSSL: "
451               << (!ssl_server_name_.empty() ? ssl_server_name_ :
452                                               "with peer");
453
454  if (role_ == SSL_CLIENT) {
455    LOG(LS_INFO) << "BeginSSL: as client";
456
457    rv = SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
458                                   this);
459    if (rv != SECSuccess) {
460      Error("BeginSSL", -1, false);
461      return -1;
462    }
463  } else {
464    LOG(LS_INFO) << "BeginSSL: as server";
465    NSSIdentity *identity;
466
467    if (identity_.get()) {
468      identity = static_cast<NSSIdentity *>(identity_.get());
469    } else {
470      LOG(LS_ERROR) << "Can't be an SSL server without an identity";
471      Error("BeginSSL", -1, false);
472      return -1;
473    }
474    rv = SSL_ConfigSecureServer(ssl_fd_, identity->certificate().certificate(),
475                                identity->keypair()->privkey(),
476                                kt_rsa);
477    if (rv != SECSuccess) {
478      Error("BeginSSL", -1, false);
479      return -1;
480    }
481
482    // Insist on a certificate from the client
483    rv = SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE);
484    if (rv != SECSuccess) {
485      Error("BeginSSL", -1, false);
486      return -1;
487    }
488
489    // TODO(juberti): Check for client_auth_enabled()
490
491    rv = SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, PR_TRUE);
492    if (rv != SECSuccess) {
493      Error("BeginSSL", -1, false);
494      return -1;
495    }
496  }
497
498  // Set the version range.
499  SSLVersionRange vrange;
500  vrange.min =  (ssl_mode_ == SSL_MODE_DTLS) ?
501      SSL_LIBRARY_VERSION_TLS_1_1 :
502      SSL_LIBRARY_VERSION_TLS_1_0;
503  vrange.max = SSL_LIBRARY_VERSION_TLS_1_1;
504
505  rv = SSL_VersionRangeSet(ssl_fd_, &vrange);
506  if (rv != SECSuccess) {
507    Error("BeginSSL", -1, false);
508    return -1;
509  }
510
511  // SRTP
512#ifdef HAVE_DTLS_SRTP
513  if (!srtp_ciphers_.empty()) {
514    rv = SSL_SetSRTPCiphers(
515        ssl_fd_, &srtp_ciphers_[0],
516        checked_cast<unsigned int>(srtp_ciphers_.size()));
517    if (rv != SECSuccess) {
518      Error("BeginSSL", -1, false);
519      return -1;
520    }
521  }
522#endif
523
524  // Certificate validation
525  rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
526  if (rv != SECSuccess) {
527    Error("BeginSSL", -1, false);
528    return -1;
529  }
530
531  // Now start the handshake
532  rv = SSL_ResetHandshake(ssl_fd_, role_ == SSL_SERVER ? PR_TRUE : PR_FALSE);
533  if (rv != SECSuccess) {
534    Error("BeginSSL", -1, false);
535    return -1;
536  }
537
538  return ContinueSSL();
539}
540
541int NSSStreamAdapter::ContinueSSL() {
542  LOG(LS_INFO) << "ContinueSSL";
543  ASSERT(state_ == SSL_CONNECTING);
544
545  // Clear the DTLS timer
546  Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
547
548  SECStatus rv = SSL_ForceHandshake(ssl_fd_);
549
550  if (rv == SECSuccess) {
551    LOG(LS_INFO) << "Handshake complete";
552
553    ASSERT(cert_ok_);
554    if (!cert_ok_) {
555      Error("ContinueSSL", -1, true);
556      return -1;
557    }
558
559    state_ = SSL_CONNECTED;
560    StreamAdapterInterface::OnEvent(stream(), SE_OPEN|SE_READ|SE_WRITE, 0);
561    return 0;
562  }
563
564  PRInt32 err = PR_GetError();
565  switch (err) {
566    case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
567      if (ssl_mode_ != SSL_MODE_DTLS) {
568        Error("ContinueSSL", -1, true);
569        return -1;
570      } else {
571        LOG(LS_INFO) << "Malformed DTLS message. Ignoring.";
572        // Fall through
573      }
574    case PR_WOULD_BLOCK_ERROR:
575      LOG(LS_INFO) << "Would have blocked";
576      if (ssl_mode_ == SSL_MODE_DTLS) {
577        PRIntervalTime timeout;
578
579        SECStatus rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
580        if (rv == SECSuccess) {
581          LOG(LS_INFO) << "Timeout is " << timeout << " ms";
582          Thread::Current()->PostDelayed(PR_IntervalToMilliseconds(timeout),
583                                         this, MSG_DTLS_TIMEOUT, 0);
584        }
585      }
586
587      return 0;
588    default:
589      LOG(LS_INFO) << "Error " << err;
590      break;
591  }
592
593  Error("ContinueSSL", -1, true);
594  return -1;
595}
596
597void NSSStreamAdapter::Cleanup() {
598  if (state_ != SSL_ERROR) {
599    state_ = SSL_CLOSED;
600  }
601
602  if (ssl_fd_) {
603    PR_Close(ssl_fd_);
604    ssl_fd_ = NULL;
605  }
606
607  identity_.reset();
608  peer_certificate_.reset();
609
610  Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
611}
612
613StreamResult NSSStreamAdapter::Read(void* data, size_t data_len,
614                                    size_t* read, int* error) {
615  // SSL_CONNECTED sanity check.
616  switch (state_) {
617    case SSL_NONE:
618    case SSL_WAIT:
619    case SSL_CONNECTING:
620      return SR_BLOCK;
621
622    case SSL_CONNECTED:
623      break;
624
625    case SSL_CLOSED:
626      return SR_EOS;
627
628    case SSL_ERROR:
629    default:
630      if (error)
631        *error = ssl_error_code_;
632      return SR_ERROR;
633  }
634
635  PRInt32 rv = PR_Read(ssl_fd_, data, checked_cast<PRInt32>(data_len));
636
637  if (rv == 0) {
638    return SR_EOS;
639  }
640
641  // Error
642  if (rv < 0) {
643    PRInt32 err = PR_GetError();
644
645    switch (err) {
646      case PR_WOULD_BLOCK_ERROR:
647        return SR_BLOCK;
648      default:
649        Error("Read", -1, false);
650        *error = err;  // libjingle semantics are that this is impl-specific
651        return SR_ERROR;
652    }
653  }
654
655  // Success
656  *read = rv;
657
658  return SR_SUCCESS;
659}
660
661StreamResult NSSStreamAdapter::Write(const void* data, size_t data_len,
662                                     size_t* written, int* error) {
663  // SSL_CONNECTED sanity check.
664  switch (state_) {
665    case SSL_NONE:
666    case SSL_WAIT:
667    case SSL_CONNECTING:
668      return SR_BLOCK;
669
670    case SSL_CONNECTED:
671      break;
672
673    case SSL_ERROR:
674    case SSL_CLOSED:
675    default:
676      if (error)
677        *error = ssl_error_code_;
678      return SR_ERROR;
679  }
680
681  PRInt32 rv = PR_Write(ssl_fd_, data, checked_cast<PRInt32>(data_len));
682
683  // Error
684  if (rv < 0) {
685    PRInt32 err = PR_GetError();
686
687    switch (err) {
688      case PR_WOULD_BLOCK_ERROR:
689        return SR_BLOCK;
690      default:
691        Error("Write", -1, false);
692        *error = err;  // libjingle semantics are that this is impl-specific
693        return SR_ERROR;
694    }
695  }
696
697  // Success
698  *written = rv;
699
700  return SR_SUCCESS;
701}
702
703void NSSStreamAdapter::OnEvent(StreamInterface* stream, int events,
704                               int err) {
705  int events_to_signal = 0;
706  int signal_error = 0;
707  ASSERT(stream == this->stream());
708  if ((events & SE_OPEN)) {
709    LOG(LS_INFO) << "NSSStreamAdapter::OnEvent SE_OPEN";
710    if (state_ != SSL_WAIT) {
711      ASSERT(state_ == SSL_NONE);
712      events_to_signal |= SE_OPEN;
713    } else {
714      state_ = SSL_CONNECTING;
715      if (int err = BeginSSL()) {
716        Error("BeginSSL", err, true);
717        return;
718      }
719    }
720  }
721  if ((events & (SE_READ|SE_WRITE))) {
722    LOG(LS_INFO) << "NSSStreamAdapter::OnEvent"
723                 << ((events & SE_READ) ? " SE_READ" : "")
724                 << ((events & SE_WRITE) ? " SE_WRITE" : "");
725    if (state_ == SSL_NONE) {
726      events_to_signal |= events & (SE_READ|SE_WRITE);
727    } else if (state_ == SSL_CONNECTING) {
728      if (int err = ContinueSSL()) {
729        Error("ContinueSSL", err, true);
730        return;
731      }
732    } else if (state_ == SSL_CONNECTED) {
733      if (events & SE_WRITE) {
734        LOG(LS_INFO) << " -- onStreamWriteable";
735        events_to_signal |= SE_WRITE;
736      }
737      if (events & SE_READ) {
738        LOG(LS_INFO) << " -- onStreamReadable";
739        events_to_signal |= SE_READ;
740      }
741    }
742  }
743  if ((events & SE_CLOSE)) {
744    LOG(LS_INFO) << "NSSStreamAdapter::OnEvent(SE_CLOSE, " << err << ")";
745    Cleanup();
746    events_to_signal |= SE_CLOSE;
747    // SE_CLOSE is the only event that uses the final parameter to OnEvent().
748    ASSERT(signal_error == 0);
749    signal_error = err;
750  }
751  if (events_to_signal)
752    StreamAdapterInterface::OnEvent(stream, events_to_signal, signal_error);
753}
754
755void NSSStreamAdapter::OnMessage(Message* msg) {
756  // Process our own messages and then pass others to the superclass
757  if (MSG_DTLS_TIMEOUT == msg->message_id) {
758    LOG(LS_INFO) << "DTLS timeout expired";
759    ContinueSSL();
760  } else {
761    StreamInterface::OnMessage(msg);
762  }
763}
764
765// Certificate verification callback. Called to check any certificate
766SECStatus NSSStreamAdapter::AuthCertificateHook(void *arg,
767                                                PRFileDesc *fd,
768                                                PRBool checksig,
769                                                PRBool isServer) {
770  LOG(LS_INFO) << "NSSStreamAdapter::AuthCertificateHook";
771  // SSL_PeerCertificate returns a pointer that is owned by the caller, and
772  // the NSSCertificate constructor copies its argument, so |raw_peer_cert|
773  // must be destroyed in this function.
774  CERTCertificate* raw_peer_cert = SSL_PeerCertificate(fd);
775  NSSCertificate peer_cert(raw_peer_cert);
776  CERT_DestroyCertificate(raw_peer_cert);
777
778  NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
779  stream->cert_ok_ = false;
780
781  // Read the peer's certificate chain.
782  CERTCertList* cert_list = SSL_PeerCertificateChain(fd);
783  ASSERT(cert_list != NULL);
784
785  // If the peer provided multiple certificates, check that they form a valid
786  // chain as defined by RFC 5246 Section 7.4.2: "Each following certificate
787  // MUST directly certify the one preceding it.".  This check does NOT
788  // verify other requirements, such as whether the chain reaches a trusted
789  // root, self-signed certificates have valid signatures, certificates are not
790  // expired, etc.
791  // Even if the chain is valid, the leaf certificate must still match a
792  // provided certificate or digest.
793  if (!NSSCertificate::IsValidChain(cert_list)) {
794    CERT_DestroyCertList(cert_list);
795    PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
796    return SECFailure;
797  }
798
799  if (stream->peer_certificate_.get()) {
800    LOG(LS_INFO) << "Checking against specified certificate";
801
802    // The peer certificate was specified
803    if (reinterpret_cast<NSSCertificate *>(stream->peer_certificate_.get())->
804        Equals(&peer_cert)) {
805      LOG(LS_INFO) << "Accepted peer certificate";
806      stream->cert_ok_ = true;
807    }
808  } else if (!stream->peer_certificate_digest_algorithm_.empty()) {
809    LOG(LS_INFO) << "Checking against specified digest";
810    // The peer certificate digest was specified
811    unsigned char digest[64];  // Maximum size
812    size_t digest_length;
813
814    if (!peer_cert.ComputeDigest(
815            stream->peer_certificate_digest_algorithm_,
816            digest, sizeof(digest), &digest_length)) {
817      LOG(LS_ERROR) << "Digest computation failed";
818    } else {
819      Buffer computed_digest(digest, digest_length);
820      if (computed_digest == stream->peer_certificate_digest_value_) {
821        LOG(LS_INFO) << "Accepted peer certificate";
822        stream->cert_ok_ = true;
823      }
824    }
825  } else {
826    // Other modes, but we haven't implemented yet
827    // TODO(ekr@rtfm.com): Implement real certificate validation
828    UNIMPLEMENTED;
829  }
830
831  if (!stream->cert_ok_ && stream->ignore_bad_cert()) {
832    LOG(LS_WARNING) << "Ignoring cert error while verifying cert chain";
833    stream->cert_ok_ = true;
834  }
835
836  if (stream->cert_ok_)
837    stream->peer_certificate_.reset(new NSSCertificate(cert_list));
838
839  CERT_DestroyCertList(cert_list);
840
841  if (stream->cert_ok_)
842    return SECSuccess;
843
844  PORT_SetError(SEC_ERROR_UNTRUSTED_CERT);
845  return SECFailure;
846}
847
848
849SECStatus NSSStreamAdapter::GetClientAuthDataHook(void *arg, PRFileDesc *fd,
850                                                  CERTDistNames *caNames,
851                                                  CERTCertificate **pRetCert,
852                                                  SECKEYPrivateKey **pRetKey) {
853  LOG(LS_INFO) << "Client cert requested";
854  NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
855
856  if (!stream->identity_.get()) {
857    LOG(LS_ERROR) << "No identity available";
858    return SECFailure;
859  }
860
861  NSSIdentity *identity = static_cast<NSSIdentity *>(stream->identity_.get());
862  // Destroyed internally by NSS
863  *pRetCert = CERT_DupCertificate(identity->certificate().certificate());
864  *pRetKey = SECKEY_CopyPrivateKey(identity->keypair()->privkey());
865
866  return SECSuccess;
867}
868
869// RFC 5705 Key Exporter
870bool NSSStreamAdapter::ExportKeyingMaterial(const std::string& label,
871                                            const uint8* context,
872                                            size_t context_len,
873                                            bool use_context,
874                                            uint8* result,
875                                            size_t result_len) {
876  SECStatus rv = SSL_ExportKeyingMaterial(
877      ssl_fd_,
878      label.c_str(),
879      checked_cast<unsigned int>(label.size()),
880      use_context,
881      context,
882      checked_cast<unsigned int>(context_len),
883      result,
884      checked_cast<unsigned int>(result_len));
885
886  return rv == SECSuccess;
887}
888
889bool NSSStreamAdapter::SetDtlsSrtpCiphers(
890    const std::vector<std::string>& ciphers) {
891#ifdef HAVE_DTLS_SRTP
892  std::vector<PRUint16> internal_ciphers;
893  if (state_ != SSL_NONE)
894    return false;
895
896  for (std::vector<std::string>::const_iterator cipher = ciphers.begin();
897       cipher != ciphers.end(); ++cipher) {
898    bool found = false;
899    for (const SrtpCipherMapEntry *entry = kSrtpCipherMap; entry->cipher_id;
900         ++entry) {
901      if (*cipher == entry->external_name) {
902        found = true;
903        internal_ciphers.push_back(entry->cipher_id);
904        break;
905      }
906    }
907
908    if (!found) {
909      LOG(LS_ERROR) << "Could not find cipher: " << *cipher;
910      return false;
911    }
912  }
913
914  if (internal_ciphers.empty())
915    return false;
916
917  srtp_ciphers_ = internal_ciphers;
918
919  return true;
920#else
921  return false;
922#endif
923}
924
925bool NSSStreamAdapter::GetDtlsSrtpCipher(std::string* cipher) {
926#ifdef HAVE_DTLS_SRTP
927  ASSERT(state_ == SSL_CONNECTED);
928  if (state_ != SSL_CONNECTED)
929    return false;
930
931  PRUint16 selected_cipher;
932
933  SECStatus rv = SSL_GetSRTPCipher(ssl_fd_, &selected_cipher);
934  if (rv == SECFailure)
935    return false;
936
937  for (const SrtpCipherMapEntry *entry = kSrtpCipherMap;
938       entry->cipher_id; ++entry) {
939    if (selected_cipher == entry->cipher_id) {
940      *cipher = entry->external_name;
941      return true;
942    }
943  }
944
945  ASSERT(false);  // This should never happen
946#endif
947  return false;
948}
949
950
951bool NSSContext::initialized;
952NSSContext *NSSContext::global_nss_context;
953
954// Static initialization and shutdown
955NSSContext *NSSContext::Instance() {
956  if (!global_nss_context) {
957    scoped_ptr<NSSContext> new_ctx(new NSSContext());
958    new_ctx->slot_ = PK11_GetInternalSlot();
959    if (new_ctx->slot_)
960      global_nss_context = new_ctx.release();
961  }
962  return global_nss_context;
963}
964
965
966
967bool NSSContext::InitializeSSL(VerificationCallback callback) {
968  ASSERT(!callback);
969
970  if (!initialized) {
971    SECStatus rv;
972
973    rv = NSS_NoDB_Init(NULL);
974    if (rv != SECSuccess) {
975      LOG(LS_ERROR) << "Couldn't initialize NSS error=" <<
976          PORT_GetError();
977      return false;
978    }
979
980    NSS_SetDomesticPolicy();
981
982    initialized = true;
983  }
984
985  return true;
986}
987
988bool NSSContext::InitializeSSLThread() {
989  // Not needed
990  return true;
991}
992
993bool NSSContext::CleanupSSL() {
994  // Not needed
995  return true;
996}
997
998bool NSSStreamAdapter::HaveDtls() {
999  return true;
1000}
1001
1002bool NSSStreamAdapter::HaveDtlsSrtp() {
1003#ifdef HAVE_DTLS_SRTP
1004  return true;
1005#else
1006  return false;
1007#endif
1008}
1009
1010bool NSSStreamAdapter::HaveExporter() {
1011  return true;
1012}
1013
1014}  // namespace rtc
1015
1016#endif  // HAVE_NSS_SSL_H
1017