1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
6#define NET_SOCKET_SOCKET_TEST_UTIL_H_
7
8#include <cstring>
9#include <deque>
10#include <string>
11#include <vector>
12
13#include "base/basictypes.h"
14#include "base/callback.h"
15#include "base/logging.h"
16#include "base/memory/scoped_ptr.h"
17#include "base/memory/scoped_vector.h"
18#include "base/memory/weak_ptr.h"
19#include "base/strings/string16.h"
20#include "net/base/address_list.h"
21#include "net/base/io_buffer.h"
22#include "net/base/net_errors.h"
23#include "net/base/net_log.h"
24#include "net/base/test_completion_callback.h"
25#include "net/http/http_auth_controller.h"
26#include "net/http/http_proxy_client_socket_pool.h"
27#include "net/socket/client_socket_factory.h"
28#include "net/socket/client_socket_handle.h"
29#include "net/socket/socks_client_socket_pool.h"
30#include "net/socket/ssl_client_socket.h"
31#include "net/socket/ssl_client_socket_pool.h"
32#include "net/socket/transport_client_socket_pool.h"
33#include "net/ssl/ssl_config_service.h"
34#include "net/udp/datagram_client_socket.h"
35#include "testing/gtest/include/gtest/gtest.h"
36
37namespace net {
38
// Private network error code used only by the socket test utility classes.
// A MockRead whose |result| equals this value carries no data of its own; it
// is a marker meaning the peer will close the connection after the MockRead
// that follows it. Every other member of such a marker MockRead is ignored.
enum {
  ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000
};
47
// Types referenced in this header before (or instead of) their definitions.
class AsyncSocket;
class MockClientSocket;
class ServerBoundCertService;
class SSLClientSocket;
class StreamSocket;

// Selects whether a mock connect/read/write completes asynchronously or
// synchronously.
enum IoMode {
  ASYNC = 0,
  SYNCHRONOUS = 1
};
58
59struct MockConnect {
60  // Asynchronous connection success.
61  // Creates a MockConnect with |mode| ASYC, |result| OK, and
62  // |peer_addr| 192.0.2.33.
63  MockConnect();
64  // Creates a MockConnect with the specified mode and result, with
65  // |peer_addr| 192.0.2.33.
66  MockConnect(IoMode io_mode, int r);
67  MockConnect(IoMode io_mode, int r, IPEndPoint addr);
68  ~MockConnect();
69
70  IoMode mode;
71  int result;
72  IPEndPoint peer_addr;
73};
74
// MockRead and MockWrite share the same interface and members, but we'd like
// them to be distinct types so that they cannot be used interchangeably. To do
// this, a struct template is defined, and MockRead and MockWrite are
// instantiations of that template. The template parameter |type| is not used
// in the struct definition; it exists purely to create a distinct type.
//
// |data| has different meanings in MockRead and MockWrite: in a MockRead it
// is the data returned from the socket when MockTCPClientSocket::Read() is
// attempted, while in a MockWrite it is the data that is expected to be
// passed to MockTCPClientSocket::Write().
enum MockReadWriteType {
  MOCK_READ = 0,   // Tag that produces the MockRead type.
  MOCK_WRITE = 1   // Tag that produces the MockWrite type.
};
90
91template <MockReadWriteType type>
92struct MockReadWrite {
93  // Flag to indicate that the message loop should be terminated.
94  enum {
95    STOPLOOP = 1 << 31
96  };
97
98  // Default
99  MockReadWrite() : mode(SYNCHRONOUS), result(0), data(NULL), data_len(0),
100      sequence_number(0), time_stamp(base::Time::Now()) {}
101
102  // Read/write failure (no data).
103  MockReadWrite(IoMode io_mode, int result) : mode(io_mode), result(result),
104      data(NULL), data_len(0), sequence_number(0),
105      time_stamp(base::Time::Now()) { }
106
107  // Read/write failure (no data), with sequence information.
108  MockReadWrite(IoMode io_mode, int result, int seq) : mode(io_mode),
109      result(result), data(NULL), data_len(0), sequence_number(seq),
110      time_stamp(base::Time::Now()) { }
111
112  // Asynchronous read/write success (inferred data length).
113  explicit MockReadWrite(const char* data) : mode(ASYNC),  result(0),
114      data(data), data_len(strlen(data)), sequence_number(0),
115      time_stamp(base::Time::Now()) { }
116
117  // Read/write success (inferred data length).
118  MockReadWrite(IoMode io_mode, const char* data) : mode(io_mode), result(0),
119      data(data), data_len(strlen(data)), sequence_number(0),
120      time_stamp(base::Time::Now()) { }
121
122  // Read/write success.
123  MockReadWrite(IoMode io_mode, const char* data, int data_len) : mode(io_mode),
124      result(0), data(data), data_len(data_len), sequence_number(0),
125      time_stamp(base::Time::Now()) { }
126
127  // Read/write success (inferred data length) with sequence information.
128  MockReadWrite(IoMode io_mode, int seq, const char* data) : mode(io_mode),
129      result(0), data(data), data_len(strlen(data)), sequence_number(seq),
130      time_stamp(base::Time::Now()) { }
131
132  // Read/write success with sequence information.
133  MockReadWrite(IoMode io_mode, const char* data, int data_len, int seq) :
134      mode(io_mode), result(0), data(data), data_len(data_len),
135      sequence_number(seq), time_stamp(base::Time::Now()) { }
136
137  IoMode mode;
138  int result;
139  const char* data;
140  int data_len;
141
142  // For OrderedSocketData, which only allows reads to occur in a particular
143  // sequence.  If a read occurs before the given |sequence_number| is reached,
144  // an ERR_IO_PENDING is returned.
145  int sequence_number;      // The sequence number at which a read is allowed
146                            // to occur.
147  base::Time time_stamp;    // The time stamp at which the operation occurred.
148};
149
150typedef MockReadWrite<MOCK_READ> MockRead;
151typedef MockReadWrite<MOCK_WRITE> MockWrite;
152
153struct MockWriteResult {
154  MockWriteResult(IoMode io_mode, int result)
155      : mode(io_mode),
156        result(result) {}
157
158  IoMode mode;
159  int result;
160};
161
162// The SocketDataProvider is an interface used by the MockClientSocket
163// for getting data about individual reads and writes on the socket.
164class SocketDataProvider {
165 public:
166  SocketDataProvider() : socket_(NULL) {}
167
168  virtual ~SocketDataProvider() {}
169
170  // Returns the buffer and result code for the next simulated read.
171  // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
172  // that it will be called via the AsyncSocket::OnReadComplete()
173  // function at a later time.
174  virtual MockRead GetNextRead() = 0;
175  virtual MockWriteResult OnWrite(const std::string& data) = 0;
176  virtual void Reset() = 0;
177
178  // Accessor for the socket which is using the SocketDataProvider.
179  AsyncSocket* socket() { return socket_; }
180  void set_socket(AsyncSocket* socket) { socket_ = socket; }
181
182  MockConnect connect_data() const { return connect_; }
183  void set_connect_data(const MockConnect& connect) { connect_ = connect; }
184
185 private:
186  MockConnect connect_;
187  AsyncSocket* socket_;
188
189  DISALLOW_COPY_AND_ASSIGN(SocketDataProvider);
190};
191
192// The AsyncSocket is an interface used by the SocketDataProvider to
193// complete the asynchronous read operation.
194class AsyncSocket {
195 public:
196  // If an async IO is pending because the SocketDataProvider returned
197  // ERR_IO_PENDING, then the AsyncSocket waits until this OnReadComplete
198  // is called to complete the asynchronous read operation.
199  // data.async is ignored, and this read is completed synchronously as
200  // part of this call.
201  virtual void OnReadComplete(const MockRead& data) = 0;
202  virtual void OnConnectComplete(const MockConnect& data) = 0;
203};
204
205// SocketDataProvider which responds based on static tables of mock reads and
206// writes.
207class StaticSocketDataProvider : public SocketDataProvider {
208 public:
209  StaticSocketDataProvider();
210  StaticSocketDataProvider(MockRead* reads, size_t reads_count,
211                           MockWrite* writes, size_t writes_count);
212  virtual ~StaticSocketDataProvider();
213
214  // These functions get access to the next available read and write data.
215  const MockRead& PeekRead() const;
216  const MockWrite& PeekWrite() const;
217  // These functions get random access to the read and write data, for timing.
218  const MockRead& PeekRead(size_t index) const;
219  const MockWrite& PeekWrite(size_t index) const;
220  size_t read_index() const { return read_index_; }
221  size_t write_index() const { return write_index_; }
222  size_t read_count() const { return read_count_; }
223  size_t write_count() const { return write_count_; }
224
225  bool at_read_eof() const { return read_index_ >= read_count_; }
226  bool at_write_eof() const { return write_index_ >= write_count_; }
227
228  virtual void CompleteRead() {}
229
230  // SocketDataProvider implementation.
231  virtual MockRead GetNextRead() OVERRIDE;
232  virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
233  ;  virtual void Reset() OVERRIDE;
234
235 private:
236  MockRead* reads_;
237  size_t read_index_;
238  size_t read_count_;
239  MockWrite* writes_;
240  size_t write_index_;
241  size_t write_count_;
242
243  DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider);
244};
245
246// SocketDataProvider which can make decisions about next mock reads based on
247// received writes. It can also be used to enforce order of operations, for
248// example that tested code must send the "Hello!" message before receiving
249// response. This is useful for testing conversation-like protocols like FTP.
250class DynamicSocketDataProvider : public SocketDataProvider {
251 public:
252  DynamicSocketDataProvider();
253  virtual ~DynamicSocketDataProvider();
254
255  int short_read_limit() const { return short_read_limit_; }
256  void set_short_read_limit(int limit) { short_read_limit_ = limit; }
257
258  void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; }
259
260  // SocketDataProvider implementation.
261  virtual MockRead GetNextRead() OVERRIDE;
262  virtual MockWriteResult OnWrite(const std::string& data) = 0;
263  virtual void Reset() OVERRIDE;
264
265 protected:
266  // The next time there is a read from this socket, it will return |data|.
267  // Before calling SimulateRead next time, the previous data must be consumed.
268  void SimulateRead(const char* data, size_t length);
269  void SimulateRead(const char* data) {
270    SimulateRead(data, std::strlen(data));
271  }
272
273 private:
274  std::deque<MockRead> reads_;
275
276  // Max number of bytes we will read at a time. 0 means no limit.
277  int short_read_limit_;
278
279  // If true, we'll not require the client to consume all data before we
280  // mock the next read.
281  bool allow_unconsumed_reads_;
282
283  DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider);
284};
285
286// SSLSocketDataProviders only need to keep track of the return code from calls
287// to Connect().
288struct SSLSocketDataProvider {
289  SSLSocketDataProvider(IoMode mode, int result);
290  ~SSLSocketDataProvider();
291
292  void SetNextProto(NextProto proto);
293
294  MockConnect connect;
295  SSLClientSocket::NextProtoStatus next_proto_status;
296  std::string next_proto;
297  std::string server_protos;
298  bool was_npn_negotiated;
299  NextProto protocol_negotiated;
300  bool client_cert_sent;
301  SSLCertRequestInfo* cert_request_info;
302  scoped_refptr<X509Certificate> cert;
303  bool channel_id_sent;
304  ServerBoundCertService* server_bound_cert_service;
305};
306
307// A DataProvider where the client must write a request before the reads (e.g.
308// the response) will complete.
309class DelayedSocketData : public StaticSocketDataProvider {
310 public:
311  // |write_delay| the number of MockWrites to complete before allowing
312  //               a MockRead to complete.
313  // |reads| the list of MockRead completions.
314  // |writes| the list of MockWrite completions.
315  // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
316  //       MockRead(true, 0, 0);
317  DelayedSocketData(int write_delay,
318                    MockRead* reads, size_t reads_count,
319                    MockWrite* writes, size_t writes_count);
320
321  // |connect| the result for the connect phase.
322  // |reads| the list of MockRead completions.
323  // |write_delay| the number of MockWrites to complete before allowing
324  //               a MockRead to complete.
325  // |writes| the list of MockWrite completions.
326  // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
327  //       MockRead(true, 0, 0);
328  DelayedSocketData(const MockConnect& connect, int write_delay,
329                    MockRead* reads, size_t reads_count,
330                    MockWrite* writes, size_t writes_count);
331  virtual ~DelayedSocketData();
332
333  void ForceNextRead();
334
335  // StaticSocketDataProvider:
336  virtual MockRead GetNextRead() OVERRIDE;
337  virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
338  virtual void Reset() OVERRIDE;
339  virtual void CompleteRead() OVERRIDE;
340
341 private:
342  int write_delay_;
343  bool read_in_progress_;
344  base::WeakPtrFactory<DelayedSocketData> weak_factory_;
345};
346
347// A DataProvider where the reads are ordered.
348// If a read is requested before its sequence number is reached, we return an
349// ERR_IO_PENDING (that way we don't have to explicitly add a MockRead just to
350// wait).
351// The sequence number is incremented on every read and write operation.
352// The message loop may be interrupted by setting the high bit of the sequence
353// number in the MockRead's sequence number.  When that MockRead is reached,
354// we post a Quit message to the loop.  This allows us to interrupt the reading
355// of data before a complete message has arrived, and provides support for
356// testing server push when the request is issued while the response is in the
357// middle of being received.
358class OrderedSocketData : public StaticSocketDataProvider {
359 public:
360  // |reads| the list of MockRead completions.
361  // |writes| the list of MockWrite completions.
362  // Note: All MockReads and MockWrites must be async.
363  // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
364  //       MockRead(true, 0, 0);
365  OrderedSocketData(MockRead* reads, size_t reads_count,
366                    MockWrite* writes, size_t writes_count);
367  virtual ~OrderedSocketData();
368
369  // |connect| the result for the connect phase.
370  // |reads| the list of MockRead completions.
371  // |writes| the list of MockWrite completions.
372  // Note: All MockReads and MockWrites must be async.
373  // Note: For stream sockets, the MockRead list must end with a EOF, e.g., a
374  //       MockRead(true, 0, 0);
375  OrderedSocketData(const MockConnect& connect,
376                    MockRead* reads, size_t reads_count,
377                    MockWrite* writes, size_t writes_count);
378
379  // Posts a quit message to the current message loop, if one is running.
380  void EndLoop();
381
382  // StaticSocketDataProvider:
383  virtual MockRead GetNextRead() OVERRIDE;
384  virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
385  virtual void Reset() OVERRIDE;
386  virtual void CompleteRead() OVERRIDE;
387
388 private:
389  int sequence_number_;
390  int loop_stop_stage_;
391  bool blocked_;
392  base::WeakPtrFactory<OrderedSocketData> weak_factory_;
393};
394
395class DeterministicMockTCPClientSocket;
396
397// This class gives the user full control over the network activity,
398// specifically the timing of the COMPLETION of I/O operations.  Regardless of
399// the order in which I/O operations are initiated, this class ensures that they
400// complete in the correct order.
401//
402// Network activity is modeled as a sequence of numbered steps which is
403// incremented whenever an I/O operation completes.  This can happen under two
404// different circumstances:
405//
406// 1) Performing a synchronous I/O operation.  (Invoking Read() or Write()
407//    when the corresponding MockRead or MockWrite is marked !async).
408// 2) Running the Run() method of this class.  The run method will invoke
409//    the current MessageLoop, running all pending events, and will then
410//    invoke any pending IO callbacks.
411//
412// In addition, this class allows for I/O processing to "stop" at a specified
413// step, by calling SetStop(int) or StopAfter(int).  Initiating an I/O operation
414// by calling Read() or Write() while stopped is permitted if the operation is
415// asynchronous.  It is an error to perform synchronous I/O while stopped.
416//
417// When creating the MockReads and MockWrites, note that the sequence number
418// refers to the number of the step in which the I/O will complete.  In the
419// case of synchronous I/O, this will be the same step as the I/O is initiated.
420// However, in the case of asynchronous I/O, this I/O may be initiated in
421// a much earlier step. Furthermore, when the a Read() or Write() is separated
422// from its completion by other Read() or Writes()'s, it can not be marked
423// synchronous.  If it is, ERR_UNUEXPECTED will be returned indicating that a
424// synchronous Read() or Write() could not be completed synchronously because of
425// the specific ordering constraints.
426//
427// Sequence numbers are preserved across both reads and writes. There should be
428// no gaps in sequence numbers, and no repeated sequence numbers. i.e.
429//  MockRead reads[] = {
430//    MockRead(false, "first read", length, 0)   // sync
431//    MockRead(true, "second read", length, 2)   // async
432//  };
433//  MockWrite writes[] = {
434//    MockWrite(true, "first write", length, 1),    // async
435//    MockWrite(false, "second write", length, 3),  // sync
436//  };
437//
438// Example control flow:
439// Read() is called.  The current step is 0.  The first available read is
440// synchronous, so the call to Read() returns length.  The current step is
441// now 1.  Next, Read() is called again.  The next available read can
442// not be completed until step 2, so Read() returns ERR_IO_PENDING.  The current
443// step is still 1.  Write is called().  The first available write is able to
444// complete in this step, but is marked asynchronous.  Write() returns
445// ERR_IO_PENDING.  The current step is still 1.  At this point RunFor(1) is
446// called which will cause the write callback to be invoked, and will then
447// stop.  The current state is now 2.  RunFor(1) is called again, which
448// causes the read callback to be invoked, and will then stop.  Then current
449// step is 2.  Write() is called again.  Then next available write is
450// synchronous so the call to Write() returns length.
451//
452// For examples of how to use this class, see:
453//   deterministic_socket_data_unittests.cc
454class DeterministicSocketData
455    : public StaticSocketDataProvider {
456 public:
457  // The Delegate is an abstract interface which handles the communication from
458  // the DeterministicSocketData to the Deterministic MockSocket.  The
459  // MockSockets directly store a pointer to the DeterministicSocketData,
460  // whereas the DeterministicSocketData only stores a pointer to the
461  // abstract Delegate interface.
462  class Delegate {
463   public:
464    // Returns true if there is currently a write pending. That is to say, if
465    // an asynchronous write has been started but the callback has not been
466    // invoked.
467    virtual bool WritePending() const = 0;
468    // Returns true if there is currently a read pending. That is to say, if
469    // an asynchronous read has been started but the callback has not been
470    // invoked.
471    virtual bool ReadPending() const = 0;
472    // Called to complete an asynchronous write to execute the write callback.
473    virtual void CompleteWrite() = 0;
474    // Called to complete an asynchronous read to execute the read callback.
475    virtual int CompleteRead() = 0;
476
477   protected:
478    virtual ~Delegate() {}
479  };
480
481  // |reads| the list of MockRead completions.
482  // |writes| the list of MockWrite completions.
483  DeterministicSocketData(MockRead* reads, size_t reads_count,
484                          MockWrite* writes, size_t writes_count);
485  virtual ~DeterministicSocketData();
486
487  // Consume all the data up to the give stop point (via SetStop()).
488  void Run();
489
490  // Set the stop point to be |steps| from now, and then invoke Run().
491  void RunFor(int steps);
492
493  // Stop at step |seq|, which must be in the future.
494  virtual void SetStop(int seq);
495
496  // Stop |seq| steps after the current step.
497  virtual void StopAfter(int seq);
498  bool stopped() const { return stopped_; }
499  void SetStopped(bool val) { stopped_ = val; }
500  MockRead& current_read() { return current_read_; }
501  MockWrite& current_write() { return current_write_; }
502  int sequence_number() const { return sequence_number_; }
503  void set_delegate(base::WeakPtr<Delegate> delegate) {
504    delegate_ = delegate;
505  }
506
507  // StaticSocketDataProvider:
508
509  // When the socket calls Read(), that calls GetNextRead(), and expects either
510  // ERR_IO_PENDING or data.
511  virtual MockRead GetNextRead() OVERRIDE;
512
513  // When the socket calls Write(), it always completes synchronously. OnWrite()
514  // checks to make sure the written data matches the expected data. The
515  // callback will not be invoked until its sequence number is reached.
516  virtual MockWriteResult OnWrite(const std::string& data) OVERRIDE;
517  virtual void Reset() OVERRIDE;
518  virtual void CompleteRead() OVERRIDE {}
519
520 private:
521  // Invoke the read and write callbacks, if the timing is appropriate.
522  void InvokeCallbacks();
523
524  void NextStep();
525
526  void VerifyCorrectSequenceNumbers(MockRead* reads, size_t reads_count,
527                                    MockWrite* writes, size_t writes_count);
528
529  int sequence_number_;
530  MockRead current_read_;
531  MockWrite current_write_;
532  int stopping_sequence_number_;
533  bool stopped_;
534  base::WeakPtr<Delegate> delegate_;
535  bool print_debug_;
536  bool is_running_;
537};
538
// Holds an array of SocketDataProvider elements.  As Mock{TCP,SSL}StreamSocket
// objects get instantiated, they take their data from the i'th element of this
// array.
//
// The array does not own its elements: it stores raw pointers and never
// deletes them, so every provider passed to Add() must stay alive for as long
// as it may be returned by GetNext().
template<typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray() : next_index_(0) {}

  // Returns the next unused provider and advances the cursor. Calling this
  // more times than providers were added trips the DCHECK.
  T* GetNext() {
    DCHECK_LT(next_index_, data_providers_.size());
    return data_providers_[next_index_++];
  }

  // Appends |data_provider|, which must be non-NULL.
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  // Index of the element the next GetNext() call will return. Const-qualified
  // so it can be queried through a const reference; this is backward
  // compatible with existing non-const callers.
  size_t next_index() const { return next_index_; }

  // Rewinds the cursor so providers are handed out again from the start.
  void ResetNextIndex() {
    next_index_ = 0;
  }

 private:
  // Index of the next |data_providers_| element to use. Not an iterator
  // because those are invalidated on vector reallocation.
  size_t next_index_;

  // SocketDataProviders to be returned (not owned).
  std::vector<T*> data_providers_;
};
571
572class MockUDPClientSocket;
573class MockTCPClientSocket;
574class MockSSLClientSocket;
575
576// ClientSocketFactory which contains arrays of sockets of each type.
577// You should first fill the arrays using AddMock{SSL,}Socket. When the factory
578// is asked to create a socket, it takes next entry from appropriate array.
579// You can use ResetNextMockIndexes to reset that next entry index for all mock
580// socket types.
581class MockClientSocketFactory : public ClientSocketFactory {
582 public:
583  MockClientSocketFactory();
584  virtual ~MockClientSocketFactory();
585
586  void AddSocketDataProvider(SocketDataProvider* socket);
587  void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
588  void ResetNextMockIndexes();
589
590  SocketDataProviderArray<SocketDataProvider>& mock_data() {
591    return mock_data_;
592  }
593
594  // ClientSocketFactory
595  virtual DatagramClientSocket* CreateDatagramClientSocket(
596      DatagramSocket::BindType bind_type,
597      const RandIntCallback& rand_int_cb,
598      NetLog* net_log,
599      const NetLog::Source& source) OVERRIDE;
600  virtual StreamSocket* CreateTransportClientSocket(
601      const AddressList& addresses,
602      NetLog* net_log,
603      const NetLog::Source& source) OVERRIDE;
604  virtual SSLClientSocket* CreateSSLClientSocket(
605      ClientSocketHandle* transport_socket,
606      const HostPortPair& host_and_port,
607      const SSLConfig& ssl_config,
608      const SSLClientSocketContext& context) OVERRIDE;
609  virtual void ClearSSLSessionCache() OVERRIDE;
610
611 private:
612  SocketDataProviderArray<SocketDataProvider> mock_data_;
613  SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
614};
615
616class MockClientSocket : public SSLClientSocket {
617 public:
618  // Value returned by GetTLSUniqueChannelBinding().
619  static const char kTlsUnique[];
620
621  // The BoundNetLog is needed to test LoadTimingInfo, which uses NetLog IDs as
622  // unique socket IDs.
623  explicit MockClientSocket(const BoundNetLog& net_log);
624
625  // Socket implementation.
626  virtual int Read(IOBuffer* buf, int buf_len,
627                   const CompletionCallback& callback) = 0;
628  virtual int Write(IOBuffer* buf, int buf_len,
629                    const CompletionCallback& callback) = 0;
630  virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
631  virtual bool SetSendBufferSize(int32 size) OVERRIDE;
632
633  // StreamSocket implementation.
634  virtual int Connect(const CompletionCallback& callback) = 0;
635  virtual void Disconnect() OVERRIDE;
636  virtual bool IsConnected() const OVERRIDE;
637  virtual bool IsConnectedAndIdle() const OVERRIDE;
638  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
639  virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
640  virtual const BoundNetLog& NetLog() const OVERRIDE;
641  virtual void SetSubresourceSpeculation() OVERRIDE {}
642  virtual void SetOmniboxSpeculation() OVERRIDE {}
643
644  // SSLClientSocket implementation.
645  virtual void GetSSLCertRequestInfo(
646      SSLCertRequestInfo* cert_request_info) OVERRIDE;
647  virtual int ExportKeyingMaterial(const base::StringPiece& label,
648                                   bool has_context,
649                                   const base::StringPiece& context,
650                                   unsigned char* out,
651                                   unsigned int outlen) OVERRIDE;
652  virtual int GetTLSUniqueChannelBinding(std::string* out) OVERRIDE;
653  virtual NextProtoStatus GetNextProto(std::string* proto,
654                                       std::string* server_protos) OVERRIDE;
655  virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
656
657 protected:
658  virtual ~MockClientSocket();
659  void RunCallbackAsync(const CompletionCallback& callback, int result);
660  void RunCallback(const CompletionCallback& callback, int result);
661
662  base::WeakPtrFactory<MockClientSocket> weak_factory_;
663
664  // True if Connect completed successfully and Disconnect hasn't been called.
665  bool connected_;
666
667  // Address of the "remote" peer we're connected to.
668  IPEndPoint peer_addr_;
669
670  BoundNetLog net_log_;
671};
672
673class MockTCPClientSocket : public MockClientSocket, public AsyncSocket {
674 public:
675  MockTCPClientSocket(const AddressList& addresses, net::NetLog* net_log,
676                      SocketDataProvider* socket);
677  virtual ~MockTCPClientSocket();
678
679  const AddressList& addresses() const { return addresses_; }
680
681  // Socket implementation.
682  virtual int Read(IOBuffer* buf, int buf_len,
683                   const CompletionCallback& callback) OVERRIDE;
684  virtual int Write(IOBuffer* buf, int buf_len,
685                    const CompletionCallback& callback) OVERRIDE;
686
687  // StreamSocket implementation.
688  virtual int Connect(const CompletionCallback& callback) OVERRIDE;
689  virtual void Disconnect() OVERRIDE;
690  virtual bool IsConnected() const OVERRIDE;
691  virtual bool IsConnectedAndIdle() const OVERRIDE;
692  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
693  virtual bool WasEverUsed() const OVERRIDE;
694  virtual bool UsingTCPFastOpen() const OVERRIDE;
695  virtual bool WasNpnNegotiated() const OVERRIDE;
696  virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
697
698  // AsyncSocket:
699  virtual void OnReadComplete(const MockRead& data) OVERRIDE;
700  virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
701
702 private:
703  int CompleteRead();
704
705  AddressList addresses_;
706
707  SocketDataProvider* data_;
708  int read_offset_;
709  MockRead read_data_;
710  bool need_read_data_;
711
712  // True if the peer has closed the connection.  This allows us to simulate
713  // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
714  // TCPClientSocket.
715  bool peer_closed_connection_;
716
717  // While an asynchronous IO is pending, we save our user-buffer state.
718  IOBuffer* pending_buf_;
719  int pending_buf_len_;
720  CompletionCallback pending_callback_;
721  bool was_used_to_convey_data_;
722};
723
724// DeterministicSocketHelper is a helper class that can be used
725// to simulate net::Socket::Read() and net::Socket::Write()
726// using deterministic |data|.
727// Note: This is provided as a common helper class because
728// of the inheritance hierarchy of DeterministicMock[UDP,TCP]ClientSocket and a
729// desire not to introduce an additional common base class.
730class DeterministicSocketHelper {
731 public:
732  DeterministicSocketHelper(net::NetLog* net_log,
733                            DeterministicSocketData* data);
734  virtual ~DeterministicSocketHelper();
735
736  bool write_pending() const { return write_pending_; }
737  bool read_pending() const { return read_pending_; }
738
739  void CompleteWrite();
740  int CompleteRead();
741
742  int Write(IOBuffer* buf, int buf_len,
743            const CompletionCallback& callback);
744  int Read(IOBuffer* buf, int buf_len,
745           const CompletionCallback& callback);
746
747  const BoundNetLog& net_log() const { return net_log_; }
748
749  bool was_used_to_convey_data() const { return was_used_to_convey_data_; }
750
751  bool peer_closed_connection() const { return peer_closed_connection_; }
752
753  DeterministicSocketData* data() const { return data_; }
754
755 private:
756  bool write_pending_;
757  CompletionCallback write_callback_;
758  int write_result_;
759
760  MockRead read_data_;
761
762  IOBuffer* read_buf_;
763  int read_buf_len_;
764  bool read_pending_;
765  CompletionCallback read_callback_;
766  DeterministicSocketData* data_;
767  bool was_used_to_convey_data_;
768  bool peer_closed_connection_;
769  BoundNetLog net_log_;
770};
771
772// Mock UDP socket to be used in conjunction with DeterministicSocketData.
773class DeterministicMockUDPClientSocket
774    : public DatagramClientSocket,
775      public AsyncSocket,
776      public DeterministicSocketData::Delegate,
777      public base::SupportsWeakPtr<DeterministicMockUDPClientSocket> {
778 public:
779  DeterministicMockUDPClientSocket(net::NetLog* net_log,
780                                   DeterministicSocketData* data);
781  virtual ~DeterministicMockUDPClientSocket();
782
783  // DeterministicSocketData::Delegate:
784  virtual bool WritePending() const OVERRIDE;
785  virtual bool ReadPending() const OVERRIDE;
786  virtual void CompleteWrite() OVERRIDE;
787  virtual int CompleteRead() OVERRIDE;
788
789  // Socket implementation.
790  virtual int Read(IOBuffer* buf, int buf_len,
791                   const CompletionCallback& callback) OVERRIDE;
792  virtual int Write(IOBuffer* buf, int buf_len,
793                    const CompletionCallback& callback) OVERRIDE;
794  virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
795  virtual bool SetSendBufferSize(int32 size) OVERRIDE;
796
797  // DatagramSocket implementation.
798  virtual void Close() OVERRIDE;
799  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
800  virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
801  virtual const BoundNetLog& NetLog() const OVERRIDE;
802
803  // DatagramClientSocket implementation.
804  virtual int Connect(const IPEndPoint& address) OVERRIDE;
805
806  // AsyncSocket implementation.
807  virtual void OnReadComplete(const MockRead& data) OVERRIDE;
808  virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
809
810 private:
811  bool connected_;
812  IPEndPoint peer_address_;
813  DeterministicSocketHelper helper_;
814};
815
816// Mock TCP socket to be used in conjunction with DeterministicSocketData.
817class DeterministicMockTCPClientSocket
818    : public MockClientSocket,
819      public AsyncSocket,
820      public DeterministicSocketData::Delegate,
821      public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> {
822 public:
823  DeterministicMockTCPClientSocket(net::NetLog* net_log,
824                                   DeterministicSocketData* data);
825  virtual ~DeterministicMockTCPClientSocket();
826
827  // DeterministicSocketData::Delegate:
828  virtual bool WritePending() const OVERRIDE;
829  virtual bool ReadPending() const OVERRIDE;
830  virtual void CompleteWrite() OVERRIDE;
831  virtual int CompleteRead() OVERRIDE;
832
833  // Socket:
834  virtual int Write(IOBuffer* buf, int buf_len,
835                    const CompletionCallback& callback) OVERRIDE;
836  virtual int Read(IOBuffer* buf, int buf_len,
837                   const CompletionCallback& callback) OVERRIDE;
838
839  // StreamSocket:
840  virtual int Connect(const CompletionCallback& callback) OVERRIDE;
841  virtual void Disconnect() OVERRIDE;
842  virtual bool IsConnected() const OVERRIDE;
843  virtual bool IsConnectedAndIdle() const OVERRIDE;
844  virtual bool WasEverUsed() const OVERRIDE;
845  virtual bool UsingTCPFastOpen() const OVERRIDE;
846  virtual bool WasNpnNegotiated() const OVERRIDE;
847  virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
848
849  // AsyncSocket:
850  virtual void OnReadComplete(const MockRead& data) OVERRIDE;
851  virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
852
853 private:
854  DeterministicSocketHelper helper_;
855};
856
857class MockSSLClientSocket : public MockClientSocket, public AsyncSocket {
858 public:
859  MockSSLClientSocket(
860      ClientSocketHandle* transport_socket,
861      const HostPortPair& host_and_port,
862      const SSLConfig& ssl_config,
863      SSLSocketDataProvider* socket);
864  virtual ~MockSSLClientSocket();
865
866  // Socket implementation.
867  virtual int Read(IOBuffer* buf, int buf_len,
868                   const CompletionCallback& callback) OVERRIDE;
869  virtual int Write(IOBuffer* buf, int buf_len,
870                    const CompletionCallback& callback) OVERRIDE;
871
872  // StreamSocket implementation.
873  virtual int Connect(const CompletionCallback& callback) OVERRIDE;
874  virtual void Disconnect() OVERRIDE;
875  virtual bool IsConnected() const OVERRIDE;
876  virtual bool WasEverUsed() const OVERRIDE;
877  virtual bool UsingTCPFastOpen() const OVERRIDE;
878  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
879  virtual bool WasNpnNegotiated() const OVERRIDE;
880  virtual bool GetSSLInfo(SSLInfo* ssl_info) OVERRIDE;
881
882  // SSLClientSocket implementation.
883  virtual void GetSSLCertRequestInfo(
884      SSLCertRequestInfo* cert_request_info) OVERRIDE;
885  virtual NextProtoStatus GetNextProto(std::string* proto,
886                                       std::string* server_protos) OVERRIDE;
887  virtual bool set_was_npn_negotiated(bool negotiated) OVERRIDE;
888  virtual void set_protocol_negotiated(
889      NextProto protocol_negotiated) OVERRIDE;
890  virtual NextProto GetNegotiatedProtocol() const OVERRIDE;
891
892  // This MockSocket does not implement the manual async IO feature.
893  virtual void OnReadComplete(const MockRead& data) OVERRIDE;
894  virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
895
896  virtual bool WasChannelIDSent() const OVERRIDE;
897  virtual void set_channel_id_sent(bool channel_id_sent) OVERRIDE;
898  virtual ServerBoundCertService* GetServerBoundCertService() const OVERRIDE;
899
900 private:
901  static void ConnectCallback(MockSSLClientSocket *ssl_client_socket,
902                              const CompletionCallback& callback,
903                              int rv);
904
905  scoped_ptr<ClientSocketHandle> transport_;
906  SSLSocketDataProvider* data_;
907  bool is_npn_state_set_;
908  bool new_npn_value_;
909  bool is_protocol_negotiated_set_;
910  NextProto protocol_negotiated_;
911};
912
913class MockUDPClientSocket
914    : public DatagramClientSocket,
915      public AsyncSocket {
916 public:
917  MockUDPClientSocket(SocketDataProvider* data, net::NetLog* net_log);
918  virtual ~MockUDPClientSocket();
919
920  // Socket implementation.
921  virtual int Read(IOBuffer* buf, int buf_len,
922                   const CompletionCallback& callback) OVERRIDE;
923  virtual int Write(IOBuffer* buf, int buf_len,
924                    const CompletionCallback& callback) OVERRIDE;
925  virtual bool SetReceiveBufferSize(int32 size) OVERRIDE;
926  virtual bool SetSendBufferSize(int32 size) OVERRIDE;
927
928  // DatagramSocket implementation.
929  virtual void Close() OVERRIDE;
930  virtual int GetPeerAddress(IPEndPoint* address) const OVERRIDE;
931  virtual int GetLocalAddress(IPEndPoint* address) const OVERRIDE;
932  virtual const BoundNetLog& NetLog() const OVERRIDE;
933
934  // DatagramClientSocket implementation.
935  virtual int Connect(const IPEndPoint& address) OVERRIDE;
936
937  // AsyncSocket implementation.
938  virtual void OnReadComplete(const MockRead& data) OVERRIDE;
939  virtual void OnConnectComplete(const MockConnect& data) OVERRIDE;
940
941 private:
942  int CompleteRead();
943
944  void RunCallbackAsync(const CompletionCallback& callback, int result);
945  void RunCallback(const CompletionCallback& callback, int result);
946
947  bool connected_;
948  SocketDataProvider* data_;
949  int read_offset_;
950  MockRead read_data_;
951  bool need_read_data_;
952
953  // Address of the "remote" peer we're connected to.
954  IPEndPoint peer_addr_;
955
956  // While an asynchronous IO is pending, we save our user-buffer state.
957  IOBuffer* pending_buf_;
958  int pending_buf_len_;
959  CompletionCallback pending_callback_;
960
961  BoundNetLog net_log_;
962
963  base::WeakPtrFactory<MockUDPClientSocket> weak_factory_;
964
965  DISALLOW_COPY_AND_ASSIGN(MockUDPClientSocket);
966};
967
968class TestSocketRequest : public TestCompletionCallbackBase {
969 public:
970  TestSocketRequest(std::vector<TestSocketRequest*>* request_order,
971                    size_t* completion_count);
972  virtual ~TestSocketRequest();
973
974  ClientSocketHandle* handle() { return &handle_; }
975
976  const net::CompletionCallback& callback() const { return callback_; }
977
978 private:
979  void OnComplete(int result);
980
981  ClientSocketHandle handle_;
982  std::vector<TestSocketRequest*>* request_order_;
983  size_t* completion_count_;
984  CompletionCallback callback_;
985
986  DISALLOW_COPY_AND_ASSIGN(TestSocketRequest);
987};
988
989class ClientSocketPoolTest {
990 public:
991  enum KeepAlive {
992    KEEP_ALIVE,
993
994    // A socket will be disconnected in addition to handle being reset.
995    NO_KEEP_ALIVE,
996  };
997
998  static const int kIndexOutOfBounds;
999  static const int kRequestNotFound;
1000
1001  ClientSocketPoolTest();
1002  ~ClientSocketPoolTest();
1003
1004  template <typename PoolType, typename SocketParams>
1005  int StartRequestUsingPool(PoolType* socket_pool,
1006                            const std::string& group_name,
1007                            RequestPriority priority,
1008                            const scoped_refptr<SocketParams>& socket_params) {
1009    DCHECK(socket_pool);
1010    TestSocketRequest* request = new TestSocketRequest(&request_order_,
1011                                                       &completion_count_);
1012    requests_.push_back(request);
1013    int rv = request->handle()->Init(
1014        group_name, socket_params, priority, request->callback(),
1015        socket_pool, BoundNetLog());
1016    if (rv != ERR_IO_PENDING)
1017      request_order_.push_back(request);
1018    return rv;
1019  }
1020
1021  // Provided there were n requests started, takes |index| in range 1..n
1022  // and returns order in which that request completed, in range 1..n,
1023  // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
1024  // if that request did not complete (for example was canceled).
1025  int GetOrderOfRequest(size_t index) const;
1026
1027  // Resets first initialized socket handle from |requests_|. If found such
1028  // a handle, returns true.
1029  bool ReleaseOneConnection(KeepAlive keep_alive);
1030
1031  // Releases connections until there is nothing to release.
1032  void ReleaseAllConnections(KeepAlive keep_alive);
1033
1034  // Note that this uses 0-based indices, while GetOrderOfRequest takes and
1035  // returns 0-based indices.
1036  TestSocketRequest* request(int i) { return requests_[i]; }
1037
1038  size_t requests_size() const { return requests_.size(); }
1039  ScopedVector<TestSocketRequest>* requests() { return &requests_; }
1040  size_t completion_count() const { return completion_count_; }
1041
1042 private:
1043  ScopedVector<TestSocketRequest> requests_;
1044  std::vector<TestSocketRequest*> request_order_;
1045  size_t completion_count_;
1046};
1047
1048class MockTransportClientSocketPool : public TransportClientSocketPool {
1049 public:
1050  class MockConnectJob {
1051   public:
1052    MockConnectJob(StreamSocket* socket, ClientSocketHandle* handle,
1053                   const CompletionCallback& callback);
1054    ~MockConnectJob();
1055
1056    int Connect();
1057    bool CancelHandle(const ClientSocketHandle* handle);
1058
1059   private:
1060    void OnConnect(int rv);
1061
1062    scoped_ptr<StreamSocket> socket_;
1063    ClientSocketHandle* handle_;
1064    CompletionCallback user_callback_;
1065
1066    DISALLOW_COPY_AND_ASSIGN(MockConnectJob);
1067  };
1068
1069  MockTransportClientSocketPool(
1070      int max_sockets,
1071      int max_sockets_per_group,
1072      ClientSocketPoolHistograms* histograms,
1073      ClientSocketFactory* socket_factory);
1074
1075  virtual ~MockTransportClientSocketPool();
1076
1077  int release_count() const { return release_count_; }
1078  int cancel_count() const { return cancel_count_; }
1079
1080  // TransportClientSocketPool implementation.
1081  virtual int RequestSocket(const std::string& group_name,
1082                            const void* socket_params,
1083                            RequestPriority priority,
1084                            ClientSocketHandle* handle,
1085                            const CompletionCallback& callback,
1086                            const BoundNetLog& net_log) OVERRIDE;
1087
1088  virtual void CancelRequest(const std::string& group_name,
1089                             ClientSocketHandle* handle) OVERRIDE;
1090  virtual void ReleaseSocket(const std::string& group_name,
1091                             StreamSocket* socket, int id) OVERRIDE;
1092
1093 private:
1094  ClientSocketFactory* client_socket_factory_;
1095  ScopedVector<MockConnectJob> job_list_;
1096  int release_count_;
1097  int cancel_count_;
1098
1099  DISALLOW_COPY_AND_ASSIGN(MockTransportClientSocketPool);
1100};
1101
1102class DeterministicMockClientSocketFactory : public ClientSocketFactory {
1103 public:
1104  DeterministicMockClientSocketFactory();
1105  virtual ~DeterministicMockClientSocketFactory();
1106
1107  void AddSocketDataProvider(DeterministicSocketData* socket);
1108  void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
1109  void ResetNextMockIndexes();
1110
1111  // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
1112  // created.
1113  MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
1114
1115  SocketDataProviderArray<DeterministicSocketData>& mock_data() {
1116    return mock_data_;
1117  }
1118  std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() {
1119    return tcp_client_sockets_;
1120  }
1121  std::vector<DeterministicMockUDPClientSocket*>& udp_client_sockets() {
1122    return udp_client_sockets_;
1123  }
1124
1125  // ClientSocketFactory
1126  virtual DatagramClientSocket* CreateDatagramClientSocket(
1127      DatagramSocket::BindType bind_type,
1128      const RandIntCallback& rand_int_cb,
1129      NetLog* net_log,
1130      const NetLog::Source& source) OVERRIDE;
1131  virtual StreamSocket* CreateTransportClientSocket(
1132      const AddressList& addresses,
1133      NetLog* net_log,
1134      const NetLog::Source& source) OVERRIDE;
1135  virtual SSLClientSocket* CreateSSLClientSocket(
1136      ClientSocketHandle* transport_socket,
1137      const HostPortPair& host_and_port,
1138      const SSLConfig& ssl_config,
1139      const SSLClientSocketContext& context) OVERRIDE;
1140  virtual void ClearSSLSessionCache() OVERRIDE;
1141
1142 private:
1143  SocketDataProviderArray<DeterministicSocketData> mock_data_;
1144  SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
1145
1146  // Store pointers to handed out sockets in case the test wants to get them.
1147  std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_;
1148  std::vector<DeterministicMockUDPClientSocket*> udp_client_sockets_;
1149  std::vector<MockSSLClientSocket*> ssl_client_sockets_;
1150};
1151
1152class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
1153 public:
1154  MockSOCKSClientSocketPool(
1155      int max_sockets,
1156      int max_sockets_per_group,
1157      ClientSocketPoolHistograms* histograms,
1158      TransportClientSocketPool* transport_pool);
1159
1160  virtual ~MockSOCKSClientSocketPool();
1161
1162  // SOCKSClientSocketPool implementation.
1163  virtual int RequestSocket(const std::string& group_name,
1164                            const void* socket_params,
1165                            RequestPriority priority,
1166                            ClientSocketHandle* handle,
1167                            const CompletionCallback& callback,
1168                            const BoundNetLog& net_log) OVERRIDE;
1169
1170  virtual void CancelRequest(const std::string& group_name,
1171                             ClientSocketHandle* handle) OVERRIDE;
1172  virtual void ReleaseSocket(const std::string& group_name,
1173                             StreamSocket* socket, int id) OVERRIDE;
1174
1175 private:
1176  TransportClientSocketPool* const transport_pool_;
1177
1178  DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool);
1179};
1180
// Byte sequences (with their lengths) that make up a successful SOCKS v5
// handshake. Definitions live in the corresponding .cc file.

// Client -> server messages.
extern const char kSOCKS5GreetRequest[];
extern const int kSOCKS5GreetRequestLength;
extern const char kSOCKS5OkRequest[];
extern const int kSOCKS5OkRequestLength;

// Server -> client messages.
extern const char kSOCKS5GreetResponse[];
extern const int kSOCKS5GreetResponseLength;
extern const char kSOCKS5OkResponse[];
extern const int kSOCKS5OkResponseLength;
1193
1194}  // namespace net
1195
1196#endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
1197