socket_test_util.h revision b2df76ea8fec9e32f6f3718986dba0d95315b29c
1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
6#define NET_SOCKET_SOCKET_TEST_UTIL_H_
7
8#include <cstring>
9#include <deque>
10#include <string>
11#include <vector>
12
13#include "base/basictypes.h"
14#include "base/callback.h"
15#include "base/logging.h"
16#include "base/memory/scoped_ptr.h"
17#include "base/memory/scoped_vector.h"
18#include "base/memory/weak_ptr.h"
19#include "base/string16.h"
20#include "net/base/address_list.h"
21#include "net/base/io_buffer.h"
22#include "net/base/net_errors.h"
23#include "net/base/net_log.h"
24#include "net/base/test_completion_callback.h"
25#include "net/http/http_auth_controller.h"
26#include "net/http/http_proxy_client_socket_pool.h"
27#include "net/socket/client_socket_factory.h"
28#include "net/socket/client_socket_handle.h"
29#include "net/socket/socks_client_socket_pool.h"
30#include "net/socket/ssl_client_socket.h"
31#include "net/socket/ssl_client_socket_pool.h"
32#include "net/socket/transport_client_socket_pool.h"
33#include "net/ssl/ssl_config_service.h"
34#include "net/udp/datagram_client_socket.h"
35#include "testing/gtest/include/gtest/gtest.h"
36
37namespace net {
38
enum {
  // Private network error code reserved for the socket test utility classes.
  // A MockRead whose |result| equals ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ
  // is purely a marker: it says that the peer closes the connection after the
  // MockRead that follows it.  All other members of the marker MockRead are
  // ignored.
  ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000
};
47
48class AsyncSocket;
49class MockClientSocket;
50class ServerBoundCertService;
51class SSLClientSocket;
52class StreamSocket;
53
// Completion mode of a mocked operation: asynchronous or synchronous.
enum IoMode {
  ASYNC = 0,
  SYNCHRONOUS = 1
};
58
59struct MockConnect {
60  // Asynchronous connection success.
61  // Creates a MockConnect with |mode| ASYC, |result| OK, and
62  // |peer_addr| 192.0.2.33.
63  MockConnect();
64  // Creates a MockConnect with the specified mode and result, with
65  // |peer_addr| 192.0.2.33.
66  MockConnect(IoMode io_mode, int r);
67  MockConnect(IoMode io_mode, int r, IPEndPoint addr);
68  ~MockConnect();
69
70  IoMode mode;
71  int result;
72  IPEndPoint peer_addr;
73};
74
// MockRead and MockWrite share the same interface and the same members, but
// they need to be distinct types so that they cannot be used
// interchangeably.  To get that, a single struct template is defined and
// MockRead and MockWrite are instantiations of it.  The template parameter
// |type| is never used inside the struct definition; its only job is to make
// the two instantiations different types.
//
// The meaning of |data| differs between the two: in a MockRead it is the data
// the socket returns when MockTCPClientSocket::Read() is attempted, whereas in
// a MockWrite it is the data that is expected to be passed to
// MockTCPClientSocket::Write().
enum MockReadWriteType {
  MOCK_READ = 0,
  MOCK_WRITE = 1
};
90
91template <MockReadWriteType type>
92struct MockReadWrite {
93  // Flag to indicate that the message loop should be terminated.
94  enum {
95    STOPLOOP = 1 << 31
96  };
97
98  // Default
99  MockReadWrite() : mode(SYNCHRONOUS), result(0), data(NULL), data_len(0),
100      sequence_number(0), time_stamp(base::Time::Now()) {}
101
102  // Read/write failure (no data).
103  MockReadWrite(IoMode io_mode, int result) : mode(io_mode), result(result),
104      data(NULL), data_len(0), sequence_number(0),
105      time_stamp(base::Time::Now()) { }
106
107  // Read/write failure (no data), with sequence information.
108  MockReadWrite(IoMode io_mode, int result, int seq) : mode(io_mode),
109      result(result), data(NULL), data_len(0), sequence_number(seq),
110      time_stamp(base::Time::Now()) { }
111
112  // Asynchronous read/write success (inferred data length).
113  explicit MockReadWrite(const char* data) : mode(ASYNC),  result(0),
114      data(data), data_len(strlen(data)), sequence_number(0),
115      time_stamp(base::Time::Now()) { }
116
117  // Read/write success (inferred data length).
118  MockReadWrite(IoMode io_mode, const char* data) : mode(io_mode), result(0),
119      data(data), data_len(strlen(data)), sequence_number(0),
120      time_stamp(base::Time::Now()) { }
121
122  // Read/write success.
123  MockReadWrite(IoMode io_mode, const char* data, int data_len) : mode(io_mode),
124      result(0), data(data), data_len(data_len), sequence_number(0),
125      time_stamp(base::Time::Now()) { }
126
127  // Read/write success (inferred data length) with sequence information.
128  MockReadWrite(IoMode io_mode, int seq, const char* data) : mode(io_mode),
129      result(0), data(data), data_len(strlen(data)), sequence_number(seq),
130      time_stamp(base::Time::Now()) { }
131
132  // Read/write success with sequence information.
133  MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) :
134      mode(io_mode), result(0), data(data), data_len(data_len),
135      sequence_number(seq), time_stamp(base::Time::Now()) { }
136
137  IoMode mode;
138  int result;
139  const char* data;
140  int data_len;
141
142  // For OrderedSocketData, which only allows reads to occur in a particular
143  // sequence.  If a read occurs before the given |sequence_number| is reached,
144  // an ERR_IO_PENDING is returned.
145  int sequence_number;      // The sequence number at which a read is allowed
146                            // to occur.
147  base::Time time_stamp;    // The time stamp at which the operation occurred.
148};
149
150typedef MockReadWrite<MOCK_READ> MockRead;
151typedef MockReadWrite<MOCK_WRITE> MockWrite;
152
153struct MockWriteResult {
154  MockWriteResult(IoMode io_mode, int result)
155      : mode(io_mode),
156        result(result) {}
157
158  IoMode mode;
159  int result;
160};
161
162// The SocketDataProvider is an interface used by the MockClientSocket
163// for getting data about individual reads and writes on the socket.
164class SocketDataProvider {
165 public:
166  SocketDataProvider() : socket_(NULL) {}
167
168  virtual ~SocketDataProvider() {}
169
170  // Returns the buffer and result code for the next simulated read.
171  // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
172  // that it will be called via the AsyncSocket::OnReadComplete()
173  // function at a later time.
174  virtual MockRead GetNextRead() = 0;
175  virtual MockWriteResult OnWrite(const std::string& data) = 0;
176  virtual void Reset() = 0;
177
178  // Accessor for the socket which is using the SocketDataProvider.
179  AsyncSocket* socket() { return socket_; }
180  void set_socket(AsyncSocket* socket) { socket_ = socket; }
181
182  MockConnect connect_data() const { return connect_; }
183  void set_connect_data(const MockConnect& connect) { connect_ = connect; }
184
185 private:
186  MockConnect connect_;
187  AsyncSocket* socket_;
188
189  DISALLOW_COPY_AND_ASSIGN(SocketDataProvider);
190};
191
192// The AsyncSocket is an interface used by the SocketDataProvider to
193// complete the asynchronous read operation.
194class AsyncSocket {
195 public:
196  // If an async IO is pending because the SocketDataProvider returned
197  // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
198  // is called to complete the asynchronous read operation.
199  // data.async is ignored, and this read is completed synchronously as
200  // part of this call.
201  virtual void OnReadComplete(const MockRead& data) = 0;
202};
203
204// SocketDataProvider which responds based on static tables of mock reads and
205// writes.
206class StaticSocketDataProvider : public SocketDataProvider {
207 public:
208  StaticSocketDataProvider();
209  StaticSocketDataProvider(MockRead* reads, size_t reads_count,
210                           MockWrite* writes, size_t writes_count);
211  virtual ~StaticSocketDataProvider();
212
213  // These functions get access to the next available read and write data.
214  const MockRead& PeekRead() const;
215  const MockWrite& PeekWrite() const;
216  // These functions get random access to the read and write data, for timing.
217  const MockRead& PeekRead(size_t index) const;
218  const MockWrite& PeekWrite(size_t index) const;
219  size_t read_index() const { return read_index_; }
220  size_t write_index() const { return write_index_; }
221  size_t read_count() const { return read_count_; }
222  size_t write_count() const { return write_count_; }
223
224  bool at_read_eof() const { return read_index_ >= read_count_; }
225  bool at_write_eof() const { return write_index_ >= write_count_; }
226
227  virtual void CompleteRead() {}
228
229  // SocketDataProvider implementation.
230  virtual MockRead GetNextRead() OVERRIDE;
231  virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
232  virtual void Reset() OVERRIDE;
233
234 private:
235  MockRead* reads_;
236  size_t read_index_;
237  size_t read_count_;
238  MockWrite* writes_;
239  size_t write_index_;
240  size_t write_count_;
241
242  DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider);
243};
244
245// SocketDataProvider which can make decisions about next mock reads based on
246// received writes. It can also be used to enforce order of operations, for
247// example that tested code must send the "Hello!" message before receiving
248// response. This is useful for testing conversation-like protocols like FTP.
249class DynamicSocketDataProvider : public SocketDataProvider {
250 public:
251  DynamicSocketDataProvider();
252  virtual ~DynamicSocketDataProvider();
253
254  int short_read_limit() const { return short_read_limit_; }
255  void set_short_read_limit(int limit) { short_read_limit_ = limit; }
256
257  void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; }
258
259  // SocketDataProvider implementation.
260  virtual MockRead GetNextRead() OVERRIDE;
261  virtual MockWriteResult OnWrite(const std::string& data) = 0;
262  virtual void Reset() OVERRIDE;
263
264 protected:
265  // The next time there is a read from this socket, it will return |data|.
266  // Before calling SimulateRead next time, the previous data must be consumed.
267  void SimulateRead(const char* data, size_t length);
268  void SimulateRead(const char* data) {
269    SimulateRead(data, std::strlen(data));
270  }
271
272 private:
273  std::deque<MockRead> reads_;
274
275  // Max number of bytes we will read at a time. 0 means no limit.
276  int short_read_limit_;
277
278  // If true, we'll not require the client to consume all data before we
279  // mock the next read.
280  bool allow_unconsumed_reads_;
281
282  DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider);
283};
284
285// SSLSocketDataProviders only need to keep track of the return code from calls
286// to Connect().
287struct SSLSocketDataProvider {
288  SSLSocketDataProvider(IoMode mode, int result);
289  ~SSLSocketDataProvider();
290
291  void SetNextProto(NextProto proto);
292
293  MockConnect connect;
294  SSLClientSocket::NextProtoStatus next_proto_status;
295  std::string next_proto;
296  std::string server_protos;
297  bool was_npn_negotiated;
298  NextProto protocol_negotiated;
299  bool client_cert_sent;
300  SSLCertRequestInfo* cert_request_info;
301  scoped_refptr<X509Certificate> cert;
302  bool channel_id_sent;
303  ServerBoundCertService* server_bound_cert_service;
304};
305
306// A DataProvider where the client must write a request before the reads (e.g.
307// the response) will complete.
308class DelayedSocketData : public StaticSocketDataProvider {
309 public:
310  // |write_delay| the number of MockWrites to complete before allowing
311  //               a MockRead to complete.
312  // |reads| the list of MockRead completions.
313  // |writes| the list of MockWrite completions.
314  // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
315  //       MockRead(true, 0, 0);
316  DelayedSocketData(int write_delay,
317                    MockRead* reads, size_t reads_count,
318                    MockWrite* writes, size_t writes_count);
319
320  // |connect| the result for the connect phase.
321  // |reads| the list of MockRead completions.
322  // |write_delay| the number of MockWrites to complete before allowing
323  //               a MockRead to complete.
324  // |writes| the list of MockWrite completions.
325  // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
326  //       MockRead(true, 0, 0);
327  DelayedSocketData(const MockConnect& connect, int write_delay,
328                    MockRead* reads, size_t reads_count,
329                    MockWrite* writes, size_t writes_count);
330  virtual ~DelayedSocketData();
331
332  void ForceNextRead();
333
334  // StaticSocketDataProvider:
335  virtual MockRead GetNextRead() OVERRIDE;
336  virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
337  virtual void Reset() OVERRIDE;
338  virtual void CompleteRead() OVERRIDE;
339
340 private:
341  int write_delay_;
342  bool read_in_progress_;
343  base::WeakPtrFactory<DelayedSocketData> weak_factory_;
344};
345
346// A DataProvider where the reads are ordered.
347// If a read is requested before its sequence number is reached, we return an
348// ERR_IO_PENDING (that way we don't have to explicitly add a MockRead just to
349// wait).
350// The sequence number is incremented on every read and write operation.
351// The message loop may be interrupted by setting the high bit of the sequence
352// number in the MockRead's sequence number.  When that MockRead is reached,
353// we post a Quit message to the loop.  This allows us to interrupt the reading
354// of data before a complete message has arrived, and provides support for
355// testing server push when the request is issued while the response is in the
356// middle of being received.
357class OrderedSocketData : public StaticSocketDataProvider {
358 public:
359  // |reads| the list of MockRead completions.
360  // |writes| the list of MockWrite completions.
361  // Note: All MockReads and MockWrites must be async.
362  // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
363  //       MockRead(true, 0, 0);
364  OrderedSocketData(MockRead* reads, size_t reads_count,
365                    MockWrite* writes, size_t writes_count);
366  virtual ~OrderedSocketData();
367
368  // |connect| the result for the connect phase.
369  // |reads| the list of MockRead completions.
370  // |writes| the list of MockWrite completions.
371  // Note: All MockReads and MockWrites must be async.
372  // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
373  //       MockRead(true, 0, 0);
374  OrderedSocketData(const MockConnect& connect,
375                    MockRead* reads, size_t reads_count,
376                    MockWrite* writes, size_t writes_count);
377
378  // Posts a quit message to the current message loop, if one is running.
379  void EndLoop();
380
381  // StaticSocketDataProvider:
382  virtual MockRead GetNextRead() OVERRIDE;
383  virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
384  virtual void Reset() OVERRIDE;
385  virtual void CompleteRead() OVERRIDE;
386
387 private:
388  int sequence_number_;
389  int loop_stop_stage_;
390  bool blocked_;
391  base::WeakPtrFactory<OrderedSocketData> weak_factory_;
392};
393
394class DeterministicMockTCPClientSocket;
395
396// This class gives the user full control over the network activity,
397// specifically the timing of the COMPLETION of I/O operations.  Regardless of
398// the order in which I/O operations are initiated, this class ensures that they
399// complete in the correct order.
400//
401// Network activity is modeled as a sequence of numbered steps which is
402// incremented whenever an I/O operation completes.  This can happen under two
403// different circumstances:
404//
405// 1) Performing a synchronous I/O operation.  (Invoking Read() or Write()
406//    when the corresponding MockRead or MockWrite is marked !async).
407// 2) Running the Run() method of this class.  The run method will invoke
408//    the current MessageLoop, running all pending events, and will then
409//    invoke any pending IO callbacks.
410//
411// In addition, this class allows for I/O processing to "stop" at a specified
412// step, by calling SetStop(int) or StopAfter(int).  Initiating an I/O operation
413// by calling Read() or Write() while stopped is permitted if the operation is
414// asynchronous.  It is an error to perform synchronous I/O while stopped.
415//
416// When creating the MockReads and MockWrites, note that the sequence number
417// refers to the number of the step in which the I/O will complete.  In the
418// case of synchronous I/O, this will be the same step as the I/O is initiated.
419// However, in the case of asynchronous I/O, this I/O may be initiated in
420// a much earlier step. Furthermore, when the a Read() or Write() is separated
421// from its completion by other Read() or Writes()'s, it can not be marked
422// synchronous.  If it is, ERR_UNUEXPECTED will be returned indicating that a
423// synchronous Read() or Write() could not be completed synchronously because of
424// the specific ordering constraints.
425//
426// Sequence numbers are preserved across both reads and writes. There should be
427// no gaps in sequence numbers, and no repeated sequence numbers. i.e.
428//  MockRead reads[] = {
429//    MockRead(false, "first read", length, 0)   // sync
430//    MockRead(true, "second read", length, 2)   // async
431//  };
432//  MockWrite writes[] = {
433//    MockWrite(true, "first write", length, 1),    // async
434//    MockWrite(false, "second write", length, 3),  // sync
435//  };
436//
437// Example control flow:
438// Read() is called.  The current step is 0.  The first available read is
439// synchronous, so the call to Read() returns length.  The current step is
440// now 1.  Next, Read() is called again.  The next available read can
441// not be completed until step 2, so Read() returns ERR_IO_PENDING.  The current
442// step is still 1.  Write is called().  The first available write is able to
443// complete in this step, but is marked asynchronous.  Write() returns
444// ERR_IO_PENDING.  The current step is still 1.  At this point RunFor(1) is
445// called which will cause the write callback to be invoked, and will then
446// stop.  The current state is now 2.  RunFor(1) is called again, which
447// causes the read callback to be invoked, and will then stop.  Then current
448// step is 2.  Write() is called again.  Then next available write is
449// synchronous so the call to Write() returns length.
450//
451// For examples of how to use this class, see:
452//   deterministic_socket_data_unittests.cc
453class DeterministicSocketData
454    : public StaticSocketDataProvider {
455 public:
456  // The Delegate is an abstract interface which handles the communication from
457  // the DeterministicSocketData to the Deterministic MockSocket.  The
458  // MockSockets directly store a pointer to the DeterministicSocketData,
459  // whereas the DeterministicSocketData only stores a pointer to the
460  // abstract Delegate interface.
461  class Delegate {
462   public:
463    // Returns true if there is currently a write pending. That is to say, if
464    // an asynchronous write has been started but the callback has not been
465    // invoked.
466    virtual bool WritePending() const = 0;
467    // Returns true if there is currently a read pending. That is to say, if
468    // an asynchronous read has been started but the callback has not been
469    // invoked.
470    virtual bool ReadPending() const = 0;
471    // Called to complete an asynchronous write to execute the write callback.
472    virtual void CompleteWrite() = 0;
473    // Called to complete an asynchronous read to execute the read callback.
474    virtual int CompleteRead() = 0;
475
476   protected:
477    virtual ~Delegate() {}
478  };
479
480  // |reads| the list of MockRead completions.
481  // |writes| the list of MockWrite completions.
482  DeterministicSocketData(MockRead* reads, size_t reads_count,
483                          MockWrite* writes, size_t writes_count);
484  virtual ~DeterministicSocketData();
485
486  // Consume all the data up to the give stop point (via SetStop()).
487  void Run();
488
489  // Set the stop point to be |steps| from now, and then invoke Run().
490  void RunFor(int steps);
491
492  // Stop at step |seq|, which must be in the future.
493  virtual void SetStop(int seq);
494
495  // Stop |seq| steps after the current step.
496  virtual void StopAfter(int seq);
497  bool stopped() const { return stopped_; }
498  void SetStopped(bool val) { stopped_ = val; }
499  MockRead& current_read() { return current_read_; }
500  MockWrite& current_write() { return current_write_; }
501  int sequence_number() const { return sequence_number_; }
502  void set_delegate(base::WeakPtr<Delegate> delegate) {
503    delegate_ = delegate;
504  }
505
506  // StaticSocketDataProvider:
507
508  // When the socket calls Read(), that calls GetNextRead(), and expects either
509  // ERR_IO_PENDING or data.
510  virtual MockRead GetNextRead() OVERRIDE;
511
512  // When the socket calls Write(), it always completes synchronously. OnWrite()
513  // checks to make sure the written data matches the expected data. The
514  // callback will not be invoked until its sequence number is reached.
515  virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
516  virtual void Reset() OVERRIDE;
517  virtual void CompleteRead() OVERRIDE {}
518
519 private:
520  // Invoke the read and write callbacks, if the timing is appropriate.
521  void InvokeCallbacks();
522
523  void NextStep();
524
525  void VerifyCorrectSequenceNumbers(MockRead* reads, size_t reads_count,
526                                    MockWrite* writes, size_t writes_count);
527
528  int sequence_number_;
529  MockRead current_read_;
530  MockWrite current_write_;
531  int stopping_sequence_number_;
532  bool stopped_;
533  base::WeakPtr<Delegate> delegate_;
534  bool print_debug_;
535  bool is_running_;
536};
537
// Holds an array of data-provider elements of type T.  As
// Mock{TCP,SSL}ClientSocket objects get instantiated, they take their data
// from the i'th element of this array.
//
// Only raw T* pointers are stored; nothing here deletes them.
template<typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray() : next_index_(0) {}

  // Returns the element at the current index and advances the index.  Asking
  // for more elements than were added is a programming error (DCHECK).
  T* GetNext() {
    DCHECK_LT(next_index_, data_providers_.size());
    return data_providers_[next_index_++];
  }

  // Appends |data_provider|, which must be non-NULL (DCHECK).
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  // Index of the element the next GetNext() call will return.  Read-only
  // accessor, hence const-qualified.
  size_t next_index() const { return next_index_; }

  // Rewinds to the first element; the stored elements are kept.
  void ResetNextIndex() {
    next_index_ = 0;
  }

 private:
  // Index of the next |data_providers_| element to use. Not an iterator
  // because those are invalidated on vector reallocation.
  size_t next_index_;

  // SocketDataProviders to be returned.
  std::vector<T*> data_providers_;
};
570
571class MockUDPClientSocket;
572class MockTCPClientSocket;
573class MockSSLClientSocket;
574
575// ClientSocketFactory which contains arrays of sockets of each type.
576// You should first fill the arrays using AddMock{SSL,}Socket. When the factory
577// is asked to create a socket, it takes next entry from appropriate array.
578// You can use ResetNextMockIndexes to reset that next entry index for all mock
579// socket types.
580class MockClientSocketFactory : public ClientSocketFactory {
581 public:
582  MockClientSocketFactory();
583  virtual ~MockClientSocketFactory();
584
585  void AddSocketDataProvider(SocketDataProvider* socket);
586  void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
587  void ResetNextMockIndexes();
588
589  SocketDataProviderArray<SocketDataProvider>& mock_data() {
590    return mock_data_;
591  }
592
593  // ClientSocketFactory
594  virtual DatagramClientSocket* CreateDatagramClientSocket(
595      DatagramSocket::BindType bind_type,
596      const RandIntCallback& rand_int_cb,
597      NetLog* net_log,
598      const NetLog::Source& source) OVERRIDE;
599  virtual StreamSocket* CreateTransportClientSocket(
600      const AddressList& addresses,
601      NetLog* net_log,
602      const NetLog::Source& source) OVERRIDE;
603  virtual SSLClientSocket* CreateSSLClientSocket(
604      ClientSocketHandle* transport_socket,
605      const HostPortPair& host_and_port,
606      const SSLConfig& ssl_config,
607      const SSLClientSocketContext& context) OVERRIDE;
608  virtual void ClearSSLSessionCache() OVERRIDE;
609
610 private:
611  SocketDataProviderArray<SocketDataProvider> mock_data_;
612  SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
613};
614
615class MockClientSocket : public SSLClientSocket {
616 public:
617  // Value returned by GetTLSUniqueChannelBinding().
618  static const char kTlsUnique[];
619
620  // The BoundNetLog is needed to test LoadTimingInfo, which uses NetLog IDs as
621  // unique socket IDs.
622  explicit MockClientSocket(const BoundNetLog& net_log);
623
624  // Socket implementation.
625  virtual int Read(IOBuffer* buf, int buf_len,
626                   const CompletionCallback& callback) = 0;
627  virtual int Write(IOBuffer* buf, int buf_len,
628                    const CompletionCallback& callback) = 0;
629  virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
630  virtual bool SetSendBufferSize(int32 size) OVERRIDE;
631
632  // StreamSocket implementation.
633  virtual int Connect(const CompletionCallback& callback) = 0;
634  virtual void Disconnect() OVERRIDE;
635  virtual bool IsConnected() const OVERRIDE;
636  virtual bool IsConnectedAndIdle() const OVERRIDE;
637  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
638  virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
639  virtual const BoundNetLog& NetLog() const OVERRIDE;
640  virtual void SetSubresourceSpeculation() OVERRIDE {}
641  virtual void SetOmniboxSpeculation() OVERRIDE {}
642
643  // SSLClientSocket implementation.
644  virtual void GetSSLCertRequestInfo(
645      SSLCertRequestInfo* cert_request_info) OVERRIDE;
646  virtual int ExportKeyingMaterial(const base::StringPiece& label,
647                                   bool has_context,
648                                   const base::StringPiece& context,
649                                   unsigned char* out,
650                                   unsigned int outlen) OVERRIDE;
651  virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
652  virtual NextProtoStatus GetNextProto(std::string* proto,
653                                       std::string* server_protos) OVERRIDE;
654  virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
655
656 protected:
657  virtual ~MockClientSocket();
658  void RunCallbackAsync(const CompletionCallback& callback, int result);
659  void RunCallback(const CompletionCallback& callback, int result);
660
661  base::WeakPtrFactory<MockClientSocket> weak_factory_;
662
663  // True if Connect completed successfully and Disconnect hasn't been called.
664  bool connected_;
665
666  // Address of the "remote" peer we're connected to.
667  IPEndPoint peer_addr_;
668
669  BoundNetLog net_log_;
670};
671
672class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
673 public:
674  MockTCPClientSocket(const AddressList& addresses, net::NetLog* net_log,
675                      SocketDataProvider* socket);
676  virtual ~MockTCPClientSocket();
677
678  const AddressList& addresses() const { return addresses_; }
679
680  // Socket implementation.
681  virtual int Read(IOBuffer* buf, int buf_len,
682                   const CompletionCallback& callback) OVERRIDE;
683  virtual int Write(IOBuffer* buf, int buf_len,
684                    const CompletionCallback& callback) OVERRIDE;
685
686  // StreamSocket implementation.
687  virtual int Connect(const CompletionCallback& callback) OVERRIDE;
688  virtual void Disconnect() OVERRIDE;
689  virtual bool IsConnected() const OVERRIDE;
690  virtual bool IsConnectedAndIdle() const OVERRIDE;
691  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
692  virtual bool WasEverUsed() const OVERRIDE;
693  virtual bool UsingTCPFastOpen() const OVERRIDE;
694  virtual bool WasNpnNegotiated() const OVERRIDE;
695  virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
696
697  // AsyncSocket:
698  virtual void OnReadComplete(const MockRead& data) OVERRIDE;
699
700 private:
701  int CompleteRead();
702
703  AddressList addresses_;
704
705  SocketDataProvider* data_;
706  int read_offset_;
707  MockRead read_data_;
708  bool need_read_data_;
709
710  // True if the peer has closed the connection.  This allows us to simulate
711  // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
712  // TCPClientSocket.
713  bool peer_closed_connection_;
714
715  // While an asynchronous IO is pending, we save our user-buffer state.
716  IOBuffer* pending_buf_;
717  int pending_buf_len_;
718  CompletionCallback pending_callback_;
719  bool was_used_to_convey_data_;
720};
721
722// DeterministicSocketHelper is a helper class that can be used
723// to simulate net::Socket::Read() and net::Socket::Write()
724// using deterministic |data|.
725// Note: This is provided as a common helper class because
726// of the inheritance hierarchy of DeterministicMock[UDP,TCP]ClientSocket and a
727// desire not to introduce an additional common base class.
728class DeterministicSocketHelper {
729 public:
730  DeterministicSocketHelper(net::NetLog* net_log,
731                            DeterministicSocketData* data);
732  virtual ~DeterministicSocketHelper();
733
734  bool write_pending() const { return write_pending_; }
735  bool read_pending() const { return read_pending_; }
736
737  void CompleteWrite();
738  int CompleteRead();
739
740  int Write(IOBuffer* buf, int buf_len,
741            const CompletionCallback& callback);
742  int Read(IOBuffer* buf, int buf_len,
743           const CompletionCallback& callback);
744
745  const BoundNetLog& net_log() const { return net_log_; }
746
747  bool was_used_to_convey_data() const { return was_used_to_convey_data_; }
748
749  bool peer_closed_connection() const { return peer_closed_connection_; }
750
751  DeterministicSocketData* data() const { return data_; }
752
753 private:
754  bool write_pending_;
755  CompletionCallback write_callback_;
756  int write_result_;
757
758  MockRead read_data_;
759
760  IOBuffer* read_buf_;
761  int read_buf_len_;
762  bool read_pending_;
763  CompletionCallback read_callback_;
764  DeterministicSocketData* data_;
765  bool was_used_to_convey_data_;
766  bool peer_closed_connection_;
767  BoundNetLog net_log_;
768};
769
770// Mock UDP socket to be used in conjunction with DeterministicSocketData.
771class DeterministicMockUDPClientSocket
772    : public DatagramClientSocket,
773      public AsyncSocket,
774      public DeterministicSocketData::Delegate,
775      public base::SupportsWeakPtr<DeterministicMockUDPClientSocket> {
776 public:
777  DeterministicMockUDPClientSocket(net::NetLog* net_log,
778                                   DeterministicSocketData* data);
779  virtual ~DeterministicMockUDPClientSocket();
780
781  // DeterministicSocketData::Delegate:
782  virtual bool WritePending() const OVERRIDE;
783  virtual bool ReadPending() const OVERRIDE;
784  virtual void CompleteWrite() OVERRIDE;
785  virtual int CompleteRead() OVERRIDE;
786
787  // Socket implementation.
788  virtual int Read(IOBuffer* buf, int buf_len,
789                   const CompletionCallback& callback) OVERRIDE;
790  virtual int Write(IOBuffer* buf, int buf_len,
791                    const CompletionCallback& callback) OVERRIDE;
792  virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
793  virtual bool SetSendBufferSize(int32 size) OVERRIDE;
794
795  // DatagramSocket implementation.
796  virtual void Close() OVERRIDE;
797  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
798  virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
799  virtual const BoundNetLog& NetLog() const OVERRIDE;
800
801  // DatagramClientSocket implementation.
802  virtual int Connect(const IPEndPoint& address) OVERRIDE;
803
804  // AsyncSocket implementation.
805  virtual void OnReadComplete(const MockRead& data) OVERRIDE;
806
807 private:
808  bool connected_;
809  IPEndPoint peer_address_;
810  DeterministicSocketHelper helper_;
811};
812
813// Mock TCP socket to be used in conjunction with DeterministicSocketData.
814class DeterministicMockTCPClientSocket
815    : public MockClientSocket,
816      public AsyncSocket,
817      public DeterministicSocketData::Delegate,
818      public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> {
819 public:
820  DeterministicMockTCPClientSocket(net::NetLog* net_log,
821                                   DeterministicSocketData* data);
822  virtual ~DeterministicMockTCPClientSocket();
823
824  // DeterministicSocketData::Delegate:
825  virtual bool WritePending() const OVERRIDE;
826  virtual bool ReadPending() const OVERRIDE;
827  virtual void CompleteWrite() OVERRIDE;
828  virtual int CompleteRead() OVERRIDE;
829
830  // Socket:
831  virtual int Write(IOBuffer* buf, int buf_len,
832                    const CompletionCallback& callback) OVERRIDE;
833  virtual int Read(IOBuffer* buf, int buf_len,
834                   const CompletionCallback& callback) OVERRIDE;
835
836  // StreamSocket:
837  virtual int Connect(const CompletionCallback& callback) OVERRIDE;
838  virtual void Disconnect() OVERRIDE;
839  virtual bool IsConnected() const OVERRIDE;
840  virtual bool IsConnectedAndIdle() const OVERRIDE;
841  virtual bool WasEverUsed() const OVERRIDE;
842  virtual bool UsingTCPFastOpen() const OVERRIDE;
843  virtual bool WasNpnNegotiated() const OVERRIDE;
844  virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
845
846  // AsyncSocket:
847  virtual void OnReadComplete(const MockRead& data) OVERRIDE;
848
849 private:
850  DeterministicSocketHelper helper_;
851};
852
853class MockSSLClientSocket : public MockClientSocket, public AsyncSocket {
854 public:
855  MockSSLClientSocket(
856      ClientSocketHandle* transport_socket,
857      const HostPortPair& host_and_port,
858      const SSLConfig& ssl_config,
859      SSLSocketDataProvider* socket);
860  virtual ~MockSSLClientSocket();
861
862  // Socket implementation.
863  virtual int Read(IOBuffer* buf, int buf_len,
864                   const CompletionCallback& callback) OVERRIDE;
865  virtual int Write(IOBuffer* buf, int buf_len,
866                    const CompletionCallback& callback) OVERRIDE;
867
868  // StreamSocket implementation.
869  virtual int Connect(const CompletionCallback& callback) OVERRIDE;
870  virtual void Disconnect() OVERRIDE;
871  virtual bool IsConnected() const OVERRIDE;
872  virtual bool WasEverUsed() const OVERRIDE;
873  virtual bool UsingTCPFastOpen() const OVERRIDE;
874  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
875  virtual bool WasNpnNegotiated() const OVERRIDE;
876  virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
877
878  // SSLClientSocket implementation.
879  virtual void GetSSLCertRequestInfo(
880      SSLCertRequestInfo* cert_request_info) OVERRIDE;
881  virtual NextProtoStatus GetNextProto(std::string* proto,
882                                       std::string* server_protos) OVERRIDE;
883  virtual bool set_was_npn_negotiated(bool negotiated) OVERRIDE;
884  virtual void set_protocol_negotiated(
885      NextProto protocol_negotiated) OVERRIDE;
886  virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
887
888  // This MockSocket does not implement the manual async IO feature.
889  virtual void OnReadComplete(const MockRead& data) OVERRIDE;
890
891  virtual bool WasChannelIDSent() const OVERRIDE;
892  virtual void set_channel_id_sent(bool channel_id_sent) OVERRIDE;
893  virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
894
895 private:
896  static void ConnectCallback(MockSSLClientSocket *ssl_client_socket,
897                              const CompletionCallback& callback,
898                              int rv);
899
900  scoped_ptr<ClientSocketHandle> transport_;
901  SSLSocketDataProvider* data_;
902  bool is_npn_state_set_;
903  bool new_npn_value_;
904  bool is_protocol_negotiated_set_;
905  NextProto protocol_negotiated_;
906};
907
908class MockUDPClientSocket
909    : public DatagramClientSocket,
910      public AsyncSocket {
911 public:
912  MockUDPClientSocket(SocketDataProvider* data, net::NetLog* net_log);
913  virtual ~MockUDPClientSocket();
914
915  // Socket implementation.
916  virtual int Read(IOBuffer* buf, int buf_len,
917                   const CompletionCallback& callback) OVERRIDE;
918  virtual int Write(IOBuffer* buf, int buf_len,
919                    const CompletionCallback& callback) OVERRIDE;
920  virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
921  virtual bool SetSendBufferSize(int32 size) OVERRIDE;
922
923  // DatagramSocket implementation.
924  virtual void Close() OVERRIDE;
925  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
926  virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
927  virtual const BoundNetLog& NetLog() const OVERRIDE;
928
929  // DatagramClientSocket implementation.
930  virtual int Connect(const IPEndPoint& address) OVERRIDE;
931
932  // AsyncSocket implementation.
933  virtual void OnReadComplete(const MockRead& data) OVERRIDE;
934
935 private:
936  int CompleteRead();
937
938  void RunCallbackAsync(const CompletionCallback& callback, int result);
939  void RunCallback(const CompletionCallback& callback, int result);
940
941  bool connected_;
942  SocketDataProvider* data_;
943  int read_offset_;
944  MockRead read_data_;
945  bool need_read_data_;
946
947  // Address of the "remote" peer we're connected to.
948  IPEndPoint peer_addr_;
949
950  // While an asynchronous IO is pending, we save our user-buffer state.
951  IOBuffer* pending_buf_;
952  int pending_buf_len_;
953  CompletionCallback pending_callback_;
954
955  BoundNetLog net_log_;
956
957  base::WeakPtrFactory<MockUDPClientSocket> weak_factory_;
958
959  DISALLOW_COPY_AND_ASSIGN(MockUDPClientSocket);
960};
961
962class TestSocketRequest : public TestCompletionCallbackBase {
963 public:
964  TestSocketRequest(std::vector<TestSocketRequest*>* request_order,
965                    size_t* completion_count);
966  virtual ~TestSocketRequest();
967
968  ClientSocketHandle* handle() { return &handle_; }
969
970  const net::CompletionCallback& callback() const { return callback_; }
971
972 private:
973  void OnComplete(int result);
974
975  ClientSocketHandle handle_;
976  std::vector<TestSocketRequest*>* request_order_;
977  size_t* completion_count_;
978  CompletionCallback callback_;
979
980  DISALLOW_COPY_AND_ASSIGN(TestSocketRequest);
981};
982
983class ClientSocketPoolTest {
984 public:
985  enum KeepAlive {
986    KEEP_ALIVE,
987
988    // A socket will be disconnected in addition to handle being reset.
989    NO_KEEP_ALIVE,
990  };
991
992  static const int kIndexOutOfBounds;
993  static const int kRequestNotFound;
994
995  ClientSocketPoolTest();
996  ~ClientSocketPoolTest();
997
998  template <typename PoolType, typename SocketParams>
999  int StartRequestUsingPool(PoolType* socket_pool,
1000                            const std::string& group_name,
1001                            RequestPriority priority,
1002                            const scoped_refptr<SocketParams>& socket_params) {
1003    DCHECK(socket_pool);
1004    TestSocketRequest* request = new TestSocketRequest(&request_order_,
1005                                                       &completion_count_);
1006    requests_.push_back(request);
1007    int rv = request->handle()->Init(
1008        group_name, socket_params, priority, request->callback(),
1009        socket_pool, BoundNetLog());
1010    if (rv != ERR_IO_PENDING)
1011      request_order_.push_back(request);
1012    return rv;
1013  }
1014
1015  // Provided there were n requests started, takes |index| in range 1..n
1016  // and returns order in which that request completed, in range 1..n,
1017  // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
1018  // if that request did not complete (for example was canceled).
1019  int GetOrderOfRequest(size_t index) const;
1020
1021  // Resets first initialized socket handle from |requests_|. If found such
1022  // a handle, returns true.
1023  bool ReleaseOneConnection(KeepAlive keep_alive);
1024
1025  // Releases connections until there is nothing to release.
1026  void ReleaseAllConnections(KeepAlive keep_alive);
1027
1028  TestSocketRequest* request(int i) { return requests_[i]; }
1029  size_t requests_size() const { return requests_.size(); }
1030  ScopedVector<TestSocketRequest>* requests() { return &requests_; }
1031  size_t completion_count() const { return completion_count_; }
1032
1033 private:
1034  ScopedVector<TestSocketRequest> requests_;
1035  std::vector<TestSocketRequest*> request_order_;
1036  size_t completion_count_;
1037};
1038
1039class MockTransportClientSocketPool : public TransportClientSocketPool {
1040 public:
1041  class MockConnectJob {
1042   public:
1043    MockConnectJob(StreamSocket* socket, ClientSocketHandle* handle,
1044                   const CompletionCallback& callback);
1045    ~MockConnectJob();
1046
1047    int Connect();
1048    bool CancelHandle(const ClientSocketHandle* handle);
1049
1050   private:
1051    void OnConnect(int rv);
1052
1053    scoped_ptr<StreamSocket> socket_;
1054    ClientSocketHandle* handle_;
1055    CompletionCallback user_callback_;
1056
1057    DISALLOW_COPY_AND_ASSIGN(MockConnectJob);
1058  };
1059
1060  MockTransportClientSocketPool(
1061      int max_sockets,
1062      int max_sockets_per_group,
1063      ClientSocketPoolHistograms* histograms,
1064      ClientSocketFactory* socket_factory);
1065
1066  virtual ~MockTransportClientSocketPool();
1067
1068  int release_count() const { return release_count_; }
1069  int cancel_count() const { return cancel_count_; }
1070
1071  // TransportClientSocketPool implementation.
1072  virtual int RequestSocket(const std::string& group_name,
1073                            const void* socket_params,
1074                            RequestPriority priority,
1075                            ClientSocketHandle* handle,
1076                            const CompletionCallback& callback,
1077                            const BoundNetLog& net_log) OVERRIDE;
1078
1079  virtual void CancelRequest(const std::string& group_name,
1080                             ClientSocketHandle* handle) OVERRIDE;
1081  virtual void ReleaseSocket(const std::string& group_name,
1082                             StreamSocket* socket, int id) OVERRIDE;
1083
1084 private:
1085  ClientSocketFactory* client_socket_factory_;
1086  ScopedVector<MockConnectJob> job_list_;
1087  int release_count_;
1088  int cancel_count_;
1089
1090  DISALLOW_COPY_AND_ASSIGN(MockTransportClientSocketPool);
1091};
1092
1093class DeterministicMockClientSocketFactory : public ClientSocketFactory {
1094 public:
1095  DeterministicMockClientSocketFactory();
1096  virtual ~DeterministicMockClientSocketFactory();
1097
1098  void AddSocketDataProvider(DeterministicSocketData* socket);
1099  void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
1100  void ResetNextMockIndexes();
1101
1102  // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
1103  // created.
1104  MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
1105
1106  SocketDataProviderArray<DeterministicSocketData>& mock_data() {
1107    return mock_data_;
1108  }
1109  std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() {
1110    return tcp_client_sockets_;
1111  }
1112  std::vector<DeterministicMockUDPClientSocket*>& udp_client_sockets() {
1113    return udp_client_sockets_;
1114  }
1115
1116  // ClientSocketFactory
1117  virtual DatagramClientSocket* CreateDatagramClientSocket(
1118      DatagramSocket::BindType bind_type,
1119      const RandIntCallback& rand_int_cb,
1120      NetLog* net_log,
1121      const NetLog::Source& source) OVERRIDE;
1122  virtual StreamSocket* CreateTransportClientSocket(
1123      const AddressList& addresses,
1124      NetLog* net_log,
1125      const NetLog::Source& source) OVERRIDE;
1126  virtual SSLClientSocket* CreateSSLClientSocket(
1127      ClientSocketHandle* transport_socket,
1128      const HostPortPair& host_and_port,
1129      const SSLConfig& ssl_config,
1130      const SSLClientSocketContext& context) OVERRIDE;
1131  virtual void ClearSSLSessionCache() OVERRIDE;
1132
1133 private:
1134  SocketDataProviderArray<DeterministicSocketData> mock_data_;
1135  SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
1136
1137  // Store pointers to handed out sockets in case the test wants to get them.
1138  std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_;
1139  std::vector<DeterministicMockUDPClientSocket*> udp_client_sockets_;
1140  std::vector<MockSSLClientSocket*> ssl_client_sockets_;
1141};
1142
1143class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
1144 public:
1145  MockSOCKSClientSocketPool(
1146      int max_sockets,
1147      int max_sockets_per_group,
1148      ClientSocketPoolHistograms* histograms,
1149      TransportClientSocketPool* transport_pool);
1150
1151  virtual ~MockSOCKSClientSocketPool();
1152
1153  // SOCKSClientSocketPool implementation.
1154  virtual int RequestSocket(const std::string& group_name,
1155                            const void* socket_params,
1156                            RequestPriority priority,
1157                            ClientSocketHandle* handle,
1158                            const CompletionCallback& callback,
1159                            const BoundNetLog& net_log) OVERRIDE;
1160
1161  virtual void CancelRequest(const std::string& group_name,
1162                             ClientSocketHandle* handle) OVERRIDE;
1163  virtual void ReleaseSocket(const std::string& group_name,
1164                             StreamSocket* socket, int id) OVERRIDE;
1165
1166 private:
1167  TransportClientSocketPool* const transport_pool_;
1168
1169  DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool);
1170};
1171
// Constants for a successful SOCKS v5 handshake.  Each byte array is paired
// with an int holding its length; only the declarations appear here.

// Greet request / response.
extern const char kSOCKS5GreetRequest[];
extern const int kSOCKS5GreetRequestLength;
extern const char kSOCKS5GreetResponse[];
extern const int kSOCKS5GreetResponseLength;

// Ok request / response.
extern const char kSOCKS5OkRequest[];
extern const int kSOCKS5OkRequestLength;
extern const char kSOCKS5OkResponse[];
extern const int kSOCKS5OkResponseLength;
1184
1185}  // namespace net
1186
1187#endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
1188