1// Copyright 2014 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "extensions/browser/api/socket/tls_socket.h" 6 7#include <deque> 8#include <utility> 9 10#include "base/memory/scoped_ptr.h" 11#include "base/strings/string_piece.h" 12#include "net/base/address_list.h" 13#include "net/base/completion_callback.h" 14#include "net/base/io_buffer.h" 15#include "net/base/net_errors.h" 16#include "net/base/rand_callback.h" 17#include "net/socket/ssl_client_socket.h" 18#include "net/socket/tcp_client_socket.h" 19#include "testing/gmock/include/gmock/gmock.h" 20 21using testing::_; 22using testing::DoAll; 23using testing::Invoke; 24using testing::Gt; 25using testing::Return; 26using testing::SaveArg; 27using testing::WithArgs; 28using base::StringPiece; 29 30namespace extensions { 31 32class MockSSLClientSocket : public net::SSLClientSocket { 33 public: 34 MockSSLClientSocket() {} 35 MOCK_METHOD0(Disconnect, void()); 36 MOCK_METHOD3(Read, 37 int(net::IOBuffer* buf, 38 int buf_len, 39 const net::CompletionCallback& callback)); 40 MOCK_METHOD3(Write, 41 int(net::IOBuffer* buf, 42 int buf_len, 43 const net::CompletionCallback& callback)); 44 MOCK_METHOD1(SetReceiveBufferSize, int(int32)); 45 MOCK_METHOD1(SetSendBufferSize, int(int32)); 46 MOCK_METHOD1(Connect, int(const CompletionCallback&)); 47 MOCK_CONST_METHOD0(IsConnectedAndIdle, bool()); 48 MOCK_CONST_METHOD1(GetPeerAddress, int(net::IPEndPoint*)); 49 MOCK_CONST_METHOD1(GetLocalAddress, int(net::IPEndPoint*)); 50 MOCK_CONST_METHOD0(NetLog, const net::BoundNetLog&()); 51 MOCK_METHOD0(SetSubresourceSpeculation, void()); 52 MOCK_METHOD0(SetOmniboxSpeculation, void()); 53 MOCK_CONST_METHOD0(WasEverUsed, bool()); 54 MOCK_CONST_METHOD0(UsingTCPFastOpen, bool()); 55 MOCK_METHOD1(GetSSLInfo, bool(net::SSLInfo*)); 56 MOCK_METHOD5(ExportKeyingMaterial, 57 int(const StringPiece&, 58 bool, 59 const StringPiece&, 60 unsigned char*, 61 unsigned int)); 62 MOCK_METHOD1(GetTLSUniqueChannelBinding, int(std::string*)); 63 MOCK_CONST_METHOD0(GetSessionCacheKey, std::string()); 64 MOCK_CONST_METHOD0(InSessionCache, bool()); 65 MOCK_METHOD1(SetHandshakeCompletionCallback, void(const base::Closure&)); 66 MOCK_METHOD1(GetSSLCertRequestInfo, void(net::SSLCertRequestInfo*)); 67 MOCK_METHOD1(GetNextProto, 68 net::SSLClientSocket::NextProtoStatus(std::string*)); 69 MOCK_CONST_METHOD0(GetUnverifiedServerCertificateChain, 70 scoped_refptr<net::X509Certificate>()); 71 MOCK_CONST_METHOD0(GetChannelIDService, net::ChannelIDService*()); 72 virtual bool IsConnected() const OVERRIDE { return true; } 73 74 private: 75 DISALLOW_COPY_AND_ASSIGN(MockSSLClientSocket); 76}; 77 78class MockTCPSocket : public net::TCPClientSocket { 79 public: 80 explicit MockTCPSocket(const net::AddressList& address_list) 81 : net::TCPClientSocket(address_list, NULL, net::NetLog::Source()) {} 82 83 MOCK_METHOD3(Read, 84 int(net::IOBuffer* buf, 85 int buf_len, 86 const net::CompletionCallback& callback)); 87 MOCK_METHOD3(Write, 88 int(net::IOBuffer* buf, 89 int buf_len, 90 const net::CompletionCallback& callback)); 91 MOCK_METHOD2(SetKeepAlive, bool(bool enable, int delay)); 92 MOCK_METHOD1(SetNoDelay, bool(bool no_delay)); 93 94 virtual bool IsConnected() const OVERRIDE { return true; } 95 96 private: 97 DISALLOW_COPY_AND_ASSIGN(MockTCPSocket); 98}; 99 100class CompleteHandler { 101 public: 102 CompleteHandler() {} 103 MOCK_METHOD1(OnComplete, void(int result_code)); 104 MOCK_METHOD2(OnReadComplete, 105 void(int result_code, scoped_refptr<net::IOBuffer> io_buffer)); 106 MOCK_METHOD2(OnAccept, void(int, net::TCPClientSocket*)); 107 108 private: 109 DISALLOW_COPY_AND_ASSIGN(CompleteHandler); 110}; 111 112class TLSSocketTest : public ::testing::Test { 113 public: 114 TLSSocketTest() {} 115 116 virtual void SetUp() { 117 net::AddressList address_list; 118 // |ssl_socket_| is owned by |socket_|. TLSSocketTest keeps a pointer to 119 // it to expect invocations from TLSSocket to |ssl_socket_|. 120 scoped_ptr<MockSSLClientSocket> ssl_sock(new MockSSLClientSocket); 121 ssl_socket_ = ssl_sock.get(); 122 socket_.reset(new TLSSocket(ssl_sock.PassAs<net::StreamSocket>(), 123 "test_extension_id")); 124 EXPECT_CALL(*ssl_socket_, Disconnect()).Times(1); 125 }; 126 127 virtual void TearDown() { 128 ssl_socket_ = NULL; 129 socket_.reset(); 130 }; 131 132 protected: 133 MockSSLClientSocket* ssl_socket_; 134 scoped_ptr<TLSSocket> socket_; 135}; 136 137// Verify that a Read() on TLSSocket will pass through into a Read() on 138// |ssl_socket_| and invoke its completion callback. 139TEST_F(TLSSocketTest, TestTLSSocketRead) { 140 CompleteHandler handler; 141 142 EXPECT_CALL(*ssl_socket_, Read(_, _, _)).Times(1); 143 EXPECT_CALL(handler, OnReadComplete(_, _)).Times(1); 144 145 const int count = 512; 146 socket_->Read( 147 count, 148 base::Bind(&CompleteHandler::OnReadComplete, base::Unretained(&handler))); 149} 150 151// Verify that a Write() on a TLSSocket will pass through to Write() 152// invocations on |ssl_socket_|, handling partial writes correctly, and calls 153// the completion callback correctly. 154TEST_F(TLSSocketTest, TestTLSSocketWrite) { 155 CompleteHandler handler; 156 net::CompletionCallback callback; 157 158 EXPECT_CALL(*ssl_socket_, Write(_, _, _)).Times(2).WillRepeatedly( 159 DoAll(SaveArg<2>(&callback), Return(128))); 160 EXPECT_CALL(handler, OnComplete(_)).Times(1); 161 162 scoped_refptr<net::IOBufferWithSize> io_buffer( 163 new net::IOBufferWithSize(256)); 164 socket_->Write( 165 io_buffer.get(), 166 io_buffer->size(), 167 base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler))); 168} 169 170// Simulate a blocked Write, and verify that, when simulating the Write going 171// through, the callback gets invoked. 172TEST_F(TLSSocketTest, TestTLSSocketBlockedWrite) { 173 CompleteHandler handler; 174 net::CompletionCallback callback; 175 176 // Return ERR_IO_PENDING to say the Write()'s blocked. Save the |callback| 177 // Write()'s passed. 178 EXPECT_CALL(*ssl_socket_, Write(_, _, _)).Times(2).WillRepeatedly( 179 DoAll(SaveArg<2>(&callback), Return(net::ERR_IO_PENDING))); 180 181 scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(42)); 182 socket_->Write( 183 io_buffer.get(), 184 io_buffer->size(), 185 base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler))); 186 187 // After the simulated asynchronous writes come back (via calls to 188 // callback.Run()), hander's OnComplete() should get invoked with the total 189 // amount written. 190 EXPECT_CALL(handler, OnComplete(42)).Times(1); 191 callback.Run(40); 192 callback.Run(2); 193} 194 195// Simulate multiple blocked Write()s. 196TEST_F(TLSSocketTest, TestTLSSocketBlockedWriteReentry) { 197 const int kNumIOs = 5; 198 CompleteHandler handlers[kNumIOs]; 199 net::CompletionCallback callback; 200 scoped_refptr<net::IOBufferWithSize> io_buffers[kNumIOs]; 201 202 // The implementation of TLSSocket::Write() is inherited from 203 // Socket::Write(), which implements an internal write queue that wraps 204 // TLSSocket::WriteImpl(). Each call from TLSSocket::WriteImpl() will invoke 205 // |ssl_socket_|'s Write() (mocked here). Save the |callback| (assume they 206 // will all be equivalent), and return ERR_IO_PENDING, to indicate a blocked 207 // request. The mocked SSLClientSocket::Write() will get one request per 208 // TLSSocket::Write() request invoked on |socket_| below. 209 EXPECT_CALL(*ssl_socket_, Write(_, _, _)).Times(kNumIOs).WillRepeatedly( 210 DoAll(SaveArg<2>(&callback), Return(net::ERR_IO_PENDING))); 211 212 // Send out |kNuMIOs| requests, each with a different size. 213 for (int i = 0; i < kNumIOs; i++) { 214 io_buffers[i] = new net::IOBufferWithSize(128 + i * 50); 215 socket_->Write(io_buffers[i].get(), 216 io_buffers[i]->size(), 217 base::Bind(&CompleteHandler::OnComplete, 218 base::Unretained(&handlers[i]))); 219 220 // Set up expectations on all |kNumIOs| handlers. 221 EXPECT_CALL(handlers[i], OnComplete(io_buffers[i]->size())).Times(1); 222 } 223 224 // Finish each pending I/O. This should satisfy the expectations on the 225 // handlers. 226 for (int i = 0; i < kNumIOs; i++) { 227 callback.Run(128 + i * 50); 228 } 229} 230 231typedef std::pair<net::CompletionCallback, int> PendingCallback; 232 233class CallbackList : public std::deque<PendingCallback> { 234 public: 235 void append(const net::CompletionCallback& cb, int arg) { 236 push_back(std::make_pair(cb, arg)); 237 } 238}; 239 240// Simulate Write()s above and below a SSLClientSocket size limit. 241TEST_F(TLSSocketTest, TestTLSSocketLargeWrites) { 242 const int kSizeIncrement = 4096; 243 const int kNumIncrements = 10; 244 const int kFragmentIncrement = 4; 245 const int kSizeLimit = kSizeIncrement * kFragmentIncrement; 246 net::CompletionCallback callback; 247 CompleteHandler handler; 248 scoped_refptr<net::IOBufferWithSize> io_buffers[kNumIncrements]; 249 CallbackList pending_callbacks; 250 size_t total_bytes_requested = 0; 251 size_t total_bytes_written = 0; 252 253 // Some implementations of SSLClientSocket may have write-size limits (e.g, 254 // max 1 TLS record, which is 16k). This test mocks a size limit at 255 // |kSizeIncrement| and calls Write() above and below that limit. It 256 // simulates SSLClientSocket::Write() behavior in only writing up to the size 257 // limit, requiring additional calls for the remaining data to be sent. 258 // Socket::Write() (and supporting methods) execute the additional calls as 259 // needed. This test verifies that this inherited implementation does 260 // properly issue additional calls, and that the total amount returned from 261 // all mocked SSLClientSocket::Write() calls is the same as originally 262 // requested. 263 264 // |ssl_socket_|'s Write() will write at most |kSizeLimit| bytes. The 265 // inherited Socket::Write() will repeatedly call |ssl_socket_|'s Write() 266 // until the entire original request is sent. Socket::Write() will queue any 267 // additional write requests until the current request is complete. A 268 // request is complete when the callback passed to Socket::WriteImpl() is 269 // invoked with an argument equal to the original number of bytes requested 270 // from Socket::Write(). If the callback is invoked with a smaller number, 271 // Socket::WriteImpl() will get repeatedly invoked until the sum of the 272 // callbacks' arguments is equal to the original requested amount. 273 EXPECT_CALL(*ssl_socket_, Write(_, _, _)).WillRepeatedly( 274 DoAll(WithArgs<2, 1>(Invoke(&pending_callbacks, &CallbackList::append)), 275 Return(net::ERR_IO_PENDING))); 276 277 // Observe what comes back from Socket::Write() here. 278 EXPECT_CALL(handler, OnComplete(Gt(0))).Times(kNumIncrements); 279 280 // Send out |kNumIncrements| requests, each with a different size. The 281 // last request is the same size as the first, and the ones in the middle 282 // are monotonically increasing from the first. 283 for (int i = 0; i < kNumIncrements; i++) { 284 const bool last = i == (kNumIncrements - 1); 285 io_buffers[i] = new net::IOBufferWithSize(last ? kSizeIncrement 286 : kSizeIncrement * (i + 1)); 287 total_bytes_requested += io_buffers[i]->size(); 288 289 // Invoke Socket::Write(). This will invoke |ssl_socket_|'s Write(), which 290 // this test mocks out. That mocked Write() is in an asynchronous waiting 291 // state until the passed callback (saved in the EXPECT_CALL for 292 // |ssl_socket_|'s Write()) is invoked. 293 socket_->Write( 294 io_buffers[i].get(), 295 io_buffers[i]->size(), 296 base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler))); 297 } 298 299 // Invoke callbacks for pending I/Os. These can synchronously invoke more of 300 // |ssl_socket_|'s Write() as needed. The callback checks how much is left 301 // in the request, and then starts issuing any queued Socket::Write() 302 // invocations. 303 while (!pending_callbacks.empty()) { 304 PendingCallback cb = pending_callbacks.front(); 305 pending_callbacks.pop_front(); 306 307 int amount_written_invocation = std::min(kSizeLimit, cb.second); 308 total_bytes_written += amount_written_invocation; 309 cb.first.Run(amount_written_invocation); 310 } 311 312 ASSERT_EQ(total_bytes_requested, total_bytes_written) 313 << "There should be exactly as many bytes written as originally " 314 << "requested to Write()."; 315} 316 317} // namespace extensions 318