1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5// This test suite uses SSLClientSocket to test the implementation of 6// SSLServerSocket. In order to establish connections between the sockets 7// we need two additional classes: 8// 1. FakeSocket 9// Connects SSL socket to FakeDataChannel. This class is just a stub. 10// 11// 2. FakeDataChannel 12// Implements the actual exchange of data between two FakeSockets. 13// 14// Implementations of these two classes are included in this file. 15 16#include "net/socket/ssl_server_socket.h" 17 18#include <stdlib.h> 19 20#include <queue> 21 22#include "base/compiler_specific.h" 23#include "base/files/file_path.h" 24#include "base/files/file_util.h" 25#include "base/message_loop/message_loop.h" 26#include "base/path_service.h" 27#include "crypto/nss_util.h" 28#include "crypto/rsa_private_key.h" 29#include "net/base/address_list.h" 30#include "net/base/completion_callback.h" 31#include "net/base/host_port_pair.h" 32#include "net/base/io_buffer.h" 33#include "net/base/ip_endpoint.h" 34#include "net/base/net_errors.h" 35#include "net/base/net_log.h" 36#include "net/base/test_data_directory.h" 37#include "net/cert/cert_status_flags.h" 38#include "net/cert/mock_cert_verifier.h" 39#include "net/cert/x509_certificate.h" 40#include "net/http/transport_security_state.h" 41#include "net/socket/client_socket_factory.h" 42#include "net/socket/socket_test_util.h" 43#include "net/socket/ssl_client_socket.h" 44#include "net/socket/stream_socket.h" 45#include "net/ssl/ssl_config_service.h" 46#include "net/ssl/ssl_info.h" 47#include "net/test/cert_test_util.h" 48#include "testing/gtest/include/gtest/gtest.h" 49#include "testing/platform_test.h" 50 51namespace net { 52 53namespace { 54 55class FakeDataChannel { 56 public: 57 FakeDataChannel() 58 : read_buf_len_(0), 59 closed_(false), 60 write_called_after_close_(false), 61 weak_factory_(this) { 62 } 63 64 int Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 65 DCHECK(read_callback_.is_null()); 66 DCHECK(!read_buf_.get()); 67 if (closed_) 68 return 0; 69 if (data_.empty()) { 70 read_callback_ = callback; 71 read_buf_ = buf; 72 read_buf_len_ = buf_len; 73 return ERR_IO_PENDING; 74 } 75 return PropogateData(buf, buf_len); 76 } 77 78 int Write(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 79 DCHECK(write_callback_.is_null()); 80 if (closed_) { 81 if (write_called_after_close_) 82 return ERR_CONNECTION_RESET; 83 write_called_after_close_ = true; 84 write_callback_ = callback; 85 base::MessageLoop::current()->PostTask( 86 FROM_HERE, base::Bind(&FakeDataChannel::DoWriteCallback, 87 weak_factory_.GetWeakPtr())); 88 return ERR_IO_PENDING; 89 } 90 // This function returns synchronously, so make a copy of the buffer. 91 data_.push(new DrainableIOBuffer( 92 new StringIOBuffer(std::string(buf->data(), buf_len)), 93 buf_len)); 94 base::MessageLoop::current()->PostTask( 95 FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback, 96 weak_factory_.GetWeakPtr())); 97 return buf_len; 98 } 99 100 // Closes the FakeDataChannel. After Close() is called, Read() returns 0, 101 // indicating EOF, and Write() fails with ERR_CONNECTION_RESET. Note that 102 // after the FakeDataChannel is closed, the first Write() call completes 103 // asynchronously, which is necessary to reproduce bug 127822. 104 void Close() { 105 closed_ = true; 106 } 107 108 private: 109 void DoReadCallback() { 110 if (read_callback_.is_null() || data_.empty()) 111 return; 112 113 int copied = PropogateData(read_buf_, read_buf_len_); 114 CompletionCallback callback = read_callback_; 115 read_callback_.Reset(); 116 read_buf_ = NULL; 117 read_buf_len_ = 0; 118 callback.Run(copied); 119 } 120 121 void DoWriteCallback() { 122 if (write_callback_.is_null()) 123 return; 124 125 CompletionCallback callback = write_callback_; 126 write_callback_.Reset(); 127 callback.Run(ERR_CONNECTION_RESET); 128 } 129 130 int PropogateData(scoped_refptr<IOBuffer> read_buf, int read_buf_len) { 131 scoped_refptr<DrainableIOBuffer> buf = data_.front(); 132 int copied = std::min(buf->BytesRemaining(), read_buf_len); 133 memcpy(read_buf->data(), buf->data(), copied); 134 buf->DidConsume(copied); 135 136 if (!buf->BytesRemaining()) 137 data_.pop(); 138 return copied; 139 } 140 141 CompletionCallback read_callback_; 142 scoped_refptr<IOBuffer> read_buf_; 143 int read_buf_len_; 144 145 CompletionCallback write_callback_; 146 147 std::queue<scoped_refptr<DrainableIOBuffer> > data_; 148 149 // True if Close() has been called. 150 bool closed_; 151 152 // Controls the completion of Write() after the FakeDataChannel is closed. 153 // After the FakeDataChannel is closed, the first Write() call completes 154 // asynchronously. 155 bool write_called_after_close_; 156 157 base::WeakPtrFactory<FakeDataChannel> weak_factory_; 158 159 DISALLOW_COPY_AND_ASSIGN(FakeDataChannel); 160}; 161 162class FakeSocket : public StreamSocket { 163 public: 164 FakeSocket(FakeDataChannel* incoming_channel, 165 FakeDataChannel* outgoing_channel) 166 : incoming_(incoming_channel), 167 outgoing_(outgoing_channel) { 168 } 169 170 virtual ~FakeSocket() { 171 } 172 173 virtual int Read(IOBuffer* buf, int buf_len, 174 const CompletionCallback& callback) OVERRIDE { 175 // Read random number of bytes. 176 buf_len = rand() % buf_len + 1; 177 return incoming_->Read(buf, buf_len, callback); 178 } 179 180 virtual int Write(IOBuffer* buf, int buf_len, 181 const CompletionCallback& callback) OVERRIDE { 182 // Write random number of bytes. 183 buf_len = rand() % buf_len + 1; 184 return outgoing_->Write(buf, buf_len, callback); 185 } 186 187 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { 188 return OK; 189 } 190 191 virtual int SetSendBufferSize(int32 size) OVERRIDE { 192 return OK; 193 } 194 195 virtual int Connect(const CompletionCallback& callback) OVERRIDE { 196 return OK; 197 } 198 199 virtual void Disconnect() OVERRIDE { 200 incoming_->Close(); 201 outgoing_->Close(); 202 } 203 204 virtual bool IsConnected() const OVERRIDE { 205 return true; 206 } 207 208 virtual bool IsConnectedAndIdle() const OVERRIDE { 209 return true; 210 } 211 212 virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE { 213 IPAddressNumber ip_address(kIPv4AddressSize); 214 *address = IPEndPoint(ip_address, 0 /*port*/); 215 return OK; 216 } 217 218 virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE { 219 IPAddressNumber ip_address(4); 220 *address = IPEndPoint(ip_address, 0); 221 return OK; 222 } 223 224 virtual const BoundNetLog& NetLog() const OVERRIDE { 225 return net_log_; 226 } 227 228 virtual void SetSubresourceSpeculation() OVERRIDE {} 229 virtual void SetOmniboxSpeculation() OVERRIDE {} 230 231 virtual bool WasEverUsed() const OVERRIDE { 232 return true; 233 } 234 235 virtual bool UsingTCPFastOpen() const OVERRIDE { 236 return false; 237 } 238 239 240 virtual bool WasNpnNegotiated() const OVERRIDE { 241 return false; 242 } 243 244 virtual NextProto GetNegotiatedProtocol() const OVERRIDE { 245 return kProtoUnknown; 246 } 247 248 virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE { 249 return false; 250 } 251 252 private: 253 BoundNetLog net_log_; 254 FakeDataChannel* incoming_; 255 FakeDataChannel* outgoing_; 256 257 DISALLOW_COPY_AND_ASSIGN(FakeSocket); 258}; 259 260} // namespace 261 262// Verify the correctness of the test helper classes first. 263TEST(FakeSocketTest, DataTransfer) { 264 // Establish channels between two sockets. 265 FakeDataChannel channel_1; 266 FakeDataChannel channel_2; 267 FakeSocket client(&channel_1, &channel_2); 268 FakeSocket server(&channel_2, &channel_1); 269 270 const char kTestData[] = "testing123"; 271 const int kTestDataSize = strlen(kTestData); 272 const int kReadBufSize = 1024; 273 scoped_refptr<IOBuffer> write_buf = new StringIOBuffer(kTestData); 274 scoped_refptr<IOBuffer> read_buf = new IOBuffer(kReadBufSize); 275 276 // Write then read. 277 int written = 278 server.Write(write_buf.get(), kTestDataSize, CompletionCallback()); 279 EXPECT_GT(written, 0); 280 EXPECT_LE(written, kTestDataSize); 281 282 int read = client.Read(read_buf.get(), kReadBufSize, CompletionCallback()); 283 EXPECT_GT(read, 0); 284 EXPECT_LE(read, written); 285 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); 286 287 // Read then write. 288 TestCompletionCallback callback; 289 EXPECT_EQ(ERR_IO_PENDING, 290 server.Read(read_buf.get(), kReadBufSize, callback.callback())); 291 292 written = client.Write(write_buf.get(), kTestDataSize, CompletionCallback()); 293 EXPECT_GT(written, 0); 294 EXPECT_LE(written, kTestDataSize); 295 296 read = callback.WaitForResult(); 297 EXPECT_GT(read, 0); 298 EXPECT_LE(read, written); 299 EXPECT_EQ(0, memcmp(kTestData, read_buf->data(), read)); 300} 301 302class SSLServerSocketTest : public PlatformTest { 303 public: 304 SSLServerSocketTest() 305 : socket_factory_(ClientSocketFactory::GetDefaultFactory()), 306 cert_verifier_(new MockCertVerifier()), 307 transport_security_state_(new TransportSecurityState) { 308 cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID); 309 } 310 311 protected: 312 void Initialize() { 313 scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle); 314 client_connection->SetSocket( 315 scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_))); 316 scoped_ptr<StreamSocket> server_socket( 317 new FakeSocket(&channel_2_, &channel_1_)); 318 319 base::FilePath certs_dir(GetTestCertsDirectory()); 320 321 base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); 322 std::string cert_der; 323 ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der)); 324 325 scoped_refptr<X509Certificate> cert = 326 X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); 327 328 base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); 329 std::string key_string; 330 ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); 331 std::vector<uint8> key_vector( 332 reinterpret_cast<const uint8*>(key_string.data()), 333 reinterpret_cast<const uint8*>(key_string.data() + 334 key_string.length())); 335 336 scoped_ptr<crypto::RSAPrivateKey> private_key( 337 crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); 338 339 SSLConfig ssl_config; 340 ssl_config.false_start_enabled = false; 341 ssl_config.channel_id_enabled = false; 342 343 // Certificate provided by the host doesn't need authority. 344 SSLConfig::CertAndStatus cert_and_status; 345 cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID; 346 cert_and_status.der_cert = cert_der; 347 ssl_config.allowed_bad_certs.push_back(cert_and_status); 348 349 HostPortPair host_and_pair("unittest", 0); 350 SSLClientSocketContext context; 351 context.cert_verifier = cert_verifier_.get(); 352 context.transport_security_state = transport_security_state_.get(); 353 client_socket_ = 354 socket_factory_->CreateSSLClientSocket( 355 client_connection.Pass(), host_and_pair, ssl_config, context); 356 server_socket_ = CreateSSLServerSocket( 357 server_socket.Pass(), 358 cert.get(), private_key.get(), SSLConfig()); 359 } 360 361 FakeDataChannel channel_1_; 362 FakeDataChannel channel_2_; 363 scoped_ptr<SSLClientSocket> client_socket_; 364 scoped_ptr<SSLServerSocket> server_socket_; 365 ClientSocketFactory* socket_factory_; 366 scoped_ptr<MockCertVerifier> cert_verifier_; 367 scoped_ptr<TransportSecurityState> transport_security_state_; 368}; 369 370// This test only executes creation of client and server sockets. This is to 371// test that creation of sockets doesn't crash and have minimal code to run 372// under valgrind in order to help debugging memory problems. 373TEST_F(SSLServerSocketTest, Initialize) { 374 Initialize(); 375} 376 377// This test executes Connect() on SSLClientSocket and Handshake() on 378// SSLServerSocket to make sure handshaking between the two sockets is 379// completed successfully. 380TEST_F(SSLServerSocketTest, Handshake) { 381 Initialize(); 382 383 TestCompletionCallback connect_callback; 384 TestCompletionCallback handshake_callback; 385 386 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 387 EXPECT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); 388 389 int client_ret = client_socket_->Connect(connect_callback.callback()); 390 EXPECT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); 391 392 if (client_ret == ERR_IO_PENDING) { 393 EXPECT_EQ(OK, connect_callback.WaitForResult()); 394 } 395 if (server_ret == ERR_IO_PENDING) { 396 EXPECT_EQ(OK, handshake_callback.WaitForResult()); 397 } 398 399 // Make sure the cert status is expected. 400 SSLInfo ssl_info; 401 client_socket_->GetSSLInfo(&ssl_info); 402 EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status); 403} 404 405TEST_F(SSLServerSocketTest, DataTransfer) { 406 Initialize(); 407 408 TestCompletionCallback connect_callback; 409 TestCompletionCallback handshake_callback; 410 411 // Establish connection. 412 int client_ret = client_socket_->Connect(connect_callback.callback()); 413 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); 414 415 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 416 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); 417 418 client_ret = connect_callback.GetResult(client_ret); 419 ASSERT_EQ(OK, client_ret); 420 server_ret = handshake_callback.GetResult(server_ret); 421 ASSERT_EQ(OK, server_ret); 422 423 const int kReadBufSize = 1024; 424 scoped_refptr<StringIOBuffer> write_buf = 425 new StringIOBuffer("testing123"); 426 scoped_refptr<DrainableIOBuffer> read_buf = 427 new DrainableIOBuffer(new IOBuffer(kReadBufSize), kReadBufSize); 428 429 // Write then read. 430 TestCompletionCallback write_callback; 431 TestCompletionCallback read_callback; 432 server_ret = server_socket_->Write( 433 write_buf.get(), write_buf->size(), write_callback.callback()); 434 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); 435 client_ret = client_socket_->Read( 436 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 437 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING); 438 439 server_ret = write_callback.GetResult(server_ret); 440 EXPECT_GT(server_ret, 0); 441 client_ret = read_callback.GetResult(client_ret); 442 ASSERT_GT(client_ret, 0); 443 444 read_buf->DidConsume(client_ret); 445 while (read_buf->BytesConsumed() < write_buf->size()) { 446 client_ret = client_socket_->Read( 447 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 448 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING); 449 client_ret = read_callback.GetResult(client_ret); 450 ASSERT_GT(client_ret, 0); 451 read_buf->DidConsume(client_ret); 452 } 453 EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed()); 454 read_buf->SetOffset(0); 455 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 456 457 // Read then write. 458 write_buf = new StringIOBuffer("hello123"); 459 server_ret = server_socket_->Read( 460 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 461 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); 462 client_ret = client_socket_->Write( 463 write_buf.get(), write_buf->size(), write_callback.callback()); 464 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING); 465 466 server_ret = read_callback.GetResult(server_ret); 467 ASSERT_GT(server_ret, 0); 468 client_ret = write_callback.GetResult(client_ret); 469 EXPECT_GT(client_ret, 0); 470 471 read_buf->DidConsume(server_ret); 472 while (read_buf->BytesConsumed() < write_buf->size()) { 473 server_ret = server_socket_->Read( 474 read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); 475 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); 476 server_ret = read_callback.GetResult(server_ret); 477 ASSERT_GT(server_ret, 0); 478 read_buf->DidConsume(server_ret); 479 } 480 EXPECT_EQ(write_buf->size(), read_buf->BytesConsumed()); 481 read_buf->SetOffset(0); 482 EXPECT_EQ(0, memcmp(write_buf->data(), read_buf->data(), write_buf->size())); 483} 484 485// A regression test for bug 127822 (http://crbug.com/127822). 486// If the server closes the connection after the handshake is finished, 487// the client's Write() call should not cause an infinite loop. 488// NOTE: this is a test for SSLClientSocket rather than SSLServerSocket. 489TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { 490 Initialize(); 491 492 TestCompletionCallback connect_callback; 493 TestCompletionCallback handshake_callback; 494 495 // Establish connection. 496 int client_ret = client_socket_->Connect(connect_callback.callback()); 497 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); 498 499 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 500 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); 501 502 client_ret = connect_callback.GetResult(client_ret); 503 ASSERT_EQ(OK, client_ret); 504 server_ret = handshake_callback.GetResult(server_ret); 505 ASSERT_EQ(OK, server_ret); 506 507 scoped_refptr<StringIOBuffer> write_buf = new StringIOBuffer("testing123"); 508 509 // The server closes the connection. The server needs to write some 510 // data first so that the client's Read() calls from the transport 511 // socket won't return ERR_IO_PENDING. This ensures that the client 512 // will call Read() on the transport socket again. 513 TestCompletionCallback write_callback; 514 515 server_ret = server_socket_->Write( 516 write_buf.get(), write_buf->size(), write_callback.callback()); 517 EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); 518 519 server_ret = write_callback.GetResult(server_ret); 520 EXPECT_GT(server_ret, 0); 521 522 server_socket_->Disconnect(); 523 524 // The client writes some data. This should not cause an infinite loop. 525 client_ret = client_socket_->Write( 526 write_buf.get(), write_buf->size(), write_callback.callback()); 527 EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING); 528 529 client_ret = write_callback.GetResult(client_ret); 530 EXPECT_GT(client_ret, 0); 531 532 base::MessageLoop::current()->PostDelayedTask( 533 FROM_HERE, base::MessageLoop::QuitClosure(), 534 base::TimeDelta::FromMilliseconds(10)); 535 base::MessageLoop::current()->Run(); 536} 537 538// This test executes ExportKeyingMaterial() on the client and server sockets, 539// after connecting them, and verifies that the results match. 540// This test will fail if False Start is enabled (see crbug.com/90208). 541TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { 542 Initialize(); 543 544 TestCompletionCallback connect_callback; 545 TestCompletionCallback handshake_callback; 546 547 int client_ret = client_socket_->Connect(connect_callback.callback()); 548 ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); 549 550 int server_ret = server_socket_->Handshake(handshake_callback.callback()); 551 ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); 552 553 if (client_ret == ERR_IO_PENDING) { 554 ASSERT_EQ(OK, connect_callback.WaitForResult()); 555 } 556 if (server_ret == ERR_IO_PENDING) { 557 ASSERT_EQ(OK, handshake_callback.WaitForResult()); 558 } 559 560 const int kKeyingMaterialSize = 32; 561 const char* kKeyingLabel = "EXPERIMENTAL-server-socket-test"; 562 const char* kKeyingContext = ""; 563 unsigned char server_out[kKeyingMaterialSize]; 564 int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel, 565 false, kKeyingContext, 566 server_out, sizeof(server_out)); 567 ASSERT_EQ(OK, rv); 568 569 unsigned char client_out[kKeyingMaterialSize]; 570 rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, 571 false, kKeyingContext, 572 client_out, sizeof(client_out)); 573 ASSERT_EQ(OK, rv); 574 EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out))); 575 576 const char* kKeyingLabelBad = "EXPERIMENTAL-server-socket-test-bad"; 577 unsigned char client_bad[kKeyingMaterialSize]; 578 rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad, 579 false, kKeyingContext, 580 client_bad, sizeof(client_bad)); 581 ASSERT_EQ(rv, OK); 582 EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out))); 583} 584 585} // namespace net 586