1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
#include "extensions/browser/api/socket/tls_socket.h"

#include <algorithm>
#include <deque>
#include <utility>

#include "base/memory/scoped_ptr.h"
#include "base/strings/string_piece.h"
#include "net/base/address_list.h"
#include "net/base/completion_callback.h"
#include "net/base/io_buffer.h"
#include "net/base/net_errors.h"
#include "net/base/rand_callback.h"
#include "net/socket/ssl_client_socket.h"
#include "net/socket/tcp_client_socket.h"
#include "testing/gmock/include/gmock/gmock.h"
20
21using testing::_;
22using testing::DoAll;
23using testing::Invoke;
24using testing::Gt;
25using testing::Return;
26using testing::SaveArg;
27using testing::WithArgs;
28using base::StringPiece;
29
30namespace extensions {
31
32class MockSSLClientSocket : public net::SSLClientSocket {
33 public:
34  MockSSLClientSocket() {}
35  MOCK_METHOD0(Disconnect, void());
36  MOCK_METHOD3(Read,
37               int(net::IOBuffer* buf,
38                   int buf_len,
39                   const net::CompletionCallback& callback));
40  MOCK_METHOD3(Write,
41               int(net::IOBuffer* buf,
42                   int buf_len,
43                   const net::CompletionCallback& callback));
44  MOCK_METHOD1(SetReceiveBufferSize, int(int32));
45  MOCK_METHOD1(SetSendBufferSize, int(int32));
46  MOCK_METHOD1(Connect, int(const CompletionCallback&));
47  MOCK_CONST_METHOD0(IsConnectedAndIdle, bool());
48  MOCK_CONST_METHOD1(GetPeerAddress, int(net::IPEndPoint*));
49  MOCK_CONST_METHOD1(GetLocalAddress, int(net::IPEndPoint*));
50  MOCK_CONST_METHOD0(NetLog, const net::BoundNetLog&());
51  MOCK_METHOD0(SetSubresourceSpeculation, void());
52  MOCK_METHOD0(SetOmniboxSpeculation, void());
53  MOCK_CONST_METHOD0(WasEverUsed, bool());
54  MOCK_CONST_METHOD0(UsingTCPFastOpen, bool());
55  MOCK_METHOD1(GetSSLInfo, bool(net::SSLInfo*));
56  MOCK_METHOD5(ExportKeyingMaterial,
57               int(const StringPiece&,
58                   bool,
59                   const StringPiece&,
60                   unsigned char*,
61                   unsigned int));
62  MOCK_METHOD1(GetTLSUniqueChannelBinding, int(std::string*));
63  MOCK_CONST_METHOD0(GetSessionCacheKey, std::string());
64  MOCK_CONST_METHOD0(InSessionCache, bool());
65  MOCK_METHOD1(SetHandshakeCompletionCallback, void(const base::Closure&));
66  MOCK_METHOD1(GetSSLCertRequestInfo, void(net::SSLCertRequestInfo*));
67  MOCK_METHOD1(GetNextProto,
68               net::SSLClientSocket::NextProtoStatus(std::string*));
69  MOCK_CONST_METHOD0(GetUnverifiedServerCertificateChain,
70                     scoped_refptr<net::X509Certificate>());
71  MOCK_CONST_METHOD0(GetChannelIDService, net::ChannelIDService*());
72  virtual bool IsConnected() const OVERRIDE { return true; }
73
74 private:
75  DISALLOW_COPY_AND_ASSIGN(MockSSLClientSocket);
76};
77
78class MockTCPSocket : public net::TCPClientSocket {
79 public:
80  explicit MockTCPSocket(const net::AddressList& address_list)
81      : net::TCPClientSocket(address_list, NULL, net::NetLog::Source()) {}
82
83  MOCK_METHOD3(Read,
84               int(net::IOBuffer* buf,
85                   int buf_len,
86                   const net::CompletionCallback& callback));
87  MOCK_METHOD3(Write,
88               int(net::IOBuffer* buf,
89                   int buf_len,
90                   const net::CompletionCallback& callback));
91  MOCK_METHOD2(SetKeepAlive, bool(bool enable, int delay));
92  MOCK_METHOD1(SetNoDelay, bool(bool no_delay));
93
94  virtual bool IsConnected() const OVERRIDE { return true; }
95
96 private:
97  DISALLOW_COPY_AND_ASSIGN(MockTCPSocket);
98};
99
100class CompleteHandler {
101 public:
102  CompleteHandler() {}
103  MOCK_METHOD1(OnComplete, void(int result_code));
104  MOCK_METHOD2(OnReadComplete,
105               void(int result_code, scoped_refptr<net::IOBuffer> io_buffer));
106  MOCK_METHOD2(OnAccept, void(int, net::TCPClientSocket*));
107
108 private:
109  DISALLOW_COPY_AND_ASSIGN(CompleteHandler);
110};
111
112class TLSSocketTest : public ::testing::Test {
113 public:
114  TLSSocketTest() {}
115
116  virtual void SetUp() {
117    net::AddressList address_list;
118    // |ssl_socket_| is owned by |socket_|. TLSSocketTest keeps a pointer to
119    // it to expect invocations from TLSSocket to |ssl_socket_|.
120    scoped_ptr<MockSSLClientSocket> ssl_sock(new MockSSLClientSocket);
121    ssl_socket_ = ssl_sock.get();
122    socket_.reset(new TLSSocket(ssl_sock.PassAs<net::StreamSocket>(),
123                                "test_extension_id"));
124    EXPECT_CALL(*ssl_socket_, Disconnect()).Times(1);
125  };
126
127  virtual void TearDown() {
128    ssl_socket_ = NULL;
129    socket_.reset();
130  };
131
132 protected:
133  MockSSLClientSocket* ssl_socket_;
134  scoped_ptr<TLSSocket> socket_;
135};
136
137// Verify that a Read() on TLSSocket will pass through into a Read() on
138// |ssl_socket_| and invoke its completion callback.
139TEST_F(TLSSocketTest, TestTLSSocketRead) {
140  CompleteHandler handler;
141
142  EXPECT_CALL(*ssl_socket_, Read(_, _, _)).Times(1);
143  EXPECT_CALL(handler, OnReadComplete(_, _)).Times(1);
144
145  const int count = 512;
146  socket_->Read(
147      count,
148      base::Bind(&CompleteHandler::OnReadComplete, base::Unretained(&handler)));
149}
150
151// Verify that a Write() on a TLSSocket will pass through to Write()
152// invocations on |ssl_socket_|, handling partial writes correctly, and calls
153// the completion callback correctly.
154TEST_F(TLSSocketTest, TestTLSSocketWrite) {
155  CompleteHandler handler;
156  net::CompletionCallback callback;
157
158  EXPECT_CALL(*ssl_socket_, Write(_, _, _)).Times(2).WillRepeatedly(
159      DoAll(SaveArg<2>(&callback), Return(128)));
160  EXPECT_CALL(handler, OnComplete(_)).Times(1);
161
162  scoped_refptr<net::IOBufferWithSize> io_buffer(
163      new net::IOBufferWithSize(256));
164  socket_->Write(
165      io_buffer.get(),
166      io_buffer->size(),
167      base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler)));
168}
169
170// Simulate a blocked Write, and verify that, when simulating the Write going
171// through, the callback gets invoked.
172TEST_F(TLSSocketTest, TestTLSSocketBlockedWrite) {
173  CompleteHandler handler;
174  net::CompletionCallback callback;
175
176  // Return ERR_IO_PENDING to say the Write()'s blocked. Save the |callback|
177  // Write()'s passed.
178  EXPECT_CALL(*ssl_socket_, Write(_, _, _)).Times(2).WillRepeatedly(
179      DoAll(SaveArg<2>(&callback), Return(net::ERR_IO_PENDING)));
180
181  scoped_refptr<net::IOBufferWithSize> io_buffer(new net::IOBufferWithSize(42));
182  socket_->Write(
183      io_buffer.get(),
184      io_buffer->size(),
185      base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler)));
186
187  // After the simulated asynchronous writes come back (via calls to
188  // callback.Run()), hander's OnComplete() should get invoked with the total
189  // amount written.
190  EXPECT_CALL(handler, OnComplete(42)).Times(1);
191  callback.Run(40);
192  callback.Run(2);
193}
194
195// Simulate multiple blocked Write()s.
196TEST_F(TLSSocketTest, TestTLSSocketBlockedWriteReentry) {
197  const int kNumIOs = 5;
198  CompleteHandler handlers[kNumIOs];
199  net::CompletionCallback callback;
200  scoped_refptr<net::IOBufferWithSize> io_buffers[kNumIOs];
201
202  // The implementation of TLSSocket::Write() is inherited from
203  // Socket::Write(), which implements an internal write queue that wraps
204  // TLSSocket::WriteImpl(). Each call from TLSSocket::WriteImpl() will invoke
205  // |ssl_socket_|'s Write() (mocked here). Save the |callback| (assume they
206  // will all be equivalent), and return ERR_IO_PENDING, to indicate a blocked
207  // request. The mocked SSLClientSocket::Write() will get one request per
208  // TLSSocket::Write() request invoked on |socket_| below.
209  EXPECT_CALL(*ssl_socket_, Write(_, _, _)).Times(kNumIOs).WillRepeatedly(
210      DoAll(SaveArg<2>(&callback), Return(net::ERR_IO_PENDING)));
211
212  // Send out |kNuMIOs| requests, each with a different size.
213  for (int i = 0; i < kNumIOs; i++) {
214    io_buffers[i] = new net::IOBufferWithSize(128 + i * 50);
215    socket_->Write(io_buffers[i].get(),
216                   io_buffers[i]->size(),
217                   base::Bind(&CompleteHandler::OnComplete,
218                              base::Unretained(&handlers[i])));
219
220    // Set up expectations on all |kNumIOs| handlers.
221    EXPECT_CALL(handlers[i], OnComplete(io_buffers[i]->size())).Times(1);
222  }
223
224  // Finish each pending I/O. This should satisfy the expectations on the
225  // handlers.
226  for (int i = 0; i < kNumIOs; i++) {
227    callback.Run(128 + i * 50);
228  }
229}
230
231typedef std::pair<net::CompletionCallback, int> PendingCallback;
232
233class CallbackList : public std::deque<PendingCallback> {
234 public:
235  void append(const net::CompletionCallback& cb, int arg) {
236    push_back(std::make_pair(cb, arg));
237  }
238};
239
240// Simulate Write()s above and below a SSLClientSocket size limit.
241TEST_F(TLSSocketTest, TestTLSSocketLargeWrites) {
242  const int kSizeIncrement = 4096;
243  const int kNumIncrements = 10;
244  const int kFragmentIncrement = 4;
245  const int kSizeLimit = kSizeIncrement * kFragmentIncrement;
246  net::CompletionCallback callback;
247  CompleteHandler handler;
248  scoped_refptr<net::IOBufferWithSize> io_buffers[kNumIncrements];
249  CallbackList pending_callbacks;
250  size_t total_bytes_requested = 0;
251  size_t total_bytes_written = 0;
252
253  // Some implementations of SSLClientSocket may have write-size limits (e.g,
254  // max 1 TLS record, which is 16k). This test mocks a size limit at
255  // |kSizeIncrement| and calls Write() above and below that limit. It
256  // simulates SSLClientSocket::Write() behavior in only writing up to the size
257  // limit, requiring additional calls for the remaining data to be sent.
258  // Socket::Write() (and supporting methods) execute the additional calls as
259  // needed. This test verifies that this inherited implementation does
260  // properly issue additional calls, and that the total amount returned from
261  // all mocked SSLClientSocket::Write() calls is the same as originally
262  // requested.
263
264  // |ssl_socket_|'s Write() will write at most |kSizeLimit| bytes. The
265  // inherited Socket::Write() will repeatedly call |ssl_socket_|'s Write()
266  // until the entire original request is sent. Socket::Write() will queue any
267  // additional write requests until the current request is complete. A
268  // request is complete when the callback passed to Socket::WriteImpl() is
269  // invoked with an argument equal to the original number of bytes requested
270  // from Socket::Write(). If the callback is invoked with a smaller number,
271  // Socket::WriteImpl() will get repeatedly invoked until the sum of the
272  // callbacks' arguments is equal to the original requested amount.
273  EXPECT_CALL(*ssl_socket_, Write(_, _, _)).WillRepeatedly(
274      DoAll(WithArgs<2, 1>(Invoke(&pending_callbacks, &CallbackList::append)),
275            Return(net::ERR_IO_PENDING)));
276
277  // Observe what comes back from Socket::Write() here.
278  EXPECT_CALL(handler, OnComplete(Gt(0))).Times(kNumIncrements);
279
280  // Send out |kNumIncrements| requests, each with a different size. The
281  // last request is the same size as the first, and the ones in the middle
282  // are monotonically increasing from the first.
283  for (int i = 0; i < kNumIncrements; i++) {
284    const bool last = i == (kNumIncrements - 1);
285    io_buffers[i] = new net::IOBufferWithSize(last ? kSizeIncrement
286                                                   : kSizeIncrement * (i + 1));
287    total_bytes_requested += io_buffers[i]->size();
288
289    // Invoke Socket::Write(). This will invoke |ssl_socket_|'s Write(), which
290    // this test mocks out. That mocked Write() is in an asynchronous waiting
291    // state until the passed callback (saved in the EXPECT_CALL for
292    // |ssl_socket_|'s Write()) is invoked.
293    socket_->Write(
294        io_buffers[i].get(),
295        io_buffers[i]->size(),
296        base::Bind(&CompleteHandler::OnComplete, base::Unretained(&handler)));
297  }
298
299  // Invoke callbacks for pending I/Os. These can synchronously invoke more of
300  // |ssl_socket_|'s Write() as needed. The callback checks how much is left
301  // in the request, and then starts issuing any queued Socket::Write()
302  // invocations.
303  while (!pending_callbacks.empty()) {
304    PendingCallback cb = pending_callbacks.front();
305    pending_callbacks.pop_front();
306
307    int amount_written_invocation = std::min(kSizeLimit, cb.second);
308    total_bytes_written += amount_written_invocation;
309    cb.first.Run(amount_written_invocation);
310  }
311
312  ASSERT_EQ(total_bytes_requested, total_bytes_written)
313      << "There should be exactly as many bytes written as originally "
314      << "requested to Write().";
315}
316
317}  // namespace extensions
318