1// Copyright (c) 2011 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5// This test suite uses SSLClientSocket to test the implementation of 6// SSLServerSocket. In order to establish connections between the sockets 7// we need two additional classes: 8// 1. FakeSocket 9// Connects SSL socket to FakeDataChannel. This class is just a stub. 10// 11// 2. FakeDataChannel 12// Implements the actual exchange of data between two FakeSockets. 13// 14// Implementations of these two classes are included in this file. 15 16#include "net/socket/ssl_server_socket.h" 17 18#include <queue> 19 20#include "base/file_path.h" 21#include "base/file_util.h" 22#include "base/path_service.h" 23#include "crypto/nss_util.h" 24#include "crypto/rsa_private_key.h" 25#include "net/base/address_list.h" 26#include "net/base/cert_status_flags.h" 27#include "net/base/cert_verifier.h" 28#include "net/base/host_port_pair.h" 29#include "net/base/io_buffer.h" 30#include "net/base/ip_endpoint.h" 31#include "net/base/net_errors.h" 32#include "net/base/net_log.h" 33#include "net/base/ssl_config_service.h" 34#include "net/base/x509_certificate.h" 35#include "net/socket/client_socket.h" 36#include "net/socket/client_socket_factory.h" 37#include "net/socket/socket_test_util.h" 38#include "net/socket/ssl_client_socket.h" 39#include "testing/gtest/include/gtest/gtest.h" 40#include "testing/platform_test.h" 41 42namespace net { 43 44namespace { 45 46class FakeDataChannel { 47 public: 48 FakeDataChannel() : read_callback_(NULL), read_buf_len_(0) { 49 } 50 51 virtual int Read(IOBuffer* buf, int buf_len, 52 CompletionCallback* callback) { 53 if (data_.empty()) { 54 read_callback_ = callback; 55 read_buf_ = buf; 56 read_buf_len_ = buf_len; 57 return net::ERR_IO_PENDING; 58 } 59 return PropogateData(buf, buf_len); 60 } 61 62 virtual int Write(IOBuffer* buf, int buf_len, 63 CompletionCallback* callback) { 64 data_.push(new net::DrainableIOBuffer(buf, buf_len)); 65 DoReadCallback(); 66 return buf_len; 67 } 68 69 private: 70 void DoReadCallback() { 71 if (!read_callback_) 72 return; 73 74 int copied = PropogateData(read_buf_, read_buf_len_); 75 net::CompletionCallback* callback = read_callback_; 76 read_callback_ = NULL; 77 read_buf_ = NULL; 78 read_buf_len_ = 0; 79 callback->Run(copied); 80 } 81 82 int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) { 83 scoped_refptr<net::DrainableIOBuffer> buf = data_.front(); 84 int copied = std::min(buf->BytesRemaining(), read_buf_len); 85 memcpy(read_buf->data(), buf->data(), copied); 86 buf->DidConsume(copied); 87 88 if (!buf->BytesRemaining()) 89 data_.pop(); 90 return copied; 91 } 92 93 net::CompletionCallback* read_callback_; 94 scoped_refptr<net::IOBuffer> read_buf_; 95 int read_buf_len_; 96 97 std::queue<scoped_refptr<net::DrainableIOBuffer> > data_; 98 99 DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); 100}; 101 102class FakeSocket : public ClientSocket { 103 public: 104 FakeSocket(FakeDataChannel* incoming_channel, 105 FakeDataChannel* outgoing_channel) 106 : incoming_(incoming_channel), 107 outgoing_(outgoing_channel) { 108 } 109 110 virtual ~FakeSocket() { 111 112 } 113 114 virtual int Read(IOBuffer* buf, int buf_len, 115 CompletionCallback* callback) { 116 return incoming_->Read(buf, buf_len, callback); 117 } 118 119 virtual int Write(IOBuffer* buf, int buf_len, 120 CompletionCallback* callback) { 121 return outgoing_->Write(buf, buf_len, callback); 122 } 123 124 virtual bool SetReceiveBufferSize(int32 size) { 125 return true; 126 } 127 128 virtual bool SetSendBufferSize(int32 size) { 129 return true; 130 } 131 132 virtual int Connect(CompletionCallback* callback) { 133 return net::OK; 134 } 135 136 virtual void Disconnect() {} 137 138 virtual bool IsConnected() const { 139 return true; 140 } 141 142 virtual bool IsConnectedAndIdle() const { 143 return true; 144 } 145 146 virtual int GetPeerAddress(AddressList* address) const { 147 net::IPAddressNumber ip_address(4); 148 *address = net::AddressList(ip_address, 0, false); 149 return net::OK; 150 } 151 152 virtual int GetLocalAddress(IPEndPoint* address) const { 153 net::IPAddressNumber ip_address(4); 154 *address = net::IPEndPoint(ip_address, 0); 155 return net::OK; 156 } 157 158 virtual const BoundNetLog& NetLog() const { 159 return net_log_; 160 } 161 162 virtual void SetSubresourceSpeculation() {} 163 virtual void SetOmniboxSpeculation() {} 164 165 virtual bool WasEverUsed() const { 166 return true; 167 } 168 169 virtual bool UsingTCPFastOpen() const { 170 return false; 171 } 172 173 private: 174 net::BoundNetLog net_log_; 175 FakeDataChannel* incoming_; 176 FakeDataChannel* outgoing_; 177 178 DISALLOW_COPY_AND_ASSIGN(FakeSocket); 179}; 180 181} // namespace 182 183// Verify the correctness of the test helper classes first. 184TEST(FakeSocketTest, DataTransfer) { 185 // Establish channels between two sockets. 186 FakeDataChannel channel_1; 187 FakeDataChannel channel_2; 188 FakeSocket client(&channel_1, &channel_2); 189 FakeSocket server(&channel_2, &channel_1); 190 191 const char kTestData[] = "testing123"; 192 const int kTestDataSize = strlen(kTestData); 193 const int kReadBufSize = 1024; 194 scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData); 195 scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); 196 197 // Write then read. 198 EXPECT_EQ(kTestDataSize, server.Write(write_buf, kTestDataSize, NULL)); 199 EXPECT_EQ(kTestDataSize, client.Read(read_buf, kReadBufSize, NULL)); 200 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize)); 201 202 // Read then write. 203 TestCompletionCallback callback; 204 EXPECT_EQ(net::ERR_IO_PENDING, 205 server.Read(read_buf, kReadBufSize, &callback)); 206 EXPECT_EQ(kTestDataSize, client.Write(write_buf, kTestDataSize, NULL)); 207 EXPECT_EQ(kTestDataSize, callback.WaitForResult()); 208 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), kTestDataSize)); 209} 210 211class SSLServerSocketTest : public PlatformTest { 212 public: 213 SSLServerSocketTest() 214 : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()) { 215 } 216 217 protected: 218 void Initialize() { 219 FakeSocket* fake_client_socket = new FakeSocket(&channel_1_, &channel_2_); 220 FakeSocket* fake_server_socket = new FakeSocket(&channel_2_, &channel_1_); 221 222 FilePath certs_dir; 223 PathService::Get(base::DIR_SOURCE_ROOT, &certs_dir); 224 certs_dir = certs_dir.AppendASCII("net"); 225 certs_dir = certs_dir.AppendASCII("data"); 226 certs_dir = certs_dir.AppendASCII("ssl"); 227 certs_dir = certs_dir.AppendASCII("certificates"); 228 229 FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); 230 std::string cert_der; 231 ASSERT_TRUE(file_util::ReadFileToString(cert_path, &cert_der)); 232 233 scoped_refptr<net::X509Certificate> cert = 234 X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); 235 236 FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); 237 std::string key_string; 238 ASSERT_TRUE(file_util::ReadFileToString(key_path, &key_string)); 239 std::vector<uint8> key_vector( 240 reinterpret_cast<const uint8*>(key_string.data()), 241 reinterpret_cast<const uint8*>(key_string.data() + 242 key_string.length())); 243 244 scoped_ptr<crypto::RSAPrivateKey> private_key( 245 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); 246 247 net::SSLConfig ssl_config; 248 ssl_config.false_start_enabled = false; 249 ssl_config.ssl3_enabled = true; 250 ssl_config.tls1_enabled = true; 251 252 // Certificate provided by the host doesn't need authority. 253 net::SSLConfig::CertAndStatus cert_and_status; 254 cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID; 255 cert_and_status.cert = cert; 256 ssl_config.allowed_bad_certs.push_back(cert_and_status); 257 258 net::HostPortPair host_and_pair("unittest", 0); 259 client_socket_.reset( 260 socket_factory_->CreateSSLClientSocket( 261 fake_client_socket, host_and_pair, ssl_config, NULL, 262 &cert_verifier_)); 263 server_socket_.reset(net::CreateSSLServerSocket(fake_server_socket, 264 cert, private_key.get(), 265 net::SSLConfig())); 266 } 267 268 FakeDataChannel channel_1_; 269 FakeDataChannel channel_2_; 270 scoped_ptr<net::SSLClientSocket> client_socket_; 271 scoped_ptr<net::SSLServerSocket> server_socket_; 272 net::ClientSocketFactory* socket_factory_; 273 net::CertVerifier cert_verifier_; 274}; 275 276// SSLServerSocket is only implemented using NSS. 277#if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX) 278 279// This test only executes creation of client and server sockets. This is to 280// test that creation of sockets doesn't crash and have minimal code to run 281// under valgrind in order to help debugging memory problems. 282TEST_F(SSLServerSocketTest, Initialize) { 283 Initialize(); 284} 285 286// This test executes Connect() of SSLClientSocket and Accept() of 287// SSLServerSocket to make sure handshaking between the two sockets are 288// completed successfully. 289TEST_F(SSLServerSocketTest, Handshake) { 290 Initialize(); 291 292 TestCompletionCallback connect_callback; 293 TestCompletionCallback accept_callback; 294 295 int server_ret = server_socket_->Accept(&accept_callback); 296 EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 297 298 int client_ret = client_socket_->Connect(&connect_callback); 299 EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 300 301 if (client_ret == net::ERR_IO_PENDING) { 302 EXPECT_EQ(net::OK, connect_callback.WaitForResult()); 303 } 304 if (server_ret == net::ERR_IO_PENDING) { 305 EXPECT_EQ(net::OK, accept_callback.WaitForResult()); 306 } 307} 308 309TEST_F(SSLServerSocketTest, DataTransfer) { 310 Initialize(); 311 312 TestCompletionCallback connect_callback; 313 TestCompletionCallback accept_callback; 314 315 // Establish connection. 316 int client_ret = client_socket_->Connect(&connect_callback); 317 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 318 319 int server_ret = server_socket_->Accept(&accept_callback); 320 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 321 322 if (client_ret == net::ERR_IO_PENDING) { 323 ASSERT_EQ(net::OK, connect_callback.WaitForResult()); 324 } 325 if (server_ret == net::ERR_IO_PENDING) { 326 ASSERT_EQ(net::OK, accept_callback.WaitForResult()); 327 } 328 329 const int kReadBufSize = 1024; 330 scoped_refptr<net::StringIOBuffer> write_buf = 331 new net::StringIOBuffer("testing123"); 332 scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); 333 334 // Write then read. 335 TestCompletionCallback write_callback; 336 TestCompletionCallback read_callback; 337 server_ret = server_socket_->Write(write_buf, write_buf->size(), 338 &write_callback); 339 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 340 client_ret = client_socket_->Read(read_buf, kReadBufSize, &read_callback); 341 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 342 343 if (server_ret == net::ERR_IO_PENDING) { 344 EXPECT_GT(write_callback.WaitForResult(), 0); 345 } 346 if (client_ret == net::ERR_IO_PENDING) { 347 EXPECT_GT(read_callback.WaitForResult(), 0); 348 } 349 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 350 351 // Read then write. 352 write_buf = new net::StringIOBuffer("hello123"); 353 server_ret = server_socket_->Read(read_buf, kReadBufSize, &read_callback); 354 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 355 client_ret = client_socket_->Write(write_buf, write_buf->size(), 356 &write_callback); 357 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 358 359 if (server_ret == net::ERR_IO_PENDING) { 360 EXPECT_GT(read_callback.WaitForResult(), 0); 361 } 362 if (client_ret == net::ERR_IO_PENDING) { 363 EXPECT_GT(write_callback.WaitForResult(), 0); 364 } 365 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 366} 367#endif 368 369} // namespace net 370