1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
6#define NET_SOCKET_SOCKET_TEST_UTIL_H_
7#pragma once
8
9#include <cstring>
10#include <deque>
11#include <string>
12#include <vector>
13
14#include "base/basictypes.h"
15#include "base/callback.h"
16#include "base/logging.h"
17#include "base/memory/scoped_ptr.h"
18#include "base/memory/scoped_vector.h"
19#include "base/memory/weak_ptr.h"
20#include "base/string16.h"
21#include "net/base/address_list.h"
22#include "net/base/io_buffer.h"
23#include "net/base/net_errors.h"
24#include "net/base/net_log.h"
25#include "net/base/ssl_config_service.h"
26#include "net/base/test_completion_callback.h"
27#include "net/http/http_auth_controller.h"
28#include "net/http/http_proxy_client_socket_pool.h"
29#include "net/socket/client_socket_factory.h"
30#include "net/socket/client_socket_handle.h"
31#include "net/socket/socks_client_socket_pool.h"
32#include "net/socket/ssl_client_socket.h"
33#include "net/socket/ssl_client_socket_pool.h"
34#include "net/socket/transport_client_socket_pool.h"
35#include "testing/gtest/include/gtest/gtest.h"
36
37namespace net {
38
enum {
  // Private error code used only by the socket test utility classes.
  // A MockRead whose |result| equals this value carries no data of its own:
  // it is a marker saying that the peer closes the connection after the
  // MockRead that follows it.  Every other member of such a marker MockRead
  // is ignored.
  ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000,
};
47
48class ClientSocket;
49class MockClientSocket;
50class SSLClientSocket;
51class SSLHostInfo;
52
53struct MockConnect {
54  // Asynchronous connection success.
55  MockConnect() : async(true), result(OK) { }
56  MockConnect(bool a, int r) : async(a), result(r) { }
57
58  bool async;
59  int result;
60};
61
62struct MockRead {
63  // Flag to indicate that the message loop should be terminated.
64  enum {
65    STOPLOOP = 1 << 31
66  };
67
68  // Default
69  MockRead() : async(false), result(0), data(NULL), data_len(0),
70      sequence_number(0), time_stamp(base::Time::Now()) {}
71
72  // Read failure (no data).
73  MockRead(bool async, int result) : async(async) , result(result), data(NULL),
74      data_len(0), sequence_number(0), time_stamp(base::Time::Now()) { }
75
76  // Read failure (no data), with sequence information.
77  MockRead(bool async, int result, int seq) : async(async) , result(result),
78      data(NULL), data_len(0), sequence_number(seq),
79      time_stamp(base::Time::Now()) { }
80
81  // Asynchronous read success (inferred data length).
82  explicit MockRead(const char* data) : async(true),  result(0), data(data),
83      data_len(strlen(data)), sequence_number(0),
84      time_stamp(base::Time::Now()) { }
85
86  // Read success (inferred data length).
87  MockRead(bool async, const char* data) : async(async), result(0), data(data),
88      data_len(strlen(data)), sequence_number(0),
89      time_stamp(base::Time::Now()) { }
90
91  // Read success.
92  MockRead(bool async, const char* data, int data_len) : async(async),
93      result(0), data(data), data_len(data_len), sequence_number(0),
94      time_stamp(base::Time::Now()) { }
95
96  // Read success (inferred data length) with sequence information.
97  MockRead(bool async, int seq, const char* data) : async(async),
98      result(0), data(data), data_len(strlen(data)), sequence_number(seq),
99      time_stamp(base::Time::Now()) { }
100
101  // Read success with sequence information.
102  MockRead(bool async, const char* data, int data_len, int seq) : async(async),
103      result(0), data(data), data_len(data_len), sequence_number(seq),
104      time_stamp(base::Time::Now()) { }
105
106  bool async;
107  int result;
108  const char* data;
109  int data_len;
110
111  // For OrderedSocketData, which only allows reads to occur in a particular
112  // sequence.  If a read occurs before the given |sequence_number| is reached,
113  // an ERR_IO_PENDING is returned.
114  int sequence_number;      // The sequence number at which a read is allowed
115                            // to occur.
116  base::Time time_stamp;    // The time stamp at which the operation occurred.
117};
118
119// MockWrite uses the same member fields as MockRead, but with different
120// meanings. The expected input to MockTCPClientSocket::Write() is given
121// by {data, data_len}, and the return value of Write() is controlled by
122// {async, result}.
123typedef MockRead MockWrite;
124
// Outcome of a simulated write: whether it completes asynchronously and the
// result code to report.
struct MockWriteResult {
  MockWriteResult(bool async, int result)
      : async(async),
        result(result) {
  }

  bool async;
  int result;
};
131
132// The SocketDataProvider is an interface used by the MockClientSocket
133// for getting data about individual reads and writes on the socket.
134class SocketDataProvider {
135 public:
136  SocketDataProvider() : socket_(NULL) {}
137
138  virtual ~SocketDataProvider() {}
139
140  // Returns the buffer and result code for the next simulated read.
141  // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
142  // that it will be called via the MockClientSocket::OnReadComplete()
143  // function at a later time.
144  virtual MockRead GetNextRead() = 0;
145  virtual MockWriteResult OnWrite(const std::string& data) = 0;
146  virtual void Reset() = 0;
147
148  // Accessor for the socket which is using the SocketDataProvider.
149  MockClientSocket* socket() { return socket_; }
150  void set_socket(MockClientSocket* socket) { socket_ = socket; }
151
152  MockConnect connect_data() const { return connect_; }
153  void set_connect_data(const MockConnect& connect) { connect_ = connect; }
154
155 private:
156  MockConnect connect_;
157  MockClientSocket* socket_;
158
159  DISALLOW_COPY_AND_ASSIGN(SocketDataProvider);
160};
161
162// SocketDataProvider which responds based on static tables of mock reads and
163// writes.
164class StaticSocketDataProvider : public SocketDataProvider {
165 public:
166  StaticSocketDataProvider();
167  StaticSocketDataProvider(MockRead* reads, size_t reads_count,
168                           MockWrite* writes, size_t writes_count);
169  virtual ~StaticSocketDataProvider();
170
171  // These functions get access to the next available read and write data.
172  const MockRead& PeekRead() const;
173  const MockWrite& PeekWrite() const;
174  // These functions get random access to the read and write data, for timing.
175  const MockRead& PeekRead(size_t index) const;
176  const MockWrite& PeekWrite(size_t index) const;
177  size_t read_index() const { return read_index_; }
178  size_t write_index() const { return write_index_; }
179  size_t read_count() const { return read_count_; }
180  size_t write_count() const { return write_count_; }
181
182  bool at_read_eof() const { return read_index_ >= read_count_; }
183  bool at_write_eof() const { return write_index_ >= write_count_; }
184
185  virtual void CompleteRead() {}
186
187  // SocketDataProvider methods:
188  virtual MockRead GetNextRead();
189  virtual MockWriteResult OnWrite(const std::string& data);
190  virtual void Reset();
191
192 private:
193  MockRead* reads_;
194  size_t read_index_;
195  size_t read_count_;
196  MockWrite* writes_;
197  size_t write_index_;
198  size_t write_count_;
199
200  DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider);
201};
202
203// SocketDataProvider which can make decisions about next mock reads based on
204// received writes. It can also be used to enforce order of operations, for
205// example that tested code must send the "Hello!" message before receiving
206// response. This is useful for testing conversation-like protocols like FTP.
207class DynamicSocketDataProvider : public SocketDataProvider {
208 public:
209  DynamicSocketDataProvider();
210  virtual ~DynamicSocketDataProvider();
211
212  int short_read_limit() const { return short_read_limit_; }
213  void set_short_read_limit(int limit) { short_read_limit_ = limit; }
214
215  void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; }
216
217  // SocketDataProvider methods:
218  virtual MockRead GetNextRead();
219  virtual MockWriteResult OnWrite(const std::string& data) = 0;
220  virtual void Reset();
221
222 protected:
223  // The next time there is a read from this socket, it will return |data|.
224  // Before calling SimulateRead next time, the previous data must be consumed.
225  void SimulateRead(const char* data, size_t length);
226  void SimulateRead(const char* data) {
227    SimulateRead(data, std::strlen(data));
228  }
229
230 private:
231  std::deque<MockRead> reads_;
232
233  // Max number of bytes we will read at a time. 0 means no limit.
234  int short_read_limit_;
235
236  // If true, we'll not require the client to consume all data before we
237  // mock the next read.
238  bool allow_unconsumed_reads_;
239
240  DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider);
241};
242
243// SSLSocketDataProviders only need to keep track of the return code from calls
244// to Connect().
245struct SSLSocketDataProvider {
246  SSLSocketDataProvider(bool async, int result);
247  ~SSLSocketDataProvider();
248
249  MockConnect connect;
250  SSLClientSocket::NextProtoStatus next_proto_status;
251  std::string next_proto;
252  bool was_npn_negotiated;
253  net::SSLCertRequestInfo* cert_request_info;
254  scoped_refptr<X509Certificate> cert_;
255};
256
257// A DataProvider where the client must write a request before the reads (e.g.
258// the response) will complete.
259class DelayedSocketData : public StaticSocketDataProvider,
260                          public base::RefCounted<DelayedSocketData> {
261 public:
262  // |write_delay| the number of MockWrites to complete before allowing
263  //               a MockRead to complete.
264  // |reads| the list of MockRead completions.
265  // |writes| the list of MockWrite completions.
266  // Note: All MockReads and MockWrites must be async.
267  // Note: The MockRead and MockWrite lists musts end with a EOF
268  //       e.g. a MockRead(true, 0, 0);
269  DelayedSocketData(int write_delay,
270                    MockRead* reads, size_t reads_count,
271                    MockWrite* writes, size_t writes_count);
272
273  // |connect| the result for the connect phase.
274  // |reads| the list of MockRead completions.
275  // |write_delay| the number of MockWrites to complete before allowing
276  //               a MockRead to complete.
277  // |writes| the list of MockWrite completions.
278  // Note: All MockReads and MockWrites must be async.
279  // Note: The MockRead and MockWrite lists musts end with a EOF
280  //       e.g. a MockRead(true, 0, 0);
281  DelayedSocketData(const MockConnect& connect, int write_delay,
282                    MockRead* reads, size_t reads_count,
283                    MockWrite* writes, size_t writes_count);
284  ~DelayedSocketData();
285
286  void ForceNextRead();
287
288  // StaticSocketDataProvider:
289  virtual MockRead GetNextRead();
290  virtual MockWriteResult OnWrite(const std::string& data);
291  virtual void Reset();
292  virtual void CompleteRead();
293
294 private:
295  int write_delay_;
296  ScopedRunnableMethodFactory<DelayedSocketData> factory_;
297};
298
299// A DataProvider where the reads are ordered.
300// If a read is requested before its sequence number is reached, we return an
301// ERR_IO_PENDING (that way we don't have to explicitly add a MockRead just to
302// wait).
303// The sequence number is incremented on every read and write operation.
304// The message loop may be interrupted by setting the high bit of the sequence
305// number in the MockRead's sequence number.  When that MockRead is reached,
306// we post a Quit message to the loop.  This allows us to interrupt the reading
307// of data before a complete message has arrived, and provides support for
308// testing server push when the request is issued while the response is in the
309// middle of being received.
310class OrderedSocketData : public StaticSocketDataProvider,
311                          public base::RefCounted<OrderedSocketData> {
312 public:
313  // |reads| the list of MockRead completions.
314  // |writes| the list of MockWrite completions.
315  // Note: All MockReads and MockWrites must be async.
316  // Note: The MockRead and MockWrite lists musts end with a EOF
317  //       e.g. a MockRead(true, 0, 0);
318  OrderedSocketData(MockRead* reads, size_t reads_count,
319                    MockWrite* writes, size_t writes_count);
320
321  // |connect| the result for the connect phase.
322  // |reads| the list of MockRead completions.
323  // |writes| the list of MockWrite completions.
324  // Note: All MockReads and MockWrites must be async.
325  // Note: The MockRead and MockWrite lists musts end with a EOF
326  //       e.g. a MockRead(true, 0, 0);
327  OrderedSocketData(const MockConnect& connect,
328                    MockRead* reads, size_t reads_count,
329                    MockWrite* writes, size_t writes_count);
330
331  void SetCompletionCallback(CompletionCallback* callback) {
332    callback_ = callback;
333  }
334
335  // Posts a quit message to the current message loop, if one is running.
336  void EndLoop();
337
338  // StaticSocketDataProvider:
339  virtual MockRead GetNextRead();
340  virtual MockWriteResult OnWrite(const std::string& data);
341  virtual void Reset();
342  virtual void CompleteRead();
343
344 private:
345  friend class base::RefCounted<OrderedSocketData>;
346  virtual ~OrderedSocketData();
347
348  int sequence_number_;
349  int loop_stop_stage_;
350  CompletionCallback* callback_;
351  bool blocked_;
352  ScopedRunnableMethodFactory<OrderedSocketData> factory_;
353};
354
355class DeterministicMockTCPClientSocket;
356
357// This class gives the user full control over the network activity,
358// specifically the timing of the COMPLETION of I/O operations.  Regardless of
359// the order in which I/O operations are initiated, this class ensures that they
360// complete in the correct order.
361//
362// Network activity is modeled as a sequence of numbered steps which is
363// incremented whenever an I/O operation completes.  This can happen under two
364// different circumstances:
365//
366// 1) Performing a synchronous I/O operation.  (Invoking Read() or Write()
367//    when the corresponding MockRead or MockWrite is marked !async).
368// 2) Running the Run() method of this class.  The run method will invoke
369//    the current MessageLoop, running all pending events, and will then
370//    invoke any pending IO callbacks.
371//
372// In addition, this class allows for I/O processing to "stop" at a specified
373// step, by calling SetStop(int) or StopAfter(int).  Initiating an I/O operation
374// by calling Read() or Write() while stopped is permitted if the operation is
375// asynchronous.  It is an error to perform synchronous I/O while stopped.
376//
377// When creating the MockReads and MockWrites, note that the sequence number
378// refers to the number of the step in which the I/O will complete.  In the
379// case of synchronous I/O, this will be the same step as the I/O is initiated.
380// However, in the case of asynchronous I/O, this I/O may be initiated in
381// a much earlier step. Furthermore, when the a Read() or Write() is separated
382// from its completion by other Read() or Writes()'s, it can not be marked
383// synchronous.  If it is, ERR_UNUEXPECTED will be returned indicating that a
384// synchronous Read() or Write() could not be completed synchronously because of
385// the specific ordering constraints.
386//
387// Sequence numbers are preserved across both reads and writes. There should be
388// no gaps in sequence numbers, and no repeated sequence numbers. i.e.
389//  MockRead reads[] = {
390//    MockRead(false, "first read", length, 0)   // sync
391//    MockRead(true, "second read", length, 2)   // async
392//  };
393//  MockWrite writes[] = {
394//    MockWrite(true, "first write", length, 1),    // async
395//    MockWrite(false, "second write", length, 3),  // sync
396//  };
397//
398// Example control flow:
399// Read() is called.  The current step is 0.  The first available read is
400// synchronous, so the call to Read() returns length.  The current step is
401// now 1.  Next, Read() is called again.  The next available read can
402// not be completed until step 2, so Read() returns ERR_IO_PENDING.  The current
403// step is still 1.  Write is called().  The first available write is able to
404// complete in this step, but is marked asynchronous.  Write() returns
405// ERR_IO_PENDING.  The current step is still 1.  At this point RunFor(1) is
406// called which will cause the write callback to be invoked, and will then
407// stop.  The current state is now 2.  RunFor(1) is called again, which
408// causes the read callback to be invoked, and will then stop.  Then current
409// step is 2.  Write() is called again.  Then next available write is
410// synchronous so the call to Write() returns length.
411//
412// For examples of how to use this class, see:
413//   deterministic_socket_data_unittests.cc
414class DeterministicSocketData : public StaticSocketDataProvider,
415    public base::RefCounted<DeterministicSocketData> {
416 public:
417  // |reads| the list of MockRead completions.
418  // |writes| the list of MockWrite completions.
419  DeterministicSocketData(MockRead* reads, size_t reads_count,
420                          MockWrite* writes, size_t writes_count);
421  virtual ~DeterministicSocketData();
422
423  // Consume all the data up to the give stop point (via SetStop()).
424  void Run();
425
426  // Set the stop point to be |steps| from now, and then invoke Run().
427  void RunFor(int steps);
428
429  // Stop at step |seq|, which must be in the future.
430  virtual void SetStop(int seq);
431
432  // Stop |seq| steps after the current step.
433  virtual void StopAfter(int seq);
434  bool stopped() const { return stopped_; }
435  void SetStopped(bool val) { stopped_ = val; }
436  MockRead& current_read() { return current_read_; }
437  MockRead& current_write() { return current_write_; }
438  int sequence_number() const { return sequence_number_; }
439  void set_socket(base::WeakPtr<DeterministicMockTCPClientSocket> socket) {
440    socket_ = socket;
441  }
442
443  // StaticSocketDataProvider:
444
445  // When the socket calls Read(), that calls GetNextRead(), and expects either
446  // ERR_IO_PENDING or data.
447  virtual MockRead GetNextRead();
448
449  // When the socket calls Write(), it always completes synchronously. OnWrite()
450  // checks to make sure the written data matches the expected data. The
451  // callback will not be invoked until its sequence number is reached.
452  virtual MockWriteResult OnWrite(const std::string& data);
453  virtual void Reset();
454  virtual void CompleteRead() {}
455
456 private:
457  // Invoke the read and write callbacks, if the timing is appropriate.
458  void InvokeCallbacks();
459
460  void NextStep();
461
462  int sequence_number_;
463  MockRead current_read_;
464  MockWrite current_write_;
465  int stopping_sequence_number_;
466  bool stopped_;
467  base::WeakPtr<DeterministicMockTCPClientSocket> socket_;
468  bool print_debug_;
469};
470
// An ordered collection of SocketDataProvider-like elements.  Each time a
// Mock{TCP,SSL}ClientSocket is instantiated it takes its data from the next
// element of this collection.
template<typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray()
      : next_index_(0) {
  }

  // Hands out the element at the current position and advances the position.
  // Running past the end is only caught by a DCHECK.
  T* GetNext() {
    DCHECK_LT(next_index_, data_providers_.size());
    T* const provider = data_providers_[next_index_];
    ++next_index_;
    return provider;
  }

  // Appends |data_provider|, which must be non-NULL (DCHECKed).  Only the
  // pointer is stored.
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  // Rewinds so that the next GetNext() returns the first element again.
  void ResetNextIndex() {
    next_index_ = 0;
  }

 private:
  // Position of the next |data_providers_| element to hand out.  Kept as an
  // index rather than an iterator because iterators are invalidated when the
  // vector reallocates.
  size_t next_index_;

  // SocketDataProviders to be returned.
  std::vector<T*> data_providers_;
};
502
503class MockTCPClientSocket;
504class MockSSLClientSocket;
505
506// ClientSocketFactory which contains arrays of sockets of each type.
507// You should first fill the arrays using AddMock{SSL,}Socket. When the factory
508// is asked to create a socket, it takes next entry from appropriate array.
509// You can use ResetNextMockIndexes to reset that next entry index for all mock
510// socket types.
511class MockClientSocketFactory : public ClientSocketFactory {
512 public:
513  MockClientSocketFactory();
514  virtual ~MockClientSocketFactory();
515
516  void AddSocketDataProvider(SocketDataProvider* socket);
517  void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
518  void ResetNextMockIndexes();
519
520  // Return |index|-th MockTCPClientSocket (starting from 0) that the factory
521  // created.
522  MockTCPClientSocket* GetMockTCPClientSocket(size_t index) const;
523
524  // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
525  // created.
526  MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
527
528  SocketDataProviderArray<SocketDataProvider>& mock_data() {
529    return mock_data_;
530  }
531  std::vector<MockTCPClientSocket*>& tcp_client_sockets() {
532    return tcp_client_sockets_;
533  }
534
535  // ClientSocketFactory
536  virtual ClientSocket* CreateTransportClientSocket(
537      const AddressList& addresses,
538      NetLog* net_log,
539      const NetLog::Source& source);
540  virtual SSLClientSocket* CreateSSLClientSocket(
541      ClientSocketHandle* transport_socket,
542      const HostPortPair& host_and_port,
543      const SSLConfig& ssl_config,
544      SSLHostInfo* ssl_host_info,
545      CertVerifier* cert_verifier,
546      DnsCertProvenanceChecker* dns_cert_checker);
547  virtual void ClearSSLSessionCache();
548
549 private:
550  SocketDataProviderArray<SocketDataProvider> mock_data_;
551  SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
552
553  // Store pointers to handed out sockets in case the test wants to get them.
554  std::vector<MockTCPClientSocket*> tcp_client_sockets_;
555  std::vector<MockSSLClientSocket*> ssl_client_sockets_;
556};
557
558class MockClientSocket : public net::SSLClientSocket {
559 public:
560  explicit MockClientSocket(net::NetLog* net_log);
561
562  // If an async IO is pending because the SocketDataProvider returned
563  // ERR_IO_PENDING, then the MockClientSocket waits until this OnReadComplete
564  // is called to complete the asynchronous read operation.
565  // data.async is ignored, and this read is completed synchronously as
566  // part of this call.
567  virtual void OnReadComplete(const MockRead& data) = 0;
568
569  // Socket methods:
570  virtual int Read(net::IOBuffer* buf, int buf_len,
571                   net::CompletionCallback* callback) = 0;
572  virtual int Write(net::IOBuffer* buf, int buf_len,
573                    net::CompletionCallback* callback) = 0;
574  virtual bool SetReceiveBufferSize(int32 size);
575  virtual bool SetSendBufferSize(int32 size);
576
577  // ClientSocket methods:
578  virtual int Connect(net::CompletionCallback* callback) = 0;
579  virtual void Disconnect();
580  virtual bool IsConnected() const;
581  virtual bool IsConnectedAndIdle() const;
582  virtual int GetPeerAddress(AddressList* address) const;
583  virtual int GetLocalAddress(IPEndPoint* address) const;
584  virtual const BoundNetLog& NetLog() const;
585  virtual void SetSubresourceSpeculation() {}
586  virtual void SetOmniboxSpeculation() {}
587
588  // SSLClientSocket methods:
589  virtual void GetSSLInfo(net::SSLInfo* ssl_info);
590  virtual void GetSSLCertRequestInfo(
591      net::SSLCertRequestInfo* cert_request_info);
592  virtual NextProtoStatus GetNextProto(std::string* proto);
593
594 protected:
595  virtual ~MockClientSocket();
596  void RunCallbackAsync(net::CompletionCallback* callback, int result);
597  void RunCallback(net::CompletionCallback*, int result);
598
599  ScopedRunnableMethodFactory<MockClientSocket> method_factory_;
600
601  // True if Connect completed successfully and Disconnect hasn't been called.
602  bool connected_;
603
604  net::BoundNetLog net_log_;
605};
606
607class MockTCPClientSocket : public MockClientSocket {
608 public:
609  MockTCPClientSocket(const net::AddressList& addresses, net::NetLog* net_log,
610                      net::SocketDataProvider* socket);
611
612  net::AddressList addresses() const { return addresses_; }
613
614  // Socket methods:
615  virtual int Read(net::IOBuffer* buf, int buf_len,
616                   net::CompletionCallback* callback);
617  virtual int Write(net::IOBuffer* buf, int buf_len,
618                    net::CompletionCallback* callback);
619
620  // ClientSocket methods:
621  virtual int Connect(net::CompletionCallback* callback);
622  virtual void Disconnect();
623  virtual bool IsConnected() const;
624  virtual bool IsConnectedAndIdle() const;
625  virtual int GetPeerAddress(AddressList* address) const;
626  virtual bool WasEverUsed() const;
627  virtual bool UsingTCPFastOpen() const;
628
629  // MockClientSocket:
630  virtual void OnReadComplete(const MockRead& data);
631
632 private:
633  int CompleteRead();
634
635  net::AddressList addresses_;
636
637  net::SocketDataProvider* data_;
638  int read_offset_;
639  net::MockRead read_data_;
640  bool need_read_data_;
641
642  // True if the peer has closed the connection.  This allows us to simulate
643  // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
644  // TCPClientSocket.
645  bool peer_closed_connection_;
646
647  // While an asynchronous IO is pending, we save our user-buffer state.
648  net::IOBuffer* pending_buf_;
649  int pending_buf_len_;
650  net::CompletionCallback* pending_callback_;
651  bool was_used_to_convey_data_;
652};
653
654class DeterministicMockTCPClientSocket : public MockClientSocket,
655    public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> {
656 public:
657  DeterministicMockTCPClientSocket(net::NetLog* net_log,
658      net::DeterministicSocketData* data);
659  virtual ~DeterministicMockTCPClientSocket();
660
661  bool write_pending() const { return write_pending_; }
662  bool read_pending() const { return read_pending_; }
663
664  void CompleteWrite();
665  int CompleteRead();
666
667  // Socket:
668  virtual int Write(net::IOBuffer* buf, int buf_len,
669                    net::CompletionCallback* callback);
670  virtual int Read(net::IOBuffer* buf, int buf_len,
671                   net::CompletionCallback* callback);
672
673  // ClientSocket:
674  virtual int Connect(net::CompletionCallback* callback);
675  virtual void Disconnect();
676  virtual bool IsConnected() const;
677  virtual bool IsConnectedAndIdle() const;
678  virtual bool WasEverUsed() const;
679  virtual bool UsingTCPFastOpen() const;
680
681  // MockClientSocket:
682  virtual void OnReadComplete(const MockRead& data);
683
684 private:
685  bool write_pending_;
686  net::CompletionCallback* write_callback_;
687  int write_result_;
688
689  net::MockRead read_data_;
690
691  net::IOBuffer* read_buf_;
692  int read_buf_len_;
693  bool read_pending_;
694  net::CompletionCallback* read_callback_;
695  net::DeterministicSocketData* data_;
696  bool was_used_to_convey_data_;
697};
698
699class MockSSLClientSocket : public MockClientSocket {
700 public:
701  MockSSLClientSocket(
702      net::ClientSocketHandle* transport_socket,
703      const HostPortPair& host_and_port,
704      const net::SSLConfig& ssl_config,
705      SSLHostInfo* ssl_host_info,
706      net::SSLSocketDataProvider* socket);
707  virtual ~MockSSLClientSocket();
708
709  // Socket methods:
710  virtual int Read(net::IOBuffer* buf, int buf_len,
711                   net::CompletionCallback* callback);
712  virtual int Write(net::IOBuffer* buf, int buf_len,
713                    net::CompletionCallback* callback);
714
715  // ClientSocket methods:
716  virtual int Connect(net::CompletionCallback* callback);
717  virtual void Disconnect();
718  virtual bool IsConnected() const;
719  virtual bool WasEverUsed() const;
720  virtual bool UsingTCPFastOpen() const;
721
722  // SSLClientSocket methods:
723  virtual void GetSSLInfo(net::SSLInfo* ssl_info);
724  virtual void GetSSLCertRequestInfo(
725      net::SSLCertRequestInfo* cert_request_info);
726  virtual NextProtoStatus GetNextProto(std::string* proto);
727  virtual bool was_npn_negotiated() const;
728  virtual bool set_was_npn_negotiated(bool negotiated);
729
730  // This MockSocket does not implement the manual async IO feature.
731  virtual void OnReadComplete(const MockRead& data);
732
733 private:
734  class ConnectCallback;
735
736  scoped_ptr<ClientSocketHandle> transport_;
737  net::SSLSocketDataProvider* data_;
738  bool is_npn_state_set_;
739  bool new_npn_value_;
740  bool was_used_to_convey_data_;
741};
742
743class TestSocketRequest : public CallbackRunner< Tuple1<int> > {
744 public:
745  TestSocketRequest(
746      std::vector<TestSocketRequest*>* request_order,
747      size_t* completion_count);
748  virtual ~TestSocketRequest();
749
750  ClientSocketHandle* handle() { return &handle_; }
751
752  int WaitForResult();
753  virtual void RunWithParams(const Tuple1<int>& params);
754
755 private:
756  ClientSocketHandle handle_;
757  std::vector<TestSocketRequest*>* request_order_;
758  size_t* completion_count_;
759  TestCompletionCallback callback_;
760};
761
762class ClientSocketPoolTest {
763 public:
764  enum KeepAlive {
765    KEEP_ALIVE,
766
767    // A socket will be disconnected in addition to handle being reset.
768    NO_KEEP_ALIVE,
769  };
770
771  static const int kIndexOutOfBounds;
772  static const int kRequestNotFound;
773
774  ClientSocketPoolTest();
775  ~ClientSocketPoolTest();
776
777  template <typename PoolType, typename SocketParams>
778  int StartRequestUsingPool(PoolType* socket_pool,
779                            const std::string& group_name,
780                            RequestPriority priority,
781                            const scoped_refptr<SocketParams>& socket_params) {
782    DCHECK(socket_pool);
783    TestSocketRequest* request = new TestSocketRequest(&request_order_,
784                                                       &completion_count_);
785    requests_.push_back(request);
786    int rv = request->handle()->Init(
787        group_name, socket_params, priority, request,
788        socket_pool, BoundNetLog());
789    if (rv != ERR_IO_PENDING)
790      request_order_.push_back(request);
791    return rv;
792  }
793
794  // Provided there were n requests started, takes |index| in range 1..n
795  // and returns order in which that request completed, in range 1..n,
796  // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
797  // if that request did not complete (for example was canceled).
798  int GetOrderOfRequest(size_t index) const;
799
800  // Resets first initialized socket handle from |requests_|. If found such
801  // a handle, returns true.
802  bool ReleaseOneConnection(KeepAlive keep_alive);
803
804  // Releases connections until there is nothing to release.
805  void ReleaseAllConnections(KeepAlive keep_alive);
806
807  TestSocketRequest* request(int i) { return requests_[i]; }
808  size_t requests_size() const { return requests_.size(); }
809  ScopedVector<TestSocketRequest>* requests() { return &requests_; }
810  size_t completion_count() const { return completion_count_; }
811
812 private:
813  ScopedVector<TestSocketRequest> requests_;
814  std::vector<TestSocketRequest*> request_order_;
815  size_t completion_count_;
816};
817
818class MockTransportClientSocketPool : public TransportClientSocketPool {
819 public:
820  class MockConnectJob {
821   public:
822    MockConnectJob(ClientSocket* socket, ClientSocketHandle* handle,
823                   CompletionCallback* callback);
824    ~MockConnectJob();
825
826    int Connect();
827    bool CancelHandle(const ClientSocketHandle* handle);
828
829   private:
830    void OnConnect(int rv);
831
832    scoped_ptr<ClientSocket> socket_;
833    ClientSocketHandle* handle_;
834    CompletionCallback* user_callback_;
835    CompletionCallbackImpl<MockConnectJob> connect_callback_;
836
837    DISALLOW_COPY_AND_ASSIGN(MockConnectJob);
838  };
839
840  MockTransportClientSocketPool(
841      int max_sockets,
842      int max_sockets_per_group,
843      ClientSocketPoolHistograms* histograms,
844      ClientSocketFactory* socket_factory);
845
846  virtual ~MockTransportClientSocketPool();
847
848  int release_count() const { return release_count_; }
849  int cancel_count() const { return cancel_count_; }
850
851  // TransportClientSocketPool methods.
852  virtual int RequestSocket(const std::string& group_name,
853                            const void* socket_params,
854                            RequestPriority priority,
855                            ClientSocketHandle* handle,
856                            CompletionCallback* callback,
857                            const BoundNetLog& net_log);
858
859  virtual void CancelRequest(const std::string& group_name,
860                             ClientSocketHandle* handle);
861  virtual void ReleaseSocket(const std::string& group_name,
862                             ClientSocket* socket, int id);
863
864 private:
865  ClientSocketFactory* client_socket_factory_;
866  ScopedVector<MockConnectJob> job_list_;
867  int release_count_;
868  int cancel_count_;
869
870  DISALLOW_COPY_AND_ASSIGN(MockTransportClientSocketPool);
871};
872
873class DeterministicMockClientSocketFactory : public ClientSocketFactory {
874 public:
875  DeterministicMockClientSocketFactory();
876  virtual ~DeterministicMockClientSocketFactory();
877
878  void AddSocketDataProvider(DeterministicSocketData* socket);
879  void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
880  void ResetNextMockIndexes();
881
882  // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
883  // created.
884  MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
885
886  SocketDataProviderArray<DeterministicSocketData>& mock_data() {
887    return mock_data_;
888  }
889  std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() {
890    return tcp_client_sockets_;
891  }
892
893  // ClientSocketFactory
894  virtual ClientSocket* CreateTransportClientSocket(
895      const AddressList& addresses,
896      NetLog* net_log,
897      const NetLog::Source& source);
898  virtual SSLClientSocket* CreateSSLClientSocket(
899      ClientSocketHandle* transport_socket,
900      const HostPortPair& host_and_port,
901      const SSLConfig& ssl_config,
902      SSLHostInfo* ssl_host_info,
903      CertVerifier* cert_verifier,
904      DnsCertProvenanceChecker* dns_cert_checker);
905  virtual void ClearSSLSessionCache();
906
907 private:
908  SocketDataProviderArray<DeterministicSocketData> mock_data_;
909  SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
910
911  // Store pointers to handed out sockets in case the test wants to get them.
912  std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_;
913  std::vector<MockSSLClientSocket*> ssl_client_sockets_;
914};
915
916class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
917 public:
918  MockSOCKSClientSocketPool(
919      int max_sockets,
920      int max_sockets_per_group,
921      ClientSocketPoolHistograms* histograms,
922      TransportClientSocketPool* transport_pool);
923
924  virtual ~MockSOCKSClientSocketPool();
925
926  // SOCKSClientSocketPool methods.
927  virtual int RequestSocket(const std::string& group_name,
928                            const void* socket_params,
929                            RequestPriority priority,
930                            ClientSocketHandle* handle,
931                            CompletionCallback* callback,
932                            const BoundNetLog& net_log);
933
934  virtual void CancelRequest(const std::string& group_name,
935                             ClientSocketHandle* handle);
936  virtual void ReleaseSocket(const std::string& group_name,
937                             ClientSocket* socket, int id);
938
939 private:
940  TransportClientSocketPool* const transport_pool_;
941
942  DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool);
943};
944
945// Constants for a successful SOCKS v5 handshake.
946extern const char kSOCKS5GreetRequest[];
947extern const int kSOCKS5GreetRequestLength;
948
949extern const char kSOCKS5GreetResponse[];
950extern const int kSOCKS5GreetResponseLength;
951
952extern const char kSOCKS5OkRequest[];
953extern const int kSOCKS5OkRequestLength;
954
955extern const char kSOCKS5OkResponse[];
956extern const int kSOCKS5OkResponseLength;
957
958}  // namespace net
959
960#endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
961