socket_test_util.h revision 21d179b334e59e9a3bfcaed4c4430bef1bc5759d
1// Copyright (c) 2010 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
6#define NET_SOCKET_SOCKET_TEST_UTIL_H_
7#pragma once
8
9#include <cstring>
10#include <deque>
11#include <string>
12#include <vector>
13
14#include "base/basictypes.h"
15#include "base/callback.h"
16#include "base/logging.h"
17#include "base/scoped_ptr.h"
18#include "base/scoped_vector.h"
19#include "base/string16.h"
20#include "base/weak_ptr.h"
21#include "net/base/address_list.h"
22#include "net/base/io_buffer.h"
23#include "net/base/net_errors.h"
24#include "net/base/net_log.h"
25#include "net/base/ssl_config_service.h"
26#include "net/base/test_completion_callback.h"
27#include "net/http/http_auth_controller.h"
28#include "net/http/http_proxy_client_socket_pool.h"
29#include "net/socket/client_socket_factory.h"
30#include "net/socket/client_socket_handle.h"
31#include "net/socket/socks_client_socket_pool.h"
32#include "net/socket/ssl_client_socket.h"
33#include "net/socket/ssl_client_socket_pool.h"
34#include "net/socket/tcp_client_socket_pool.h"
35#include "testing/gtest/include/gtest/gtest.h"
36
37namespace net {
38
enum {
  // A private network error code used only by the socket test utilities.
  // A MockRead whose |result| equals this value is a pure marker: it
  // indicates that the peer will close the connection after the next
  // MockRead.  All other members of such a marker MockRead are ignored.
  ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000
};

class ClientSocket;
class MockClientSocket;
class SSLClientSocket;
class SSLHostInfo;
52
53struct MockConnect {
54  // Asynchronous connection success.
55  MockConnect() : async(true), result(OK) { }
56  MockConnect(bool a, int r) : async(a), result(r) { }
57
58  bool async;
59  int result;
60};
61
62struct MockRead {
63  // Flag to indicate that the message loop should be terminated.
64  enum {
65    STOPLOOP = 1 << 31
66  };
67
68  // Default
69  MockRead() : async(false), result(0), data(NULL), data_len(0),
70      sequence_number(0), time_stamp(base::Time::Now()) {}
71
72  // Read failure (no data).
73  MockRead(bool async, int result) : async(async) , result(result), data(NULL),
74      data_len(0), sequence_number(0), time_stamp(base::Time::Now()) { }
75
76  // Read failure (no data), with sequence information.
77  MockRead(bool async, int result, int seq) : async(async) , result(result),
78      data(NULL), data_len(0), sequence_number(seq),
79      time_stamp(base::Time::Now()) { }
80
81  // Asynchronous read success (inferred data length).
82  explicit MockRead(const char* data) : async(true),  result(0), data(data),
83      data_len(strlen(data)), sequence_number(0),
84      time_stamp(base::Time::Now()) { }
85
86  // Read success (inferred data length).
87  MockRead(bool async, const char* data) : async(async), result(0), data(data),
88      data_len(strlen(data)), sequence_number(0),
89      time_stamp(base::Time::Now()) { }
90
91  // Read success.
92  MockRead(bool async, const char* data, int data_len) : async(async),
93      result(0), data(data), data_len(data_len), sequence_number(0),
94      time_stamp(base::Time::Now()) { }
95
96  // Read success (inferred data length) with sequence information.
97  MockRead(bool async, int seq, const char* data) : async(async),
98      result(0), data(data), data_len(strlen(data)), sequence_number(seq),
99      time_stamp(base::Time::Now()) { }
100
101  // Read success with sequence information.
102  MockRead(bool async, const char* data, int data_len, int seq) : async(async),
103      result(0), data(data), data_len(data_len), sequence_number(seq),
104      time_stamp(base::Time::Now()) { }
105
106  bool async;
107  int result;
108  const char* data;
109  int data_len;
110
111  // For OrderedSocketData, which only allows reads to occur in a particular
112  // sequence.  If a read occurs before the given |sequence_number| is reached,
113  // an ERR_IO_PENDING is returned.
114  int sequence_number;      // The sequence number at which a read is allowed
115                            // to occur.
116  base::Time time_stamp;    // The time stamp at which the operation occurred.
117};
118
119// MockWrite uses the same member fields as MockRead, but with different
120// meanings. The expected input to MockTCPClientSocket::Write() is given
121// by {data, data_len}, and the return value of Write() is controlled by
122// {async, result}.
123typedef MockRead MockWrite;
124
// The outcome a SocketDataProvider reports for a write (see OnWrite()).
struct MockWriteResult {
  MockWriteResult(bool async, int result) : async(async), result(result) {}

  // Whether the write completes asynchronously.
  bool async;
  // Result code (byte count or net error) for the write.
  int result;
};
131
132// The SocketDataProvider is an interface used by the MockClientSocket
133// for getting data about individual reads and writes on the socket.
134class SocketDataProvider {
135 public:
136  SocketDataProvider() : socket_(NULL) {}
137
138  virtual ~SocketDataProvider() {}
139
140  // Returns the buffer and result code for the next simulated read.
141  // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
142  // that it will be called via the MockClientSocket::OnReadComplete()
143  // function at a later time.
144  virtual MockRead GetNextRead() = 0;
145  virtual MockWriteResult OnWrite(const std::string& data) = 0;
146  virtual void Reset() = 0;
147
148  // Accessor for the socket which is using the SocketDataProvider.
149  MockClientSocket* socket() { return socket_; }
150  void set_socket(MockClientSocket* socket) { socket_ = socket; }
151
152  MockConnect connect_data() const { return connect_; }
153  void set_connect_data(const MockConnect& connect) { connect_ = connect; }
154
155 private:
156  MockConnect connect_;
157  MockClientSocket* socket_;
158
159  DISALLOW_COPY_AND_ASSIGN(SocketDataProvider);
160};
161
162// SocketDataProvider which responds based on static tables of mock reads and
163// writes.
164class StaticSocketDataProvider : public SocketDataProvider {
165 public:
166  StaticSocketDataProvider();
167  StaticSocketDataProvider(MockRead* reads, size_t reads_count,
168                           MockWrite* writes, size_t writes_count);
169  virtual ~StaticSocketDataProvider();
170
171  // SocketDataProvider methods:
172  virtual MockRead GetNextRead();
173  virtual MockWriteResult OnWrite(const std::string& data);
174  virtual void Reset();
175  virtual void CompleteRead() {}
176
177  // These functions get access to the next available read and write data.
178  const MockRead& PeekRead() const;
179  const MockWrite& PeekWrite() const;
180  // These functions get random access to the read and write data, for timing.
181  const MockRead& PeekRead(size_t index) const;
182  const MockWrite& PeekWrite(size_t index) const;
183  size_t read_index() const { return read_index_; }
184  size_t write_index() const { return write_index_; }
185  size_t read_count() const { return read_count_; }
186  size_t write_count() const { return write_count_; }
187
188  bool at_read_eof() const { return read_index_ >= read_count_; }
189  bool at_write_eof() const { return write_index_ >= write_count_; }
190
191 private:
192  MockRead* reads_;
193  size_t read_index_;
194  size_t read_count_;
195  MockWrite* writes_;
196  size_t write_index_;
197  size_t write_count_;
198
199  DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider);
200};
201
202// SocketDataProvider which can make decisions about next mock reads based on
203// received writes. It can also be used to enforce order of operations, for
204// example that tested code must send the "Hello!" message before receiving
205// response. This is useful for testing conversation-like protocols like FTP.
206class DynamicSocketDataProvider : public SocketDataProvider {
207 public:
208  DynamicSocketDataProvider();
209  virtual ~DynamicSocketDataProvider();
210
211  // SocketDataProvider methods:
212  virtual MockRead GetNextRead();
213  virtual MockWriteResult OnWrite(const std::string& data) = 0;
214  virtual void Reset();
215
216  int short_read_limit() const { return short_read_limit_; }
217  void set_short_read_limit(int limit) { short_read_limit_ = limit; }
218
219  void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; }
220
221 protected:
222  // The next time there is a read from this socket, it will return |data|.
223  // Before calling SimulateRead next time, the previous data must be consumed.
224  void SimulateRead(const char* data, size_t length);
225  void SimulateRead(const char* data) {
226    SimulateRead(data, std::strlen(data));
227  }
228
229 private:
230  std::deque<MockRead> reads_;
231
232  // Max number of bytes we will read at a time. 0 means no limit.
233  int short_read_limit_;
234
235  // If true, we'll not require the client to consume all data before we
236  // mock the next read.
237  bool allow_unconsumed_reads_;
238
239  DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider);
240};
241
242// SSLSocketDataProviders only need to keep track of the return code from calls
243// to Connect().
244struct SSLSocketDataProvider {
245  SSLSocketDataProvider(bool async, int result)
246      : connect(async, result),
247        next_proto_status(SSLClientSocket::kNextProtoUnsupported),
248        was_npn_negotiated(false) { }
249
250  MockConnect connect;
251  SSLClientSocket::NextProtoStatus next_proto_status;
252  std::string next_proto;
253  bool was_npn_negotiated;
254};
255
256// A DataProvider where the client must write a request before the reads (e.g.
257// the response) will complete.
258class DelayedSocketData : public StaticSocketDataProvider,
259                          public base::RefCounted<DelayedSocketData> {
260 public:
261  // |write_delay| the number of MockWrites to complete before allowing
262  //               a MockRead to complete.
263  // |reads| the list of MockRead completions.
264  // |writes| the list of MockWrite completions.
265  // Note: All MockReads and MockWrites must be async.
266  // Note: The MockRead and MockWrite lists musts end with a EOF
267  //       e.g. a MockRead(true, 0, 0);
268  DelayedSocketData(int write_delay,
269                    MockRead* reads, size_t reads_count,
270                    MockWrite* writes, size_t writes_count);
271
272  // |connect| the result for the connect phase.
273  // |reads| the list of MockRead completions.
274  // |write_delay| the number of MockWrites to complete before allowing
275  //               a MockRead to complete.
276  // |writes| the list of MockWrite completions.
277  // Note: All MockReads and MockWrites must be async.
278  // Note: The MockRead and MockWrite lists musts end with a EOF
279  //       e.g. a MockRead(true, 0, 0);
280  DelayedSocketData(const MockConnect& connect, int write_delay,
281                    MockRead* reads, size_t reads_count,
282                    MockWrite* writes, size_t writes_count);
283  ~DelayedSocketData();
284
285  virtual MockRead GetNextRead();
286  virtual MockWriteResult OnWrite(const std::string& data);
287  virtual void Reset();
288  virtual void CompleteRead();
289  void ForceNextRead();
290
291 private:
292  int write_delay_;
293  ScopedRunnableMethodFactory<DelayedSocketData> factory_;
294};
295
296// A DataProvider where the reads are ordered.
297// If a read is requested before its sequence number is reached, we return an
298// ERR_IO_PENDING (that way we don't have to explicitly add a MockRead just to
299// wait).
300// The sequence number is incremented on every read and write operation.
301// The message loop may be interrupted by setting the high bit of the sequence
302// number in the MockRead's sequence number.  When that MockRead is reached,
303// we post a Quit message to the loop.  This allows us to interrupt the reading
304// of data before a complete message has arrived, and provides support for
305// testing server push when the request is issued while the response is in the
306// middle of being received.
307class OrderedSocketData : public StaticSocketDataProvider,
308                          public base::RefCounted<OrderedSocketData> {
309 public:
310  // |reads| the list of MockRead completions.
311  // |writes| the list of MockWrite completions.
312  // Note: All MockReads and MockWrites must be async.
313  // Note: The MockRead and MockWrite lists musts end with a EOF
314  //       e.g. a MockRead(true, 0, 0);
315  OrderedSocketData(MockRead* reads, size_t reads_count,
316                    MockWrite* writes, size_t writes_count);
317
318  // |connect| the result for the connect phase.
319  // |reads| the list of MockRead completions.
320  // |writes| the list of MockWrite completions.
321  // Note: All MockReads and MockWrites must be async.
322  // Note: The MockRead and MockWrite lists musts end with a EOF
323  //       e.g. a MockRead(true, 0, 0);
324  OrderedSocketData(const MockConnect& connect,
325                    MockRead* reads, size_t reads_count,
326                    MockWrite* writes, size_t writes_count);
327
328  virtual MockRead GetNextRead();
329  virtual MockWriteResult OnWrite(const std::string& data);
330  virtual void Reset();
331  virtual void CompleteRead();
332
333  void SetCompletionCallback(CompletionCallback* callback) {
334    callback_ = callback;
335  }
336
337  // Posts a quit message to the current message loop, if one is running.
338  void EndLoop();
339
340 private:
341  friend class base::RefCounted<OrderedSocketData>;
342  virtual ~OrderedSocketData();
343
344  int sequence_number_;
345  int loop_stop_stage_;
346  CompletionCallback* callback_;
347  bool blocked_;
348  ScopedRunnableMethodFactory<OrderedSocketData> factory_;
349};
350
351class DeterministicMockTCPClientSocket;
352
353// This class gives the user full control over the network activity,
354// specifically the timing of the COMPLETION of I/O operations.  Regardless of
355// the order in which I/O operations are initiated, this class ensures that they
356// complete in the correct order.
357//
358// Network activity is modeled as a sequence of numbered steps which is
359// incremented whenever an I/O operation completes.  This can happen under two
360// different circumstances:
361//
362// 1) Performing a synchronous I/O operation.  (Invoking Read() or Write()
363//    when the corresponding MockRead or MockWrite is marked !async).
364// 2) Running the Run() method of this class.  The run method will invoke
365//    the current MessageLoop, running all pending events, and will then
366//    invoke any pending IO callbacks.
367//
368// In addition, this class allows for I/O processing to "stop" at a specified
369// step, by calling SetStop(int) or StopAfter(int).  Initiating an I/O operation
370// by calling Read() or Write() while stopped is permitted if the operation is
371// asynchronous.  It is an error to perform synchronous I/O while stopped.
372//
373// When creating the MockReads and MockWrites, note that the sequence number
374// refers to the number of the step in which the I/O will complete.  In the
375// case of synchronous I/O, this will be the same step as the I/O is initiated.
376// However, in the case of asynchronous I/O, this I/O may be initiated in
377// a much earlier step. Furthermore, when the a Read() or Write() is separated
378// from its completion by other Read() or Writes()'s, it can not be marked
379// synchronous.  If it is, ERR_UNUEXPECTED will be returned indicating that a
380// synchronous Read() or Write() could not be completed synchronously because of
381// the specific ordering constraints.
382//
383// Sequence numbers are preserved across both reads and writes. There should be
384// no gaps in sequence numbers, and no repeated sequence numbers. i.e.
385//  MockRead reads[] = {
386//    MockRead(false, "first read", length, 0)   // sync
387//    MockRead(true, "second read", length, 2)   // async
388//  };
389//  MockWrite writes[] = {
390//    MockWrite(true, "first write", length, 1),    // async
391//    MockWrite(false, "second write", length, 3),  // sync
392//  };
393//
394// Example control flow:
395// Read() is called.  The current step is 0.  The first available read is
396// synchronous, so the call to Read() returns length.  The current step is
397// now 1.  Next, Read() is called again.  The next available read can
398// not be completed until step 2, so Read() returns ERR_IO_PENDING.  The current
399// step is still 1.  Write is called().  The first available write is able to
400// complete in this step, but is marked asynchronous.  Write() returns
401// ERR_IO_PENDING.  The current step is still 1.  At this point RunFor(1) is
402// called which will cause the write callback to be invoked, and will then
403// stop.  The current state is now 2.  RunFor(1) is called again, which
404// causes the read callback to be invoked, and will then stop.  Then current
405// step is 2.  Write() is called again.  Then next available write is
406// synchronous so the call to Write() returns length.
407//
408// For examples of how to use this class, see:
409//   deterministic_socket_data_unittests.cc
410class DeterministicSocketData : public StaticSocketDataProvider,
411    public base::RefCounted<DeterministicSocketData> {
412 public:
413  // |reads| the list of MockRead completions.
414  // |writes| the list of MockWrite completions.
415  DeterministicSocketData(MockRead* reads, size_t reads_count,
416                          MockWrite* writes, size_t writes_count);
417
418  // When the socket calls Read(), that calls GetNextRead(), and expects either
419  // ERR_IO_PENDING or data.
420  virtual MockRead GetNextRead();
421
422  // When the socket calls Write(), it always completes synchronously. OnWrite()
423  // checks to make sure the written data matches the expected data. The
424  // callback will not be invoked until its sequence number is reached.
425  virtual MockWriteResult OnWrite(const std::string& data);
426
427  virtual void Reset();
428
429  virtual void CompleteRead() {}
430
431  // Consume all the data up to the give stop point (via SetStop()).
432  void Run();
433
434  // Set the stop point to be |steps| from now, and then invoke Run().
435  void RunFor(int steps);
436
437  // Stop at step |seq|, which must be in the future.
438  virtual void SetStop(int seq) {
439    DCHECK_LT(sequence_number_, seq);
440    stopping_sequence_number_ = seq;
441    stopped_ = false;
442  }
443
444  // Stop |seq| steps after the current step.
445  virtual void StopAfter(int seq) {
446    SetStop(sequence_number_ + seq);
447  }
448  bool stopped() const { return stopped_; }
449  void SetStopped(bool val) { stopped_ = val; }
450  MockRead& current_read() { return current_read_; }
451  MockRead& current_write() { return current_write_; }
452  int sequence_number() const { return sequence_number_; }
453  void set_socket(base::WeakPtr<DeterministicMockTCPClientSocket> socket) {
454    socket_ = socket;
455  }
456
457 private:
458  // Invoke the read and write callbacks, if the timing is appropriate.
459  void InvokeCallbacks();
460
461  void NextStep();
462
463  int sequence_number_;
464  MockRead current_read_;
465  MockWrite current_write_;
466  int stopping_sequence_number_;
467  bool stopped_;
468  base::WeakPtr<DeterministicMockTCPClientSocket> socket_;
469  bool print_debug_;
470};
471
472
// Holds an array of SocketDataProvider elements.  As Mock{TCP,SSL}ClientSocket
// objects get instantiated, they take their data from the i'th element of this
// array.  The elements are not owned.
template<typename T>
class SocketDataProviderArray {
 public:
  SocketDataProviderArray() : next_index_(0) {
  }

