1/*
2 *  Copyright 2014 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include <string>
12
13#include "webrtc/base/gunit.h"
14#include "webrtc/base/ipaddress.h"
15#include "webrtc/base/socketstream.h"
16#include "webrtc/base/ssladapter.h"
17#include "webrtc/base/sslstreamadapter.h"
18#include "webrtc/base/sslidentity.h"
19#include "webrtc/base/stream.h"
20#include "webrtc/base/virtualsocketserver.h"
21
// Upper bound, in milliseconds, used by the EXPECT_*_WAIT checks below.
static const int kTimeout = 5 * 1000;
23
24static rtc::AsyncSocket* CreateSocket(const rtc::SSLMode& ssl_mode) {
25  rtc::SocketAddress address(rtc::IPAddress(INADDR_ANY), 0);
26
27  rtc::AsyncSocket* socket = rtc::Thread::Current()->
28      socketserver()->CreateAsyncSocket(
29      address.family(), (ssl_mode == rtc::SSL_MODE_DTLS) ?
30      SOCK_DGRAM : SOCK_STREAM);
31  socket->Bind(address);
32
33  return socket;
34}
35
36static std::string GetSSLProtocolName(const rtc::SSLMode& ssl_mode) {
37  return (ssl_mode == rtc::SSL_MODE_DTLS) ? "DTLS" : "TLS";
38}
39
40class SSLAdapterTestDummyClient : public sigslot::has_slots<> {
41 public:
42  explicit SSLAdapterTestDummyClient(const rtc::SSLMode& ssl_mode)
43      : ssl_mode_(ssl_mode) {
44    rtc::AsyncSocket* socket = CreateSocket(ssl_mode_);
45
46    ssl_adapter_.reset(rtc::SSLAdapter::Create(socket));
47
48    ssl_adapter_->SetMode(ssl_mode_);
49
50    // Ignore any certificate errors for the purpose of testing.
51    // Note: We do this only because we don't have a real certificate.
52    // NEVER USE THIS IN PRODUCTION CODE!
53    ssl_adapter_->set_ignore_bad_cert(true);
54
55    ssl_adapter_->SignalReadEvent.connect(this,
56        &SSLAdapterTestDummyClient::OnSSLAdapterReadEvent);
57    ssl_adapter_->SignalCloseEvent.connect(this,
58        &SSLAdapterTestDummyClient::OnSSLAdapterCloseEvent);
59  }
60
61  rtc::SocketAddress GetAddress() const {
62    return ssl_adapter_->GetLocalAddress();
63  }
64
65  rtc::AsyncSocket::ConnState GetState() const {
66    return ssl_adapter_->GetState();
67  }
68
69  const std::string& GetReceivedData() const {
70    return data_;
71  }
72
73  int Connect(const std::string& hostname, const rtc::SocketAddress& address) {
74    LOG(LS_INFO) << "Initiating connection with " << address;
75
76    int rv = ssl_adapter_->Connect(address);
77
78    if (rv == 0) {
79      LOG(LS_INFO) << "Starting " << GetSSLProtocolName(ssl_mode_)
80          << " handshake with " << hostname;
81
82      if (ssl_adapter_->StartSSL(hostname.c_str(), false) != 0) {
83        return -1;
84      }
85    }
86
87    return rv;
88  }
89
90  int Close() {
91    return ssl_adapter_->Close();
92  }
93
94  int Send(const std::string& message) {
95    LOG(LS_INFO) << "Client sending '" << message << "'";
96
97    return ssl_adapter_->Send(message.data(), message.length());
98  }
99
100  void OnSSLAdapterReadEvent(rtc::AsyncSocket* socket) {
101    char buffer[4096] = "";
102
103    // Read data received from the server and store it in our internal buffer.
104    int read = socket->Recv(buffer, sizeof(buffer) - 1);
105    if (read != -1) {
106      buffer[read] = '\0';
107
108      LOG(LS_INFO) << "Client received '" << buffer << "'";
109
110      data_ += buffer;
111    }
112  }
113
114  void OnSSLAdapterCloseEvent(rtc::AsyncSocket* socket, int error) {
115    // OpenSSLAdapter signals handshake failure with a close event, but without
116    // closing the socket! Let's close the socket here. This way GetState() can
117    // return CS_CLOSED after failure.
118    if (socket->GetState() != rtc::AsyncSocket::CS_CLOSED) {
119      socket->Close();
120    }
121  }
122
123 private:
124  const rtc::SSLMode ssl_mode_;
125
126  rtc::scoped_ptr<rtc::SSLAdapter> ssl_adapter_;
127
128  std::string data_;
129};
130
131class SSLAdapterTestDummyServer : public sigslot::has_slots<> {
132 public:
133  explicit SSLAdapterTestDummyServer(const rtc::SSLMode& ssl_mode,
134                                     const rtc::KeyParams& key_params)
135      : ssl_mode_(ssl_mode) {
136    // Generate a key pair and a certificate for this host.
137    ssl_identity_.reset(rtc::SSLIdentity::Generate(GetHostname(), key_params));
138
139    server_socket_.reset(CreateSocket(ssl_mode_));
140
141    if (ssl_mode_ == rtc::SSL_MODE_TLS) {
142      server_socket_->SignalReadEvent.connect(this,
143          &SSLAdapterTestDummyServer::OnServerSocketReadEvent);
144
145      server_socket_->Listen(1);
146    }
147
148    LOG(LS_INFO) << ((ssl_mode_ == rtc::SSL_MODE_DTLS) ? "UDP" : "TCP")
149        << " server listening on " << server_socket_->GetLocalAddress();
150  }
151
152  rtc::SocketAddress GetAddress() const {
153    return server_socket_->GetLocalAddress();
154  }
155
156  std::string GetHostname() const {
157    // Since we don't have a real certificate anyway, the value here doesn't
158    // really matter.
159    return "example.com";
160  }
161
162  const std::string& GetReceivedData() const {
163    return data_;
164  }
165
166  int Send(const std::string& message) {
167    if (ssl_stream_adapter_ == NULL
168        || ssl_stream_adapter_->GetState() != rtc::SS_OPEN) {
169      // No connection yet.
170      return -1;
171    }
172
173    LOG(LS_INFO) << "Server sending '" << message << "'";
174
175    size_t written;
176    int error;
177
178    rtc::StreamResult r = ssl_stream_adapter_->Write(message.data(),
179        message.length(), &written, &error);
180    if (r == rtc::SR_SUCCESS) {
181      return written;
182    } else {
183      return -1;
184    }
185  }
186
187  void AcceptConnection(const rtc::SocketAddress& address) {
188    // Only a single connection is supported.
189    ASSERT_TRUE(ssl_stream_adapter_ == NULL);
190
191    // This is only for DTLS.
192    ASSERT_EQ(rtc::SSL_MODE_DTLS, ssl_mode_);
193
194    // Transfer ownership of the socket to the SSLStreamAdapter object.
195    rtc::AsyncSocket* socket = server_socket_.release();
196
197    socket->Connect(address);
198
199    DoHandshake(socket);
200  }
201
202  void OnServerSocketReadEvent(rtc::AsyncSocket* socket) {
203    // Only a single connection is supported.
204    ASSERT_TRUE(ssl_stream_adapter_ == NULL);
205
206    DoHandshake(server_socket_->Accept(NULL));
207  }
208
209  void OnSSLStreamAdapterEvent(rtc::StreamInterface* stream, int sig, int err) {
210    if (sig & rtc::SE_READ) {
211      char buffer[4096] = "";
212
213      size_t read;
214      int error;
215
216      // Read data received from the client and store it in our internal
217      // buffer.
218      rtc::StreamResult r = stream->Read(buffer,
219          sizeof(buffer) - 1, &read, &error);
220      if (r == rtc::SR_SUCCESS) {
221        buffer[read] = '\0';
222
223        LOG(LS_INFO) << "Server received '" << buffer << "'";
224
225        data_ += buffer;
226      }
227    }
228  }
229
230 private:
231  void DoHandshake(rtc::AsyncSocket* socket) {
232    rtc::SocketStream* stream = new rtc::SocketStream(socket);
233
234    ssl_stream_adapter_.reset(rtc::SSLStreamAdapter::Create(stream));
235
236    ssl_stream_adapter_->SetMode(ssl_mode_);
237    ssl_stream_adapter_->SetServerRole();
238
239    // SSLStreamAdapter is normally used for peer-to-peer communication, but
240    // here we're testing communication between a client and a server
241    // (e.g. a WebRTC-based application and an RFC 5766 TURN server), where
242    // clients are not required to provide a certificate during handshake.
243    // Accordingly, we must disable client authentication here.
244    ssl_stream_adapter_->set_client_auth_enabled(false);
245
246    ssl_stream_adapter_->SetIdentity(ssl_identity_->GetReference());
247
248    // Set a bogus peer certificate digest.
249    unsigned char digest[20];
250    size_t digest_len = sizeof(digest);
251    ssl_stream_adapter_->SetPeerCertificateDigest(rtc::DIGEST_SHA_1, digest,
252        digest_len);
253
254    ssl_stream_adapter_->StartSSLWithPeer();
255
256    ssl_stream_adapter_->SignalEvent.connect(this,
257        &SSLAdapterTestDummyServer::OnSSLStreamAdapterEvent);
258  }
259
260  const rtc::SSLMode ssl_mode_;
261
262  rtc::scoped_ptr<rtc::AsyncSocket> server_socket_;
263  rtc::scoped_ptr<rtc::SSLStreamAdapter> ssl_stream_adapter_;
264
265  rtc::scoped_ptr<rtc::SSLIdentity> ssl_identity_;
266
267  std::string data_;
268};
269
270class SSLAdapterTestBase : public testing::Test,
271                           public sigslot::has_slots<> {
272 public:
273  explicit SSLAdapterTestBase(const rtc::SSLMode& ssl_mode,
274                              const rtc::KeyParams& key_params)
275      : ssl_mode_(ssl_mode),
276        ss_scope_(new rtc::VirtualSocketServer(NULL)),
277        server_(new SSLAdapterTestDummyServer(ssl_mode_, key_params)),
278        client_(new SSLAdapterTestDummyClient(ssl_mode_)),
279        handshake_wait_(kTimeout) {}
280
281  void SetHandshakeWait(int wait) {
282    handshake_wait_ = wait;
283  }
284
285  void TestHandshake(bool expect_success) {
286    int rv;
287
288    // The initial state is CS_CLOSED
289    ASSERT_EQ(rtc::AsyncSocket::CS_CLOSED, client_->GetState());
290
291    rv = client_->Connect(server_->GetHostname(), server_->GetAddress());
292    ASSERT_EQ(0, rv);
293
294    // Now the state should be CS_CONNECTING
295    ASSERT_EQ(rtc::AsyncSocket::CS_CONNECTING, client_->GetState());
296
297    if (ssl_mode_ == rtc::SSL_MODE_DTLS) {
298      // For DTLS, call AcceptConnection() with the client's address.
299      server_->AcceptConnection(client_->GetAddress());
300    }
301
302    if (expect_success) {
303      // If expecting success, the client should end up in the CS_CONNECTED
304      // state after handshake.
305      EXPECT_EQ_WAIT(rtc::AsyncSocket::CS_CONNECTED, client_->GetState(),
306          handshake_wait_);
307
308      LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake complete.";
309
310    } else {
311      // On handshake failure the client should end up in the CS_CLOSED state.
312      EXPECT_EQ_WAIT(rtc::AsyncSocket::CS_CLOSED, client_->GetState(),
313          handshake_wait_);
314
315      LOG(LS_INFO) << GetSSLProtocolName(ssl_mode_) << " handshake failed.";
316    }
317  }
318
319  void TestTransfer(const std::string& message) {
320    int rv;
321
322    rv = client_->Send(message);
323    ASSERT_EQ(static_cast<int>(message.length()), rv);
324
325    // The server should have received the client's message.
326    EXPECT_EQ_WAIT(message, server_->GetReceivedData(), kTimeout);
327
328    rv = server_->Send(message);
329    ASSERT_EQ(static_cast<int>(message.length()), rv);
330
331    // The client should have received the server's message.
332    EXPECT_EQ_WAIT(message, client_->GetReceivedData(), kTimeout);
333
334    LOG(LS_INFO) << "Transfer complete.";
335  }
336
337 private:
338  const rtc::SSLMode ssl_mode_;
339
340  const rtc::SocketServerScope ss_scope_;
341
342  rtc::scoped_ptr<SSLAdapterTestDummyServer> server_;
343  rtc::scoped_ptr<SSLAdapterTestDummyClient> client_;
344
345  int handshake_wait_;
346};
347
348class SSLAdapterTestTLS_RSA : public SSLAdapterTestBase {
349 public:
350  SSLAdapterTestTLS_RSA()
351      : SSLAdapterTestBase(rtc::SSL_MODE_TLS, rtc::KeyParams::RSA()) {}
352};
353
354class SSLAdapterTestTLS_ECDSA : public SSLAdapterTestBase {
355 public:
356  SSLAdapterTestTLS_ECDSA()
357      : SSLAdapterTestBase(rtc::SSL_MODE_TLS, rtc::KeyParams::ECDSA()) {}
358};
359
360class SSLAdapterTestDTLS_RSA : public SSLAdapterTestBase {
361 public:
362  SSLAdapterTestDTLS_RSA()
363      : SSLAdapterTestBase(rtc::SSL_MODE_DTLS, rtc::KeyParams::RSA()) {}
364};
365
366class SSLAdapterTestDTLS_ECDSA : public SSLAdapterTestBase {
367 public:
368  SSLAdapterTestDTLS_ECDSA()
369      : SSLAdapterTestBase(rtc::SSL_MODE_DTLS, rtc::KeyParams::ECDSA()) {}
370};
371
372#if SSL_USE_OPENSSL
373
374// Basic tests: TLS
375
376// Test that handshake works, using RSA
377TEST_F(SSLAdapterTestTLS_RSA, TestTLSConnect) {
378  TestHandshake(true);
379}
380
381// Test that handshake works, using ECDSA
382TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSConnect) {
383  TestHandshake(true);
384}
385
386// Test transfer between client and server, using RSA
387TEST_F(SSLAdapterTestTLS_RSA, TestTLSTransfer) {
388  TestHandshake(true);
389  TestTransfer("Hello, world!");
390}
391
392// Test transfer between client and server, using ECDSA
393TEST_F(SSLAdapterTestTLS_ECDSA, TestTLSTransfer) {
394  TestHandshake(true);
395  TestTransfer("Hello, world!");
396}
397
398// Basic tests: DTLS
399
400// Test that handshake works, using RSA
401TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSConnect) {
402  TestHandshake(true);
403}
404
405// Test that handshake works, using ECDSA
406TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSConnect) {
407  TestHandshake(true);
408}
409
410// Test transfer between client and server, using RSA
411TEST_F(SSLAdapterTestDTLS_RSA, TestDTLSTransfer) {
412  TestHandshake(true);
413  TestTransfer("Hello, world!");
414}
415
416// Test transfer between client and server, using ECDSA
417TEST_F(SSLAdapterTestDTLS_ECDSA, TestDTLSTransfer) {
418  TestHandshake(true);
419  TestTransfer("Hello, world!");
420}
421
422#endif  // SSL_USE_OPENSSL
423