ssl_server_socket_unittest.cc revision 0f1bc08d4cfcc34181b0b5cbf065c40f687bf740
1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5// This test suite uses SSLClientSocket to test the implementation of 6// SSLServerSocket. In order to establish connections between the sockets 7// we need two additional classes: 8// 1. FakeSocket 9// Connects SSL socket to FakeDataChannel. This class is just a stub. 10// 11// 2. FakeDataChannel 12// Implements the actual exchange of data between two FakeSockets. 13// 14// Implementations of these two classes are included in this file. 15 16#include "net/socket/ssl_server_socket.h" 17 18#include <stdlib.h> 19 20#include <queue> 21 22#include "base/compiler_specific.h" 23#include "base/file_util.h" 24#include "base/files/file_path.h" 25#include "base/message_loop/message_loop.h" 26#include "base/path_service.h" 27#include "crypto/nss_util.h" 28#include "crypto/rsa_private_key.h" 29#include "net/base/address_list.h" 30#include "net/base/completion_callback.h" 31#include "net/base/host_port_pair.h" 32#include "net/base/io_buffer.h" 33#include "net/base/ip_endpoint.h" 34#include "net/base/net_errors.h" 35#include "net/base/net_log.h" 36#include "net/base/test_data_directory.h" 37#include "net/cert/cert_status_flags.h" 38#include "net/cert/mock_cert_verifier.h" 39#include "net/cert/x509_certificate.h" 40#include "net/http/transport_security_state.h" 41#include "net/socket/client_socket_factory.h" 42#include "net/socket/socket_test_util.h" 43#include "net/socket/ssl_client_socket.h" 44#include "net/socket/stream_socket.h" 45#include "net/ssl/ssl_config_service.h" 46#include "net/ssl/ssl_info.h" 47#include "net/test/cert_test_util.h" 48#include "testing/gtest/include/gtest/gtest.h" 49#include "testing/platform_test.h" 50 51namespace net { 52 53namespace { 54 55class FakeDataChannel { 56 public: 57 FakeDataChannel() 58 : read_buf_len_(0), 59 weak_factory_(this), 60 closed_(false), 61 write_called_after_close_(false) { 62 } 63 64 int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 65 if (closed_) 66 return 0; 67 if (data_.empty()) { 68 read_callback_ = callback; 69 read_buf_ = buf; 70 read_buf_len_ = buf_len; 71 return net::ERR_IO_PENDING; 72 } 73 return PropogateData(buf, buf_len); 74 } 75 76 int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 77 if (closed_) { 78 if (write_called_after_close_) 79 return net::ERR_CONNECTION_RESET; 80 write_called_after_close_ = true; 81 write_callback_ = callback; 82 base::MessageLoop::current()->PostTask( 83 FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback, 84 weak_factory_.GetWeakPtr())); 85 return net::ERR_IO_PENDING; 86 } 87 data_.push(new net::DrainableIOBuffer(buf, buf_len)); 88 base::MessageLoop::current()->PostTask( 89 FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback, 90 weak_factory_.GetWeakPtr())); 91 return buf_len; 92 } 93 94 // Closes the FakeDataChannel. After Close() is called, Read() returns 0, 95 // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that 96 // after the FakeDataChannel is closed, the first Write() call completes 97 // asynchronously, which is necessary to reproduce bug 127822. 98 void Close() { 99 closed_ = true; 100 } 101 102 private: 103 void DoReadCallback() { 104 if (read_callback_.is_null() || data_.empty()) 105 return; 106 107 int copied = PropogateData(read_buf_, read_buf_len_); 108 CompletionCallback callback = read_callback_; 109 read_callback_.Reset(); 110 read_buf_ = NULL; 111 read_buf_len_ = 0; 112 callback.Run(copied); 113 } 114 115 void DoWriteCallback() { 116 if (write_callback_.is_null()) 117 return; 118 119 CompletionCallback callback = write_callback_; 120 write_callback_.Reset(); 121 callback.Run(net::ERR_CONNECTION_RESET); 122 } 123 124 int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) { 125 scoped_refptr<net::DrainableIOBuffer> buf = data_.front(); 126 int copied = std::min(buf->BytesRemaining(), read_buf_len); 127 memcpy(read_buf->data(), buf->data(), copied); 128 buf->DidConsume(copied); 129 130 if (!buf->BytesRemaining()) 131 data_.pop(); 132 return copied; 133 } 134 135 CompletionCallback read_callback_; 136 scoped_refptr<net::IOBuffer> read_buf_; 137 int read_buf_len_; 138 139 CompletionCallback write_callback_; 140 141 std::queue<scoped_refptr<net::DrainableIOBuffer> > data_; 142 143 base::WeakPtrFactory<FakeDataChannel> weak_factory_; 144 145 // True if Close() has been called. 146 bool closed_; 147 148 // Controls the completion of Write() after the FakeDataChannel is closed. 149 // After the FakeDataChannel is closed, the first Write() call completes 150 // asynchronously. 151 bool write_called_after_close_; 152 153 DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); 154}; 155 156class FakeSocket : public StreamSocket { 157 public: 158 FakeSocket(FakeDataChannel* incoming_channel, 159 FakeDataChannel* outgoing_channel) 160 : incoming_(incoming_channel), 161 outgoing_(outgoing_channel) { 162 } 163 164 virtual ~FakeSocket() { 165 } 166 167 virtual int Read(IOBuffer* buf, int buf_len, 168 const CompletionCallback& callback) OVERRIDE { 169 // Read random number of bytes. 170 buf_len = rand() % buf_len + 1; 171 return incoming_->Read(buf, buf_len, callback); 172 } 173 174 virtual int Write(IOBuffer* buf, int buf_len, 175 const CompletionCallback& callback) OVERRIDE { 176 // Write random number of bytes. 177 buf_len = rand() % buf_len + 1; 178 return outgoing_->Write(buf, buf_len, callback); 179 } 180 181 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { 182 return true; 183 } 184 185 virtual bool SetSendBufferSize(int32 size) OVERRIDE { 186 return true; 187 } 188 189 virtual int Connect(const CompletionCallback& callback) OVERRIDE { 190 return net::OK; 191 } 192 193 virtual void Disconnect() OVERRIDE { 194 incoming_->Close(); 195 outgoing_->Close(); 196 } 197 198 virtual bool IsConnected() const OVERRIDE { 199 return true; 200 } 201 202 virtual bool IsConnectedAndIdle() const OVERRIDE { 203 return true; 204 } 205 206 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { 207 net::IPAddressNumber ip_address(net::kIPv4AddressSize); 208 *address = net::IPEndPoint(ip_address, 0 /*port*/); 209 return net::OK; 210 } 211 212 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { 213 net::IPAddressNumber ip_address(4); 214 *address = net::IPEndPoint(ip_address, 0); 215 return net::OK; 216 } 217 218 virtual const BoundNetLog& NetLog() const OVERRIDE { 219 return net_log_; 220 } 221 222 virtual void SetSubresourceSpeculation() OVERRIDE {} 223 virtual void SetOmniboxSpeculation() OVERRIDE {} 224 225 virtual bool WasEverUsed() const OVERRIDE { 226 return true; 227 } 228 229 virtual bool UsingTCPFastOpen() const OVERRIDE { 230 return false; 231 } 232 233 234 virtual bool WasNpnNegotiated() const OVERRIDE { 235 return false; 236 } 237 238 virtual NextProto GetNegotiatedProtocol() const OVERRIDE { 239 return kProtoUnknown; 240 } 241 242 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { 243 return false; 244 } 245 246 private: 247 net::BoundNetLog net_log_; 248 FakeDataChannel* incoming_; 249 FakeDataChannel* outgoing_; 250 251 DISALLOW_COPY_AND_ASSIGN(FakeSocket); 252}; 253 254} // namespace 255 256// Verify the correctness of the test helper classes first. 257TEST(FakeSocketTest, DataTransfer) { 258 // Establish channels between two sockets. 259 FakeDataChannel channel_1; 260 FakeDataChannel channel_2; 261 FakeSocket client(&channel_1, &channel_2); 262 FakeSocket server(&channel_2, &channel_1); 263 264 const char kTestData[] = "testing123"; 265 const int kTestDataSize = strlen(kTestData); 266 const int kReadBufSize = 1024; 267 scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData); 268 scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); 269 270 // Write then read. 271 int written = 272 server.Write(write_buf.get(), kTestDataSize, CompletionCallback()); 273 EXPECT_GT(written, 0); 274 EXPECT_LE(written, kTestDataSize); 275 276 int read = client.Read(read_buf.get(), kReadBufSize, CompletionCallback()); 277 EXPECT_GT(read, 0); 278 EXPECT_LE(read, written); 279 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); 280 281 // Read then write. 282 TestCompletionCallback callback; 283 EXPECT_EQ(net::ERR_IO_PENDING, 284 server.Read(read_buf.get(), kReadBufSize, callback.callback())); 285 286 written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback()); 287 EXPECT_GT(written, 0); 288 EXPECT_LE(written, kTestDataSize); 289 290 read = callback.WaitForResult(); 291 EXPECT_GT(read, 0); 292 EXPECT_LE(read, written); 293 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); 294} 295 296class SSLServerSocketTest : public PlatformTest { 297 public: 298 SSLServerSocketTest() 299 : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), 300 cert_verifier_(new MockCertVerifier()), 301 transport_security_state_(new TransportSecurityState) { 302 cert_verifier_->set_default_result(net::CERT_STATUS_AUTHORITY_INVALID); 303 } 304 305 protected: 306 void Initialize() { 307 scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle); 308 client_connection->SetSocket( 309 scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_))); 310 scoped_ptr<StreamSocket> server_socket( 311 new FakeSocket(&channel_2_, &channel_1_)); 312 313 base::FilePath certs_dir(GetTestCertsDirectory()); 314 315 base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); 316 std::string cert_der; 317 ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der)); 318 319 scoped_refptr<net::X509Certificate> cert = 320 X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); 321 322 base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); 323 std::string key_string; 324 ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); 325 std::vector<uint8> key_vector( 326 reinterpret_cast<const uint8*>(key_string.data()), 327 reinterpret_cast<const uint8*>(key_string.data() + 328 key_string.length())); 329 330 scoped_ptr<crypto::RSAPrivateKey> private_key( 331 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); 332 333 net::SSLConfig ssl_config; 334 ssl_config.cached_info_enabled = false; 335 ssl_config.false_start_enabled = false; 336 ssl_config.channel_id_enabled = false; 337 338 // Certificate provided by the host doesn't need authority. 339 net::SSLConfig::CertAndStatus cert_and_status; 340 cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID; 341 cert_and_status.der_cert = cert_der; 342 ssl_config.allowed_bad_certs.push_back(cert_and_status); 343 344 net::HostPortPair host_and_pair("unittest", 0); 345 net::SSLClientSocketContext context; 346 context.cert_verifier = cert_verifier_.get(); 347 context.transport_security_state = transport_security_state_.get(); 348 client_socket_ = 349 socket_factory_->CreateSSLClientSocket( 350 client_connection.Pass(), host_and_pair, ssl_config, context); 351 server_socket_ = net::CreateSSLServerSocket( 352 server_socket.Pass(), 353 cert.get(), private_key.get(), net::SSLConfig()); 354 } 355 356 FakeDataChannel channel_1_; 357 FakeDataChannel channel_2_; 358 scoped_ptr<net::SSLClientSocket> client_socket_; 359 scoped_ptr<net::SSLServerSocket> server_socket_; 360 net::ClientSocketFactory* socket_factory_; 361 scoped_ptr<net::MockCertVerifier> cert_verifier_; 362 scoped_ptr<net::TransportSecurityState> transport_security_state_; 363}; 364 365// SSLServerSocket is only implemented using NSS. 366#if defined(USE_NSS) || defined(OS_WIN) || defined(OS_MACOSX) 367 368// This test only executes creation of client and server sockets. This is to 369// test that creation of sockets doesn't crash and have minimal code to run 370// under valgrind in order to help debugging memory problems. 371TEST_F(SSLServerSocketTest, Initialize) { 372 Initialize(); 373} 374 375// This test executes Connect() on SSLClientSocket and Handshake() on 376// SSLServerSocket to make sure handshaking between the two sockets is 377// completed successfully. 378TEST_F(SSLServerSocketTest, Handshake) { 379 Initialize(); 380 381 TestCompletionCallback connect_callback; 382 TestCompletionCallback handshake_callback; 383 384 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 385 EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 386 387 int client_ret = client_socket_->Connect(connect_callback.callback()); 388 EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 389 390 if (client_ret == net::ERR_IO_PENDING) { 391 EXPECT_EQ(net::OK, connect_callback.WaitForResult()); 392 } 393 if (server_ret == net::ERR_IO_PENDING) { 394 EXPECT_EQ(net::OK, handshake_callback.WaitForResult()); 395 } 396 397 // Make sure the cert status is expected. 398 SSLInfo ssl_info; 399 client_socket_->GetSSLInfo(&ssl_info); 400 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status); 401} 402 403TEST_F(SSLServerSocketTest, DataTransfer) { 404 Initialize(); 405 406 TestCompletionCallback connect_callback; 407 TestCompletionCallback handshake_callback; 408 409 // Establish connection. 410 int client_ret = client_socket_->Connect(connect_callback.callback()); 411 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 412 413 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 414 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 415 416 client_ret = connect_callback.GetResult(client_ret); 417 ASSERT_EQ(net::OK, client_ret); 418 server_ret = handshake_callback.GetResult(server_ret); 419 ASSERT_EQ(net::OK, server_ret); 420 421 const int kReadBufSize = 1024; 422 scoped_refptr<net::StringIOBuffer> write_buf = 423 new net::StringIOBuffer("testing123"); 424 scoped_refptr<net::DrainableIOBuffer> read_buf = 425 new net::DrainableIOBuffer(new net::IOBuffer(kReadBufSize), 426 kReadBufSize); 427 428 // Write then read. 429 TestCompletionCallback write_callback; 430 TestCompletionCallback read_callback; 431 server_ret = server_socket_->Write( 432 write_buf.get(), write_buf->size(), write_callback.callback()); 433 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 434 client_ret = client_socket_->Read( 435 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 436 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 437 438 server_ret = write_callback.GetResult(server_ret); 439 EXPECT_GT(server_ret, 0); 440 client_ret = read_callback.GetResult(client_ret); 441 ASSERT_GT(client_ret, 0); 442 443 read_buf->DidConsume(client_ret); 444 while (read_buf->BytesConsumed() < write_buf->size()) { 445 client_ret = client_socket_->Read( 446 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 447 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 448 client_ret = read_callback.GetResult(client_ret); 449 ASSERT_GT(client_ret, 0); 450 read_buf->DidConsume(client_ret); 451 } 452 EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed()); 453 read_buf->SetOffset(0); 454 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 455 456 // Read then write. 457 write_buf = new net::StringIOBuffer("hello123"); 458 server_ret = server_socket_->Read( 459 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 460 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 461 client_ret = client_socket_->Write( 462 write_buf.get(), write_buf->size(), write_callback.callback()); 463 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 464 465 server_ret = read_callback.GetResult(server_ret); 466 ASSERT_GT(server_ret, 0); 467 client_ret = write_callback.GetResult(client_ret); 468 EXPECT_GT(client_ret, 0); 469 470 read_buf->DidConsume(server_ret); 471 while (read_buf->BytesConsumed() < write_buf->size()) { 472 server_ret = server_socket_->Read( 473 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 474 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 475 server_ret = read_callback.GetResult(server_ret); 476 ASSERT_GT(server_ret, 0); 477 read_buf->DidConsume(server_ret); 478 } 479 EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed()); 480 read_buf->SetOffset(0); 481 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 482} 483 484// A regression test for bug 127822 (http://crbug.com/127822). 485// If the server closes the connection after the handshake is finished, 486// the client's Write() call should not cause an infinite loop. 487// NOTE: this is a test for SSLClientSocket rather than SSLServerSocket. 488TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { 489 Initialize(); 490 491 TestCompletionCallback connect_callback; 492 TestCompletionCallback handshake_callback; 493 494 // Establish connection. 495 int client_ret = client_socket_->Connect(connect_callback.callback()); 496 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 497 498 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 499 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 500 501 client_ret = connect_callback.GetResult(client_ret); 502 ASSERT_EQ(net::OK, client_ret); 503 server_ret = handshake_callback.GetResult(server_ret); 504 ASSERT_EQ(net::OK, server_ret); 505 506 scoped_refptr<net::StringIOBuffer> write_buf = 507 new net::StringIOBuffer("testing123"); 508 509 // The server closes the connection. The server needs to write some 510 // data first so that the client's Read() calls from the transport 511 // socket won't return ERR_IO_PENDING. This ensures that the client 512 // will call Read() on the transport socket again. 513 TestCompletionCallback write_callback; 514 515 server_ret = server_socket_->Write( 516 write_buf.get(), write_buf->size(), write_callback.callback()); 517 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 518 519 server_ret = write_callback.GetResult(server_ret); 520 EXPECT_GT(server_ret, 0); 521 522 server_socket_->Disconnect(); 523 524 // The client writes some data. This should not cause an infinite loop. 525 client_ret = client_socket_->Write( 526 write_buf.get(), write_buf->size(), write_callback.callback()); 527 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 528 529 client_ret = write_callback.GetResult(client_ret); 530 EXPECT_GT(client_ret, 0); 531 532 base::MessageLoop::current()->PostDelayedTask( 533 FROM_HERE, base::MessageLoop::QuitClosure(), 534 base::TimeDelta::FromMilliseconds(10)); 535 base::MessageLoop::current()->Run(); 536} 537 538// This test executes ExportKeyingMaterial() on the client and server sockets, 539// after connecting them, and verifies that the results match. 540// This test will fail if False Start is enabled (see crbug.com/90208). 541TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { 542 Initialize(); 543 544 TestCompletionCallback connect_callback; 545 TestCompletionCallback handshake_callback; 546 547 int client_ret = client_socket_->Connect(connect_callback.callback()); 548 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 549 550 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 551 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 552 553 if (client_ret == net::ERR_IO_PENDING) { 554 ASSERT_EQ(net::OK, connect_callback.WaitForResult()); 555 } 556 if (server_ret == net::ERR_IO_PENDING) { 557 ASSERT_EQ(net::OK, handshake_callback.WaitForResult()); 558 } 559 560 const int kKeyingMaterialSize = 32; 561 const char* kKeyingLabel = "EXPERIMENTAL-server-socket-test"; 562 const char* kKeyingContext = ""; 563 unsigned char server_out[kKeyingMaterialSize]; 564 int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel, 565 false, kKeyingContext, 566 server_out, sizeof(server_out)); 567 ASSERT_EQ(net::OK, rv); 568 569 unsigned char client_out[kKeyingMaterialSize]; 570 rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, 571 false, kKeyingContext, 572 client_out, sizeof(client_out)); 573 ASSERT_EQ(net::OK, rv); 574 EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out))); 575 576 const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad"; 577 unsigned char client_bad[kKeyingMaterialSize]; 578 rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad, 579 false, kKeyingContext, 580 client_bad, sizeof(client_bad)); 581 ASSERT_EQ(rv, net::OK); 582 EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out))); 583} 584#endif 585 586} // namespace net 587