1/*
2 *  Copyright 2015 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "webrtc/p2p/base/transportcontroller.h"
12
13#include <algorithm>
14
15#include "webrtc/base/bind.h"
16#include "webrtc/base/checks.h"
17#include "webrtc/base/thread.h"
18#include "webrtc/p2p/base/dtlstransport.h"
19#include "webrtc/p2p/base/p2ptransport.h"
20#include "webrtc/p2p/base/port.h"
21
22namespace cricket {
23
// Message ids posted from the worker thread to the signaling thread; each one
// is dispatched by TransportController::OnMessage.
enum {
  MSG_ICECONNECTIONSTATE = 0,  // payload: TypedMessageData<IceConnectionState>
  MSG_RECEIVING = 1,           // payload: TypedMessageData<bool>
  MSG_ICEGATHERINGSTATE = 2,   // payload: TypedMessageData<IceGatheringState>
  MSG_CANDIDATESGATHERED = 3,  // payload: CandidatesData
};
30
31struct CandidatesData : public rtc::MessageData {
32  CandidatesData(const std::string& transport_name,
33                 const Candidates& candidates)
34      : transport_name(transport_name), candidates(candidates) {}
35
36  std::string transport_name;
37  Candidates candidates;
38};
39
40TransportController::TransportController(rtc::Thread* signaling_thread,
41                                         rtc::Thread* worker_thread,
42                                         PortAllocator* port_allocator)
43    : signaling_thread_(signaling_thread),
44      worker_thread_(worker_thread),
45      port_allocator_(port_allocator) {}
46
47TransportController::~TransportController() {
48  worker_thread_->Invoke<void>(
49      rtc::Bind(&TransportController::DestroyAllTransports_w, this));
50  signaling_thread_->Clear(this);
51}
52
53bool TransportController::SetSslMaxProtocolVersion(
54    rtc::SSLProtocolVersion version) {
55  return worker_thread_->Invoke<bool>(rtc::Bind(
56      &TransportController::SetSslMaxProtocolVersion_w, this, version));
57}
58
59void TransportController::SetIceConfig(const IceConfig& config) {
60  worker_thread_->Invoke<void>(
61      rtc::Bind(&TransportController::SetIceConfig_w, this, config));
62}
63
64void TransportController::SetIceRole(IceRole ice_role) {
65  worker_thread_->Invoke<void>(
66      rtc::Bind(&TransportController::SetIceRole_w, this, ice_role));
67}
68
69bool TransportController::GetSslRole(const std::string& transport_name,
70                                     rtc::SSLRole* role) {
71  return worker_thread_->Invoke<bool>(rtc::Bind(
72      &TransportController::GetSslRole_w, this, transport_name, role));
73}
74
75bool TransportController::SetLocalCertificate(
76    const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
77  return worker_thread_->Invoke<bool>(rtc::Bind(
78      &TransportController::SetLocalCertificate_w, this, certificate));
79}
80
81bool TransportController::GetLocalCertificate(
82    const std::string& transport_name,
83    rtc::scoped_refptr<rtc::RTCCertificate>* certificate) {
84  return worker_thread_->Invoke<bool>(
85      rtc::Bind(&TransportController::GetLocalCertificate_w, this,
86                transport_name, certificate));
87}
88
89bool TransportController::GetRemoteSSLCertificate(
90    const std::string& transport_name,
91    rtc::SSLCertificate** cert) {
92  return worker_thread_->Invoke<bool>(
93      rtc::Bind(&TransportController::GetRemoteSSLCertificate_w, this,
94                transport_name, cert));
95}
96
97bool TransportController::SetLocalTransportDescription(
98    const std::string& transport_name,
99    const TransportDescription& tdesc,
100    ContentAction action,
101    std::string* err) {
102  return worker_thread_->Invoke<bool>(
103      rtc::Bind(&TransportController::SetLocalTransportDescription_w, this,
104                transport_name, tdesc, action, err));
105}
106
107bool TransportController::SetRemoteTransportDescription(
108    const std::string& transport_name,
109    const TransportDescription& tdesc,
110    ContentAction action,
111    std::string* err) {
112  return worker_thread_->Invoke<bool>(
113      rtc::Bind(&TransportController::SetRemoteTransportDescription_w, this,
114                transport_name, tdesc, action, err));
115}
116
117void TransportController::MaybeStartGathering() {
118  worker_thread_->Invoke<void>(
119      rtc::Bind(&TransportController::MaybeStartGathering_w, this));
120}
121
122bool TransportController::AddRemoteCandidates(const std::string& transport_name,
123                                              const Candidates& candidates,
124                                              std::string* err) {
125  return worker_thread_->Invoke<bool>(
126      rtc::Bind(&TransportController::AddRemoteCandidates_w, this,
127                transport_name, candidates, err));
128}
129
130bool TransportController::ReadyForRemoteCandidates(
131    const std::string& transport_name) {
132  return worker_thread_->Invoke<bool>(rtc::Bind(
133      &TransportController::ReadyForRemoteCandidates_w, this, transport_name));
134}
135
136bool TransportController::GetStats(const std::string& transport_name,
137                                   TransportStats* stats) {
138  return worker_thread_->Invoke<bool>(
139      rtc::Bind(&TransportController::GetStats_w, this, transport_name, stats));
140}
141
142TransportChannel* TransportController::CreateTransportChannel_w(
143    const std::string& transport_name,
144    int component) {
145  RTC_DCHECK(worker_thread_->IsCurrent());
146
147  auto it = FindChannel_w(transport_name, component);
148  if (it != channels_.end()) {
149    // Channel already exists; increment reference count and return.
150    it->AddRef();
151    return it->get();
152  }
153
154  // Need to create a new channel.
155  Transport* transport = GetOrCreateTransport_w(transport_name);
156  TransportChannelImpl* channel = transport->CreateChannel(component);
157  channel->SignalWritableState.connect(
158      this, &TransportController::OnChannelWritableState_w);
159  channel->SignalReceivingState.connect(
160      this, &TransportController::OnChannelReceivingState_w);
161  channel->SignalGatheringState.connect(
162      this, &TransportController::OnChannelGatheringState_w);
163  channel->SignalCandidateGathered.connect(
164      this, &TransportController::OnChannelCandidateGathered_w);
165  channel->SignalRoleConflict.connect(
166      this, &TransportController::OnChannelRoleConflict_w);
167  channel->SignalConnectionRemoved.connect(
168      this, &TransportController::OnChannelConnectionRemoved_w);
169  channels_.insert(channels_.end(), RefCountedChannel(channel))->AddRef();
170  // Adding a channel could cause aggregate state to change.
171  UpdateAggregateStates_w();
172  return channel;
173}
174
175void TransportController::DestroyTransportChannel_w(
176    const std::string& transport_name,
177    int component) {
178  RTC_DCHECK(worker_thread_->IsCurrent());
179
180  auto it = FindChannel_w(transport_name, component);
181  if (it == channels_.end()) {
182    LOG(LS_WARNING) << "Attempting to delete " << transport_name
183                    << " TransportChannel " << component
184                    << ", which doesn't exist.";
185    return;
186  }
187
188  it->DecRef();
189  if (it->ref() > 0) {
190    return;
191  }
192
193  channels_.erase(it);
194  Transport* transport = GetTransport_w(transport_name);
195  transport->DestroyChannel(component);
196  // Just as we create a Transport when its first channel is created,
197  // we delete it when its last channel is deleted.
198  if (!transport->HasChannels()) {
199    DestroyTransport_w(transport_name);
200  }
201  // Removing a channel could cause aggregate state to change.
202  UpdateAggregateStates_w();
203}
204
205const rtc::scoped_refptr<rtc::RTCCertificate>&
206TransportController::certificate_for_testing() {
207  return certificate_;
208}
209
210Transport* TransportController::CreateTransport_w(
211    const std::string& transport_name) {
212  RTC_DCHECK(worker_thread_->IsCurrent());
213
214  Transport* transport = new DtlsTransport<P2PTransport>(
215      transport_name, port_allocator(), certificate_);
216  return transport;
217}
218
219Transport* TransportController::GetTransport_w(
220    const std::string& transport_name) {
221  RTC_DCHECK(worker_thread_->IsCurrent());
222
223  auto iter = transports_.find(transport_name);
224  return (iter != transports_.end()) ? iter->second : nullptr;
225}
226
227void TransportController::OnMessage(rtc::Message* pmsg) {
228  RTC_DCHECK(signaling_thread_->IsCurrent());
229
230  switch (pmsg->message_id) {
231    case MSG_ICECONNECTIONSTATE: {
232      rtc::TypedMessageData<IceConnectionState>* data =
233          static_cast<rtc::TypedMessageData<IceConnectionState>*>(pmsg->pdata);
234      SignalConnectionState(data->data());
235      delete data;
236      break;
237    }
238    case MSG_RECEIVING: {
239      rtc::TypedMessageData<bool>* data =
240          static_cast<rtc::TypedMessageData<bool>*>(pmsg->pdata);
241      SignalReceiving(data->data());
242      delete data;
243      break;
244    }
245    case MSG_ICEGATHERINGSTATE: {
246      rtc::TypedMessageData<IceGatheringState>* data =
247          static_cast<rtc::TypedMessageData<IceGatheringState>*>(pmsg->pdata);
248      SignalGatheringState(data->data());
249      delete data;
250      break;
251    }
252    case MSG_CANDIDATESGATHERED: {
253      CandidatesData* data = static_cast<CandidatesData*>(pmsg->pdata);
254      SignalCandidatesGathered(data->transport_name, data->candidates);
255      delete data;
256      break;
257    }
258    default:
259      ASSERT(false);
260  }
261}
262
263std::vector<TransportController::RefCountedChannel>::iterator
264TransportController::FindChannel_w(const std::string& transport_name,
265                                   int component) {
266  return std::find_if(
267      channels_.begin(), channels_.end(),
268      [transport_name, component](const RefCountedChannel& channel) {
269        return channel->transport_name() == transport_name &&
270               channel->component() == component;
271      });
272}
273
274Transport* TransportController::GetOrCreateTransport_w(
275    const std::string& transport_name) {
276  RTC_DCHECK(worker_thread_->IsCurrent());
277
278  Transport* transport = GetTransport_w(transport_name);
279  if (transport) {
280    return transport;
281  }
282
283  transport = CreateTransport_w(transport_name);
284  // The stuff below happens outside of CreateTransport_w so that unit tests
285  // can override CreateTransport_w to return a different type of transport.
286  transport->SetSslMaxProtocolVersion(ssl_max_version_);
287  transport->SetIceConfig(ice_config_);
288  transport->SetIceRole(ice_role_);
289  transport->SetIceTiebreaker(ice_tiebreaker_);
290  if (certificate_) {
291    transport->SetLocalCertificate(certificate_);
292  }
293  transports_[transport_name] = transport;
294
295  return transport;
296}
297
298void TransportController::DestroyTransport_w(
299    const std::string& transport_name) {
300  RTC_DCHECK(worker_thread_->IsCurrent());
301
302  auto iter = transports_.find(transport_name);
303  if (iter != transports_.end()) {
304    delete iter->second;
305    transports_.erase(transport_name);
306  }
307}
308
309void TransportController::DestroyAllTransports_w() {
310  RTC_DCHECK(worker_thread_->IsCurrent());
311
312  for (const auto& kv : transports_) {
313    delete kv.second;
314  }
315  transports_.clear();
316}
317
318bool TransportController::SetSslMaxProtocolVersion_w(
319    rtc::SSLProtocolVersion version) {
320  RTC_DCHECK(worker_thread_->IsCurrent());
321
322  // Max SSL version can only be set before transports are created.
323  if (!transports_.empty()) {
324    return false;
325  }
326
327  ssl_max_version_ = version;
328  return true;
329}
330
331void TransportController::SetIceConfig_w(const IceConfig& config) {
332  RTC_DCHECK(worker_thread_->IsCurrent());
333  ice_config_ = config;
334  for (const auto& kv : transports_) {
335    kv.second->SetIceConfig(ice_config_);
336  }
337}
338
339void TransportController::SetIceRole_w(IceRole ice_role) {
340  RTC_DCHECK(worker_thread_->IsCurrent());
341  ice_role_ = ice_role;
342  for (const auto& kv : transports_) {
343    kv.second->SetIceRole(ice_role_);
344  }
345}
346
347bool TransportController::GetSslRole_w(const std::string& transport_name,
348                                       rtc::SSLRole* role) {
349  RTC_DCHECK(worker_thread()->IsCurrent());
350
351  Transport* t = GetTransport_w(transport_name);
352  if (!t) {
353    return false;
354  }
355
356  return t->GetSslRole(role);
357}
358
359bool TransportController::SetLocalCertificate_w(
360    const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) {
361  RTC_DCHECK(worker_thread_->IsCurrent());
362
363  if (certificate_) {
364    return false;
365  }
366  if (!certificate) {
367    return false;
368  }
369  certificate_ = certificate;
370
371  for (const auto& kv : transports_) {
372    kv.second->SetLocalCertificate(certificate_);
373  }
374  return true;
375}
376
377bool TransportController::GetLocalCertificate_w(
378    const std::string& transport_name,
379    rtc::scoped_refptr<rtc::RTCCertificate>* certificate) {
380  RTC_DCHECK(worker_thread_->IsCurrent());
381
382  Transport* t = GetTransport_w(transport_name);
383  if (!t) {
384    return false;
385  }
386
387  return t->GetLocalCertificate(certificate);
388}
389
390bool TransportController::GetRemoteSSLCertificate_w(
391    const std::string& transport_name,
392    rtc::SSLCertificate** cert) {
393  RTC_DCHECK(worker_thread_->IsCurrent());
394
395  Transport* t = GetTransport_w(transport_name);
396  if (!t) {
397    return false;
398  }
399
400  return t->GetRemoteSSLCertificate(cert);
401}
402
403bool TransportController::SetLocalTransportDescription_w(
404    const std::string& transport_name,
405    const TransportDescription& tdesc,
406    ContentAction action,
407    std::string* err) {
408  RTC_DCHECK(worker_thread()->IsCurrent());
409
410  Transport* transport = GetTransport_w(transport_name);
411  if (!transport) {
412    // If we didn't find a transport, that's not an error;
413    // it could have been deleted as a result of bundling.
414    // TODO(deadbeef): Make callers smarter so they won't attempt to set a
415    // description on a deleted transport.
416    return true;
417  }
418
419  return transport->SetLocalTransportDescription(tdesc, action, err);
420}
421
422bool TransportController::SetRemoteTransportDescription_w(
423    const std::string& transport_name,
424    const TransportDescription& tdesc,
425    ContentAction action,
426    std::string* err) {
427  RTC_DCHECK(worker_thread()->IsCurrent());
428
429  Transport* transport = GetTransport_w(transport_name);
430  if (!transport) {
431    // If we didn't find a transport, that's not an error;
432    // it could have been deleted as a result of bundling.
433    // TODO(deadbeef): Make callers smarter so they won't attempt to set a
434    // description on a deleted transport.
435    return true;
436  }
437
438  return transport->SetRemoteTransportDescription(tdesc, action, err);
439}
440
441void TransportController::MaybeStartGathering_w() {
442  for (const auto& kv : transports_) {
443    kv.second->MaybeStartGathering();
444  }
445}
446
447bool TransportController::AddRemoteCandidates_w(
448    const std::string& transport_name,
449    const Candidates& candidates,
450    std::string* err) {
451  RTC_DCHECK(worker_thread()->IsCurrent());
452
453  Transport* transport = GetTransport_w(transport_name);
454  if (!transport) {
455    // If we didn't find a transport, that's not an error;
456    // it could have been deleted as a result of bundling.
457    return true;
458  }
459
460  return transport->AddRemoteCandidates(candidates, err);
461}
462
463bool TransportController::ReadyForRemoteCandidates_w(
464    const std::string& transport_name) {
465  RTC_DCHECK(worker_thread()->IsCurrent());
466
467  Transport* transport = GetTransport_w(transport_name);
468  if (!transport) {
469    return false;
470  }
471  return transport->ready_for_remote_candidates();
472}
473
474bool TransportController::GetStats_w(const std::string& transport_name,
475                                     TransportStats* stats) {
476  RTC_DCHECK(worker_thread()->IsCurrent());
477
478  Transport* transport = GetTransport_w(transport_name);
479  if (!transport) {
480    return false;
481  }
482  return transport->GetStats(stats);
483}
484
485void TransportController::OnChannelWritableState_w(TransportChannel* channel) {
486  RTC_DCHECK(worker_thread_->IsCurrent());
487  LOG(LS_INFO) << channel->transport_name() << " TransportChannel "
488               << channel->component() << " writability changed to "
489               << channel->writable() << ".";
490  UpdateAggregateStates_w();
491}
492
493void TransportController::OnChannelReceivingState_w(TransportChannel* channel) {
494  RTC_DCHECK(worker_thread_->IsCurrent());
495  UpdateAggregateStates_w();
496}
497
498void TransportController::OnChannelGatheringState_w(
499    TransportChannelImpl* channel) {
500  RTC_DCHECK(worker_thread_->IsCurrent());
501  UpdateAggregateStates_w();
502}
503
504void TransportController::OnChannelCandidateGathered_w(
505    TransportChannelImpl* channel,
506    const Candidate& candidate) {
507  RTC_DCHECK(worker_thread_->IsCurrent());
508
509  // We should never signal peer-reflexive candidates.
510  if (candidate.type() == PRFLX_PORT_TYPE) {
511    RTC_DCHECK(false);
512    return;
513  }
514  std::vector<Candidate> candidates;
515  candidates.push_back(candidate);
516  CandidatesData* data =
517      new CandidatesData(channel->transport_name(), candidates);
518  signaling_thread_->Post(this, MSG_CANDIDATESGATHERED, data);
519}
520
521void TransportController::OnChannelRoleConflict_w(
522    TransportChannelImpl* channel) {
523  RTC_DCHECK(worker_thread_->IsCurrent());
524
525  if (ice_role_switch_) {
526    LOG(LS_WARNING)
527        << "Repeat of role conflict signal from TransportChannelImpl.";
528    return;
529  }
530
531  ice_role_switch_ = true;
532  IceRole reversed_role = (ice_role_ == ICEROLE_CONTROLLING)
533                              ? ICEROLE_CONTROLLED
534                              : ICEROLE_CONTROLLING;
535  for (const auto& kv : transports_) {
536    kv.second->SetIceRole(reversed_role);
537  }
538}
539
540void TransportController::OnChannelConnectionRemoved_w(
541    TransportChannelImpl* channel) {
542  RTC_DCHECK(worker_thread_->IsCurrent());
543  LOG(LS_INFO) << channel->transport_name() << " TransportChannel "
544               << channel->component()
545               << " connection removed. Check if state is complete.";
546  UpdateAggregateStates_w();
547}
548
549void TransportController::UpdateAggregateStates_w() {
550  RTC_DCHECK(worker_thread_->IsCurrent());
551
552  IceConnectionState new_connection_state = kIceConnectionConnecting;
553  IceGatheringState new_gathering_state = kIceGatheringNew;
554  bool any_receiving = false;
555  bool any_failed = false;
556  bool all_connected = !channels_.empty();
557  bool all_completed = !channels_.empty();
558  bool any_gathering = false;
559  bool all_done_gathering = !channels_.empty();
560  for (const auto& channel : channels_) {
561    any_receiving = any_receiving || channel->receiving();
562    any_failed = any_failed ||
563                 channel->GetState() == TransportChannelState::STATE_FAILED;
564    all_connected = all_connected && channel->writable();
565    all_completed =
566        all_completed && channel->writable() &&
567        channel->GetState() == TransportChannelState::STATE_COMPLETED &&
568        channel->GetIceRole() == ICEROLE_CONTROLLING &&
569        channel->gathering_state() == kIceGatheringComplete;
570    any_gathering =
571        any_gathering || channel->gathering_state() != kIceGatheringNew;
572    all_done_gathering = all_done_gathering &&
573                         channel->gathering_state() == kIceGatheringComplete;
574  }
575
576  if (any_failed) {
577    new_connection_state = kIceConnectionFailed;
578  } else if (all_completed) {
579    new_connection_state = kIceConnectionCompleted;
580  } else if (all_connected) {
581    new_connection_state = kIceConnectionConnected;
582  }
583  if (connection_state_ != new_connection_state) {
584    connection_state_ = new_connection_state;
585    signaling_thread_->Post(
586        this, MSG_ICECONNECTIONSTATE,
587        new rtc::TypedMessageData<IceConnectionState>(new_connection_state));
588  }
589
590  if (receiving_ != any_receiving) {
591    receiving_ = any_receiving;
592    signaling_thread_->Post(this, MSG_RECEIVING,
593                            new rtc::TypedMessageData<bool>(any_receiving));
594  }
595
596  if (all_done_gathering) {
597    new_gathering_state = kIceGatheringComplete;
598  } else if (any_gathering) {
599    new_gathering_state = kIceGatheringGathering;
600  }
601  if (gathering_state_ != new_gathering_state) {
602    gathering_state_ = new_gathering_state;
603    signaling_thread_->Post(
604        this, MSG_ICEGATHERINGSTATE,
605        new rtc::TypedMessageData<IceGatheringState>(new_gathering_state));
606  }
607}
608
609}  // namespace cricket
610