1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
#include "jingle/glue/fake_ssl_client_socket.h"

#include <algorithm>
#include <string>
#include <vector>

#include "base/basictypes.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/message_loop/message_loop.h"
#include "net/base/io_buffer.h"
#include "net/base/net_log.h"
#include "net/base/test_completion_callback.h"
#include "net/socket/socket_test_util.h"
#include "net/socket/stream_socket.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
21
22namespace jingle_glue {
23
24namespace {
25
26using ::testing::Return;
27using ::testing::ReturnRef;
28
// Handshake step at which RunUnsuccessfulHandshakeTestHelper() injects its
// error.
enum HandshakeErrorLocation {
  // The transport socket's Connect() reports the error.
  CONNECT_ERROR = 0,
  // One of the client hello writes reports the error.
  SEND_CLIENT_HELLO_ERROR = 1,
  // One of the server hello reads reports the error or returns bad data.
  VERIFY_SERVER_HELLO_ERROR = 2
};
36
// Test-only error code (chosen, presumably, not to collide with any
// net::Error value) standing for a server hello corrupted in transit.
enum { ERR_MALFORMED_SERVER_HELLO = -15000 };
43
44// Used by PassThroughMethods test.
45class MockClientSocket : public net::StreamSocket {
46 public:
47  virtual ~MockClientSocket() {}
48
49  MOCK_METHOD3(Read, int(net::IOBuffer*, int,
50                         const net::CompletionCallback&));
51  MOCK_METHOD3(Write, int(net::IOBuffer*, int,
52                          const net::CompletionCallback&));
53  MOCK_METHOD1(SetReceiveBufferSize, int(int32));
54  MOCK_METHOD1(SetSendBufferSize, int(int32));
55  MOCK_METHOD1(Connect, int(const net::CompletionCallback&));
56  MOCK_METHOD0(Disconnect, void());
57  MOCK_CONST_METHOD0(IsConnected, bool());
58  MOCK_CONST_METHOD0(IsConnectedAndIdle, bool());
59  MOCK_CONST_METHOD1(GetPeerAddress, int(net::IPEndPoint*));
60  MOCK_CONST_METHOD1(GetLocalAddress, int(net::IPEndPoint*));
61  MOCK_CONST_METHOD0(NetLog, const net::BoundNetLog&());
62  MOCK_METHOD0(SetSubresourceSpeculation, void());
63  MOCK_METHOD0(SetOmniboxSpeculation, void());
64  MOCK_CONST_METHOD0(WasEverUsed, bool());
65  MOCK_CONST_METHOD0(UsingTCPFastOpen, bool());
66  MOCK_CONST_METHOD0(NumBytesRead, int64());
67  MOCK_CONST_METHOD0(GetConnectTimeMicros, base::TimeDelta());
68  MOCK_CONST_METHOD0(WasNpnNegotiated, bool());
69  MOCK_CONST_METHOD0(GetNegotiatedProtocol, net::NextProto());
70  MOCK_METHOD1(GetSSLInfo, bool(net::SSLInfo*));
71};
72
73// Break up |data| into a bunch of chunked MockReads/Writes and push
74// them onto |ops|.
75template <net::MockReadWriteType type>
76void AddChunkedOps(base::StringPiece data, size_t chunk_size, net::IoMode mode,
77                   std::vector<net::MockReadWrite<type> >* ops) {
78  DCHECK_GT(chunk_size, 0U);
79  size_t offset = 0;
80  while (offset < data.size()) {
81    size_t bounded_chunk_size = std::min(data.size() - offset, chunk_size);
82    ops->push_back(net::MockReadWrite<type>(mode, data.data() + offset,
83                                            bounded_chunk_size));
84    offset += bounded_chunk_size;
85  }
86}
87
88class FakeSSLClientSocketTest : public testing::Test {
89 protected:
90  FakeSSLClientSocketTest() {}
91
92  virtual ~FakeSSLClientSocketTest() {}
93
94  scoped_ptr<net::StreamSocket> MakeClientSocket() {
95    return mock_client_socket_factory_.CreateTransportClientSocket(
96        net::AddressList(), NULL, net::NetLog::Source());
97  }
98
99  void SetData(const net::MockConnect& mock_connect,
100               std::vector<net::MockRead>* reads,
101               std::vector<net::MockWrite>* writes) {
102    static_socket_data_provider_.reset(
103        new net::StaticSocketDataProvider(
104            reads->empty() ? NULL : &*reads->begin(), reads->size(),
105            writes->empty() ? NULL : &*writes->begin(), writes->size()));
106    static_socket_data_provider_->set_connect_data(mock_connect);
107    mock_client_socket_factory_.AddSocketDataProvider(
108        static_socket_data_provider_.get());
109  }
110
111  void ExpectStatus(
112      net::IoMode mode, int expected_status, int immediate_status,
113      net::TestCompletionCallback* test_completion_callback) {
114    if (mode == net::ASYNC) {
115      EXPECT_EQ(net::ERR_IO_PENDING, immediate_status);
116      int status = test_completion_callback->WaitForResult();
117      EXPECT_EQ(expected_status, status);
118    } else {
119      EXPECT_EQ(expected_status, immediate_status);
120    }
121  }
122
123  // Sets up the mock socket to generate a successful handshake
124  // (sliced up according to the parameters) and makes sure the
125  // FakeSSLClientSocket behaves as expected.
126  void RunSuccessfulHandshakeTest(
127      net::IoMode mode, size_t read_chunk_size, size_t write_chunk_size,
128      int num_resets) {
129    base::StringPiece ssl_client_hello =
130        FakeSSLClientSocket::GetSslClientHello();
131    base::StringPiece ssl_server_hello =
132        FakeSSLClientSocket::GetSslServerHello();
133
134    net::MockConnect mock_connect(mode, net::OK);
135    std::vector<net::MockRead> reads;
136    std::vector<net::MockWrite> writes;
137    static const char kReadTestData[] = "read test data";
138    static const char kWriteTestData[] = "write test data";
139    for (int i = 0; i < num_resets + 1; ++i) {
140      SCOPED_TRACE(i);
141      AddChunkedOps(ssl_server_hello, read_chunk_size, mode, &reads);
142      AddChunkedOps(ssl_client_hello, write_chunk_size, mode, &writes);
143      reads.push_back(
144          net::MockRead(mode, kReadTestData, arraysize(kReadTestData)));
145      writes.push_back(
146          net::MockWrite(mode, kWriteTestData, arraysize(kWriteTestData)));
147    }
148    SetData(mock_connect, &reads, &writes);
149
150    FakeSSLClientSocket fake_ssl_client_socket(MakeClientSocket());
151
152    for (int i = 0; i < num_resets + 1; ++i) {
153      SCOPED_TRACE(i);
154      net::TestCompletionCallback test_completion_callback;
155      int status = fake_ssl_client_socket.Connect(
156          test_completion_callback.callback());
157      if (mode == net::ASYNC) {
158        EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
159      }
160      ExpectStatus(mode, net::OK, status, &test_completion_callback);
161      if (fake_ssl_client_socket.IsConnected()) {
162        int read_len = arraysize(kReadTestData);
163        int read_buf_len = 2 * read_len;
164        scoped_refptr<net::IOBuffer> read_buf(
165            new net::IOBuffer(read_buf_len));
166        int read_status = fake_ssl_client_socket.Read(
167            read_buf.get(), read_buf_len, test_completion_callback.callback());
168        ExpectStatus(mode, read_len, read_status, &test_completion_callback);
169
170        scoped_refptr<net::IOBuffer> write_buf(
171            new net::StringIOBuffer(kWriteTestData));
172        int write_status =
173            fake_ssl_client_socket.Write(write_buf.get(),
174                                         arraysize(kWriteTestData),
175                                         test_completion_callback.callback());
176        ExpectStatus(mode, arraysize(kWriteTestData), write_status,
177                     &test_completion_callback);
178      } else {
179        ADD_FAILURE();
180      }
181      fake_ssl_client_socket.Disconnect();
182      EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
183    }
184  }
185
186  // Sets up the mock socket to generate an unsuccessful handshake
187  // FakeSSLClientSocket fails as expected.
188  void RunUnsuccessfulHandshakeTestHelper(
189      net::IoMode mode, int error, HandshakeErrorLocation location) {
190    DCHECK_NE(error, net::OK);
191    base::StringPiece ssl_client_hello =
192        FakeSSLClientSocket::GetSslClientHello();
193    base::StringPiece ssl_server_hello =
194        FakeSSLClientSocket::GetSslServerHello();
195
196    net::MockConnect mock_connect(mode, net::OK);
197    std::vector<net::MockRead> reads;
198    std::vector<net::MockWrite> writes;
199    const size_t kChunkSize = 1;
200    AddChunkedOps(ssl_server_hello, kChunkSize, mode, &reads);
201    AddChunkedOps(ssl_client_hello, kChunkSize, mode, &writes);
202    switch (location) {
203      case CONNECT_ERROR:
204        mock_connect.result = error;
205        writes.clear();
206        reads.clear();
207        break;
208      case SEND_CLIENT_HELLO_ERROR: {
209        // Use a fixed index for repeatability.
210        size_t index = 100 % writes.size();
211        writes[index].result = error;
212        writes[index].data = NULL;
213        writes[index].data_len = 0;
214        writes.resize(index + 1);
215        reads.clear();
216        break;
217      }
218      case VERIFY_SERVER_HELLO_ERROR: {
219        // Use a fixed index for repeatability.
220        size_t index = 50 % reads.size();
221        if (error == ERR_MALFORMED_SERVER_HELLO) {
222          static const char kBadData[] = "BAD_DATA";
223          reads[index].data = kBadData;
224          reads[index].data_len = arraysize(kBadData);
225        } else {
226          reads[index].result = error;
227          reads[index].data = NULL;
228          reads[index].data_len = 0;
229        }
230        reads.resize(index + 1);
231        if (error ==
232            net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
233          static const char kDummyData[] = "DUMMY";
234          reads.push_back(net::MockRead(mode, kDummyData));
235        }
236        break;
237      }
238    }
239    SetData(mock_connect, &reads, &writes);
240
241    FakeSSLClientSocket fake_ssl_client_socket(MakeClientSocket());
242
243    // The two errors below are interpreted by FakeSSLClientSocket as
244    // an unexpected event.
245    int expected_status =
246        ((error == net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) ||
247         (error == ERR_MALFORMED_SERVER_HELLO)) ?
248        net::ERR_UNEXPECTED : error;
249
250    net::TestCompletionCallback test_completion_callback;
251    int status = fake_ssl_client_socket.Connect(
252        test_completion_callback.callback());
253    EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
254    ExpectStatus(mode, expected_status, status, &test_completion_callback);
255    EXPECT_FALSE(fake_ssl_client_socket.IsConnected());
256  }
257
258  void RunUnsuccessfulHandshakeTest(
259      int error, HandshakeErrorLocation location) {
260    RunUnsuccessfulHandshakeTestHelper(net::SYNCHRONOUS, error, location);
261    RunUnsuccessfulHandshakeTestHelper(net::ASYNC, error, location);
262  }
263
264  // MockTCPClientSocket needs a message loop.
265  base::MessageLoop message_loop_;
266
267  net::MockClientSocketFactory mock_client_socket_factory_;
268  scoped_ptr<net::StaticSocketDataProvider> static_socket_data_provider_;
269};
270
271TEST_F(FakeSSLClientSocketTest, PassThroughMethods) {
272  scoped_ptr<MockClientSocket> mock_client_socket(new MockClientSocket());
273  const int kReceiveBufferSize = 10;
274  const int kSendBufferSize = 20;
275  net::IPEndPoint ip_endpoint(net::IPAddressNumber(net::kIPv4AddressSize), 80);
276  const int kPeerAddress = 30;
277  net::BoundNetLog net_log;
278  EXPECT_CALL(*mock_client_socket, SetReceiveBufferSize(kReceiveBufferSize));
279  EXPECT_CALL(*mock_client_socket, SetSendBufferSize(kSendBufferSize));
280  EXPECT_CALL(*mock_client_socket, GetPeerAddress(&ip_endpoint)).
281      WillOnce(Return(kPeerAddress));
282  EXPECT_CALL(*mock_client_socket, NetLog()).WillOnce(ReturnRef(net_log));
283  EXPECT_CALL(*mock_client_socket, SetSubresourceSpeculation());
284  EXPECT_CALL(*mock_client_socket, SetOmniboxSpeculation());
285
286  // Takes ownership of |mock_client_socket|.
287  FakeSSLClientSocket fake_ssl_client_socket(
288      mock_client_socket.PassAs<net::StreamSocket>());
289  fake_ssl_client_socket.SetReceiveBufferSize(kReceiveBufferSize);
290  fake_ssl_client_socket.SetSendBufferSize(kSendBufferSize);
291  EXPECT_EQ(kPeerAddress,
292            fake_ssl_client_socket.GetPeerAddress(&ip_endpoint));
293  EXPECT_EQ(&net_log, &fake_ssl_client_socket.NetLog());
294  fake_ssl_client_socket.SetSubresourceSpeculation();
295  fake_ssl_client_socket.SetOmniboxSpeculation();
296}
297
298TEST_F(FakeSSLClientSocketTest, SuccessfulHandshakeSync) {
299  for (size_t i = 1; i < 100; i += 3) {
300    SCOPED_TRACE(i);
301    for (size_t j = 1; j < 100; j += 5) {
302      SCOPED_TRACE(j);
303      RunSuccessfulHandshakeTest(net::SYNCHRONOUS, i, j, 0);
304    }
305  }
306}
307
308TEST_F(FakeSSLClientSocketTest, SuccessfulHandshakeAsync) {
309  for (size_t i = 1; i < 100; i += 7) {
310    SCOPED_TRACE(i);
311    for (size_t j = 1; j < 100; j += 9) {
312      SCOPED_TRACE(j);
313      RunSuccessfulHandshakeTest(net::ASYNC, i, j, 0);
314    }
315  }
316}
317
318TEST_F(FakeSSLClientSocketTest, ResetSocket) {
319  RunSuccessfulHandshakeTest(net::ASYNC, 1, 2, 3);
320}
321
322TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeConnectError) {
323  RunUnsuccessfulHandshakeTest(net::ERR_ACCESS_DENIED, CONNECT_ERROR);
324}
325
326TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeWriteError) {
327  RunUnsuccessfulHandshakeTest(net::ERR_OUT_OF_MEMORY,
328                               SEND_CLIENT_HELLO_ERROR);
329}
330
331TEST_F(FakeSSLClientSocketTest, UnsuccessfulHandshakeReadError) {
332  RunUnsuccessfulHandshakeTest(net::ERR_CONNECTION_CLOSED,
333                               VERIFY_SERVER_HELLO_ERROR);
334}
335
336TEST_F(FakeSSLClientSocketTest, PeerClosedDuringHandshake) {
337  RunUnsuccessfulHandshakeTest(
338      net::ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ,
339      VERIFY_SERVER_HELLO_ERROR);
340}
341
342TEST_F(FakeSSLClientSocketTest, MalformedServerHello) {
343  RunUnsuccessfulHandshakeTest(ERR_MALFORMED_SERVER_HELLO,
344                               VERIFY_SERVER_HELLO_ERROR);
345}
346
347}  // namespace
348
349}  // namespace jingle_glue
350