1/*
2 * libjingle
3 * Copyright 2009, Google, Inc.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 *  1. Redistributions of source code must retain the above copyright notice,
9 *     this list of conditions and the following disclaimer.
10 *  2. Redistributions in binary form must reproduce the above copyright notice,
11 *     this list of conditions and the following disclaimer in the documentation
12 *     and/or other materials provided with the distribution.
13 *  3. The name of the author may not be used to endorse or promote products
14 *     derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#ifndef TALK_P2P_BASE_FAKESESSION_H_
29#define TALK_P2P_BASE_FAKESESSION_H_
30
#include <cstring>
#include <map>
#include <string>
#include <vector>

#include "talk/p2p/base/session.h"
#include "talk/p2p/base/transport.h"
#include "talk/p2p/base/transportchannel.h"
#include "talk/p2p/base/transportchannelimpl.h"
#include "webrtc/base/buffer.h"
#include "webrtc/base/fakesslidentity.h"
#include "webrtc/base/messagequeue.h"
#include "webrtc/base/sigslot.h"
#include "webrtc/base/sslfingerprint.h"
44
45namespace cricket {
46
47class FakeTransport;
48
49struct PacketMessageData : public rtc::MessageData {
50  PacketMessageData(const char* data, size_t len) : packet(data, len) {
51  }
52  rtc::Buffer packet;
53};
54
55// Fake transport channel class, which can be passed to anything that needs a
56// transport channel. Can be informed of another FakeTransportChannel via
57// SetDestination.
58class FakeTransportChannel : public TransportChannelImpl,
59                             public rtc::MessageHandler {
60 public:
61  explicit FakeTransportChannel(Transport* transport,
62                                const std::string& content_name,
63                                int component)
64      : TransportChannelImpl(content_name, component),
65        transport_(transport),
66        dest_(NULL),
67        state_(STATE_INIT),
68        async_(false),
69        identity_(NULL),
70        do_dtls_(false),
71        role_(ICEROLE_UNKNOWN),
72        tiebreaker_(0),
73        ice_proto_(ICEPROTO_HYBRID),
74        remote_ice_mode_(ICEMODE_FULL),
75        dtls_fingerprint_("", NULL, 0),
76        ssl_role_(rtc::SSL_CLIENT),
77        connection_count_(0) {
78  }
79  ~FakeTransportChannel() {
80    Reset();
81  }
82
83  uint64 IceTiebreaker() const { return tiebreaker_; }
84  TransportProtocol protocol() const { return ice_proto_; }
85  IceMode remote_ice_mode() const { return remote_ice_mode_; }
86  const std::string& ice_ufrag() const { return ice_ufrag_; }
87  const std::string& ice_pwd() const { return ice_pwd_; }
88  const std::string& remote_ice_ufrag() const { return remote_ice_ufrag_; }
89  const std::string& remote_ice_pwd() const { return remote_ice_pwd_; }
90  const rtc::SSLFingerprint& dtls_fingerprint() const {
91    return dtls_fingerprint_;
92  }
93
94  void SetAsync(bool async) {
95    async_ = async;
96  }
97
98  virtual Transport* GetTransport() {
99    return transport_;
100  }
101
102  virtual void SetIceRole(IceRole role) { role_ = role; }
103  virtual IceRole GetIceRole() const { return role_; }
104  virtual size_t GetConnectionCount() const { return connection_count_; }
105  virtual void SetIceTiebreaker(uint64 tiebreaker) { tiebreaker_ = tiebreaker; }
106  virtual bool GetIceProtocolType(IceProtocolType* type) const {
107    *type = ice_proto_;
108    return true;
109  }
110  virtual void SetIceProtocolType(IceProtocolType type) { ice_proto_ = type; }
111  virtual void SetIceCredentials(const std::string& ice_ufrag,
112                                 const std::string& ice_pwd) {
113    ice_ufrag_ = ice_ufrag;
114    ice_pwd_ = ice_pwd;
115  }
116  virtual void SetRemoteIceCredentials(const std::string& ice_ufrag,
117                                       const std::string& ice_pwd) {
118    remote_ice_ufrag_ = ice_ufrag;
119    remote_ice_pwd_ = ice_pwd;
120  }
121
122  virtual void SetRemoteIceMode(IceMode mode) { remote_ice_mode_ = mode; }
123  virtual bool SetRemoteFingerprint(const std::string& alg, const uint8* digest,
124                                    size_t digest_len) {
125    dtls_fingerprint_ = rtc::SSLFingerprint(alg, digest, digest_len);
126    return true;
127  }
128  virtual bool SetSslRole(rtc::SSLRole role) {
129    ssl_role_ = role;
130    return true;
131  }
132  virtual bool GetSslRole(rtc::SSLRole* role) const {
133    *role = ssl_role_;
134    return true;
135  }
136
137  virtual void Connect() {
138    if (state_ == STATE_INIT) {
139      state_ = STATE_CONNECTING;
140    }
141  }
142  virtual void Reset() {
143    if (state_ != STATE_INIT) {
144      state_ = STATE_INIT;
145      if (dest_) {
146        dest_->state_ = STATE_INIT;
147        dest_->dest_ = NULL;
148        dest_ = NULL;
149      }
150    }
151  }
152
153  void SetWritable(bool writable) {
154    set_writable(writable);
155  }
156
157  void SetDestination(FakeTransportChannel* dest) {
158    if (state_ == STATE_CONNECTING && dest) {
159      // This simulates the delivery of candidates.
160      dest_ = dest;
161      dest_->dest_ = this;
162      if (identity_ && dest_->identity_) {
163        do_dtls_ = true;
164        dest_->do_dtls_ = true;
165        NegotiateSrtpCiphers();
166      }
167      state_ = STATE_CONNECTED;
168      dest_->state_ = STATE_CONNECTED;
169      set_writable(true);
170      dest_->set_writable(true);
171    } else if (state_ == STATE_CONNECTED && !dest) {
172      // Simulates loss of connectivity, by asymmetrically forgetting dest_.
173      dest_ = NULL;
174      state_ = STATE_CONNECTING;
175      set_writable(false);
176    }
177  }
178
179  void SetConnectionCount(size_t connection_count) {
180    size_t old_connection_count = connection_count_;
181    connection_count_ = connection_count;
182    if (connection_count_ < old_connection_count)
183      SignalConnectionRemoved(this);
184  }
185
186  virtual int SendPacket(const char* data, size_t len,
187                         const rtc::PacketOptions& options, int flags) {
188    if (state_ != STATE_CONNECTED) {
189      return -1;
190    }
191
192    if (flags != PF_SRTP_BYPASS && flags != 0) {
193      return -1;
194    }
195
196    PacketMessageData* packet = new PacketMessageData(data, len);
197    if (async_) {
198      rtc::Thread::Current()->Post(this, 0, packet);
199    } else {
200      rtc::Thread::Current()->Send(this, 0, packet);
201    }
202    return static_cast<int>(len);
203  }
204  virtual int SetOption(rtc::Socket::Option opt, int value) {
205    return true;
206  }
207  virtual int GetError() {
208    return 0;
209  }
210
211  virtual void OnSignalingReady() {
212  }
213  virtual void OnCandidate(const Candidate& candidate) {
214  }
215
216  virtual void OnMessage(rtc::Message* msg) {
217    PacketMessageData* data = static_cast<PacketMessageData*>(
218        msg->pdata);
219    dest_->SignalReadPacket(dest_, data->packet.data(),
220                            data->packet.length(),
221                            rtc::CreatePacketTime(0), 0);
222    delete data;
223  }
224
225  bool SetLocalIdentity(rtc::SSLIdentity* identity) {
226    identity_ = identity;
227    return true;
228  }
229
230
231  void SetRemoteCertificate(rtc::FakeSSLCertificate* cert) {
232    remote_cert_ = cert;
233  }
234
235  virtual bool IsDtlsActive() const {
236    return do_dtls_;
237  }
238
239  virtual bool SetSrtpCiphers(const std::vector<std::string>& ciphers) {
240    srtp_ciphers_ = ciphers;
241    return true;
242  }
243
244  virtual bool GetSrtpCipher(std::string* cipher) {
245    if (!chosen_srtp_cipher_.empty()) {
246      *cipher = chosen_srtp_cipher_;
247      return true;
248    }
249    return false;
250  }
251
252  virtual bool GetLocalIdentity(rtc::SSLIdentity** identity) const {
253    if (!identity_)
254      return false;
255
256    *identity = identity_->GetReference();
257    return true;
258  }
259
260  virtual bool GetRemoteCertificate(rtc::SSLCertificate** cert) const {
261    if (!remote_cert_)
262      return false;
263
264    *cert = remote_cert_->GetReference();
265    return true;
266  }
267
268  virtual bool ExportKeyingMaterial(const std::string& label,
269                                    const uint8* context,
270                                    size_t context_len,
271                                    bool use_context,
272                                    uint8* result,
273                                    size_t result_len) {
274    if (!chosen_srtp_cipher_.empty()) {
275      memset(result, 0xff, result_len);
276      return true;
277    }
278
279    return false;
280  }
281
282  virtual void NegotiateSrtpCiphers() {
283    for (std::vector<std::string>::const_iterator it1 = srtp_ciphers_.begin();
284        it1 != srtp_ciphers_.end(); ++it1) {
285      for (std::vector<std::string>::const_iterator it2 =
286              dest_->srtp_ciphers_.begin();
287          it2 != dest_->srtp_ciphers_.end(); ++it2) {
288        if (*it1 == *it2) {
289          chosen_srtp_cipher_ = *it1;
290          dest_->chosen_srtp_cipher_ = *it2;
291          return;
292        }
293      }
294    }
295  }
296
297  virtual bool GetStats(ConnectionInfos* infos) OVERRIDE {
298    ConnectionInfo info;
299    infos->clear();
300    infos->push_back(info);
301    return true;
302  }
303
304 private:
305  enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED };
306  Transport* transport_;
307  FakeTransportChannel* dest_;
308  State state_;
309  bool async_;
310  rtc::SSLIdentity* identity_;
311  rtc::FakeSSLCertificate* remote_cert_;
312  bool do_dtls_;
313  std::vector<std::string> srtp_ciphers_;
314  std::string chosen_srtp_cipher_;
315  IceRole role_;
316  uint64 tiebreaker_;
317  IceProtocolType ice_proto_;
318  std::string ice_ufrag_;
319  std::string ice_pwd_;
320  std::string remote_ice_ufrag_;
321  std::string remote_ice_pwd_;
322  IceMode remote_ice_mode_;
323  rtc::SSLFingerprint dtls_fingerprint_;
324  rtc::SSLRole ssl_role_;
325  size_t connection_count_;
326};
327
// Fake transport class, which can be passed to anything that needs a Transport.
// Can be informed of another FakeTransport via SetDestination (a low-tech way
// of simulating the exchange of candidates).
331class FakeTransport : public Transport {
332 public:
333  typedef std::map<int, FakeTransportChannel*> ChannelMap;
334  FakeTransport(rtc::Thread* signaling_thread,
335                rtc::Thread* worker_thread,
336                const std::string& content_name,
337                PortAllocator* alllocator = NULL)
338      : Transport(signaling_thread, worker_thread,
339                  content_name, "test_type", NULL),
340      dest_(NULL),
341      async_(false),
342      identity_(NULL) {
343  }
344  ~FakeTransport() {
345    DestroyAllChannels();
346  }
347
348  const ChannelMap& channels() const { return channels_; }
349
350  void SetAsync(bool async) { async_ = async; }
351  void SetDestination(FakeTransport* dest) {
352    dest_ = dest;
353    for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
354         ++it) {
355      it->second->SetLocalIdentity(identity_);
356      SetChannelDestination(it->first, it->second);
357    }
358  }
359
360  void SetWritable(bool writable) {
361    for (ChannelMap::iterator it = channels_.begin(); it != channels_.end();
362         ++it) {
363      it->second->SetWritable(writable);
364    }
365  }
366
367  void set_identity(rtc::SSLIdentity* identity) {
368    identity_ = identity;
369  }
370
371  using Transport::local_description;
372  using Transport::remote_description;
373
374 protected:
375  virtual TransportChannelImpl* CreateTransportChannel(int component) {
376    if (channels_.find(component) != channels_.end()) {
377      return NULL;
378    }
379    FakeTransportChannel* channel =
380        new FakeTransportChannel(this, content_name(), component);
381    channel->SetAsync(async_);
382    SetChannelDestination(component, channel);
383    channels_[component] = channel;
384    return channel;
385  }
386  virtual void DestroyTransportChannel(TransportChannelImpl* channel) {
387    channels_.erase(channel->component());
388    delete channel;
389  }
390  virtual void SetIdentity_w(rtc::SSLIdentity* identity) {
391    identity_ = identity;
392  }
393  virtual bool GetIdentity_w(rtc::SSLIdentity** identity) {
394    if (!identity_)
395      return false;
396
397    *identity = identity_->GetReference();
398    return true;
399  }
400
401 private:
402  FakeTransportChannel* GetFakeChannel(int component) {
403    ChannelMap::iterator it = channels_.find(component);
404    return (it != channels_.end()) ? it->second : NULL;
405  }
406  void SetChannelDestination(int component,
407                             FakeTransportChannel* channel) {
408    FakeTransportChannel* dest_channel = NULL;
409    if (dest_) {
410      dest_channel = dest_->GetFakeChannel(component);
411      if (dest_channel) {
412        dest_channel->SetLocalIdentity(dest_->identity_);
413      }
414    }
415    channel->SetDestination(dest_channel);
416  }
417
418  // Note, this is distinct from the Channel map owned by Transport.
419  // This map just tracks the FakeTransportChannels created by this class.
420  ChannelMap channels_;
421  FakeTransport* dest_;
422  bool async_;
423  rtc::SSLIdentity* identity_;
424};
425
426// Fake session class, which can be passed into a BaseChannel object for
427// test purposes. Can be connected to other FakeSessions via Connect().
428class FakeSession : public BaseSession {
429 public:
430  explicit FakeSession()
431      : BaseSession(rtc::Thread::Current(),
432                    rtc::Thread::Current(),
433                    NULL, "", "", true),
434      fail_create_channel_(false) {
435  }
436  explicit FakeSession(bool initiator)
437      : BaseSession(rtc::Thread::Current(),
438                    rtc::Thread::Current(),
439                    NULL, "", "", initiator),
440      fail_create_channel_(false) {
441  }
442  FakeSession(rtc::Thread* worker_thread, bool initiator)
443      : BaseSession(rtc::Thread::Current(),
444                    worker_thread,
445                    NULL, "", "", initiator),
446      fail_create_channel_(false) {
447  }
448
449  FakeTransport* GetTransport(const std::string& content_name) {
450    return static_cast<FakeTransport*>(
451        BaseSession::GetTransport(content_name));
452  }
453
454  void Connect(FakeSession* dest) {
455    // Simulate the exchange of candidates.
456    CompleteNegotiation();
457    dest->CompleteNegotiation();
458    for (TransportMap::const_iterator it = transport_proxies().begin();
459        it != transport_proxies().end(); ++it) {
460      static_cast<FakeTransport*>(it->second->impl())->SetDestination(
461          dest->GetTransport(it->first));
462    }
463  }
464
465  virtual TransportChannel* CreateChannel(
466      const std::string& content_name,
467      const std::string& channel_name,
468      int component) {
469    if (fail_create_channel_) {
470      return NULL;
471    }
472    return BaseSession::CreateChannel(content_name, channel_name, component);
473  }
474
475  void set_fail_channel_creation(bool fail_channel_creation) {
476    fail_create_channel_ = fail_channel_creation;
477  }
478
479  // TODO: Hoist this into Session when we re-work the Session code.
480  void set_ssl_identity(rtc::SSLIdentity* identity) {
481    for (TransportMap::const_iterator it = transport_proxies().begin();
482        it != transport_proxies().end(); ++it) {
483      // We know that we have a FakeTransport*
484
485      static_cast<FakeTransport*>(it->second->impl())->set_identity
486          (identity);
487    }
488  }
489
490 protected:
491  virtual Transport* CreateTransport(const std::string& content_name) {
492    return new FakeTransport(signaling_thread(), worker_thread(), content_name);
493  }
494
495  void CompleteNegotiation() {
496    for (TransportMap::const_iterator it = transport_proxies().begin();
497        it != transport_proxies().end(); ++it) {
498      it->second->CompleteNegotiation();
499      it->second->ConnectChannels();
500    }
501  }
502
503 private:
504  bool fail_create_channel_;
505};
506
507}  // namespace cricket
508
509#endif  // TALK_P2P_BASE_FAKESESSION_H_
510