1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5// This test suite uses SSLClientSocket to test the implementation of 6// SSLServerSocket. In order to establish connections between the sockets 7// we need two additional classes: 8// 1. FakeSocket 9// Connects SSL socket to FakeDataChannel. This class is just a stub. 10// 11// 2. FakeDataChannel 12// Implements the actual exchange of data between two FakeSockets. 13// 14// Implementations of these two classes are included in this file. 15 16#include "net/socket/ssl_server_socket.h" 17 18#include <stdlib.h> 19 20#include <queue> 21 22#include "base/compiler_specific.h" 23#include "base/file_util.h" 24#include "base/files/file_path.h" 25#include "base/message_loop/message_loop.h" 26#include "base/path_service.h" 27#include "crypto/nss_util.h" 28#include "crypto/rsa_private_key.h" 29#include "net/base/address_list.h" 30#include "net/base/completion_callback.h" 31#include "net/base/host_port_pair.h" 32#include "net/base/io_buffer.h" 33#include "net/base/ip_endpoint.h" 34#include "net/base/net_errors.h" 35#include "net/base/net_log.h" 36#include "net/base/test_data_directory.h" 37#include "net/cert/cert_status_flags.h" 38#include "net/cert/mock_cert_verifier.h" 39#include "net/cert/x509_certificate.h" 40#include "net/http/transport_security_state.h" 41#include "net/socket/client_socket_factory.h" 42#include "net/socket/socket_test_util.h" 43#include "net/socket/ssl_client_socket.h" 44#include "net/socket/stream_socket.h" 45#include "net/ssl/ssl_config_service.h" 46#include "net/ssl/ssl_info.h" 47#include "net/test/cert_test_util.h" 48#include "testing/gtest/include/gtest/gtest.h" 49#include "testing/platform_test.h" 50 51namespace net { 52 53namespace { 54 55class FakeDataChannel { 56 public: 57 FakeDataChannel() 58 : read_buf_len_(0), 59 closed_(false), 60 write_called_after_close_(false), 61 weak_factory_(this) { 62 } 63 64 int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 65 if (closed_) 66 return 0; 67 if (data_.empty()) { 68 read_callback_ = callback; 69 read_buf_ = buf; 70 read_buf_len_ = buf_len; 71 return net::ERR_IO_PENDING; 72 } 73 return PropogateData(buf, buf_len); 74 } 75 76 int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 77 if (closed_) { 78 if (write_called_after_close_) 79 return net::ERR_CONNECTION_RESET; 80 write_called_after_close_ = true; 81 write_callback_ = callback; 82 base::MessageLoop::current()->PostTask( 83 FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback, 84 weak_factory_.GetWeakPtr())); 85 return net::ERR_IO_PENDING; 86 } 87 data_.push(new net::DrainableIOBuffer(buf, buf_len)); 88 base::MessageLoop::current()->PostTask( 89 FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback, 90 weak_factory_.GetWeakPtr())); 91 return buf_len; 92 } 93 94 // Closes the FakeDataChannel. After Close() is called, Read() returns 0, 95 // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that 96 // after the FakeDataChannel is closed, the first Write() call completes 97 // asynchronously, which is necessary to reproduce bug 127822. 98 void Close() { 99 closed_ = true; 100 } 101 102 private: 103 void DoReadCallback() { 104 if (read_callback_.is_null() || data_.empty()) 105 return; 106 107 int copied = PropogateData(read_buf_, read_buf_len_); 108 CompletionCallback callback = read_callback_; 109 read_callback_.Reset(); 110 read_buf_ = NULL; 111 read_buf_len_ = 0; 112 callback.Run(copied); 113 } 114 115 void DoWriteCallback() { 116 if (write_callback_.is_null()) 117 return; 118 119 CompletionCallback callback = write_callback_; 120 write_callback_.Reset(); 121 callback.Run(net::ERR_CONNECTION_RESET); 122 } 123 124 int PropogateData(scoped_refptr<net::IOBuffer> read_buf, int read_buf_len) { 125 scoped_refptr<net::DrainableIOBuffer> buf = data_.front(); 126 int copied = std::min(buf->BytesRemaining(), read_buf_len); 127 memcpy(read_buf->data(), buf->data(), copied); 128 buf->DidConsume(copied); 129 130 if (!buf->BytesRemaining()) 131 data_.pop(); 132 return copied; 133 } 134 135 CompletionCallback read_callback_; 136 scoped_refptr<net::IOBuffer> read_buf_; 137 int read_buf_len_; 138 139 CompletionCallback write_callback_; 140 141 std::queue<scoped_refptr<net::DrainableIOBuffer> > data_; 142 143 // True if Close() has been called. 144 bool closed_; 145 146 // Controls the completion of Write() after the FakeDataChannel is closed. 147 // After the FakeDataChannel is closed, the first Write() call completes 148 // asynchronously. 149 bool write_called_after_close_; 150 151 base::WeakPtrFactory<FakeDataChannel> weak_factory_; 152 153 DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); 154}; 155 156class FakeSocket : public StreamSocket { 157 public: 158 FakeSocket(FakeDataChannel* incoming_channel, 159 FakeDataChannel* outgoing_channel) 160 : incoming_(incoming_channel), 161 outgoing_(outgoing_channel) { 162 } 163 164 virtual ~FakeSocket() { 165 } 166 167 virtual int Read(IOBuffer* buf, int buf_len, 168 const CompletionCallback& callback) OVERRIDE { 169 // Read random number of bytes. 170 buf_len = rand() % buf_len + 1; 171 return incoming_->Read(buf, buf_len, callback); 172 } 173 174 virtual int Write(IOBuffer* buf, int buf_len, 175 const CompletionCallback& callback) OVERRIDE { 176 // Write random number of bytes. 177 buf_len = rand() % buf_len + 1; 178 return outgoing_->Write(buf, buf_len, callback); 179 } 180 181 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { 182 return net::OK; 183 } 184 185 virtual int SetSendBufferSize(int32 size) OVERRIDE { 186 return net::OK; 187 } 188 189 virtual int Connect(const CompletionCallback& callback) OVERRIDE { 190 return net::OK; 191 } 192 193 virtual void Disconnect() OVERRIDE { 194 incoming_->Close(); 195 outgoing_->Close(); 196 } 197 198 virtual bool IsConnected() const OVERRIDE { 199 return true; 200 } 201 202 virtual bool IsConnectedAndIdle() const OVERRIDE { 203 return true; 204 } 205 206 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { 207 net::IPAddressNumber ip_address(net::kIPv4AddressSize); 208 *address = net::IPEndPoint(ip_address, 0 /*port*/); 209 return net::OK; 210 } 211 212 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { 213 net::IPAddressNumber ip_address(4); 214 *address = net::IPEndPoint(ip_address, 0); 215 return net::OK; 216 } 217 218 virtual const BoundNetLog& NetLog() const OVERRIDE { 219 return net_log_; 220 } 221 222 virtual void SetSubresourceSpeculation() OVERRIDE {} 223 virtual void SetOmniboxSpeculation() OVERRIDE {} 224 225 virtual bool WasEverUsed() const OVERRIDE { 226 return true; 227 } 228 229 virtual bool UsingTCPFastOpen() const OVERRIDE { 230 return false; 231 } 232 233 234 virtual bool WasNpnNegotiated() const OVERRIDE { 235 return false; 236 } 237 238 virtual NextProto GetNegotiatedProtocol() const OVERRIDE { 239 return kProtoUnknown; 240 } 241 242 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { 243 return false; 244 } 245 246 private: 247 net::BoundNetLog net_log_; 248 FakeDataChannel* incoming_; 249 FakeDataChannel* outgoing_; 250 251 DISALLOW_COPY_AND_ASSIGN(FakeSocket); 252}; 253 254} // namespace 255 256// Verify the correctness of the test helper classes first. 257TEST(FakeSocketTest, DataTransfer) { 258 // Establish channels between two sockets. 259 FakeDataChannel channel_1; 260 FakeDataChannel channel_2; 261 FakeSocket client(&channel_1, &channel_2); 262 FakeSocket server(&channel_2, &channel_1); 263 264 const char kTestData[] = "testing123"; 265 const int kTestDataSize = strlen(kTestData); 266 const int kReadBufSize = 1024; 267 scoped_refptr<net::IOBuffer> write_buf = new net::StringIOBuffer(kTestData); 268 scoped_refptr<net::IOBuffer> read_buf = new net::IOBuffer(kReadBufSize); 269 270 // Write then read. 271 int written = 272 server.Write(write_buf.get(), kTestDataSize, CompletionCallback()); 273 EXPECT_GT(written, 0); 274 EXPECT_LE(written, kTestDataSize); 275 276 int read = client.Read(read_buf.get(), kReadBufSize, CompletionCallback()); 277 EXPECT_GT(read, 0); 278 EXPECT_LE(read, written); 279 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); 280 281 // Read then write. 282 TestCompletionCallback callback; 283 EXPECT_EQ(net::ERR_IO_PENDING, 284 server.Read(read_buf.get(), kReadBufSize, callback.callback())); 285 286 written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback()); 287 EXPECT_GT(written, 0); 288 EXPECT_LE(written, kTestDataSize); 289 290 read = callback.WaitForResult(); 291 EXPECT_GT(read, 0); 292 EXPECT_LE(read, written); 293 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); 294} 295 296class SSLServerSocketTest : public PlatformTest { 297 public: 298 SSLServerSocketTest() 299 : socket_factory_(net::ClientSocketFactory::GetDefaultFactory()), 300 cert_verifier_(new MockCertVerifier()), 301 transport_security_state_(new TransportSecurityState) { 302 cert_verifier_->set_default_result(net::CERT_STATUS_AUTHORITY_INVALID); 303 } 304 305 protected: 306 void Initialize() { 307 scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle); 308 client_connection->SetSocket( 309 scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_))); 310 scoped_ptr<StreamSocket> server_socket( 311 new FakeSocket(&channel_2_, &channel_1_)); 312 313 base::FilePath certs_dir(GetTestCertsDirectory()); 314 315 base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); 316 std::string cert_der; 317 ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der)); 318 319 scoped_refptr<net::X509Certificate> cert = 320 X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); 321 322 base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); 323 std::string key_string; 324 ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); 325 std::vector<uint8> key_vector( 326 reinterpret_cast<const uint8*>(key_string.data()), 327 reinterpret_cast<const uint8*>(key_string.data() + 328 key_string.length())); 329 330 scoped_ptr<crypto::RSAPrivateKey> private_key( 331 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); 332 333 net::SSLConfig ssl_config; 334 ssl_config.false_start_enabled = false; 335 ssl_config.channel_id_enabled = false; 336 337 // Certificate provided by the host doesn't need authority. 338 net::SSLConfig::CertAndStatus cert_and_status; 339 cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID; 340 cert_and_status.der_cert = cert_der; 341 ssl_config.allowed_bad_certs.push_back(cert_and_status); 342 343 net::HostPortPair host_and_pair("unittest", 0); 344 net::SSLClientSocketContext context; 345 context.cert_verifier = cert_verifier_.get(); 346 context.transport_security_state = transport_security_state_.get(); 347 client_socket_ = 348 socket_factory_->CreateSSLClientSocket( 349 client_connection.Pass(), host_and_pair, ssl_config, context); 350 server_socket_ = net::CreateSSLServerSocket( 351 server_socket.Pass(), 352 cert.get(), private_key.get(), net::SSLConfig()); 353 } 354 355 FakeDataChannel channel_1_; 356 FakeDataChannel channel_2_; 357 scoped_ptr<net::SSLClientSocket> client_socket_; 358 scoped_ptr<net::SSLServerSocket> server_socket_; 359 net::ClientSocketFactory* socket_factory_; 360 scoped_ptr<net::MockCertVerifier> cert_verifier_; 361 scoped_ptr<net::TransportSecurityState> transport_security_state_; 362}; 363 364// This test only executes creation of client and server sockets. This is to 365// test that creation of sockets doesn't crash and have minimal code to run 366// under valgrind in order to help debugging memory problems. 367TEST_F(SSLServerSocketTest, Initialize) { 368 Initialize(); 369} 370 371// This test executes Connect() on SSLClientSocket and Handshake() on 372// SSLServerSocket to make sure handshaking between the two sockets is 373// completed successfully. 374TEST_F(SSLServerSocketTest, Handshake) { 375 Initialize(); 376 377 TestCompletionCallback connect_callback; 378 TestCompletionCallback handshake_callback; 379 380 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 381 EXPECT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 382 383 int client_ret = client_socket_->Connect(connect_callback.callback()); 384 EXPECT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 385 386 if (client_ret == net::ERR_IO_PENDING) { 387 EXPECT_EQ(net::OK, connect_callback.WaitForResult()); 388 } 389 if (server_ret == net::ERR_IO_PENDING) { 390 EXPECT_EQ(net::OK, handshake_callback.WaitForResult()); 391 } 392 393 // Make sure the cert status is expected. 394 SSLInfo ssl_info; 395 client_socket_->GetSSLInfo(&ssl_info); 396 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status); 397} 398 399TEST_F(SSLServerSocketTest, DataTransfer) { 400 Initialize(); 401 402 TestCompletionCallback connect_callback; 403 TestCompletionCallback handshake_callback; 404 405 // Establish connection. 406 int client_ret = client_socket_->Connect(connect_callback.callback()); 407 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 408 409 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 410 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 411 412 client_ret = connect_callback.GetResult(client_ret); 413 ASSERT_EQ(net::OK, client_ret); 414 server_ret = handshake_callback.GetResult(server_ret); 415 ASSERT_EQ(net::OK, server_ret); 416 417 const int kReadBufSize = 1024; 418 scoped_refptr<net::StringIOBuffer> write_buf = 419 new net::StringIOBuffer("testing123"); 420 scoped_refptr<net::DrainableIOBuffer> read_buf = 421 new net::DrainableIOBuffer(new net::IOBuffer(kReadBufSize), 422 kReadBufSize); 423 424 // Write then read. 425 TestCompletionCallback write_callback; 426 TestCompletionCallback read_callback; 427 server_ret = server_socket_->Write( 428 write_buf.get(), write_buf->size(), write_callback.callback()); 429 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 430 client_ret = client_socket_->Read( 431 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 432 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 433 434 server_ret = write_callback.GetResult(server_ret); 435 EXPECT_GT(server_ret, 0); 436 client_ret = read_callback.GetResult(client_ret); 437 ASSERT_GT(client_ret, 0); 438 439 read_buf->DidConsume(client_ret); 440 while (read_buf->BytesConsumed() < write_buf->size()) { 441 client_ret = client_socket_->Read( 442 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 443 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 444 client_ret = read_callback.GetResult(client_ret); 445 ASSERT_GT(client_ret, 0); 446 read_buf->DidConsume(client_ret); 447 } 448 EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed()); 449 read_buf->SetOffset(0); 450 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 451 452 // Read then write. 453 write_buf = new net::StringIOBuffer("hello123"); 454 server_ret = server_socket_->Read( 455 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 456 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 457 client_ret = client_socket_->Write( 458 write_buf.get(), write_buf->size(), write_callback.callback()); 459 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 460 461 server_ret = read_callback.GetResult(server_ret); 462 ASSERT_GT(server_ret, 0); 463 client_ret = write_callback.GetResult(client_ret); 464 EXPECT_GT(client_ret, 0); 465 466 read_buf->DidConsume(server_ret); 467 while (read_buf->BytesConsumed() < write_buf->size()) { 468 server_ret = server_socket_->Read( 469 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 470 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 471 server_ret = read_callback.GetResult(server_ret); 472 ASSERT_GT(server_ret, 0); 473 read_buf->DidConsume(server_ret); 474 } 475 EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed()); 476 read_buf->SetOffset(0); 477 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 478} 479 480// Flaky on Android: http://crbug.com/381147 481#if defined(OS_ANDROID) 482#define MAYBE_ClientWriteAfterServerClose DISABLED_ClientWriteAfterServerClose 483#else 484#define MAYBE_ClientWriteAfterServerClose ClientWriteAfterServerClose 485#endif 486// A regression test for bug 127822 (http://crbug.com/127822). 487// If the server closes the connection after the handshake is finished, 488// the client's Write() call should not cause an infinite loop. 489// NOTE: this is a test for SSLClientSocket rather than SSLServerSocket. 490TEST_F(SSLServerSocketTest, MAYBE_ClientWriteAfterServerClose) { 491 Initialize(); 492 493 TestCompletionCallback connect_callback; 494 TestCompletionCallback handshake_callback; 495 496 // Establish connection. 497 int client_ret = client_socket_->Connect(connect_callback.callback()); 498 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 499 500 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 501 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 502 503 client_ret = connect_callback.GetResult(client_ret); 504 ASSERT_EQ(net::OK, client_ret); 505 server_ret = handshake_callback.GetResult(server_ret); 506 ASSERT_EQ(net::OK, server_ret); 507 508 scoped_refptr<net::StringIOBuffer> write_buf = 509 new net::StringIOBuffer("testing123"); 510 511 // The server closes the connection. The server needs to write some 512 // data first so that the client's Read() calls from the transport 513 // socket won't return ERR_IO_PENDING. This ensures that the client 514 // will call Read() on the transport socket again. 515 TestCompletionCallback write_callback; 516 517 server_ret = server_socket_->Write( 518 write_buf.get(), write_buf->size(), write_callback.callback()); 519 EXPECT_TRUE(server_ret > 0 || server_ret == net::ERR_IO_PENDING); 520 521 server_ret = write_callback.GetResult(server_ret); 522 EXPECT_GT(server_ret, 0); 523 524 server_socket_->Disconnect(); 525 526 // The client writes some data. This should not cause an infinite loop. 527 client_ret = client_socket_->Write( 528 write_buf.get(), write_buf->size(), write_callback.callback()); 529 EXPECT_TRUE(client_ret > 0 || client_ret == net::ERR_IO_PENDING); 530 531 client_ret = write_callback.GetResult(client_ret); 532 EXPECT_GT(client_ret, 0); 533 534 base::MessageLoop::current()->PostDelayedTask( 535 FROM_HERE, base::MessageLoop::QuitClosure(), 536 base::TimeDelta::FromMilliseconds(10)); 537 base::MessageLoop::current()->Run(); 538} 539 540// This test executes ExportKeyingMaterial() on the client and server sockets, 541// after connecting them, and verifies that the results match. 542// This test will fail if False Start is enabled (see crbug.com/90208). 543TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { 544 Initialize(); 545 546 TestCompletionCallback connect_callback; 547 TestCompletionCallback handshake_callback; 548 549 int client_ret = client_socket_->Connect(connect_callback.callback()); 550 ASSERT_TRUE(client_ret == net::OK || client_ret == net::ERR_IO_PENDING); 551 552 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 553 ASSERT_TRUE(server_ret == net::OK || server_ret == net::ERR_IO_PENDING); 554 555 if (client_ret == net::ERR_IO_PENDING) { 556 ASSERT_EQ(net::OK, connect_callback.WaitForResult()); 557 } 558 if (server_ret == net::ERR_IO_PENDING) { 559 ASSERT_EQ(net::OK, handshake_callback.WaitForResult()); 560 } 561 562 const int kKeyingMaterialSize = 32; 563 const char* kKeyingLabel = "EXPERIMENTAL-server-socket-test"; 564 const char* kKeyingContext = ""; 565 unsigned char server_out[kKeyingMaterialSize]; 566 int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel, 567 false, kKeyingContext, 568 server_out, sizeof(server_out)); 569 ASSERT_EQ(net::OK, rv); 570 571 unsigned char client_out[kKeyingMaterialSize]; 572 rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, 573 false, kKeyingContext, 574 client_out, sizeof(client_out)); 575 ASSERT_EQ(net::OK, rv); 576 EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out))); 577 578 const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad"; 579 unsigned char client_bad[kKeyingMaterialSize]; 580 rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad, 581 false, kKeyingContext, 582 client_bad, sizeof(client_bad)); 583 ASSERT_EQ(rv, net::OK); 584 EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out))); 585} 586 587} // namespace net 588