faketransportcontroller.h revision 27dc29b0df23eed5034f28d4d5f66ea0bb425d6c
1/* 2 * Copyright 2009 The WebRTC Project Authors. All rights reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11#ifndef WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_ 12#define WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_ 13 14#include <map> 15#include <string> 16#include <vector> 17 18#include "webrtc/p2p/base/transport.h" 19#include "webrtc/p2p/base/transportchannel.h" 20#include "webrtc/p2p/base/transportcontroller.h" 21#include "webrtc/p2p/base/transportchannelimpl.h" 22#include "webrtc/base/bind.h" 23#include "webrtc/base/buffer.h" 24#include "webrtc/base/fakesslidentity.h" 25#include "webrtc/base/messagequeue.h" 26#include "webrtc/base/sigslot.h" 27#include "webrtc/base/sslfingerprint.h" 28#include "webrtc/base/thread.h" 29 30namespace cricket { 31 32class FakeTransport; 33 34struct PacketMessageData : public rtc::MessageData { 35 PacketMessageData(const char* data, size_t len) : packet(data, len) {} 36 rtc::Buffer packet; 37}; 38 39// Fake transport channel class, which can be passed to anything that needs a 40// transport channel. Can be informed of another FakeTransportChannel via 41// SetDestination. 42// TODO(hbos): Move implementation to .cc file, this and other classes in file. 43class FakeTransportChannel : public TransportChannelImpl, 44 public rtc::MessageHandler { 45 public: 46 explicit FakeTransportChannel(Transport* transport, 47 const std::string& name, 48 int component) 49 : TransportChannelImpl(name, component), 50 transport_(transport), 51 dtls_fingerprint_("", nullptr, 0) {} 52 ~FakeTransportChannel() { Reset(); } 53 54 uint64 IceTiebreaker() const { return tiebreaker_; } 55 IceMode remote_ice_mode() const { return remote_ice_mode_; } 56 const std::string& ice_ufrag() const { return ice_ufrag_; } 57 const std::string& ice_pwd() const { return ice_pwd_; } 58 const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; } 59 const std::string& remote_ice_pwd() const { return remote_ice_pwd_; } 60 const rtc::SSLFingerprint& dtls_fingerprint() const { 61 return dtls_fingerprint_; 62 } 63 64 // If async, will send packets by "Post"-ing to message queue instead of 65 // synchronously "Send"-ing. 66 void SetAsync(bool async) { async_ = async; } 67 68 Transport* GetTransport() override { return transport_; } 69 70 TransportChannelState GetState() const override { 71 if (connection_count_ == 0) { 72 return had_connection_ ? TransportChannelState::STATE_FAILED 73 : TransportChannelState::STATE_INIT; 74 } 75 76 if (connection_count_ == 1) { 77 return TransportChannelState::STATE_COMPLETED; 78 } 79 80 return TransportChannelState::STATE_CONNECTING; 81 } 82 83 void SetIceRole(IceRole role) override { role_ = role; } 84 IceRole GetIceRole() const override { return role_; } 85 void SetIceTiebreaker(uint64 tiebreaker) override { 86 tiebreaker_ = tiebreaker; 87 } 88 void SetIceCredentials(const std::string& ice_ufrag, 89 const std::string& ice_pwd) override { 90 ice_ufrag_ = ice_ufrag; 91 ice_pwd_ = ice_pwd; 92 } 93 void SetRemoteIceCredentials(const std::string& ice_ufrag, 94 const std::string& ice_pwd) override { 95 remote_ice_ufrag_ = ice_ufrag; 96 remote_ice_pwd_ = ice_pwd; 97 } 98 99 void SetRemoteIceMode(IceMode mode) override { remote_ice_mode_ = mode; } 100 bool SetRemoteFingerprint(const std::string& alg, 101 const uint8* digest, 102 size_t digest_len) override { 103 dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len); 104 return true; 105 } 106 bool SetSslRole(rtc::SSLRole role) override { 107 ssl_role_ = role; 108 return true; 109 } 110 bool GetSslRole(rtc::SSLRole* role) const override { 111 *role = ssl_role_; 112 return true; 113 } 114 115 void Connect() override { 116 if (state_ == STATE_INIT) { 117 state_ = STATE_CONNECTING; 118 } 119 } 120 121 void MaybeStartGathering() override { 122 if (gathering_state_ == kIceGatheringNew) { 123 gathering_state_ = kIceGatheringGathering; 124 SignalGatheringState(this); 125 } 126 } 127 128 IceGatheringState gathering_state() const override { 129 return gathering_state_; 130 } 131 132 void Reset() { 133 if (state_ != STATE_INIT) { 134 state_ = STATE_INIT; 135 if (dest_) { 136 dest_->state_ = STATE_INIT; 137 dest_->dest_ = nullptr; 138 dest_ = nullptr; 139 } 140 } 141 } 142 143 void SetWritable(bool writable) { set_writable(writable); } 144 145 void SetDestination(FakeTransportChannel* dest) { 146 if (state_ == STATE_CONNECTING && dest) { 147 // This simulates the delivery of candidates. 148 dest_ = dest; 149 dest_->dest_ = this; 150 if (local_cert_ && dest_->local_cert_) { 151 do_dtls_ = true; 152 dest_->do_dtls_ = true; 153 NegotiateSrtpCiphers(); 154 } 155 state_ = STATE_CONNECTED; 156 dest_->state_ = STATE_CONNECTED; 157 set_writable(true); 158 dest_->set_writable(true); 159 } else if (state_ == STATE_CONNECTED && !dest) { 160 // Simulates loss of connectivity, by asymmetrically forgetting dest_. 161 dest_ = nullptr; 162 state_ = STATE_CONNECTING; 163 set_writable(false); 164 } 165 } 166 167 void SetConnectionCount(size_t connection_count) { 168 size_t old_connection_count = connection_count_; 169 connection_count_ = connection_count; 170 if (connection_count) 171 had_connection_ = true; 172 if (connection_count_ < old_connection_count) 173 SignalConnectionRemoved(this); 174 } 175 176 void SetCandidatesGatheringComplete() { 177 if (gathering_state_ != kIceGatheringComplete) { 178 gathering_state_ = kIceGatheringComplete; 179 SignalGatheringState(this); 180 } 181 } 182 183 void SetReceiving(bool receiving) { set_receiving(receiving); } 184 185 void SetIceConfig(const IceConfig& config) override { 186 receiving_timeout_ = config.receiving_timeout_ms; 187 gather_continually_ = config.gather_continually; 188 } 189 190 int receiving_timeout() const { return receiving_timeout_; } 191 bool gather_continually() const { return gather_continually_; } 192 193 int SendPacket(const char* data, 194 size_t len, 195 const rtc::PacketOptions& options, 196 int flags) override { 197 if (state_ != STATE_CONNECTED) { 198 return -1; 199 } 200 201 if (flags != PF_SRTP_BYPASS && flags != 0) { 202 return -1; 203 } 204 205 PacketMessageData* packet = new PacketMessageData(data, len); 206 if (async_) { 207 rtc::Thread::Current()->Post(this, 0, packet); 208 } else { 209 rtc::Thread::Current()->Send(this, 0, packet); 210 } 211 return static_cast<int>(len); 212 } 213 int SetOption(rtc::Socket::Option opt, int value) override { return true; } 214 bool GetOption(rtc::Socket::Option opt, int* value) override { return true; } 215 int GetError() override { return 0; } 216 217 void AddRemoteCandidate(const Candidate& candidate) override { 218 remote_candidates_.push_back(candidate); 219 } 220 const Candidates& remote_candidates() const { return remote_candidates_; } 221 222 void OnMessage(rtc::Message* msg) override { 223 PacketMessageData* data = static_cast<PacketMessageData*>(msg->pdata); 224 dest_->SignalReadPacket(dest_, data->packet.data<char>(), 225 data->packet.size(), rtc::CreatePacketTime(0), 0); 226 delete data; 227 } 228 229 bool SetLocalCertificate( 230 const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) { 231 local_cert_ = certificate; 232 return true; 233 } 234 235 void SetRemoteSSLCertificate(rtc::FakeSSLCertificate* cert) { 236 remote_cert_ = cert; 237 } 238 239 bool IsDtlsActive() const override { return do_dtls_; } 240 241 bool SetSrtpCiphers(const std::vector<std::string>& ciphers) override { 242 srtp_ciphers_ = ciphers; 243 return true; 244 } 245 246 bool GetSrtpCipher(std::string* cipher) override { 247 if (!chosen_srtp_cipher_.empty()) { 248 *cipher = chosen_srtp_cipher_; 249 return true; 250 } 251 return false; 252 } 253 254 bool GetSslCipher(std::string* cipher) override { return false; } 255 256 rtc::scoped_refptr<rtc::RTCCertificate> GetLocalCertificate() const { 257 return local_cert_; 258 } 259 260 bool GetRemoteSSLCertificate(rtc::SSLCertificate** cert) const override { 261 if (!remote_cert_) 262 return false; 263 264 *cert = remote_cert_->GetReference(); 265 return true; 266 } 267 268 bool ExportKeyingMaterial(const std::string& label, 269 const uint8* context, 270 size_t context_len, 271 bool use_context, 272 uint8* result, 273 size_t result_len) override { 274 if (!chosen_srtp_cipher_.empty()) { 275 memset(result, 0xff, result_len); 276 return true; 277 } 278 279 return false; 280 } 281 282 void NegotiateSrtpCiphers() { 283 for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin(); 284 it1 != srtp_ciphers_.end(); ++it1) { 285 for (std::vector<std::string>::const_iterator it2 = 286 dest_->srtp_ciphers_.begin(); 287 it2 != dest_->srtp_ciphers_.end(); ++it2) { 288 if (*it1 == *it2) { 289 chosen_srtp_cipher_ = *it1; 290 dest_->chosen_srtp_cipher_ = *it2; 291 return; 292 } 293 } 294 } 295 } 296 297 bool GetStats(ConnectionInfos* infos) override { 298 ConnectionInfo info; 299 infos->clear(); 300 infos->push_back(info); 301 return true; 302 } 303 304 void set_ssl_max_protocol_version(rtc::SSLProtocolVersion version) { 305 ssl_max_version_ = version; 306 } 307 rtc::SSLProtocolVersion ssl_max_protocol_version() const { 308 return ssl_max_version_; 309 } 310 311 private: 312 enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED }; 313 Transport* transport_; 314 FakeTransportChannel* dest_ = nullptr; 315 State state_ = STATE_INIT; 316 bool async_ = false; 317 Candidates remote_candidates_; 318 rtc::scoped_refptr<rtc::RTCCertificate> local_cert_; 319 rtc::FakeSSLCertificate* remote_cert_ = nullptr; 320 bool do_dtls_ = false; 321 std::vector<std::string> srtp_ciphers_; 322 std::string chosen_srtp_cipher_; 323 int receiving_timeout_ = -1; 324 bool gather_continually_ = false; 325 IceRole role_ = ICEROLE_UNKNOWN; 326 uint64 tiebreaker_ = 0; 327 std::string ice_ufrag_; 328 std::string ice_pwd_; 329 std::string remote_ice_ufrag_; 330 std::string remote_ice_pwd_; 331 IceMode remote_ice_mode_ = ICEMODE_FULL; 332 rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_10; 333 rtc::SSLFingerprint dtls_fingerprint_; 334 rtc::SSLRole ssl_role_ = rtc::SSL_CLIENT; 335 size_t connection_count_ = 0; 336 IceGatheringState gathering_state_ = kIceGatheringNew; 337 bool had_connection_ = false; 338}; 339 340// Fake transport class, which can be passed to anything that needs a Transport. 341// Can be informed of another FakeTransport via SetDestination (low-tech way 342// of doing candidates) 343class FakeTransport : public Transport { 344 public: 345 typedef std::map<int, FakeTransportChannel*> ChannelMap; 346 347 explicit FakeTransport(const std::string& name) : Transport(name, nullptr) {} 348 349 // Note that we only have a constructor with the allocator parameter so it can 350 // be wrapped by a DtlsTransport. 351 FakeTransport(const std::string& name, PortAllocator* allocator) 352 : Transport(name, nullptr) {} 353 354 ~FakeTransport() { DestroyAllChannels(); } 355 356 const ChannelMap& channels() const { return channels_; } 357 358 // If async, will send packets by "Post"-ing to message queue instead of 359 // synchronously "Send"-ing. 360 void SetAsync(bool async) { async_ = async; } 361 void SetDestination(FakeTransport* dest) { 362 dest_ = dest; 363 for (const auto& kv : channels_) { 364 kv.second->SetLocalCertificate(certificate_); 365 SetChannelDestination(kv.first, kv.second); 366 } 367 } 368 369 void SetWritable(bool writable) { 370 for (const auto& kv : channels_) { 371 kv.second->SetWritable(writable); 372 } 373 } 374 375 void SetLocalCertificate( 376 const rtc::scoped_refptr<rtc::RTCCertificate>& certificate) override { 377 certificate_ = certificate; 378 } 379 bool GetLocalCertificate( 380 rtc::scoped_refptr<rtc::RTCCertificate>* certificate) override { 381 if (!certificate_) 382 return false; 383 384 *certificate = certificate_; 385 return true; 386 } 387 388 bool GetSslRole(rtc::SSLRole* role) const override { 389 if (channels_.empty()) { 390 return false; 391 } 392 return channels_.begin()->second->GetSslRole(role); 393 } 394 395 bool SetSslMaxProtocolVersion(rtc::SSLProtocolVersion version) override { 396 ssl_max_version_ = version; 397 for (const auto& kv : channels_) { 398 kv.second->set_ssl_max_protocol_version(ssl_max_version_); 399 } 400 return true; 401 } 402 rtc::SSLProtocolVersion ssl_max_protocol_version() const { 403 return ssl_max_version_; 404 } 405 406 using Transport::local_description; 407 using Transport::remote_description; 408 409 protected: 410 TransportChannelImpl* CreateTransportChannel(int component) override { 411 if (channels_.find(component) != channels_.end()) { 412 return nullptr; 413 } 414 FakeTransportChannel* channel = 415 new FakeTransportChannel(this, name(), component); 416 channel->set_ssl_max_protocol_version(ssl_max_version_); 417 channel->SetAsync(async_); 418 SetChannelDestination(component, channel); 419 channels_[component] = channel; 420 return channel; 421 } 422 423 void DestroyTransportChannel(TransportChannelImpl* channel) override { 424 channels_.erase(channel->component()); 425 delete channel; 426 } 427 428 private: 429 FakeTransportChannel* GetFakeChannel(int component) { 430 auto it = channels_.find(component); 431 return (it != channels_.end()) ? it->second : nullptr; 432 } 433 434 void SetChannelDestination(int component, FakeTransportChannel* channel) { 435 FakeTransportChannel* dest_channel = nullptr; 436 if (dest_) { 437 dest_channel = dest_->GetFakeChannel(component); 438 if (dest_channel) { 439 dest_channel->SetLocalCertificate(dest_->certificate_); 440 } 441 } 442 channel->SetDestination(dest_channel); 443 } 444 445 // Note, this is distinct from the Channel map owned by Transport. 446 // This map just tracks the FakeTransportChannels created by this class. 447 // It's mainly needed so that we can access a FakeTransportChannel directly, 448 // even if wrapped by a DtlsTransportChannelWrapper. 449 ChannelMap channels_; 450 FakeTransport* dest_ = nullptr; 451 bool async_ = false; 452 rtc::scoped_refptr<rtc::RTCCertificate> certificate_; 453 rtc::SSLProtocolVersion ssl_max_version_ = rtc::SSL_PROTOCOL_DTLS_10; 454}; 455 456// Fake TransportController class, which can be passed into a BaseChannel object 457// for test purposes. Can be connected to other FakeTransportControllers via 458// Connect(). 459// 460// This fake is unusual in that for the most part, it's implemented with the 461// real TransportController code, but with fake TransportChannels underneath. 462class FakeTransportController : public TransportController { 463 public: 464 FakeTransportController() 465 : TransportController(rtc::Thread::Current(), 466 rtc::Thread::Current(), 467 nullptr), 468 fail_create_channel_(false) {} 469 470 explicit FakeTransportController(IceRole role) 471 : TransportController(rtc::Thread::Current(), 472 rtc::Thread::Current(), 473 nullptr), 474 fail_create_channel_(false) { 475 SetIceRole(role); 476 } 477 478 explicit FakeTransportController(rtc::Thread* worker_thread) 479 : TransportController(rtc::Thread::Current(), worker_thread, nullptr), 480 fail_create_channel_(false) {} 481 482 FakeTransportController(rtc::Thread* worker_thread, IceRole role) 483 : TransportController(rtc::Thread::Current(), worker_thread, nullptr), 484 fail_create_channel_(false) { 485 SetIceRole(role); 486 } 487 488 FakeTransport* GetTransport_w(const std::string& transport_name) { 489 return static_cast<FakeTransport*>( 490 TransportController::GetTransport_w(transport_name)); 491 } 492 493 void Connect(FakeTransportController* dest) { 494 worker_thread()->Invoke<void>( 495 rtc::Bind(&FakeTransportController::Connect_w, this, dest)); 496 } 497 498 TransportChannel* CreateTransportChannel_w(const std::string& transport_name, 499 int component) override { 500 if (fail_create_channel_) { 501 return nullptr; 502 } 503 return TransportController::CreateTransportChannel_w(transport_name, 504 component); 505 } 506 507 void set_fail_channel_creation(bool fail_channel_creation) { 508 fail_create_channel_ = fail_channel_creation; 509 } 510 511 protected: 512 Transport* CreateTransport_w(const std::string& transport_name) override { 513 return new FakeTransport(transport_name); 514 } 515 516 void Connect_w(FakeTransportController* dest) { 517 // Simulate the exchange of candidates. 518 ConnectChannels_w(); 519 dest->ConnectChannels_w(); 520 for (auto& kv : transports()) { 521 FakeTransport* transport = static_cast<FakeTransport*>(kv.second); 522 transport->SetDestination(dest->GetTransport_w(kv.first)); 523 } 524 } 525 526 void ConnectChannels_w() { 527 for (auto& kv : transports()) { 528 FakeTransport* transport = static_cast<FakeTransport*>(kv.second); 529 transport->ConnectChannels(); 530 transport->MaybeStartGathering(); 531 } 532 } 533 534 private: 535 bool fail_create_channel_; 536}; 537 538} // namespace cricket 539 540#endif // WEBRTC_P2P_BASE_FAKETRANSPORTCONTROLLER_H_ 541