1/*
2 * libjingle
3 * Copyright 2011, Google Inc.
4 * Copyright 2011, RTFM, Inc.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions are met:
8 *
9 *  1. Redistributions of source code must retain the above copyright notice,
10 *     this list of conditions and the following disclaimer.
11 *  2. Redistributions in binary form must reproduce the above copyright notice,
12 *     this list of conditions and the following disclaimer in the documentation
13 *     and/or other materials provided with the distribution.
14 *  3. The name of the author may not be used to endorse or promote products
15 *     derived from this software without specific prior written permission.
16 *
17 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
18 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
19 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
20 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
23 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
24 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
25 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
26 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 */
28
29#include "talk/p2p/base/dtlstransportchannel.h"
30
31#include "talk/base/buffer.h"
32#include "talk/base/dscp.h"
33#include "talk/base/messagequeue.h"
34#include "talk/base/stream.h"
35#include "talk/base/sslstreamadapter.h"
36#include "talk/base/thread.h"
37#include "talk/p2p/base/common.h"
38
39namespace cricket {
40
// We don't pull the RTP constants from rtputils.h, to avoid a layer violation.

// Size of a DTLS record header; HandleDtlsPacket() reads the record body
// length from header bytes 11-12.
static const size_t kDtlsRecordHeaderLen = 13;
// Size of the buffer OnDtlsEvent() uses to read decrypted application data.
static const size_t kMaxDtlsPacketLen = 2048;
// Shortest packet IsRtpPacket() will accept (the fixed RTP header size).
static const size_t kMinRtpPacketLen = 12;
// NOTE(review): not referenced anywhere in this file; candidate for removal.
static const size_t kDefaultVideoAndDataCryptos = 1;

// Returns true if |data| could be a DTLS record: it must be at least one
// record header long and its first byte must fall in [20, 63], the range used
// to demultiplex DTLS from other traffic sharing the transport.
static bool IsDtlsPacket(const char* data, size_t len) {
  if (len < kDtlsRecordHeaderLen)
    return false;
  const unsigned char first_byte = static_cast<unsigned char>(data[0]);
  return first_byte >= 20 && first_byte <= 63;
}

// Returns true if |data| could be an RTP packet: at least a fixed RTP header
// long, with the two-bit version field at the top of byte 0 equal to 2.
static bool IsRtpPacket(const char* data, size_t len) {
  if (len < kMinRtpPacketLen)
    return false;
  const unsigned char first_byte = static_cast<unsigned char>(data[0]);
  return (first_byte >> 6) == 2;
}
55
56talk_base::StreamResult StreamInterfaceChannel::Read(void* buffer,
57                                                     size_t buffer_len,
58                                                     size_t* read,
59                                                     int* error) {
60  if (state_ == talk_base::SS_CLOSED)
61    return talk_base::SR_EOS;
62  if (state_ == talk_base::SS_OPENING)
63    return talk_base::SR_BLOCK;
64
65  return fifo_.Read(buffer, buffer_len, read, error);
66}
67
68talk_base::StreamResult StreamInterfaceChannel::Write(const void* data,
69                                                      size_t data_len,
70                                                      size_t* written,
71                                                      int* error) {
72  // Always succeeds, since this is an unreliable transport anyway.
73  // TODO: Should this block if channel_'s temporarily unwritable?
74  channel_->SendPacket(
75      static_cast<const char*>(data), data_len, talk_base::DSCP_NO_CHANGE);
76  if (written) {
77    *written = data_len;
78  }
79  return talk_base::SR_SUCCESS;
80}
81
82bool StreamInterfaceChannel::OnPacketReceived(const char* data, size_t size) {
83  // We force a read event here to ensure that we don't overflow our FIFO.
84  // Under high packet rate this can occur if we wait for the FIFO to post its
85  // own SE_READ.
86  bool ret = (fifo_.WriteAll(data, size, NULL, NULL) == talk_base::SR_SUCCESS);
87  if (ret) {
88    SignalEvent(this, talk_base::SE_READ, 0);
89  }
90  return ret;
91}
92
93void StreamInterfaceChannel::OnEvent(talk_base::StreamInterface* stream,
94                                     int sig, int err) {
95  SignalEvent(this, sig, err);
96}
97
98DtlsTransportChannelWrapper::DtlsTransportChannelWrapper(
99                                           Transport* transport,
100                                           TransportChannelImpl* channel)
101    : TransportChannelImpl(channel->content_name(), channel->component()),
102      transport_(transport),
103      worker_thread_(talk_base::Thread::Current()),
104      channel_(channel),
105      downward_(NULL),
106      dtls_state_(STATE_NONE),
107      local_identity_(NULL),
108      ssl_role_(talk_base::SSL_CLIENT) {
109  channel_->SignalReadableState.connect(this,
110      &DtlsTransportChannelWrapper::OnReadableState);
111  channel_->SignalWritableState.connect(this,
112      &DtlsTransportChannelWrapper::OnWritableState);
113  channel_->SignalReadPacket.connect(this,
114      &DtlsTransportChannelWrapper::OnReadPacket);
115  channel_->SignalReadyToSend.connect(this,
116      &DtlsTransportChannelWrapper::OnReadyToSend);
117  channel_->SignalRequestSignaling.connect(this,
118      &DtlsTransportChannelWrapper::OnRequestSignaling);
119  channel_->SignalCandidateReady.connect(this,
120      &DtlsTransportChannelWrapper::OnCandidateReady);
121  channel_->SignalCandidatesAllocationDone.connect(this,
122      &DtlsTransportChannelWrapper::OnCandidatesAllocationDone);
123  channel_->SignalRoleConflict.connect(this,
124      &DtlsTransportChannelWrapper::OnRoleConflict);
125  channel_->SignalRouteChange.connect(this,
126      &DtlsTransportChannelWrapper::OnRouteChange);
127}
128
// Nothing to release explicitly here; any cleanup is left to the members'
// own destructors (dtls_ appears to be a smart pointer given its reset()/get()
// usage in this file -- confirm against the header).
DtlsTransportChannelWrapper::~DtlsTransportChannelWrapper() {
}
131
132void DtlsTransportChannelWrapper::Connect() {
133  // We should only get a single call to Connect.
134  ASSERT(dtls_state_ == STATE_NONE ||
135         dtls_state_ == STATE_OFFERED ||
136         dtls_state_ == STATE_ACCEPTED);
137  channel_->Connect();
138}
139
140void DtlsTransportChannelWrapper::Reset() {
141  channel_->Reset();
142  set_writable(false);
143  set_readable(false);
144
145  // Re-call SetupDtls()
146  if (!SetupDtls()) {
147    LOG_J(LS_ERROR, this) << "Error re-initializing DTLS";
148    dtls_state_ = STATE_CLOSED;
149    return;
150  }
151
152  dtls_state_ = STATE_ACCEPTED;
153}
154
155bool DtlsTransportChannelWrapper::SetLocalIdentity(
156    talk_base::SSLIdentity* identity) {
157  if (dtls_state_ == STATE_OPEN && identity == local_identity_) {
158    return true;
159  }
160
161  // TODO(ekr@rtfm.com): Forbid this if Connect() has been called.
162  if (dtls_state_ != STATE_NONE) {
163    LOG_J(LS_ERROR, this) << "Can't set DTLS local identity in this state";
164    return false;
165  }
166
167  if (identity) {
168    local_identity_ = identity;
169    dtls_state_ = STATE_OFFERED;
170  } else {
171    LOG_J(LS_INFO, this) << "NULL DTLS identity supplied. Not doing DTLS";
172  }
173
174  return true;
175}
176
177bool DtlsTransportChannelWrapper::GetLocalIdentity(
178    talk_base::SSLIdentity** identity) const {
179  if (!local_identity_)
180    return false;
181
182  *identity = local_identity_->GetReference();
183  return true;
184}
185
186bool DtlsTransportChannelWrapper::SetSslRole(talk_base::SSLRole role) {
187  if (dtls_state_ == STATE_OPEN) {
188    if (ssl_role_ != role) {
189      LOG(LS_ERROR) << "SSL Role can't be reversed after the session is setup.";
190      return false;
191    }
192    return true;
193  }
194
195  ssl_role_ = role;
196  return true;
197}
198
199bool DtlsTransportChannelWrapper::GetSslRole(talk_base::SSLRole* role) const {
200  *role = ssl_role_;
201  return true;
202}
203
204bool DtlsTransportChannelWrapper::SetRemoteFingerprint(
205    const std::string& digest_alg,
206    const uint8* digest,
207    size_t digest_len) {
208
209  talk_base::Buffer remote_fingerprint_value(digest, digest_len);
210
211  if ((dtls_state_ == STATE_OPEN) &&
212      (remote_fingerprint_value_ == remote_fingerprint_value)) {
213    return true;
214  }
215
216  // Allow SetRemoteFingerprint with a NULL digest even if SetLocalIdentity
217  // hasn't been called.
218  if (dtls_state_ > STATE_OFFERED ||
219      (dtls_state_ == STATE_NONE && !digest_alg.empty())) {
220    LOG_J(LS_ERROR, this) << "Can't set DTLS remote settings in this state.";
221    return false;
222  }
223
224  if (digest_alg.empty()) {
225    LOG_J(LS_INFO, this) << "Other side didn't support DTLS.";
226    dtls_state_ = STATE_NONE;
227    return true;
228  }
229
230  // At this point we know we are doing DTLS
231  remote_fingerprint_value.TransferTo(&remote_fingerprint_value_);
232  remote_fingerprint_algorithm_ = digest_alg;
233
234  if (!SetupDtls()) {
235    dtls_state_ = STATE_CLOSED;
236    return false;
237  }
238
239  dtls_state_ = STATE_ACCEPTED;
240  return true;
241}
242
243bool DtlsTransportChannelWrapper::GetRemoteCertificate(
244    talk_base::SSLCertificate** cert) const {
245  if (!dtls_)
246    return false;
247
248  return dtls_->GetPeerCertificate(cert);
249}
250
251bool DtlsTransportChannelWrapper::SetupDtls() {
252  StreamInterfaceChannel* downward =
253      new StreamInterfaceChannel(worker_thread_, channel_);
254
255  dtls_.reset(talk_base::SSLStreamAdapter::Create(downward));
256  if (!dtls_) {
257    LOG_J(LS_ERROR, this) << "Failed to create DTLS adapter.";
258    delete downward;
259    return false;
260  }
261
262  downward_ = downward;
263
264  dtls_->SetIdentity(local_identity_->GetReference());
265  dtls_->SetMode(talk_base::SSL_MODE_DTLS);
266  dtls_->SetServerRole(ssl_role_);
267  dtls_->SignalEvent.connect(this, &DtlsTransportChannelWrapper::OnDtlsEvent);
268  if (!dtls_->SetPeerCertificateDigest(
269          remote_fingerprint_algorithm_,
270          reinterpret_cast<unsigned char *>(remote_fingerprint_value_.data()),
271          remote_fingerprint_value_.length())) {
272    LOG_J(LS_ERROR, this) << "Couldn't set DTLS certificate digest.";
273    return false;
274  }
275
276  // Set up DTLS-SRTP, if it's been enabled.
277  if (!srtp_ciphers_.empty()) {
278    if (!dtls_->SetDtlsSrtpCiphers(srtp_ciphers_)) {
279      LOG_J(LS_ERROR, this) << "Couldn't set DTLS-SRTP ciphers.";
280      return false;
281    }
282  } else {
283    LOG_J(LS_INFO, this) << "Not using DTLS.";
284  }
285
286  LOG_J(LS_INFO, this) << "DTLS setup complete.";
287  return true;
288}
289
290bool DtlsTransportChannelWrapper::SetSrtpCiphers(
291    const std::vector<std::string>& ciphers) {
292  if (srtp_ciphers_ == ciphers)
293    return true;
294
295  if (dtls_state_ == STATE_OPEN) {
296    // We don't support DTLS renegotiation currently. If new set of srtp ciphers
297    // are different than what's being used currently, we will not use it.
298    // So for now, let's be happy (or sad) with a warning message.
299    std::string current_srtp_cipher;
300    if (!dtls_->GetDtlsSrtpCipher(&current_srtp_cipher)) {
301      LOG(LS_ERROR) << "Failed to get the current SRTP cipher for DTLS channel";
302      return false;
303    }
304    const std::vector<std::string>::const_iterator iter =
305        std::find(ciphers.begin(), ciphers.end(), current_srtp_cipher);
306    if (iter == ciphers.end()) {
307      std::string requested_str;
308      for (size_t i = 0; i < ciphers.size(); ++i) {
309        requested_str.append(" ");
310        requested_str.append(ciphers[i]);
311        requested_str.append(" ");
312      }
313      LOG(LS_WARNING) << "Ignoring new set of SRTP ciphers, as DTLS "
314                      << "renegotiation is not supported currently "
315                      << "current cipher = " << current_srtp_cipher << " and "
316                      << "requested = " << "[" << requested_str << "]";
317    }
318    return true;
319  }
320
321  if (dtls_state_ != STATE_NONE &&
322      dtls_state_ != STATE_OFFERED &&
323      dtls_state_ != STATE_ACCEPTED) {
324    ASSERT(false);
325    return false;
326  }
327
328  srtp_ciphers_ = ciphers;
329  return true;
330}
331
332bool DtlsTransportChannelWrapper::GetSrtpCipher(std::string* cipher) {
333  if (dtls_state_ != STATE_OPEN) {
334    return false;
335  }
336
337  return dtls_->GetDtlsSrtpCipher(cipher);
338}
339
340
341// Called from upper layers to send a media packet.
342int DtlsTransportChannelWrapper::SendPacket(const char* data, size_t size,
343                                            talk_base::DiffServCodePoint dscp,
344                                            int flags) {
345  int result = -1;
346
347  switch (dtls_state_) {
348    case STATE_OFFERED:
349      // We don't know if we are doing DTLS yet, so we can't send a packet.
350      // TODO(ekr@rtfm.com): assert here?
351      result = -1;
352      break;
353
354    case STATE_STARTED:
355    case STATE_ACCEPTED:
356      // Can't send data until the connection is active
357      result = -1;
358      break;
359
360    case STATE_OPEN:
361      if (flags & PF_SRTP_BYPASS) {
362        ASSERT(!srtp_ciphers_.empty());
363        if (!IsRtpPacket(data, size)) {
364          result = false;
365          break;
366        }
367
368        result = channel_->SendPacket(data, size, dscp);
369      } else {
370        result = (dtls_->WriteAll(data, size, NULL, NULL) ==
371          talk_base::SR_SUCCESS) ? static_cast<int>(size) : -1;
372      }
373      break;
374      // Not doing DTLS.
375    case STATE_NONE:
376      result = channel_->SendPacket(data, size, dscp);
377      break;
378
379    case STATE_CLOSED:  // Can't send anything when we're closed.
380      return -1;
381  }
382
383  return result;
384}
385
386// The state transition logic here is as follows:
387// (1) If we're not doing DTLS-SRTP, then the state is just the
388//     state of the underlying impl()
389// (2) If we're doing DTLS-SRTP:
//     - Prior to the DTLS handshake, the state is neither readable nor
//       writable
392//     - When the impl goes writable for the first time we
393//       start the DTLS handshake
394//     - Once the DTLS handshake completes, the state is that of the
395//       impl again
396void DtlsTransportChannelWrapper::OnReadableState(TransportChannel* channel) {
397  ASSERT(talk_base::Thread::Current() == worker_thread_);
398  ASSERT(channel == channel_);
399  LOG_J(LS_VERBOSE, this)
400      << "DTLSTransportChannelWrapper: channel readable state changed.";
401
402  if (dtls_state_ == STATE_NONE || dtls_state_ == STATE_OPEN) {
403    set_readable(channel_->readable());
404    // Note: SignalReadableState fired by set_readable.
405  }
406}
407
408void DtlsTransportChannelWrapper::OnWritableState(TransportChannel* channel) {
409  ASSERT(talk_base::Thread::Current() == worker_thread_);
410  ASSERT(channel == channel_);
411  LOG_J(LS_VERBOSE, this)
412      << "DTLSTransportChannelWrapper: channel writable state changed.";
413
414  switch (dtls_state_) {
415    case STATE_NONE:
416    case STATE_OPEN:
417      set_writable(channel_->writable());
418      // Note: SignalWritableState fired by set_writable.
419      break;
420
421    case STATE_OFFERED:
422      // Do nothing
423      break;
424
425    case STATE_ACCEPTED:
426      if (!MaybeStartDtls()) {
427        // This should never happen:
428        // Because we are operating in a nonblocking mode and all
429        // incoming packets come in via OnReadPacket(), which rejects
430        // packets in this state, the incoming queue must be empty. We
431        // ignore write errors, thus any errors must be because of
432        // configuration and therefore are our fault.
433        // Note that in non-debug configurations, failure in
434        // MaybeStartDtls() changes the state to STATE_CLOSED.
435        ASSERT(false);
436      }
437      break;
438
439    case STATE_STARTED:
440      // Do nothing
441      break;
442
443    case STATE_CLOSED:
444      // Should not happen. Do nothing
445      break;
446  }
447}
448
449void DtlsTransportChannelWrapper::OnReadPacket(
450    TransportChannel* channel, const char* data, size_t size,
451    const talk_base::PacketTime& packet_time, int flags) {
452  ASSERT(talk_base::Thread::Current() == worker_thread_);
453  ASSERT(channel == channel_);
454  ASSERT(flags == 0);
455
456  switch (dtls_state_) {
457    case STATE_NONE:
458      // We are not doing DTLS
459      SignalReadPacket(this, data, size, packet_time, 0);
460      break;
461
462    case STATE_OFFERED:
463      // Currently drop the packet, but we might in future
464      // decide to take this as evidence that the other
465      // side is ready to do DTLS and start the handshake
466      // on our end
467      LOG_J(LS_WARNING, this) << "Received packet before we know if we are "
468                              << "doing DTLS or not; dropping.";
469      break;
470
471    case STATE_ACCEPTED:
472      // Drop packets received before DTLS has actually started
473      LOG_J(LS_INFO, this) << "Dropping packet received before DTLS started.";
474      break;
475
476    case STATE_STARTED:
477    case STATE_OPEN:
478      // We should only get DTLS or SRTP packets; STUN's already been demuxed.
479      // Is this potentially a DTLS packet?
480      if (IsDtlsPacket(data, size)) {
481        if (!HandleDtlsPacket(data, size)) {
482          LOG_J(LS_ERROR, this) << "Failed to handle DTLS packet.";
483          return;
484        }
485      } else {
486        // Not a DTLS packet; our handshake should be complete by now.
487        if (dtls_state_ != STATE_OPEN) {
488          LOG_J(LS_ERROR, this) << "Received non-DTLS packet before DTLS "
489                                << "complete.";
490          return;
491        }
492
493        // And it had better be a SRTP packet.
494        if (!IsRtpPacket(data, size)) {
495          LOG_J(LS_ERROR, this) << "Received unexpected non-DTLS packet.";
496          return;
497        }
498
499        // Sanity check.
500        ASSERT(!srtp_ciphers_.empty());
501
502        // Signal this upwards as a bypass packet.
503        SignalReadPacket(this, data, size, packet_time, PF_SRTP_BYPASS);
504      }
505      break;
506    case STATE_CLOSED:
507      // This shouldn't be happening. Drop the packet
508      break;
509  }
510}
511
512void DtlsTransportChannelWrapper::OnReadyToSend(TransportChannel* channel) {
513  if (writable()) {
514    SignalReadyToSend(this);
515  }
516}
517
518void DtlsTransportChannelWrapper::OnDtlsEvent(talk_base::StreamInterface* dtls,
519                                              int sig, int err) {
520  ASSERT(talk_base::Thread::Current() == worker_thread_);
521  ASSERT(dtls == dtls_.get());
522  if (sig & talk_base::SE_OPEN) {
523    // This is the first time.
524    LOG_J(LS_INFO, this) << "DTLS handshake complete.";
525    if (dtls_->GetState() == talk_base::SS_OPEN) {
526      // The check for OPEN shouldn't be necessary but let's make
527      // sure we don't accidentally frob the state if it's closed.
528      dtls_state_ = STATE_OPEN;
529
530      set_readable(true);
531      set_writable(true);
532    }
533  }
534  if (sig & talk_base::SE_READ) {
535    char buf[kMaxDtlsPacketLen];
536    size_t read;
537    if (dtls_->Read(buf, sizeof(buf), &read, NULL) == talk_base::SR_SUCCESS) {
538      SignalReadPacket(this, buf, read, talk_base::CreatePacketTime(0), 0);
539    }
540  }
541  if (sig & talk_base::SE_CLOSE) {
542    ASSERT(sig == talk_base::SE_CLOSE);  // SE_CLOSE should be by itself.
543    if (!err) {
544      LOG_J(LS_INFO, this) << "DTLS channel closed";
545    } else {
546      LOG_J(LS_INFO, this) << "DTLS channel error, code=" << err;
547    }
548
549    set_readable(false);
550    set_writable(false);
551    dtls_state_ = STATE_CLOSED;
552  }
553}
554
555bool DtlsTransportChannelWrapper::MaybeStartDtls() {
556  if (channel_->writable()) {
557    if (dtls_->StartSSLWithPeer()) {
558      LOG_J(LS_ERROR, this) << "Couldn't start DTLS handshake";
559      dtls_state_ = STATE_CLOSED;
560      return false;
561    }
562    LOG_J(LS_INFO, this)
563      << "DtlsTransportChannelWrapper: Started DTLS handshake";
564
565    dtls_state_ = STATE_STARTED;
566  }
567  return true;
568}
569
570// Called from OnReadPacket when a DTLS packet is received.
571bool DtlsTransportChannelWrapper::HandleDtlsPacket(const char* data,
572                                                   size_t size) {
573  // Sanity check we're not passing junk that
574  // just looks like DTLS.
575  const uint8* tmp_data = reinterpret_cast<const uint8* >(data);
576  size_t tmp_size = size;
577  while (tmp_size > 0) {
578    if (tmp_size < kDtlsRecordHeaderLen)
579      return false;  // Too short for the header
580
581    size_t record_len = (tmp_data[11] << 8) | (tmp_data[12]);
582    if ((record_len + kDtlsRecordHeaderLen) > tmp_size)
583      return false;  // Body too short
584
585    tmp_data += record_len + kDtlsRecordHeaderLen;
586    tmp_size -= record_len + kDtlsRecordHeaderLen;
587  }
588
589  // Looks good. Pass to the SIC which ends up being passed to
590  // the DTLS stack.
591  return downward_->OnPacketReceived(data, size);
592}
593
594void DtlsTransportChannelWrapper::OnRequestSignaling(
595    TransportChannelImpl* channel) {
596  ASSERT(channel == channel_);
597  SignalRequestSignaling(this);
598}
599
600void DtlsTransportChannelWrapper::OnCandidateReady(
601    TransportChannelImpl* channel, const Candidate& c) {
602  ASSERT(channel == channel_);
603  SignalCandidateReady(this, c);
604}
605
606void DtlsTransportChannelWrapper::OnCandidatesAllocationDone(
607    TransportChannelImpl* channel) {
608  ASSERT(channel == channel_);
609  SignalCandidatesAllocationDone(this);
610}
611
612void DtlsTransportChannelWrapper::OnRoleConflict(
613    TransportChannelImpl* channel) {
614  ASSERT(channel == channel_);
615  SignalRoleConflict(this);
616}
617
618void DtlsTransportChannelWrapper::OnRouteChange(
619    TransportChannel* channel, const Candidate& candidate) {
620  ASSERT(channel == channel_);
621  SignalRouteChange(this, candidate);
622}
623
624}  // namespace cricket
625