1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// This test suite uses SSLClientSocket to test the implementation of
6// SSLServerSocket. In order to establish connections between the sockets
7// we need two additional classes:
8// 1. FakeSocket
9//    Connects SSL socket to FakeDataChannel. This class is just a stub.
10//
11// 2. FakeDataChannel
12//    Implements the actual exchange of data between two FakeSockets.
13//
14// Implementations of these two classes are included in this file.
15
16#include "net/socket/ssl_server_socket.h"
17
18#include <queue>
19
20#include "base/file_path.h"
21#include "base/file_util.h"
22#include "base/path_service.h"
23#include "crypto/nss_util.h"
24#include "crypto/rsa_private_key.h"
25#include "net/base/address_list.h"
26#include "net/base/cert_status_flags.h"
27#include "net/base/cert_verifier.h"
28#include "net/base/host_port_pair.h"
29#include "net/base/io_buffer.h"
30#include "net/base/ip_endpoint.h"
31#include "net/base/net_errors.h"
32#include "net/base/net_log.h"
33#include "net/base/ssl_config_service.h"
34#include "net/base/x509_certificate.h"
35#include "net/socket/client_socket.h"
36#include "net/socket/client_socket_factory.h"
37#include "net/socket/socket_test_util.h"
38#include "net/socket/ssl_client_socket.h"
39#include "testing/gtest/include/gtest/gtest.h"
40#include "testing/platform_test.h"
41
42namespace net {
43
44namespace {
45
46class FakeDataChannel {
47 public:
48  FakeDataChannel() : read_callback_(NULL), read_buf_len_(0) {
49  }
50
51  virtual int Read(IOBuffer* buf, int buf_len,
52                   CompletionCallback* callback) {
53    if (data_.empty()) {
54      read_callback_ = callback;
55      read_buf_ = buf;
56      read_buf_len_ = buf_len;
57      return net::ERR_IO_PENDING;
58    }
59    return PropogateData(buf, buf_len);
60  }
61
62  virtual int Write(IOBuffer* buf, int buf_len,
63                    CompletionCallback* callback) {
64    data_.push(new net::DrainableIOBuffer(buf, buf_len));
65    DoReadCallback();
66    return buf_len;
67  }
68
69 private:
70  void DoReadCallback() {
71    if (!read_callback_)
72      return;
73
74    int copied = PropogateData(read_buf_, read_buf_len_);
75    net::CompletionCallback* callback = read_callback_;
76    read_callback_ = NULL;
77    read_buf_ = NULL;
78    read_buf_len_ = 0;
79    callback->Run(copied);
80  }
81
82  int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) {
83    scoped_refptr<net::DrainableIOBuffer> buf = data_.front();
84    int copied = std::min(buf->BytesRemaining(), read_buf_len);
85    memcpy(read_buf->data(), buf->data(), copied);
86    buf->DidConsume(copied);
87
88    if (!buf->BytesRemaining())
89      data_.pop();
90    return copied;
91  }
92
93  net::CompletionCallback* read_callback_;
94  scoped_refptr<net::IOBuffer> read_buf_;
95  int read_buf_len_;
96
97  std::queue<scoped_refptr<net::DrainableIOBuffer> > data_;
98
99  DISALLOW_COPY_AND_ASSIGN(FakeDataChannel);
100};
101
102class FakeSocket : public ClientSocket {
103 public:
104  FakeSocket(FakeDataChannel* incoming_channel,
105             FakeDataChannel* outgoing_channel)
106      : incoming_(incoming_channel),
107        outgoing_(outgoing_channel) {
108  }
109
110  virtual ~FakeSocket() {
111
112  }
113
114  virtual int Read(IOBuffer* buf, int buf_len,
115                   CompletionCallback* callback) {
116    return incoming_->Read(buf, buf_len, callback);
117  }
118
119  virtual int Write(IOBuffer* buf, int buf_len,
120                    CompletionCallback* callback) {
121    return outgoing_->Write(buf, buf_len, callback);
122  }
123
124  virtual bool SetReceiveBufferSize(int32 size) {
125    return true;
126  }
127
128  virtual bool SetSendBufferSize(int32 size) {
129    return true;
130  }
131
132  virtual int Connect(CompletionCallback* callback) {
133    return net::OK;
134  }
135
136  virtual void Disconnect() {}
137
138  virtual bool IsConnected() const {
139    return true;
140  }
141
142  virtual bool IsConnectedAndIdle() const {
143    return true;
144  }
145
146  virtual int GetPeerAddress(AddressList* address) const {
147    net::IPAddressNumber ip_address(4);
148    *address = net::AddressList(ip_address, 0, false);
149    return net::OK;
150  }
151
152  virtual int GetLocalAddress(IPEndPoint* address) const {
153    net::IPAddressNumber ip_address(4);
154    *address = net::IPEndPoint(ip_address, 0);
155    return net::OK;
156  }
157
158  virtual const BoundNetLog& NetLog() const {
159    return net_log_;
160  }
161
162  virtual void SetSubresourceSpeculation() {}
163  virtual void SetOmniboxSpeculation() {}
164
165  virtual bool WasEverUsed() const {
166    return true;
167  }
168
169  virtual bool UsingTCPFastOpen() const {
170    return false;
171  }
172
173 private:
174  net::BoundNetLog net_log_;
175  FakeDataChannel* incoming_;
176  FakeDataChannel* outgoing_;
177
178  DISALLOW_COPY_AND_ASSIGN(FakeSocket);
179};
180
181}  // namespace
182
183// Verify the correctness of the test helper classes first.
184TEST(FakeSocketTest, DataTransfer) {
185  // Establish channels between two sockets.
186  FakeDataChannel channel_1;
187  FakeDataChannel channel_2;
188  FakeSocket client(&channel_1, &channel_2);
189  FakeSocket server(&channel_2, &channel_1);
190
191  const char kTestData[] = "testing123";
192  const int kTestDataSize = strlen(kTestData);
193  const int kReadBufSize = 1024;
194  scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData);
195  scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize);
196
197  // Write then read.
198  EXPECT_EQ(kTestDataSize, server.Write(write_buf, kTestDataSize, NULL));
199  EXPECT_EQ(kTestDataSize, client.Read(read_buf, kReadBufSize, NULL));
200  EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize));
201
202  // Read then write.
203  TestCompletionCallback callback;
204  EXPECT_EQ(net::ERR_IO_PENDING,
205            server.Read(read_buf, kReadBufSize, &callback));
206  EXPECT_EQ(kTestDataSize, client.Write(write_buf, kTestDataSize, NULL));
207  EXPECT_EQ(kTestDataSize, callback.WaitForResult());
208  EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize));
209}
210
211class SSLServerSocketTest : public PlatformTest {
212 public:
213  SSLServerSocketTest()
214      : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) {
215  }
216
217 protected:
218  void Initialize() {
219    FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_);
220    FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_);
221
222    FilePath certs_dir;
223    PathService::Get(base::DIR_SOURCE_ROOT, &certs_dir);
224    certs_dir = certs_dir.AppendASCII("net");
225    certs_dir = certs_dir.AppendASCII("data");
226    certs_dir = certs_dir.AppendASCII("ssl");
227    certs_dir = certs_dir.AppendASCII("certificates");
228
229    FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der");
230    std::string cert_der;
231    ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der));
232
233    scoped_refptr<net::X509Certificate> cert =
234        X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size());
235
236    FilePath key_path = certs_dir.AppendASCII("unittest.key.bin");
237    std::string key_string;
238    ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string));
239    std::vector<uint8> key_vector(
240        reinterpret_cast<const uint8*>(key_string.data()),
241        reinterpret_cast<const uint8*>(key_string.data() +
242                                       key_string.length()));
243
244    scoped_ptr<crypto::RSAPrivateKey> private_key(
245        crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector));
246
247    net::SSLConfig ssl_config;
248    ssl_config.false_start_enabled = false;
249    ssl_config.ssl3_enabled = true;
250    ssl_config.tls1_enabled = true;
251
252    // Certificate provided by the host doesn't need authority.
253    net::SSLConfig::CertAndStatus cert_and_status;
254    cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
255    cert_and_status.cert = cert;
256    ssl_config.allowed_bad_certs.push_back(cert_and_status);
257
258    net::HostPortPair host_and_pair("unittest", 0);
259    client_socket_.reset(
260        socket_factory_->CreateSSLClientSocket(
261            fake_client_socket, host_and_pair, ssl_config, NULL,
262            &cert_verifier_));
263    server_socket_.reset(net::CreateSSLServerSocket(fake_server_socket,
264                                                    cert, private_key.get(),
265                                                    net::SSLConfig()));
266  }
267
268  FakeDataChannel channel_1_;
269  FakeDataChannel channel_2_;
270  scoped_ptr<net::SSLClientSocket> client_socket_;
271  scoped_ptr<net::SSLServerSocket> server_socket_;
272  net::ClientSocketFactory* socket_factory_;
273  net::CertVerifier cert_verifier_;
274};
275
276// SSLServerSocket is only implemented using NSS.
277#if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX)
278
279// This test only executes creation of client and server sockets. This is to
280// test that creation of sockets doesn't crash and have minimal code to run
281// under valgrind in order to help debugging memory problems.
282TEST_F(SSLServerSocketTest, Initialize) {
283  Initialize();
284}
285
286// This test executes Connect() of SSLClientSocket and Accept() of
287// SSLServerSocket to make sure handshaking between the two sockets are
288// completed successfully.
289TEST_F(SSLServerSocketTest, Handshake) {
290  Initialize();
291
292  TestCompletionCallback connect_callback;
293  TestCompletionCallback accept_callback;
294
295  int server_ret = server_socket_->Accept(&accept_callback);
296  EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
297
298  int client_ret = client_socket_->Connect(&connect_callback);
299  EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
300
301  if (client_ret == net::ERR_IO_PENDING) {
302    EXPECT_EQ(net::OK, connect_callback.WaitForResult());
303  }
304  if (server_ret == net::ERR_IO_PENDING) {
305    EXPECT_EQ(net::OK, accept_callback.WaitForResult());
306  }
307}
308
309TEST_F(SSLServerSocketTest, DataTransfer) {
310  Initialize();
311
312  TestCompletionCallback connect_callback;
313  TestCompletionCallback accept_callback;
314
315  // Establish connection.
316  int client_ret = client_socket_->Connect(&connect_callback);
317  ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING);
318
319  int server_ret = server_socket_->Accept(&accept_callback);
320  ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING);
321
322  if (client_ret == net::ERR_IO_PENDING) {
323    ASSERT_EQ(net::OK, connect_callback.WaitForResult());
324  }
325  if (server_ret == net::ERR_IO_PENDING) {
326    ASSERT_EQ(net::OK, accept_callback.WaitForResult());
327  }
328
329  const int kReadBufSize = 1024;
330  scoped_refptr<net::StringIOBuffer> write_buf =
331      new net::StringIOBuffer("testing123");
332  scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize);
333
334  // Write then read.
335  TestCompletionCallback write_callback;
336  TestCompletionCallback read_callback;
337  server_ret = server_socket_->Write(write_buf, write_buf->size(),
338                                     &write_callback);
339  EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
340  client_ret = client_socket_->Read(read_buf, kReadBufSize, &read_callback);
341  EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
342
343  if (server_ret == net::ERR_IO_PENDING) {
344    EXPECT_GT(write_callback.WaitForResult(), 0);
345  }
346  if (client_ret == net::ERR_IO_PENDING) {
347    EXPECT_GT(read_callback.WaitForResult(), 0);
348  }
349  EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
350
351  // Read then write.
352  write_buf = new net::StringIOBuffer("hello123");
353  server_ret = server_socket_->Read(read_buf, kReadBufSize, &read_callback);
354  EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING);
355  client_ret = client_socket_->Write(write_buf, write_buf->size(),
356                                     &write_callback);
357  EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING);
358
359  if (server_ret == net::ERR_IO_PENDING) {
360    EXPECT_GT(read_callback.WaitForResult(), 0);
361  }
362  if (client_ret == net::ERR_IO_PENDING) {
363    EXPECT_GT(write_callback.WaitForResult(), 0);
364  }
365  EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size()));
366}
367#endif
368
369}  // namespace net
370