  // Returns the next unused provider and advances.  The caller must have
  // Add()ed enough providers (checked only in debug builds).
  T* GetNext() {
    DCHECK_LT(next_index_, data_providers_.size());
    return data_providers_[next_index_++];
  }

  // Appends a non-NULL provider to the end of the array.
  void Add(T* data_provider) {
    DCHECK(data_provider);
    data_providers_.push_back(data_provider);
  }

  // Makes GetNext() start again from the first provider.
  void ResetNextIndex() {
    next_index_ = 0;
  }

 private:
  // Index of the next |data_providers_| element to use. Not an iterator
  // because those are invalidated on vector reallocation.
  size_t next_index_;

  // SocketDataProviders to be returned.
  std::vector<T*> data_providers_;
};
504
505class MockTCPClientSocket;
506class MockSSLClientSocket;
507
508// ClientSocketFactory which contains arrays of sockets of each type.
509// You should first fill the arrays using AddMock{SSL,}Socket. When the factory
510// is asked to create a socket, it takes next entry from appropriate array.
511// You can use ResetNextMockIndexes to reset that next entry index for all mock
512// socket types.
513class MockClientSocketFactory : public ClientSocketFactory {
514 public:
515  MockClientSocketFactory();
516  virtual ~MockClientSocketFactory();
517
518  void AddSocketDataProvider(SocketDataProvider* socket);
519  void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
520  void ResetNextMockIndexes();
521
522  // Return |index|-th MockTCPClientSocket (starting from 0) that the factory
523  // created.
524  MockTCPClientSocket* GetMockTCPClientSocket(size_t index) const;
525
526  // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
527  // created.
528  MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
529
530  // ClientSocketFactory
531  virtual ClientSocket* CreateTCPClientSocket(
532      const AddressList& addresses,
533      NetLog* net_log,
534      const NetLog::Source& source);
535  virtual SSLClientSocket* CreateSSLClientSocket(
536      ClientSocketHandle* transport_socket,
537      const HostPortPair& host_and_port,
538      const SSLConfig& ssl_config,
539      SSLHostInfo* ssl_host_info,
540      CertVerifier* cert_verifier,
541      DnsCertProvenanceChecker* dns_cert_checker);
542  SocketDataProviderArray<SocketDataProvider>& mock_data() {
543    return mock_data_;
544  }
545  std::vector<MockTCPClientSocket*>& tcp_client_sockets() {
546    return tcp_client_sockets_;
547  }
548
549 private:
550  SocketDataProviderArray<SocketDataProvider> mock_data_;
551  SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
552
553  // Store pointers to handed out sockets in case the test wants to get them.
554  std::vector<MockTCPClientSocket*> tcp_client_sockets_;
555  std::vector<MockSSLClientSocket*> ssl_client_sockets_;
556};
557
558class MockClientSocket : public net::SSLClientSocket {
559 public:
560  explicit MockClientSocket(net::NetLog* net_log);
561  // ClientSocket methods:
562  virtual int Connect(net::CompletionCallback* callback) = 0;
563  virtual void Disconnect();
564  virtual bool IsConnected() const;
565  virtual bool IsConnectedAndIdle() const;
566  virtual int GetPeerAddress(AddressList* address) const;
567  virtual const BoundNetLog& NetLog() const { return net_log_;}
568  virtual void SetSubresourceSpeculation() {}
569  virtual void SetOmniboxSpeculation() {}
570
571  // SSLClientSocket methods:
572  virtual void GetSSLInfo(net::SSLInfo* ssl_info);
573  virtual void GetSSLCertRequestInfo(
574      net::SSLCertRequestInfo* cert_request_info);
575  virtual NextProtoStatus GetNextProto(std::string* proto);
576
577  // Socket methods:
578  virtual int Read(net::IOBuffer* buf, int buf_len,
579                   net::CompletionCallback* callback) = 0;
580  virtual int Write(net::IOBuffer* buf, int buf_len,
581                    net::CompletionCallback* callback) = 0;
582  virtual bool SetReceiveBufferSize(int32 size) { return true; }
583  virtual bool SetSendBufferSize(int32 size) { return true; }
584
585  // If an async IO is pending because the SocketDataProvider returned
586  // ERR_IO_PENDING, then the MockClientSocket waits until this OnReadComplete
587  // is called to complete the asynchronous read operation.
588  // data.async is ignored, and this read is completed synchronously as
589  // part of this call.
590  virtual void OnReadComplete(const MockRead& data) = 0;
591
592 protected:
593  virtual ~MockClientSocket() {}
594  void RunCallbackAsync(net::CompletionCallback* callback, int result);
595  void RunCallback(net::CompletionCallback*, int result);
596
597  ScopedRunnableMethodFactory<MockClientSocket> method_factory_;
598
599  // True if Connect completed successfully and Disconnect hasn't been called.
600  bool connected_;
601
602  net::BoundNetLog net_log_;
603};
604
605class MockTCPClientSocket : public MockClientSocket {
606 public:
607  MockTCPClientSocket(const net::AddressList& addresses, net::NetLog* net_log,
608                      net::SocketDataProvider* socket);
609
610  // ClientSocket methods:
611  virtual int Connect(net::CompletionCallback* callback);
612  virtual void Disconnect();
613  virtual bool IsConnected() const;
614  virtual bool IsConnectedAndIdle() const { return IsConnected(); }
615  virtual bool WasEverUsed() const { return was_used_to_convey_data_; }
616  virtual bool UsingTCPFastOpen() const { return false; }
617
618  // Socket methods:
619  virtual int Read(net::IOBuffer* buf, int buf_len,
620                   net::CompletionCallback* callback);
621  virtual int Write(net::IOBuffer* buf, int buf_len,
622                    net::CompletionCallback* callback);
623
624  virtual void OnReadComplete(const MockRead& data);
625
626  net::AddressList addresses() const { return addresses_; }
627
628 private:
629  int CompleteRead();
630
631  net::AddressList addresses_;
632
633  net::SocketDataProvider* data_;
634  int read_offset_;
635  net::MockRead read_data_;
636  bool need_read_data_;
637
638  // True if the peer has closed the connection.  This allows us to simulate
639  // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
640  // TCPClientSocket.
641  bool peer_closed_connection_;
642
643  // While an asynchronous IO is pending, we save our user-buffer state.
644  net::IOBuffer* pending_buf_;
645  int pending_buf_len_;
646  net::CompletionCallback* pending_callback_;
647  bool was_used_to_convey_data_;
648};
649
650class DeterministicMockTCPClientSocket : public MockClientSocket,
651    public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> {
652 public:
653  DeterministicMockTCPClientSocket(net::NetLog* net_log,
654      net::DeterministicSocketData* data);
655
656  // ClientSocket methods:
657  virtual int Connect(net::CompletionCallback* callback);
658  virtual void Disconnect();
659  virtual bool IsConnected() const;
660  virtual bool IsConnectedAndIdle() const { return IsConnected(); }
661  virtual bool WasEverUsed() const { return was_used_to_convey_data_; }
662  virtual bool UsingTCPFastOpen() const { return false; }
663
664  // Socket methods:
665  virtual int Write(net::IOBuffer* buf, int buf_len,
666                    net::CompletionCallback* callback);
667  virtual int Read(net::IOBuffer* buf, int buf_len,
668                   net::CompletionCallback* callback);
669
670  bool write_pending() const { return write_pending_; }
671  bool read_pending() const { return read_pending_; }
672
673  void CompleteWrite();
674  int CompleteRead();
675  void OnReadComplete(const MockRead& data);
676
677 private:
678  bool write_pending_;
679  net::CompletionCallback* write_callback_;
680  int write_result_;
681
682  net::MockRead read_data_;
683
684  net::IOBuffer* read_buf_;
685  int read_buf_len_;
686  bool read_pending_;
687  net::CompletionCallback* read_callback_;
688  net::DeterministicSocketData* data_;
689  bool was_used_to_convey_data_;
690};
691
692class MockSSLClientSocket : public MockClientSocket {
693 public:
694  MockSSLClientSocket(
695      net::ClientSocketHandle* transport_socket,
696      const HostPortPair& host_and_port,
697      const net::SSLConfig& ssl_config,
698      SSLHostInfo* ssl_host_info,
699      net::SSLSocketDataProvider* socket);
700  ~MockSSLClientSocket();
701
702  // ClientSocket methods:
703  virtual int Connect(net::CompletionCallback* callback);
704  virtual void Disconnect();
705  virtual bool IsConnected() const;
706  virtual bool WasEverUsed() const;
707  virtual bool UsingTCPFastOpen() const;
708
709  // Socket methods:
710  virtual int Read(net::IOBuffer* buf, int buf_len,
711                   net::CompletionCallback* callback);
712  virtual int Write(net::IOBuffer* buf, int buf_len,
713                    net::CompletionCallback* callback);
714
715  // SSLClientSocket methods:
716  virtual void GetSSLInfo(net::SSLInfo* ssl_info);
717  virtual NextProtoStatus GetNextProto(std::string* proto);
718  virtual bool was_npn_negotiated() const;
719  virtual bool set_was_npn_negotiated(bool negotiated);
720
721  // This MockSocket does not implement the manual async IO feature.
722  virtual void OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); }
723
724 private:
725  class ConnectCallback;
726
727  scoped_ptr<ClientSocketHandle> transport_;
728  net::SSLSocketDataProvider* data_;
729  bool is_npn_state_set_;
730  bool new_npn_value_;
731  bool was_used_to_convey_data_;
732};
733
734class TestSocketRequest : public CallbackRunner< Tuple1<int> > {
735 public:
736  TestSocketRequest(
737      std::vector<TestSocketRequest*>* request_order,
738      size_t* completion_count);
739  virtual ~TestSocketRequest();
740
741  ClientSocketHandle* handle() { return &handle_; }
742
743  int WaitForResult();
744  virtual void RunWithParams(const Tuple1<int>& params);
745
746 private:
747  ClientSocketHandle handle_;
748  std::vector<TestSocketRequest*>* request_order_;
749  size_t* completion_count_;
750  TestCompletionCallback callback_;
751};
752
753class ClientSocketPoolTest {
754 public:
755  enum KeepAlive {
756    KEEP_ALIVE,
757
758    // A socket will be disconnected in addition to handle being reset.
759    NO_KEEP_ALIVE,
760  };
761
762  static const int kIndexOutOfBounds;
763  static const int kRequestNotFound;
764
765  ClientSocketPoolTest();
766  ~ClientSocketPoolTest();
767
768  template <typename PoolType, typename SocketParams>
769  int StartRequestUsingPool(PoolType* socket_pool,
770                            const std::string& group_name,
771                            RequestPriority priority,
772                            const scoped_refptr<SocketParams>& socket_params) {
773    DCHECK(socket_pool);
774    TestSocketRequest* request = new TestSocketRequest(&request_order_,
775                                                       &completion_count_);
776    requests_.push_back(request);
777    int rv = request->handle()->Init(
778        group_name, socket_params, priority, request,
779        socket_pool, BoundNetLog());
780    if (rv != ERR_IO_PENDING)
781      request_order_.push_back(request);
782    return rv;
783  }
784
785  // Provided there were n requests started, takes |index| in range 1..n
786  // and returns order in which that request completed, in range 1..n,
787  // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
788  // if that request did not complete (for example was canceled).
789  int GetOrderOfRequest(size_t index) const;
790
791  // Resets first initialized socket handle from |requests_|. If found such
792  // a handle, returns true.
793  bool ReleaseOneConnection(KeepAlive keep_alive);
794
795  // Releases connections until there is nothing to release.
796  void ReleaseAllConnections(KeepAlive keep_alive);
797
798  TestSocketRequest* request(int i) { return requests_[i]; }
799  size_t requests_size() const { return requests_.size(); }
800  ScopedVector<TestSocketRequest>* requests() { return &requests_; }
801  size_t completion_count() const { return completion_count_; }
802
803 private:
804  ScopedVector<TestSocketRequest> requests_;
805  std::vector<TestSocketRequest*> request_order_;
806  size_t completion_count_;
807};
808
809class MockTCPClientSocketPool : public TCPClientSocketPool {
810 public:
811  class MockConnectJob {
812   public:
813    MockConnectJob(ClientSocket* socket, ClientSocketHandle* handle,
814                   CompletionCallback* callback);
815    ~MockConnectJob();
816
817    int Connect();
818    bool CancelHandle(const ClientSocketHandle* handle);
819
820   private:
821    void OnConnect(int rv);
822
823    scoped_ptr<ClientSocket> socket_;
824    ClientSocketHandle* handle_;
825    CompletionCallback* user_callback_;
826    CompletionCallbackImpl<MockConnectJob> connect_callback_;
827
828    DISALLOW_COPY_AND_ASSIGN(MockConnectJob);
829  };
830
831  MockTCPClientSocketPool(
832      int max_sockets,
833      int max_sockets_per_group,
834      ClientSocketPoolHistograms* histograms,
835      ClientSocketFactory* socket_factory);
836
837  virtual ~MockTCPClientSocketPool();
838
839  int release_count() const { return release_count_; }
840  int cancel_count() const { return cancel_count_; }
841
842  // TCPClientSocketPool methods.
843  virtual int RequestSocket(const std::string& group_name,
844                            const void* socket_params,
845                            RequestPriority priority,
846                            ClientSocketHandle* handle,
847                            CompletionCallback* callback,
848                            const BoundNetLog& net_log);
849
850  virtual void CancelRequest(const std::string& group_name,
851                             ClientSocketHandle* handle);
852  virtual void ReleaseSocket(const std::string& group_name,
853                             ClientSocket* socket, int id);
854
855 private:
856  ClientSocketFactory* client_socket_factory_;
857  ScopedVector<MockConnectJob> job_list_;
858  int release_count_;
859  int cancel_count_;
860
861  DISALLOW_COPY_AND_ASSIGN(MockTCPClientSocketPool);
862};
863
864class DeterministicMockClientSocketFactory : public ClientSocketFactory {
865 public:
866  DeterministicMockClientSocketFactory();
867  virtual ~DeterministicMockClientSocketFactory();
868
869  void AddSocketDataProvider(DeterministicSocketData* socket);
870  void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
871  void ResetNextMockIndexes();
872
873  // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
874  // created.
875  MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
876
877  // ClientSocketFactory
878  virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses,
879                                              NetLog* net_log,
880                                              const NetLog::Source& source);
881  virtual SSLClientSocket* CreateSSLClientSocket(
882      ClientSocketHandle* transport_socket,
883      const HostPortPair& host_and_port,
884      const SSLConfig& ssl_config,
885      SSLHostInfo* ssl_host_info,
886      CertVerifier* cert_verifier,
887      DnsCertProvenanceChecker* dns_cert_checker);
888
889  SocketDataProviderArray<DeterministicSocketData>& mock_data() {
890    return mock_data_;
891  }
892  std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() {
893    return tcp_client_sockets_;
894  }
895
896 private:
897  SocketDataProviderArray<DeterministicSocketData> mock_data_;
898  SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
899
900  // Store pointers to handed out sockets in case the test wants to get them.
901  std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_;
902  std::vector<MockSSLClientSocket*> ssl_client_sockets_;
903};
904
905class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
906 public:
907  MockSOCKSClientSocketPool(
908      int max_sockets,
909      int max_sockets_per_group,
910      ClientSocketPoolHistograms* histograms,
911      TCPClientSocketPool* tcp_pool);
912
913  virtual ~MockSOCKSClientSocketPool();
914
915  // SOCKSClientSocketPool methods.
916  virtual int RequestSocket(const std::string& group_name,
917                            const void* socket_params,
918                            RequestPriority priority,
919                            ClientSocketHandle* handle,
920                            CompletionCallback* callback,
921                            const BoundNetLog& net_log);
922
923  virtual void CancelRequest(const std::string& group_name,
924                             ClientSocketHandle* handle);
925  virtual void ReleaseSocket(const std::string& group_name,
926                             ClientSocket* socket, int id);
927
928 private:
929  TCPClientSocketPool* const tcp_pool_;
930
931  DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool);
932};
933
// Wire bytes for a successful SOCKS v5 handshake. Each byte array is paired
// with an int holding its length. The definitions live out of line.

// First exchange: the greeting and its response.
extern const char kSOCKS5GreetRequest[];
extern const int kSOCKS5GreetRequestLength;
extern const char kSOCKS5GreetResponse[];
extern const int kSOCKS5GreetResponseLength;

// Second exchange: the "OK" request and its response.
extern const char kSOCKS5OkRequest[];
extern const int kSOCKS5OkRequestLength;
extern const char kSOCKS5OkResponse[];
extern const int kSOCKS5OkResponseLength;
946
947}  // namespace net
948
949#endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
950