1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket/socket_test_util.h"
6
7#include <algorithm>
8#include <vector>
9
10#include "base/basictypes.h"
11#include "base/bind.h"
12#include "base/bind_helpers.h"
13#include "base/callback_helpers.h"
14#include "base/compiler_specific.h"
15#include "base/message_loop/message_loop.h"
16#include "base/run_loop.h"
17#include "base/time/time.h"
18#include "net/base/address_family.h"
19#include "net/base/address_list.h"
20#include "net/base/auth.h"
21#include "net/base/load_timing_info.h"
22#include "net/http/http_network_session.h"
23#include "net/http/http_request_headers.h"
24#include "net/http/http_response_headers.h"
25#include "net/socket/client_socket_pool_histograms.h"
26#include "net/socket/socket.h"
27#include "net/ssl/ssl_cert_request_info.h"
28#include "net/ssl/ssl_connection_status_flags.h"
29#include "net/ssl/ssl_info.h"
30#include "testing/gtest/include/gtest/gtest.h"
31
32// Socket events are easier to debug if you log individual reads and writes.
33// Enable these if locally debugging, but they are too noisy for the waterfall.
34#if 0
35#define NET_TRACE(level, s) DLOG(level) << s << __FUNCTION__ << "() "
36#else
37#define NET_TRACE(level, s) EAT_STREAM_PARAMETERS
38#endif
39
40namespace net {
41
42namespace {
43
// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') that represents
// the upper four bits of |x|.
inline char AsciifyHigh(char x) {
  const int upper_bits = (static_cast<unsigned char>(x) >> 4) & 0x0F;
  if (upper_bits < 10)
    return static_cast<char>('0' + upper_bits);
  return static_cast<char>('A' + (upper_bits - 10));
}
48
// Returns the uppercase hexadecimal digit ('0'-'9', 'A'-'F') that represents
// the lower four bits of |x|.
inline char AsciifyLow(char x) {
  static const char kHexDigits[] = "0123456789ABCDEF";
  return kHexDigits[x & 0x0F];
}
53
// Returns |x| unchanged when it is a printable character, and '.' otherwise.
// Negative values are rejected before isprint() is consulted.
inline char Asciify(char x) {
  const bool printable = (x >= 0) && isprint(x);
  return printable ? x : '.';
}
59
60void DumpData(const char* data, int data_len) {
61  if (logging::LOG_INFO < logging::GetMinLogLevel())
62    return;
63  DVLOG(1) << "Length:  " << data_len;
64  const char* pfx = "Data:    ";
65  if (!data || (data_len <= 0)) {
66    DVLOG(1) << pfx << "<None>";
67  } else {
68    int i;
69    for (i = 0; i <= (data_len - 4); i += 4) {
70      DVLOG(1) << pfx
71               << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
72               << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
73               << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
74               << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3])
75               << "  '"
76               << Asciify(data[i + 0])
77               << Asciify(data[i + 1])
78               << Asciify(data[i + 2])
79               << Asciify(data[i + 3])
80               << "'";
81      pfx = "         ";
82    }
83    // Take care of any 'trailing' bytes, if data_len was not a multiple of 4.
84    switch (data_len - i) {
85      case 3:
86        DVLOG(1) << pfx
87                 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
88                 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
89                 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2])
90                 << "    '"
91                 << Asciify(data[i + 0])
92                 << Asciify(data[i + 1])
93                 << Asciify(data[i + 2])
94                 << " '";
95        break;
96      case 2:
97        DVLOG(1) << pfx
98                 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
99                 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1])
100                 << "      '"
101                 << Asciify(data[i + 0])
102                 << Asciify(data[i + 1])
103                 << "  '";
104        break;
105      case 1:
106        DVLOG(1) << pfx
107                 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0])
108                 << "        '"
109                 << Asciify(data[i + 0])
110                 << "   '";
111        break;
112    }
113  }
114}
115
116template <MockReadWriteType type>
117void DumpMockReadWrite(const MockReadWrite<type>& r) {
118  if (logging::LOG_INFO < logging::GetMinLogLevel())
119    return;
120  DVLOG(1) << "Async:   " << (r.mode == ASYNC)
121           << "\nResult:  " << r.result;
122  DumpData(r.data, r.data_len);
123  const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : "";
124  DVLOG(1) << "Stage:   " << (r.sequence_number & ~MockRead::STOPLOOP) << stop
125           << "\nTime:    " << r.time_stamp.ToInternalValue();
126}
127
128}  // namespace
129
130MockConnect::MockConnect() : mode(ASYNC), result(OK) {
131  IPAddressNumber ip;
132  CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip));
133  peer_addr = IPEndPoint(ip, 0);
134}
135
136MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) {
137  IPAddressNumber ip;
138  CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip));
139  peer_addr = IPEndPoint(ip, 0);
140}
141
142MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) :
143    mode(io_mode),
144    result(r),
145    peer_addr(addr) {
146}
147
// Empty destructor body; no explicit cleanup is performed here.
MockConnect::~MockConnect() {}
149
// Creates a provider with no scripted reads or writes: both arrays are NULL
// and all indices and counts are zero.
StaticSocketDataProvider::StaticSocketDataProvider()
    : reads_(NULL),
      read_index_(0),
      read_count_(0),
      writes_(NULL),
      write_index_(0),
      write_count_(0) {
}
158
// Creates a provider over caller-supplied arrays of scripted reads and writes.
// Only the pointers and counts are stored (the arrays are not copied), and
// both cursors start at the first element.
StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads,
                                                   size_t reads_count,
                                                   MockWrite* writes,
                                                   size_t writes_count)
    : reads_(reads),
      read_index_(0),
      read_count_(reads_count),
      writes_(writes),
      write_index_(0),
      write_count_(writes_count) {
}
170
// Empty destructor body; the read/write arrays are not freed here.
StaticSocketDataProvider::~StaticSocketDataProvider() {}
172
// Returns the next scripted read without consuming it. CHECK-fails if all
// reads have already been consumed.
const MockRead& StaticSocketDataProvider::PeekRead() const {
  CHECK(!at_read_eof());
  return reads_[read_index_];
}
177
// Returns the next scripted write without consuming it. CHECK-fails if all
// writes have already been consumed.
const MockWrite& StaticSocketDataProvider::PeekWrite() const {
  CHECK(!at_write_eof());
  return writes_[write_index_];
}
182
// Returns the scripted read at |index|, independent of the current read
// cursor. CHECK-fails if |index| is not below the number of reads.
const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const {
  CHECK_LT(index, read_count_);
  return reads_[index];
}
187
// Returns the scripted write at |index|, independent of the current write
// cursor. CHECK-fails if |index| is not below the number of writes.
const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const {
  CHECK_LT(index, write_count_);
  return writes_[index];
}
192
193MockRead StaticSocketDataProvider::GetNextRead() {
194  CHECK(!at_read_eof());
195  reads_[read_index_].time_stamp = base::Time::Now();
196  return reads_[read_index_++];
197}
198
199MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
200  if (!writes_) {
201    // Not using mock writes; succeed synchronously.
202    return MockWriteResult(SYNCHRONOUS, data.length());
203  }
204  EXPECT_FALSE(at_write_eof());
205  if (at_write_eof()) {
206    // Show what the extra write actually consists of.
207    EXPECT_EQ("<unexpected write>", data);
208    return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
209  }
210
211  // Check that what we are writing matches the expectation.
212  // Then give the mocked return value.
213  MockWrite* w = &writes_[write_index_++];
214  w->time_stamp = base::Time::Now();
215  int result = w->result;
216  if (w->data) {
217    // Note - we can simulate a partial write here.  If the expected data
218    // is a match, but shorter than the write actually written, that is legal.
219    // Example:
220    //   Application writes "foobarbaz" (9 bytes)
221    //   Expected write was "foo" (3 bytes)
222    //   This is a success, and we return 3 to the application.
223    std::string expected_data(w->data, w->data_len);
224    EXPECT_GE(data.length(), expected_data.length());
225    std::string actual_data(data.substr(0, w->data_len));
226    EXPECT_EQ(expected_data, actual_data);
227    if (expected_data != actual_data)
228      return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
229    if (result == OK)
230      result = w->data_len;
231  }
232  return MockWriteResult(w->mode, result);
233}
234
// Rewinds both cursors so the scripted reads and writes replay from the start.
void StaticSocketDataProvider::Reset() {
  read_index_ = 0;
  write_index_ = 0;
}
239
// Starts with no short-read limit (0 disables splitting in GetNextRead) and
// with unconsumed reads disallowed.
DynamicSocketDataProvider::DynamicSocketDataProvider()
    : short_read_limit_(0),
      allow_unconsumed_reads_(false) {
}
244
// Empty destructor body; no explicit cleanup is performed here.
DynamicSocketDataProvider::~DynamicSocketDataProvider() {}
246
247MockRead DynamicSocketDataProvider::GetNextRead() {
248  if (reads_.empty())
249    return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
250  MockRead result = reads_.front();
251  if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) {
252    reads_.pop_front();
253  } else {
254    result.data_len = short_read_limit_;
255    reads_.front().data += result.data_len;
256    reads_.front().data_len -= result.data_len;
257  }
258  return result;
259}
260
// Discards every queued read.
void DynamicSocketDataProvider::Reset() {
  reads_.clear();
}
264
265void DynamicSocketDataProvider::SimulateRead(const char* data,
266                                             const size_t length) {
267  if (!allow_unconsumed_reads_) {
268    EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data;
269  }
270  reads_.push_back(MockRead(ASYNC, data, length));
271}
272
// Creates SSL mock data whose connect step completes with |mode| and |result|.
// All negotiation-related fields start in their "nothing negotiated / nothing
// sent" state, and the connection status is populated with TLS 1.2 and cipher
// suite 0xcc14.
SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result)
    : connect(mode, result),
      next_proto_status(SSLClientSocket::kNextProtoUnsupported),
      was_npn_negotiated(false),
      protocol_negotiated(kProtoUnknown),
      client_cert_sent(false),
      cert_request_info(NULL),
      channel_id_sent(false),
      connection_status(0),
      should_pause_on_connect(false),
      is_in_session_cache(false) {
  SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_2,
                                &connection_status);
  // Set to TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305
  SSLConnectionStatusSetCipherSuite(0xcc14, &connection_status);
}
289
// Empty destructor body; no explicit cleanup is performed here.
SSLSocketDataProvider::~SSLSocketDataProvider() {
}
292
// Marks |proto| as having been negotiated: sets the NPN flag and status, the
// negotiated protocol, and its string form from
// SSLClientSocket::NextProtoToString().
void SSLSocketDataProvider::SetNextProto(NextProto proto) {
  was_npn_negotiated = true;
  next_proto_status = SSLClientSocket::kNextProtoNegotiated;
  protocol_negotiated = proto;
  next_proto = SSLClientSocket::NextProtoToString(proto);
}
299
// Creates scripted socket data whose reads are held back until |write_delay|
// writes have been seen (see GetNextRead() and OnWrite()). |write_delay| must
// be non-negative (DCHECKed).
DelayedSocketData::DelayedSocketData(
    int write_delay, MockRead* reads, size_t reads_count,
    MockWrite* writes, size_t writes_count)
    : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
      write_delay_(write_delay),
      read_in_progress_(false),
      weak_factory_(this) {
  DCHECK_GE(write_delay_, 0);
}
309
// Same as the constructor without |connect|, but additionally installs
// |connect| as the connect data. |write_delay| must be non-negative
// (DCHECKed).
DelayedSocketData::DelayedSocketData(
    const MockConnect& connect, int write_delay, MockRead* reads,
    size_t reads_count, MockWrite* writes, size_t writes_count)
    : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
      write_delay_(write_delay),
      read_in_progress_(false),
      weak_factory_(this) {
  DCHECK_GE(write_delay_, 0);
  set_connect_data(connect);
}
320
// Empty destructor body; no explicit cleanup is performed here.
DelayedSocketData::~DelayedSocketData() {
}
323
// Completes the pending read immediately by clearing the remaining write
// delay. A read must currently be in progress (DCHECKed).
void DelayedSocketData::ForceNextRead() {
  DCHECK(read_in_progress_);
  write_delay_ = 0;
  CompleteRead();
}
329
330MockRead DelayedSocketData::GetNextRead() {
331  MockRead out = MockRead(ASYNC, ERR_IO_PENDING);
332  if (write_delay_ <= 0)
333    out = StaticSocketDataProvider::GetNextRead();
334  read_in_progress_ = (out.result == ERR_IO_PENDING);
335  return out;
336}
337
338MockWriteResult DelayedSocketData::OnWrite(const std::string& data) {
339  MockWriteResult rv = StaticSocketDataProvider::OnWrite(data);
340  // Now that our write has completed, we can allow reads to continue.
341  if (!--write_delay_ && read_in_progress_)
342    base::MessageLoop::current()->PostDelayedTask(
343        FROM_HERE,
344        base::Bind(&DelayedSocketData::CompleteRead,
345                   weak_factory_.GetWeakPtr()),
346        base::TimeDelta::FromMilliseconds(100));
347  return rv;
348}
349
// Detaches the socket, forgets any in-progress read, cancels posted
// CompleteRead tasks by invalidating the weak pointers they were bound to, and
// rewinds the scripted data.
void DelayedSocketData::Reset() {
  set_socket(NULL);
  read_in_progress_ = false;
  weak_factory_.InvalidateWeakPtrs();
  StaticSocketDataProvider::Reset();
}
356
357void DelayedSocketData::CompleteRead() {
358  if (socket() && read_in_progress_)
359    socket()->OnReadComplete(GetNextRead());
360}
361
// Creates ordered socket data over the given scripted reads and writes,
// starting at sequence number 0 with no loop-stop stage recorded and no
// blocked read.
OrderedSocketData::OrderedSocketData(
    MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count)
    : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
      sequence_number_(0), loop_stop_stage_(0),
      blocked_(false), weak_factory_(this) {
}
368
// Same as the constructor without |connect|, but additionally installs
// |connect| as the connect data.
OrderedSocketData::OrderedSocketData(
    const MockConnect& connect,
    MockRead* reads, size_t reads_count,
    MockWrite* writes, size_t writes_count)
    : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
      sequence_number_(0), loop_stop_stage_(0),
      blocked_(false), weak_factory_(this) {
  set_connect_data(connect);
}
378
379void OrderedSocketData::EndLoop() {
380  // If we've already stopped the loop, don't do it again until we've advanced
381  // to the next sequence_number.
382  NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_ << ": EndLoop()";
383  if (loop_stop_stage_ > 0) {
384    const MockRead& next_read = StaticSocketDataProvider::PeekRead();
385    if ((next_read.sequence_number & ~MockRead::STOPLOOP) >
386        loop_stop_stage_) {
387      NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
388                                << ": Clearing stop index";
389      loop_stop_stage_ = 0;
390    } else {
391      return;
392    }
393  }
394  // Record the sequence_number at which we stopped the loop.
395  NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
396                            << ": Posting Quit at read " << read_index();
397  loop_stop_stage_ = sequence_number_;
398}
399
// Returns the next scripted read if its stage has been reached, otherwise an
// asynchronous ERR_IO_PENDING placeholder. Every call cancels any posted
// CompleteRead task and advances |sequence_number_| by one, whether or not a
// scripted read is returned.
MockRead OrderedSocketData::GetNextRead() {
  // Cancel any CompleteRead() task still bound to an earlier weak pointer.
  weak_factory_.InvalidateWeakPtrs();
  blocked_ = false;
  const MockRead& next_read = StaticSocketDataProvider::PeekRead();
  if (next_read.sequence_number & MockRead::STOPLOOP)
    EndLoop();
  // Note: the comparison uses the pre-increment value of |sequence_number_|,
  // which is why the traces below print |sequence_number_ - 1|.
  if ((next_read.sequence_number & ~MockRead::STOPLOOP) <=
      sequence_number_++) {
    NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_ - 1
                              << ": Read " << read_index();
    DumpMockReadWrite(next_read);
    // A scripted ERR_IO_PENDING leaves the reader blocked.
    blocked_ = (next_read.result == ERR_IO_PENDING);
    return StaticSocketDataProvider::GetNextRead();
  }
  // The scripted read belongs to a later stage; report it as pending.
  NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_ - 1
                            << ": I/O Pending";
  MockRead result = MockRead(ASYNC, ERR_IO_PENDING);
  DumpMockReadWrite(result);
  blocked_ = true;
  return result;
}
421
422MockWriteResult OrderedSocketData::OnWrite(const std::string& data) {
423  NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
424                            << ": Write " << write_index();
425  DumpMockReadWrite(PeekWrite());
426  ++sequence_number_;
427  if (blocked_) {
428    // TODO(willchan): This 100ms delay seems to work around some weirdness.  We
429    // should probably fix the weirdness.  One example is in SpdyStream,
430    // DoSendRequest() will return ERR_IO_PENDING, and there's a race.  If the
431    // SYN_REPLY causes OnResponseReceived() to get called before
432    // SpdyStream::ReadResponseHeaders() is called, we hit a NOTREACHED().
433    base::MessageLoop::current()->PostDelayedTask(
434        FROM_HERE,
435        base::Bind(&OrderedSocketData::CompleteRead,
436                   weak_factory_.GetWeakPtr()),
437        base::TimeDelta::FromMilliseconds(100));
438  }
439  return StaticSocketDataProvider::OnWrite(data);
440}
441
// Returns to the initial stage: clears the sequence number and loop-stop
// stage, detaches the socket, cancels posted CompleteRead tasks by
// invalidating their weak pointers, and rewinds the scripted data.
// Note that |blocked_| is not modified here.
void OrderedSocketData::Reset() {
  NET_TRACE(INFO, "  *** ") << "Stage "
                            << sequence_number_ << ": Reset()";
  sequence_number_ = 0;
  loop_stop_stage_ = 0;
  set_socket(NULL);
  weak_factory_.InvalidateWeakPtrs();
  StaticSocketDataProvider::Reset();
}
451
452void OrderedSocketData::CompleteRead() {
453  if (socket() && blocked_) {
454    NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_;
455    socket()->OnReadComplete(GetNextRead());
456  }
457}
458
// Empty destructor body; no explicit cleanup is performed here.
OrderedSocketData::~OrderedSocketData() {}
460
// Creates deterministic socket data over the given scripted reads and writes,
// starting at sequence number 0, not stopped and not running. The sequence
// numbers of the scripts are validated by VerifyCorrectSequenceNumbers().
DeterministicSocketData::DeterministicSocketData(MockRead* reads,
    size_t reads_count, MockWrite* writes, size_t writes_count)
    : StaticSocketDataProvider(reads, reads_count, writes, writes_count),
      sequence_number_(0),
      current_read_(),
      current_write_(),
      stopping_sequence_number_(0),
      stopped_(false),
      print_debug_(false),
      is_running_(false) {
  VerifyCorrectSequenceNumbers(reads, reads_count, writes, writes_count);
}
473
// Empty destructor body; no explicit cleanup is performed here.
DeterministicSocketData::~DeterministicSocketData() {}
475
476void DeterministicSocketData::Run() {
477  DCHECK(!is_running_);
478  is_running_ = true;
479
480  SetStopped(false);
481  int counter = 0;
482  // Continue to consume data until all data has run out, or the stopped_ flag
483  // has been set. Consuming data requires two separate operations -- running
484  // the tasks in the message loop, and explicitly invoking the read/write
485  // callbacks (simulating network I/O). We check our conditions between each,
486  // since they can change in either.
487  while ((!at_write_eof() || !at_read_eof()) && !stopped()) {
488    if (counter % 2 == 0)
489      base::RunLoop().RunUntilIdle();
490    if (counter % 2 == 1) {
491      InvokeCallbacks();
492    }
493    counter++;
494  }
495  // We're done consuming new data, but it is possible there are still some
496  // pending callbacks which we expect to complete before returning.
497  while (delegate_.get() &&
498         (delegate_->WritePending() || delegate_->ReadPending()) &&
499         !stopped()) {
500    InvokeCallbacks();
501    base::RunLoop().RunUntilIdle();
502  }
503  SetStopped(false);
504  is_running_ = false;
505}
506
// Runs until |steps| more sequence steps have been taken (see StopAfter()).
void DeterministicSocketData::RunFor(int steps) {
  StopAfter(steps);
  Run();
}
511
// Sets the absolute sequence number at which to stop and clears the stopped
// flag. |seq| must be greater than the current sequence number (DCHECKed).
void DeterministicSocketData::SetStop(int seq) {
  DCHECK_LT(sequence_number_, seq);
  stopping_sequence_number_ = seq;
  stopped_ = false;
}
517
// Sets the stop point relative to the current sequence number: stop after
// |seq| more steps.
void DeterministicSocketData::StopAfter(int seq) {
  SetStop(sequence_number_ + seq);
}
521
// Returns the scripted read for the current step, or a placeholder when that
// read belongs to a later step.
//  - A synchronous scripted read requested while stopped, or before its step
//    has been reached, yields a synchronous ERR_UNEXPECTED read.
//  - An asynchronous scripted read for a later step yields ERR_IO_PENDING and
//    is not consumed.
//  - Otherwise the scripted read is consumed and returned; the sequence number
//    advances here only for synchronous reads.
MockRead DeterministicSocketData::GetNextRead() {
  // Remember the upcoming read without consuming it yet.
  current_read_ = StaticSocketDataProvider::PeekRead();

  // Synchronous read while stopped is an error
  if (stopped() && current_read_.mode == SYNCHRONOUS) {
    LOG(ERROR) << "Unable to perform synchronous IO while stopped";
    return MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
  }

  // Async read which will be called back in a future step.
  if (sequence_number_ < current_read_.sequence_number) {
    NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
                              << ": I/O Pending";
    MockRead result = MockRead(SYNCHRONOUS, ERR_IO_PENDING);
    if (current_read_.mode == SYNCHRONOUS) {
      LOG(ERROR) << "Unable to perform synchronous read: "
          << current_read_.sequence_number
          << " at stage: " << sequence_number_;
      result = MockRead(SYNCHRONOUS, ERR_UNEXPECTED);
    }
    if (print_debug_)
      DumpMockReadWrite(result);
    return result;
  }

  NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
                            << ": Read " << read_index();
  if (print_debug_)
    DumpMockReadWrite(current_read_);

  // Increment the sequence number if IO is complete
  if (current_read_.mode == SYNCHRONOUS)
    NextStep();

  DCHECK_NE(ERR_IO_PENDING, current_read_.result);
  // Consume the scripted read; the copy taken above is what gets returned.
  StaticSocketDataProvider::GetNextRead();

  return current_read_;
}
561
// Validates a write against the scripted write for the current step.
//  - A synchronous scripted write requested while stopped, or before its step
//    has been reached, returns ERR_UNEXPECTED synchronously without consuming
//    the scripted write.
//  - Otherwise the write is checked by StaticSocketDataProvider::OnWrite();
//    the sequence number advances here only for synchronous writes.
MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) {
  const MockWrite& next_write = StaticSocketDataProvider::PeekWrite();
  // Keep a copy of the scripted write being processed.
  current_write_ = next_write;

  // Synchronous write while stopped is an error
  if (stopped() && next_write.mode == SYNCHRONOUS) {
    LOG(ERROR) << "Unable to perform synchronous IO while stopped";
    return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
  }

  // Async write which will be called back in a future step.
  if (sequence_number_ < next_write.sequence_number) {
    NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
                              << ": I/O Pending";
    if (next_write.mode == SYNCHRONOUS) {
      LOG(ERROR) << "Unable to perform synchronous write: "
          << next_write.sequence_number << " at stage: " << sequence_number_;
      return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED);
    }
  } else {
    NET_TRACE(INFO, "  *** ") << "Stage " << sequence_number_
                              << ": Write " << write_index();
  }

  if (print_debug_)
    DumpMockReadWrite(next_write);

  // Move to the next step if I/O is synchronous, since the operation will
  // complete when this method returns.
  if (next_write.mode == SYNCHRONOUS)
    NextStep();

  // This is either a sync write for this step, or an async write.
  return StaticSocketDataProvider::OnWrite(data);
}
597
// Resets the sequence number and rewinds the scripted data, then hits
// NOTREACHED(): this method is not expected to be called.
void DeterministicSocketData::Reset() {
  NET_TRACE(INFO, "  *** ") << "Stage "
                            << sequence_number_ << ": Reset()";
  sequence_number_ = 0;
  StaticSocketDataProvider::Reset();
  NOTREACHED();
}
605
606void DeterministicSocketData::InvokeCallbacks() {
607  if (delegate_.get() && delegate_->WritePending() &&
608      (current_write().sequence_number == sequence_number())) {
609    NextStep();
610    delegate_->CompleteWrite();
611    return;
612  }
613  if (delegate_.get() && delegate_->ReadPending() &&
614      (current_read().sequence_number == sequence_number())) {
615    NextStep();
616    delegate_->CompleteRead();
617    return;
618  }
619}
620
// Advances to the next sequence number and sets the stopped flag when the
// configured stopping sequence number is reached.
void DeterministicSocketData::NextStep() {
  // Invariant: Can never move *past* the stopping step.
  DCHECK_LT(sequence_number_, stopping_sequence_number_);
  sequence_number_++;
  if (sequence_number_ == stopping_sequence_number_)
    SetStopped(true);
}
628
629void DeterministicSocketData::VerifyCorrectSequenceNumbers(
630    MockRead* reads, size_t reads_count,
631    MockWrite* writes, size_t writes_count) {
632  size_t read = 0;
633  size_t write = 0;
634  int expected = 0;
635  while (read < reads_count || write < writes_count) {
636    // Check to see that we have a read or write at the expected
637    // state.
638    if (read < reads_count  && reads[read].sequence_number == expected) {
639      ++read;
640      ++expected;
641      continue;
642    }
643    if (write < writes_count && writes[write].sequence_number == expected) {
644      ++write;
645      ++expected;
646      continue;
647    }
648    NOTREACHED() << "Missing sequence number: " << expected;
649    return;
650  }
651  DCHECK_EQ(read, reads_count);
652  DCHECK_EQ(write, writes_count);
653}
654
// Empty constructor body; members are default-initialized.
MockClientSocketFactory::MockClientSocketFactory() {}
656
// Empty destructor body; no explicit cleanup is performed here.
MockClientSocketFactory::~MockClientSocketFactory() {}
658
// Registers |data| in the list consumed by CreateDatagramClientSocket() and
// CreateTransportClientSocket().
void MockClientSocketFactory::AddSocketDataProvider(
    SocketDataProvider* data) {
  mock_data_.Add(data);
}
663
// Registers |data| in the list consumed by CreateSSLClientSocket().
void MockClientSocketFactory::AddSSLSocketDataProvider(
    SSLSocketDataProvider* data) {
  mock_ssl_data_.Add(data);
}
668
// Resets the "next" index of both the plain and the SSL data provider lists.
void MockClientSocketFactory::ResetNextMockIndexes() {
  mock_data_.ResetNextIndex();
  mock_ssl_data_.ResetNextIndex();
}
673
// Creates a MockUDPClientSocket backed by the next registered data provider
// and attaches the socket to that provider. For RANDOM_BIND, the source port
// is chosen by calling |rand_int_cb| with the bounds 1025 and 65535.
// |source| is not used by this implementation.
scoped_ptr<DatagramClientSocket>
MockClientSocketFactory::CreateDatagramClientSocket(
    DatagramSocket::BindType bind_type,
    const RandIntCallback& rand_int_cb,
    net::NetLog* net_log,
    const net::NetLog::Source& source) {
  SocketDataProvider* data_provider = mock_data_.GetNext();
  scoped_ptr<MockUDPClientSocket> socket(
      new MockUDPClientSocket(data_provider, net_log));
  // Let the provider call back into the socket it is driving.
  data_provider->set_socket(socket.get());
  if (bind_type == DatagramSocket::RANDOM_BIND)
    socket->set_source_port(rand_int_cb.Run(1025, 65535));
  return socket.PassAs<DatagramClientSocket>();
}
688
// Creates a MockTCPClientSocket for |addresses| backed by the next registered
// data provider and attaches the socket to that provider.
// |source| is not used by this implementation.
scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket(
    const AddressList& addresses,
    net::NetLog* net_log,
    const net::NetLog::Source& source) {
  SocketDataProvider* data_provider = mock_data_.GetNext();
  scoped_ptr<MockTCPClientSocket> socket(
      new MockTCPClientSocket(addresses, net_log, data_provider));
  // Let the provider call back into the socket it is driving.
  data_provider->set_socket(socket.get());
  return socket.PassAs<StreamSocket>();
}
699
// Creates a MockSSLClientSocket that wraps |transport_socket| and is backed by
// the next registered SSL data provider. A raw pointer to the new socket is
// also recorded in |ssl_client_sockets_|. |context| is not used by this
// implementation.
scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket(
    scoped_ptr<ClientSocketHandle> transport_socket,
    const HostPortPair& host_and_port,
    const SSLConfig& ssl_config,
    const SSLClientSocketContext& context) {
  scoped_ptr<MockSSLClientSocket> socket(
      new MockSSLClientSocket(transport_socket.Pass(),
                              host_and_port,
                              ssl_config,
                              mock_ssl_data_.GetNext()));
  ssl_client_sockets_.push_back(socket.get());
  return socket.PassAs<SSLClientSocket>();
}
713
// Intentionally a no-op in this mock factory.
void MockClientSocketFactory::ClearSSLSessionCache() {
}
716
// Fixed value reported by MockClientSocket::GetTLSUniqueChannelBinding().
const char MockClientSocket::kTlsUnique[] = "MOCK_TLSUNIQ";
718
719MockClientSocket::MockClientSocket(const BoundNetLog& net_log)
720    : connected_(false),
721      net_log_(net_log),
722      weak_factory_(this) {
723  IPAddressNumber ip;
724  CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip));
725  peer_addr_ = IPEndPoint(ip, 0);
726}
727
// Ignores |size| and always reports success.
int MockClientSocket::SetReceiveBufferSize(int32 size) {
  return OK;
}
731
// Ignores |size| and always reports success.
int MockClientSocket::SetSendBufferSize(int32 size) {
  return OK;
}
735
// Marks the socket as no longer connected.
void MockClientSocket::Disconnect() {
  connected_ = false;
}
739
// Returns the connected flag.
bool MockClientSocket::IsConnected() const {
  return connected_;
}
743
// Returns the connected flag; this mock does not track idleness separately.
bool MockClientSocket::IsConnectedAndIdle() const {
  return connected_;
}
747
748int MockClientSocket::GetPeerAddress(IPEndPoint* address) const {
749  if (!IsConnected())
750    return ERR_SOCKET_NOT_CONNECTED;
751  *address = peer_addr_;
752  return OK;
753}
754
755int MockClientSocket::GetLocalAddress(IPEndPoint* address) const {
756  IPAddressNumber ip;
757  bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip);
758  CHECK(rv);
759  *address = IPEndPoint(ip, 123);
760  return OK;
761}
762
// Returns the net log this socket was constructed with.
const BoundNetLog& MockClientSocket::NetLog() const {
  return net_log_;
}
766
// Not implemented in the mock; logs NOTIMPLEMENTED() and returns an empty
// string.
std::string MockClientSocket::GetSessionCacheKey() const {
  NOTIMPLEMENTED();
  return std::string();
}
771
// Not implemented in the mock; logs NOTIMPLEMENTED() and returns false.
bool MockClientSocket::InSessionCache() const {
  NOTIMPLEMENTED();
  return false;
}
776
// Not implemented in the mock; |cb| is ignored.
void MockClientSocket::SetHandshakeCompletionCallback(const base::Closure& cb) {
  NOTIMPLEMENTED();
}
780
// No-op in this base mock; |cert_request_info| is left unmodified.
void MockClientSocket::GetSSLCertRequestInfo(
  SSLCertRequestInfo* cert_request_info) {
}
784
// Fills |out| with |outlen| copies of the byte 'A' and returns OK. |label|,
// |has_context| and |context| are ignored.
int MockClientSocket::ExportKeyingMaterial(const base::StringPiece& label,
                                           bool has_context,
                                           const base::StringPiece& context,
                                           unsigned char* out,
                                           unsigned int outlen) {
  memset(out, 'A', outlen);
  return OK;
}
793
// Sets |*out| to the fixed kTlsUnique value and returns OK.
int MockClientSocket::GetTLSUniqueChannelBinding(std::string* out) {
  out->assign(MockClientSocket::kTlsUnique);
  return OK;
}
798
// Not expected to be called on this base mock (NOTREACHED); returns NULL.
ChannelIDService* MockClientSocket::GetChannelIDService() const {
  NOTREACHED();
  return NULL;
}
803
// Clears |*proto| and reports that next-protocol negotiation is unsupported.
SSLClientSocket::NextProtoStatus
MockClientSocket::GetNextProto(std::string* proto) {
  proto->clear();
  return SSLClientSocket::kNextProtoUnsupported;
}
809
// Not expected to be called on this base mock (NOTREACHED); returns NULL.
scoped_refptr<X509Certificate>
MockClientSocket::GetUnverifiedServerCertificateChain() const {
  NOTREACHED();
  return NULL;
}
815
// Empty destructor body; no explicit cleanup is performed here.
MockClientSocket::~MockClientSocket() {}
817
// Posts a task to the current message loop that runs |callback| with |result|
// via RunCallback(). The task is bound to a weak pointer obtained from
// |weak_factory_| rather than to a raw |this|.
void MockClientSocket::RunCallbackAsync(const CompletionCallback& callback,
                                        int result) {
  base::MessageLoop::current()->PostTask(
      FROM_HERE,
      base::Bind(&MockClientSocket::RunCallback,
                 weak_factory_.GetWeakPtr(),
                 callback,
                 result));
}
827
828void MockClientSocket::RunCallback(const net::CompletionCallback& callback,
829                                   int result) {
830  if (!callback.is_null())
831    callback.Run(result);
832}
833
// Creates a mock TCP socket for |addresses| driven by |data|, which must be
// non-NULL (DCHECKed). The peer address is taken from the provider's connect
// data, and the provider is reset as part of construction.
MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses,
                                         net::NetLog* net_log,
                                         SocketDataProvider* data)
    : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)),
      addresses_(addresses),
      data_(data),
      read_offset_(0),
      read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
      need_read_data_(true),
      peer_closed_connection_(false),
      pending_buf_(NULL),
      pending_buf_len_(0),
      was_used_to_convey_data_(false) {
  DCHECK(data_);
  peer_addr_ = data->connect_data().peer_addr;
  data_->Reset();
}
851
// Empty destructor body; no explicit cleanup is performed here.
MockTCPClientSocket::~MockTCPClientSocket() {}
853
// Starts a read into |buf| of at most |buf_len| bytes.
// Returns ERR_UNEXPECTED if the socket is not connected. Returns
// ERR_IO_PENDING if the data provider answers with ERR_IO_PENDING, in which
// case the read is completed later through OnReadComplete(). Otherwise
// returns whatever CompleteRead() returns.
int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len,
                              const CompletionCallback& callback) {
  if (!connected_)
    return ERR_UNEXPECTED;

  // If the buffer is already in use, a read is already in progress!
  DCHECK(pending_buf_.get() == NULL);

  // Store our async IO data.
  pending_buf_ = buf;
  pending_buf_len_ = buf_len;
  pending_callback_ = callback;

  // Fetch a new MockRead only when the previous one has been used up.
  if (need_read_data_) {
    read_data_ = data_->GetNextRead();
    if (read_data_.result == ERR_CONNECTION_CLOSED) {
      // This MockRead is just a marker to instruct us to set
      // peer_closed_connection_.
      peer_closed_connection_ = true;
    }
    if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
      // This MockRead is just a marker to instruct us to set
      // peer_closed_connection_.  Skip it and get the next one.
      read_data_ = data_->GetNextRead();
      peer_closed_connection_ = true;
    }
    // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
    // to complete the async IO manually later (via OnReadComplete).
    if (read_data_.result == ERR_IO_PENDING) {
      // We need to be using async IO in this case.
      DCHECK(!callback.is_null());
      return ERR_IO_PENDING;
    }
    need_read_data_ = false;
  }

  return CompleteRead();
}
892
893int MockTCPClientSocket::Write(IOBuffer* buf, int buf_len,
894                               const CompletionCallback& callback) {
895  DCHECK(buf);
896  DCHECK_GT(buf_len, 0);
897
898  if (!connected_)
899    return ERR_UNEXPECTED;
900
901  std::string data(buf->data(), buf_len);
902  MockWriteResult write_result = data_->OnWrite(data);
903
904  was_used_to_convey_data_ = true;
905
906  if (write_result.mode == ASYNC) {
907    RunCallbackAsync(callback, write_result.result);
908    return ERR_IO_PENDING;
909  }
910
911  return write_result.result;
912}
913
914int MockTCPClientSocket::Connect(const CompletionCallback& callback) {
915  if (connected_)
916    return OK;
917  connected_ = true;
918  peer_closed_connection_ = false;
919  if (data_->connect_data().mode == ASYNC) {
920    if (data_->connect_data().result == ERR_IO_PENDING)
921      pending_callback_ = callback;
922    else
923      RunCallbackAsync(callback, data_->connect_data().result);
924    return ERR_IO_PENDING;
925  }
926  return data_->connect_data().result;
927}
928
929void MockTCPClientSocket::Disconnect() {
930  MockClientSocket::Disconnect();
931  pending_callback_.Reset();
932}
933
934bool MockTCPClientSocket::IsConnected() const {
935  return connected_ && !peer_closed_connection_;
936}
937
938bool MockTCPClientSocket::IsConnectedAndIdle() const {
939  return IsConnected();
940}
941
942int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const {
943  if (addresses_.empty())
944    return MockClientSocket::GetPeerAddress(address);
945
946  *address = addresses_[0];
947  return OK;
948}
949
950bool MockTCPClientSocket::WasEverUsed() const {
951  return was_used_to_convey_data_;
952}
953
954bool MockTCPClientSocket::UsingTCPFastOpen() const {
955  return false;
956}
957
958bool MockTCPClientSocket::WasNpnNegotiated() const {
959  return false;
960}
961
962bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
963  return false;
964}
965
966void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
967  // There must be a read pending.
968  DCHECK(pending_buf_.get());
969  // You can't complete a read with another ERR_IO_PENDING status code.
970  DCHECK_NE(ERR_IO_PENDING, data.result);
971  // Since we've been waiting for data, need_read_data_ should be true.
972  DCHECK(need_read_data_);
973
974  read_data_ = data;
975  need_read_data_ = false;
976
977  // The caller is simulating that this IO completes right now.  Don't
978  // let CompleteRead() schedule a callback.
979  read_data_.mode = SYNCHRONOUS;
980
981  CompletionCallback callback = pending_callback_;
982  int rv = CompleteRead();
983  RunCallback(callback, rv);
984}
985
986void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) {
987  CompletionCallback callback = pending_callback_;
988  RunCallback(callback, data.result);
989}
990
991int MockTCPClientSocket::CompleteRead() {
992  DCHECK(pending_buf_.get());
993  DCHECK(pending_buf_len_ > 0);
994
995  was_used_to_convey_data_ = true;
996
997  // Save the pending async IO data and reset our |pending_| state.
998  scoped_refptr<IOBuffer> buf = pending_buf_;
999  int buf_len = pending_buf_len_;
1000  CompletionCallback callback = pending_callback_;
1001  pending_buf_ = NULL;
1002  pending_buf_len_ = 0;
1003  pending_callback_.Reset();
1004
1005  int result = read_data_.result;
1006  DCHECK(result != ERR_IO_PENDING);
1007
1008  if (read_data_.data) {
1009    if (read_data_.data_len - read_offset_ > 0) {
1010      result = std::min(buf_len, read_data_.data_len - read_offset_);
1011      memcpy(buf->data(), read_data_.data + read_offset_, result);
1012      read_offset_ += result;
1013      if (read_offset_ == read_data_.data_len) {
1014        need_read_data_ = true;
1015        read_offset_ = 0;
1016      }
1017    } else {
1018      result = 0;  // EOF
1019    }
1020  }
1021
1022  if (read_data_.mode == ASYNC) {
1023    DCHECK(!callback.is_null());
1024    RunCallbackAsync(callback, result);
1025    return ERR_IO_PENDING;
1026  }
1027  return result;
1028}
1029
1030DeterministicSocketHelper::DeterministicSocketHelper(
1031    net::NetLog* net_log,
1032    DeterministicSocketData* data)
1033    : write_pending_(false),
1034      write_result_(0),
1035      read_data_(),
1036      read_buf_(NULL),
1037      read_buf_len_(0),
1038      read_pending_(false),
1039      data_(data),
1040      was_used_to_convey_data_(false),
1041      peer_closed_connection_(false),
1042      net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)) {
1043}
1044
1045DeterministicSocketHelper::~DeterministicSocketHelper() {}
1046
1047void DeterministicSocketHelper::CompleteWrite() {
1048  was_used_to_convey_data_ = true;
1049  write_pending_ = false;
1050  write_callback_.Run(write_result_);
1051}
1052
1053int DeterministicSocketHelper::CompleteRead() {
1054  DCHECK_GT(read_buf_len_, 0);
1055  DCHECK_LE(read_data_.data_len, read_buf_len_);
1056  DCHECK(read_buf_);
1057
1058  was_used_to_convey_data_ = true;
1059
1060  if (read_data_.result == ERR_IO_PENDING)
1061    read_data_ = data_->GetNextRead();
1062  DCHECK_NE(ERR_IO_PENDING, read_data_.result);
1063  // If read_data_.mode is ASYNC, we do not need to wait, since this is already
1064  // the callback. Therefore we don't even bother to check it.
1065  int result = read_data_.result;
1066
1067  if (read_data_.data_len > 0) {
1068    DCHECK(read_data_.data);
1069    result = std::min(read_buf_len_, read_data_.data_len);
1070    memcpy(read_buf_->data(), read_data_.data, result);
1071  }
1072
1073  if (read_pending_) {
1074    read_pending_ = false;
1075    read_callback_.Run(result);
1076  }
1077
1078  return result;
1079}
1080
1081int DeterministicSocketHelper::Write(
1082    IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
1083  DCHECK(buf);
1084  DCHECK_GT(buf_len, 0);
1085
1086  std::string data(buf->data(), buf_len);
1087  MockWriteResult write_result = data_->OnWrite(data);
1088
1089  if (write_result.mode == ASYNC) {
1090    write_callback_ = callback;
1091    write_result_ = write_result.result;
1092    DCHECK(!write_callback_.is_null());
1093    write_pending_ = true;
1094    return ERR_IO_PENDING;
1095  }
1096
1097  was_used_to_convey_data_ = true;
1098  write_pending_ = false;
1099  return write_result.result;
1100}
1101
1102int DeterministicSocketHelper::Read(
1103    IOBuffer* buf, int buf_len, const CompletionCallback& callback) {
1104
1105  read_data_ = data_->GetNextRead();
1106  // The buffer should always be big enough to contain all the MockRead data. To
1107  // use small buffers, split the data into multiple MockReads.
1108  DCHECK_LE(read_data_.data_len, buf_len);
1109
1110  if (read_data_.result == ERR_CONNECTION_CLOSED) {
1111    // This MockRead is just a marker to instruct us to set
1112    // peer_closed_connection_.
1113    peer_closed_connection_ = true;
1114  }
1115  if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
1116    // This MockRead is just a marker to instruct us to set
1117    // peer_closed_connection_.  Skip it and get the next one.
1118    read_data_ = data_->GetNextRead();
1119    peer_closed_connection_ = true;
1120  }
1121
1122  read_buf_ = buf;
1123  read_buf_len_ = buf_len;
1124  read_callback_ = callback;
1125
1126  if (read_data_.mode == ASYNC || (read_data_.result == ERR_IO_PENDING)) {
1127    read_pending_ = true;
1128    DCHECK(!read_callback_.is_null());
1129    return ERR_IO_PENDING;
1130  }
1131
1132  was_used_to_convey_data_ = true;
1133  return CompleteRead();
1134}
1135
1136DeterministicMockUDPClientSocket::DeterministicMockUDPClientSocket(
1137    net::NetLog* net_log,
1138    DeterministicSocketData* data)
1139    : connected_(false),
1140      helper_(net_log, data),
1141      source_port_(123) {
1142}
1143
1144DeterministicMockUDPClientSocket::~DeterministicMockUDPClientSocket() {}
1145
1146bool DeterministicMockUDPClientSocket::WritePending() const {
1147  return helper_.write_pending();
1148}
1149
1150bool DeterministicMockUDPClientSocket::ReadPending() const {
1151  return helper_.read_pending();
1152}
1153
1154void DeterministicMockUDPClientSocket::CompleteWrite() {
1155  helper_.CompleteWrite();
1156}
1157
1158int DeterministicMockUDPClientSocket::CompleteRead() {
1159  return helper_.CompleteRead();
1160}
1161
1162int DeterministicMockUDPClientSocket::Connect(const IPEndPoint& address) {
1163  if (connected_)
1164    return OK;
1165  connected_ = true;
1166  peer_address_ = address;
1167  return helper_.data()->connect_data().result;
1168};
1169
1170int DeterministicMockUDPClientSocket::Write(
1171    IOBuffer* buf,
1172    int buf_len,
1173    const CompletionCallback& callback) {
1174  if (!connected_)
1175    return ERR_UNEXPECTED;
1176
1177  return helper_.Write(buf, buf_len, callback);
1178}
1179
1180int DeterministicMockUDPClientSocket::Read(
1181    IOBuffer* buf,
1182    int buf_len,
1183    const CompletionCallback& callback) {
1184  if (!connected_)
1185    return ERR_UNEXPECTED;
1186
1187  return helper_.Read(buf, buf_len, callback);
1188}
1189
1190int DeterministicMockUDPClientSocket::SetReceiveBufferSize(int32 size) {
1191  return OK;
1192}
1193
1194int DeterministicMockUDPClientSocket::SetSendBufferSize(int32 size) {
1195  return OK;
1196}
1197
1198void DeterministicMockUDPClientSocket::Close() {
1199  connected_ = false;
1200}
1201
1202int DeterministicMockUDPClientSocket::GetPeerAddress(
1203    IPEndPoint* address) const {
1204  *address = peer_address_;
1205  return OK;
1206}
1207
1208int DeterministicMockUDPClientSocket::GetLocalAddress(
1209    IPEndPoint* address) const {
1210  IPAddressNumber ip;
1211  bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip);
1212  CHECK(rv);
1213  *address = IPEndPoint(ip, source_port_);
1214  return OK;
1215}
1216
1217const BoundNetLog& DeterministicMockUDPClientSocket::NetLog() const {
1218  return helper_.net_log();
1219}
1220
1221void DeterministicMockUDPClientSocket::OnReadComplete(const MockRead& data) {}
1222
1223void DeterministicMockUDPClientSocket::OnConnectComplete(
1224    const MockConnect& data) {
1225  NOTIMPLEMENTED();
1226}
1227
1228DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket(
1229    net::NetLog* net_log,
1230    DeterministicSocketData* data)
1231    : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)),
1232      helper_(net_log, data) {
1233  peer_addr_ = data->connect_data().peer_addr;
1234}
1235
1236DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {}
1237
1238bool DeterministicMockTCPClientSocket::WritePending() const {
1239  return helper_.write_pending();
1240}
1241
1242bool DeterministicMockTCPClientSocket::ReadPending() const {
1243  return helper_.read_pending();
1244}
1245
1246void DeterministicMockTCPClientSocket::CompleteWrite() {
1247  helper_.CompleteWrite();
1248}
1249
1250int DeterministicMockTCPClientSocket::CompleteRead() {
1251  return helper_.CompleteRead();
1252}
1253
1254int DeterministicMockTCPClientSocket::Write(
1255    IOBuffer* buf,
1256    int buf_len,
1257    const CompletionCallback& callback) {
1258  if (!connected_)
1259    return ERR_UNEXPECTED;
1260
1261  return helper_.Write(buf, buf_len, callback);
1262}
1263
1264int DeterministicMockTCPClientSocket::Read(
1265    IOBuffer* buf,
1266    int buf_len,
1267    const CompletionCallback& callback) {
1268  if (!connected_)
1269    return ERR_UNEXPECTED;
1270
1271  return helper_.Read(buf, buf_len, callback);
1272}
1273
1274// TODO(erikchen): Support connect sequencing.
1275int DeterministicMockTCPClientSocket::Connect(
1276    const CompletionCallback& callback) {
1277  if (connected_)
1278    return OK;
1279  connected_ = true;
1280  if (helper_.data()->connect_data().mode == ASYNC) {
1281    RunCallbackAsync(callback, helper_.data()->connect_data().result);
1282    return ERR_IO_PENDING;
1283  }
1284  return helper_.data()->connect_data().result;
1285}
1286
1287void DeterministicMockTCPClientSocket::Disconnect() {
1288  MockClientSocket::Disconnect();
1289}
1290
1291bool DeterministicMockTCPClientSocket::IsConnected() const {
1292  return connected_ && !helper_.peer_closed_connection();
1293}
1294
1295bool DeterministicMockTCPClientSocket::IsConnectedAndIdle() const {
1296  return IsConnected();
1297}
1298
1299bool DeterministicMockTCPClientSocket::WasEverUsed() const {
1300  return helper_.was_used_to_convey_data();
1301}
1302
1303bool DeterministicMockTCPClientSocket::UsingTCPFastOpen() const {
1304  return false;
1305}
1306
1307bool DeterministicMockTCPClientSocket::WasNpnNegotiated() const {
1308  return false;
1309}
1310
1311bool DeterministicMockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
1312  return false;
1313}
1314
1315void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {}
1316
1317void DeterministicMockTCPClientSocket::OnConnectComplete(
1318    const MockConnect& data) {}
1319
1320MockSSLClientSocket::MockSSLClientSocket(
1321    scoped_ptr<ClientSocketHandle> transport_socket,
1322    const HostPortPair& host_port_pair,
1323    const SSLConfig& ssl_config,
1324    SSLSocketDataProvider* data)
1325    : MockClientSocket(
1326          // Have to use the right BoundNetLog for LoadTimingInfo regression
1327          // tests.
1328          transport_socket->socket()->NetLog()),
1329      transport_(transport_socket.Pass()),
1330      host_port_pair_(host_port_pair),
1331      data_(data),
1332      is_npn_state_set_(false),
1333      new_npn_value_(false),
1334      is_protocol_negotiated_set_(false),
1335      protocol_negotiated_(kProtoUnknown),
1336      next_connect_state_(STATE_NONE),
1337      reached_connect_(false),
1338      weak_factory_(this) {
1339  DCHECK(data_);
1340  peer_addr_ = data->connect.peer_addr;
1341}
1342
1343MockSSLClientSocket::~MockSSLClientSocket() {
1344  Disconnect();
1345}
1346
1347int MockSSLClientSocket::Read(IOBuffer* buf, int buf_len,
1348                              const CompletionCallback& callback) {
1349  return transport_->socket()->Read(buf, buf_len, callback);
1350}
1351
1352int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len,
1353                               const CompletionCallback& callback) {
1354  return transport_->socket()->Write(buf, buf_len, callback);
1355}
1356
1357int MockSSLClientSocket::Connect(const CompletionCallback& callback) {
1358  next_connect_state_ = STATE_SSL_CONNECT;
1359  reached_connect_ = true;
1360  int rv = DoConnectLoop(OK);
1361  if (rv == ERR_IO_PENDING)
1362    connect_callback_ = callback;
1363  return rv;
1364}
1365
1366void MockSSLClientSocket::Disconnect() {
1367  weak_factory_.InvalidateWeakPtrs();
1368  MockClientSocket::Disconnect();
1369  if (transport_->socket() != NULL)
1370    transport_->socket()->Disconnect();
1371}
1372
1373bool MockSSLClientSocket::IsConnected() const {
1374  return transport_->socket()->IsConnected() && connected_;
1375}
1376
1377bool MockSSLClientSocket::WasEverUsed() const {
1378  return transport_->socket()->WasEverUsed();
1379}
1380
1381bool MockSSLClientSocket::UsingTCPFastOpen() const {
1382  return transport_->socket()->UsingTCPFastOpen();
1383}
1384
1385int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const {
1386  return transport_->socket()->GetPeerAddress(address);
1387}
1388
1389bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) {
1390  ssl_info->Reset();
1391  ssl_info->cert = data_->cert;
1392  ssl_info->client_cert_sent = data_->client_cert_sent;
1393  ssl_info->channel_id_sent = data_->channel_id_sent;
1394  ssl_info->connection_status = data_->connection_status;
1395  return true;
1396}
1397
1398std::string MockSSLClientSocket::GetSessionCacheKey() const {
1399  // For the purposes of these tests, |host_and_port| will serve as the
1400  // cache key.
1401  return host_port_pair_.ToString();
1402}
1403
1404bool MockSSLClientSocket::InSessionCache() const {
1405  return data_->is_in_session_cache;
1406}
1407
1408void MockSSLClientSocket::SetHandshakeCompletionCallback(
1409    const base::Closure& cb) {
1410  handshake_completion_callback_ = cb;
1411}
1412
1413void MockSSLClientSocket::GetSSLCertRequestInfo(
1414    SSLCertRequestInfo* cert_request_info) {
1415  DCHECK(cert_request_info);
1416  if (data_->cert_request_info) {
1417    cert_request_info->host_and_port =
1418        data_->cert_request_info->host_and_port;
1419    cert_request_info->client_certs = data_->cert_request_info->client_certs;
1420  } else {
1421    cert_request_info->Reset();
1422  }
1423}
1424
1425SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto(
1426    std::string* proto) {
1427  *proto = data_->next_proto;
1428  return data_->next_proto_status;
1429}
1430
1431bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) {
1432  is_npn_state_set_ = true;
1433  return new_npn_value_ = negotiated;
1434}
1435
1436bool MockSSLClientSocket::WasNpnNegotiated() const {
1437  if (is_npn_state_set_)
1438    return new_npn_value_;
1439  return data_->was_npn_negotiated;
1440}
1441
1442NextProto MockSSLClientSocket::GetNegotiatedProtocol() const {
1443  if (is_protocol_negotiated_set_)
1444    return protocol_negotiated_;
1445  return data_->protocol_negotiated;
1446}
1447
1448void MockSSLClientSocket::set_protocol_negotiated(
1449    NextProto protocol_negotiated) {
1450  is_protocol_negotiated_set_ = true;
1451  protocol_negotiated_ = protocol_negotiated;
1452}
1453
1454bool MockSSLClientSocket::WasChannelIDSent() const {
1455  return data_->channel_id_sent;
1456}
1457
1458void MockSSLClientSocket::set_channel_id_sent(bool channel_id_sent) {
1459  data_->channel_id_sent = channel_id_sent;
1460}
1461
1462ChannelIDService* MockSSLClientSocket::GetChannelIDService() const {
1463  return data_->channel_id_service;
1464}
1465
1466void MockSSLClientSocket::OnReadComplete(const MockRead& data) {
1467  NOTIMPLEMENTED();
1468}
1469
1470void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) {
1471  NOTIMPLEMENTED();
1472}
1473
1474void MockSSLClientSocket::RestartPausedConnect() {
1475  DCHECK(data_->should_pause_on_connect);
1476  DCHECK_EQ(next_connect_state_, STATE_SSL_CONNECT_COMPLETE);
1477  OnIOComplete(data_->connect.result);
1478}
1479
1480void MockSSLClientSocket::OnIOComplete(int result) {
1481  int rv = DoConnectLoop(result);
1482  if (rv != ERR_IO_PENDING)
1483    base::ResetAndReturn(&connect_callback_).Run(rv);
1484}
1485
1486int MockSSLClientSocket::DoConnectLoop(int result) {
1487  DCHECK_NE(next_connect_state_, STATE_NONE);
1488
1489  int rv = result;
1490  do {
1491    ConnectState state = next_connect_state_;
1492    next_connect_state_ = STATE_NONE;
1493    switch (state) {
1494      case STATE_SSL_CONNECT:
1495        rv = DoSSLConnect();
1496        break;
1497      case STATE_SSL_CONNECT_COMPLETE:
1498        rv = DoSSLConnectComplete(rv);
1499        break;
1500      default:
1501        NOTREACHED() << "bad state";
1502        rv = ERR_UNEXPECTED;
1503        break;
1504    }
1505  } while (rv != ERR_IO_PENDING && next_connect_state_ != STATE_NONE);
1506
1507  return rv;
1508}
1509
1510int MockSSLClientSocket::DoSSLConnect() {
1511  next_connect_state_ = STATE_SSL_CONNECT_COMPLETE;
1512
1513  if (data_->should_pause_on_connect)
1514    return ERR_IO_PENDING;
1515
1516  if (data_->connect.mode == ASYNC) {
1517    base::MessageLoop::current()->PostTask(
1518        FROM_HERE,
1519        base::Bind(&MockSSLClientSocket::OnIOComplete,
1520                   weak_factory_.GetWeakPtr(),
1521                   data_->connect.result));
1522    return ERR_IO_PENDING;
1523  }
1524
1525  return data_->connect.result;
1526}
1527
1528int MockSSLClientSocket::DoSSLConnectComplete(int result) {
1529  if (result == OK)
1530    connected_ = true;
1531
1532  if (!handshake_completion_callback_.is_null())
1533    base::ResetAndReturn(&handshake_completion_callback_).Run();
1534  return result;
1535}
1536
1537MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data,
1538                                         net::NetLog* net_log)
1539    : connected_(false),
1540      data_(data),
1541      read_offset_(0),
1542      read_data_(SYNCHRONOUS, ERR_UNEXPECTED),
1543      need_read_data_(true),
1544      source_port_(123),
1545      pending_buf_(NULL),
1546      pending_buf_len_(0),
1547      net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)),
1548      weak_factory_(this) {
1549  DCHECK(data_);
1550  data_->Reset();
1551  peer_addr_ = data->connect_data().peer_addr;
1552}
1553
1554MockUDPClientSocket::~MockUDPClientSocket() {}
1555
1556int MockUDPClientSocket::Read(IOBuffer* buf,
1557                              int buf_len,
1558                              const CompletionCallback& callback) {
1559  if (!connected_)
1560    return ERR_UNEXPECTED;
1561
1562  // If the buffer is already in use, a read is already in progress!
1563  DCHECK(pending_buf_.get() == NULL);
1564
1565  // Store our async IO data.
1566  pending_buf_ = buf;
1567  pending_buf_len_ = buf_len;
1568  pending_callback_ = callback;
1569
1570  if (need_read_data_) {
1571    read_data_ = data_->GetNextRead();
1572    // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
1573    // to complete the async IO manually later (via OnReadComplete).
1574    if (read_data_.result == ERR_IO_PENDING) {
1575      // We need to be using async IO in this case.
1576      DCHECK(!callback.is_null());
1577      return ERR_IO_PENDING;
1578    }
1579    need_read_data_ = false;
1580  }
1581
1582  return CompleteRead();
1583}
1584
1585int MockUDPClientSocket::Write(IOBuffer* buf, int buf_len,
1586                               const CompletionCallback& callback) {
1587  DCHECK(buf);
1588  DCHECK_GT(buf_len, 0);
1589
1590  if (!connected_)
1591    return ERR_UNEXPECTED;
1592
1593  std::string data(buf->data(), buf_len);
1594  MockWriteResult write_result = data_->OnWrite(data);
1595
1596  if (write_result.mode == ASYNC) {
1597    RunCallbackAsync(callback, write_result.result);
1598    return ERR_IO_PENDING;
1599  }
1600  return write_result.result;
1601}
1602
1603int MockUDPClientSocket::SetReceiveBufferSize(int32 size) {
1604  return OK;
1605}
1606
1607int MockUDPClientSocket::SetSendBufferSize(int32 size) {
1608  return OK;
1609}
1610
1611void MockUDPClientSocket::Close() {
1612  connected_ = false;
1613}
1614
1615int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const {
1616  *address = peer_addr_;
1617  return OK;
1618}
1619
1620int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const {
1621  IPAddressNumber ip;
1622  bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip);
1623  CHECK(rv);
1624  *address = IPEndPoint(ip, source_port_);
1625  return OK;
1626}
1627
1628const BoundNetLog& MockUDPClientSocket::NetLog() const {
1629  return net_log_;
1630}
1631
1632int MockUDPClientSocket::Connect(const IPEndPoint& address) {
1633  connected_ = true;
1634  peer_addr_ = address;
1635  return data_->connect_data().result;
1636}
1637
1638void MockUDPClientSocket::OnReadComplete(const MockRead& data) {
1639  // There must be a read pending.
1640  DCHECK(pending_buf_.get());
1641  // You can't complete a read with another ERR_IO_PENDING status code.
1642  DCHECK_NE(ERR_IO_PENDING, data.result);
1643  // Since we've been waiting for data, need_read_data_ should be true.
1644  DCHECK(need_read_data_);
1645
1646  read_data_ = data;
1647  need_read_data_ = false;
1648
1649  // The caller is simulating that this IO completes right now.  Don't
1650  // let CompleteRead() schedule a callback.
1651  read_data_.mode = SYNCHRONOUS;
1652
1653  net::CompletionCallback callback = pending_callback_;
1654  int rv = CompleteRead();
1655  RunCallback(callback, rv);
1656}
1657
1658void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) {
1659  NOTIMPLEMENTED();
1660}
1661
1662int MockUDPClientSocket::CompleteRead() {
1663  DCHECK(pending_buf_.get());
1664  DCHECK(pending_buf_len_ > 0);
1665
1666  // Save the pending async IO data and reset our |pending_| state.
1667  scoped_refptr<IOBuffer> buf = pending_buf_;
1668  int buf_len = pending_buf_len_;
1669  CompletionCallback callback = pending_callback_;
1670  pending_buf_ = NULL;
1671  pending_buf_len_ = 0;
1672  pending_callback_.Reset();
1673
1674  int result = read_data_.result;
1675  DCHECK(result != ERR_IO_PENDING);
1676
1677  if (read_data_.data) {
1678    if (read_data_.data_len - read_offset_ > 0) {
1679      result = std::min(buf_len, read_data_.data_len - read_offset_);
1680      memcpy(buf->data(), read_data_.data + read_offset_, result);
1681      read_offset_ += result;
1682      if (read_offset_ == read_data_.data_len) {
1683        need_read_data_ = true;
1684        read_offset_ = 0;
1685      }
1686    } else {
1687      result = 0;  // EOF
1688    }
1689  }
1690
1691  if (read_data_.mode == ASYNC) {
1692    DCHECK(!callback.is_null());
1693    RunCallbackAsync(callback, result);
1694    return ERR_IO_PENDING;
1695  }
1696  return result;
1697}
1698
1699void MockUDPClientSocket::RunCallbackAsync(const CompletionCallback& callback,
1700                                           int result) {
1701  base::MessageLoop::current()->PostTask(
1702      FROM_HERE,
1703      base::Bind(&MockUDPClientSocket::RunCallback,
1704                 weak_factory_.GetWeakPtr(),
1705                 callback,
1706                 result));
1707}
1708
1709void MockUDPClientSocket::RunCallback(const CompletionCallback& callback,
1710                                      int result) {
1711  if (!callback.is_null())
1712    callback.Run(result);
1713}
1714
1715TestSocketRequest::TestSocketRequest(
1716    std::vector<TestSocketRequest*>* request_order, size_t* completion_count)
1717    : request_order_(request_order),
1718      completion_count_(completion_count),
1719      callback_(base::Bind(&TestSocketRequest::OnComplete,
1720                           base::Unretained(this))) {
1721  DCHECK(request_order);
1722  DCHECK(completion_count);
1723}
1724
1725TestSocketRequest::~TestSocketRequest() {
1726}
1727
1728void TestSocketRequest::OnComplete(int result) {
1729  SetResult(result);
1730  (*completion_count_)++;
1731  request_order_->push_back(this);
1732}
1733
1734// static
1735const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
1736
1737// static
1738const int ClientSocketPoolTest::kRequestNotFound = -2;
1739
1740ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {}
1741ClientSocketPoolTest::~ClientSocketPoolTest() {}
1742
1743int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const {
1744  index--;
1745  if (index >= requests_.size())
1746    return kIndexOutOfBounds;
1747
1748  for (size_t i = 0; i < request_order_.size(); i++)
1749    if (requests_[index] == request_order_[i])
1750      return i + 1;
1751
1752  return kRequestNotFound;
1753}
1754
1755bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
1756  ScopedVector<TestSocketRequest>::iterator i;
1757  for (i = requests_.begin(); i != requests_.end(); ++i) {
1758    if ((*i)->handle()->is_initialized()) {
1759      if (keep_alive == NO_KEEP_ALIVE)
1760        (*i)->handle()->socket()->Disconnect();
1761      (*i)->handle()->Reset();
1762      base::RunLoop().RunUntilIdle();
1763      return true;
1764    }
1765  }
1766  return false;
1767}
1768
1769void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
1770  bool released_one;
1771  do {
1772    released_one = ReleaseOneConnection(keep_alive);
1773  } while (released_one);
1774}
1775
1776MockTransportClientSocketPool::MockConnectJob::MockConnectJob(
1777    scoped_ptr<StreamSocket> socket,
1778    ClientSocketHandle* handle,
1779    const CompletionCallback& callback)
1780    : socket_(socket.Pass()),
1781      handle_(handle),
1782      user_callback_(callback) {
1783}
1784
1785MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {}
1786
1787int MockTransportClientSocketPool::MockConnectJob::Connect() {
1788  int rv = socket_->Connect(base::Bind(&MockConnectJob::OnConnect,
1789                                       base::Unretained(this)));
1790  if (rv == OK) {
1791    user_callback_.Reset();
1792    OnConnect(OK);
1793  }
1794  return rv;
1795}
1796
1797bool MockTransportClientSocketPool::MockConnectJob::CancelHandle(
1798    const ClientSocketHandle* handle) {
1799  if (handle != handle_)
1800    return false;
1801  socket_.reset();
1802  handle_ = NULL;
1803  user_callback_.Reset();
1804  return true;
1805}
1806
1807void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) {
1808  if (!socket_.get())
1809    return;
1810  if (rv == OK) {
1811    handle_->SetSocket(socket_.Pass());
1812
1813    // Needed for socket pool tests that layer other sockets on top of mock
1814    // sockets.
1815    LoadTimingInfo::ConnectTiming connect_timing;
1816    base::TimeTicks now = base::TimeTicks::Now();
1817    connect_timing.dns_start = now;
1818    connect_timing.dns_end = now;
1819    connect_timing.connect_start = now;
1820    connect_timing.connect_end = now;
1821    handle_->set_connect_timing(connect_timing);
1822  } else {
1823    socket_.reset();
1824  }
1825
1826  handle_ = NULL;
1827
1828  if (!user_callback_.is_null()) {
1829    CompletionCallback callback = user_callback_;
1830    user_callback_.Reset();
1831    callback.Run(rv);
1832  }
1833}
1834
1835MockTransportClientSocketPool::MockTransportClientSocketPool(
1836    int max_sockets,
1837    int max_sockets_per_group,
1838    ClientSocketPoolHistograms* histograms,
1839    ClientSocketFactory* socket_factory)
1840    : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms,
1841                                NULL, NULL, NULL),
1842      client_socket_factory_(socket_factory),
1843      last_request_priority_(DEFAULT_PRIORITY),
1844      release_count_(0),
1845      cancel_count_(0) {
1846}
1847
1848MockTransportClientSocketPool::~MockTransportClientSocketPool() {}
1849
1850int MockTransportClientSocketPool::RequestSocket(
1851    const std::string& group_name, const void* socket_params,
1852    RequestPriority priority, ClientSocketHandle* handle,
1853    const CompletionCallback& callback, const BoundNetLog& net_log) {
1854  last_request_priority_ = priority;
1855  scoped_ptr<StreamSocket> socket =
1856      client_socket_factory_->CreateTransportClientSocket(
1857          AddressList(), net_log.net_log(), net::NetLog::Source());
1858  MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback);
1859  job_list_.push_back(job);
1860  handle->set_pool_id(1);
1861  return job->Connect();
1862}
1863
1864void MockTransportClientSocketPool::CancelRequest(const std::string& group_name,
1865                                                  ClientSocketHandle* handle) {
1866  std::vector<MockConnectJob*>::iterator i;
1867  for (i = job_list_.begin(); i != job_list_.end(); ++i) {
1868    if ((*i)->CancelHandle(handle)) {
1869      cancel_count_++;
1870      break;
1871    }
1872  }
1873}
1874
1875void MockTransportClientSocketPool::ReleaseSocket(
1876    const std::string& group_name,
1877    scoped_ptr<StreamSocket> socket,
1878    int id) {
1879  EXPECT_EQ(1, id);
1880  release_count_++;
1881}
1882
1883DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {}
1884
1885DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {}
1886
1887void DeterministicMockClientSocketFactory::AddSocketDataProvider(
1888    DeterministicSocketData* data) {
1889  mock_data_.Add(data);
1890}
1891
1892void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider(
1893    SSLSocketDataProvider* data) {
1894  mock_ssl_data_.Add(data);
1895}
1896
1897void DeterministicMockClientSocketFactory::ResetNextMockIndexes() {
1898  mock_data_.ResetNextIndex();
1899  mock_ssl_data_.ResetNextIndex();
1900}
1901
1902MockSSLClientSocket* DeterministicMockClientSocketFactory::
1903    GetMockSSLClientSocket(size_t index) const {
1904  DCHECK_LT(index, ssl_client_sockets_.size());
1905  return ssl_client_sockets_[index];
1906}
1907
1908scoped_ptr<DatagramClientSocket>
1909DeterministicMockClientSocketFactory::CreateDatagramClientSocket(
1910    DatagramSocket::BindType bind_type,
1911    const RandIntCallback& rand_int_cb,
1912    net::NetLog* net_log,
1913    const NetLog::Source& source) {
1914  DeterministicSocketData* data_provider = mock_data().GetNext();
1915  scoped_ptr<DeterministicMockUDPClientSocket> socket(
1916      new DeterministicMockUDPClientSocket(net_log, data_provider));
1917  data_provider->set_delegate(socket->AsWeakPtr());
1918  udp_client_sockets().push_back(socket.get());
1919  if (bind_type == DatagramSocket::RANDOM_BIND)
1920    socket->set_source_port(rand_int_cb.Run(1025, 65535));
1921  return socket.PassAs<DatagramClientSocket>();
1922}
1923
1924scoped_ptr<StreamSocket>
1925DeterministicMockClientSocketFactory::CreateTransportClientSocket(
1926    const AddressList& addresses,
1927    net::NetLog* net_log,
1928    const net::NetLog::Source& source) {
1929  DeterministicSocketData* data_provider = mock_data().GetNext();
1930  scoped_ptr<DeterministicMockTCPClientSocket> socket(
1931      new DeterministicMockTCPClientSocket(net_log, data_provider));
1932  data_provider->set_delegate(socket->AsWeakPtr());
1933  tcp_client_sockets().push_back(socket.get());
1934  return socket.PassAs<StreamSocket>();
1935}
1936
1937scoped_ptr<SSLClientSocket>
1938DeterministicMockClientSocketFactory::CreateSSLClientSocket(
1939    scoped_ptr<ClientSocketHandle> transport_socket,
1940    const HostPortPair& host_and_port,
1941    const SSLConfig& ssl_config,
1942    const SSLClientSocketContext& context) {
1943  scoped_ptr<MockSSLClientSocket> socket(
1944      new MockSSLClientSocket(transport_socket.Pass(),
1945                              host_and_port, ssl_config,
1946                              mock_ssl_data_.GetNext()));
1947  ssl_client_sockets_.push_back(socket.get());
1948  return socket.PassAs<SSLClientSocket>();
1949}
1950
1951void DeterministicMockClientSocketFactory::ClearSSLSessionCache() {
1952}
1953
1954MockSOCKSClientSocketPool::MockSOCKSClientSocketPool(
1955    int max_sockets,
1956    int max_sockets_per_group,
1957    ClientSocketPoolHistograms* histograms,
1958    TransportClientSocketPool* transport_pool)
1959    : SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms,
1960                            NULL, transport_pool, NULL),
1961      transport_pool_(transport_pool) {
1962}
1963
1964MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {}
1965
1966int MockSOCKSClientSocketPool::RequestSocket(
1967    const std::string& group_name, const void* socket_params,
1968    RequestPriority priority, ClientSocketHandle* handle,
1969    const CompletionCallback& callback, const BoundNetLog& net_log) {
1970  return transport_pool_->RequestSocket(
1971      group_name, socket_params, priority, handle, callback, net_log);
1972}
1973
1974void MockSOCKSClientSocketPool::CancelRequest(
1975    const std::string& group_name,
1976    ClientSocketHandle* handle) {
1977  return transport_pool_->CancelRequest(group_name, handle);
1978}
1979
1980void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name,
1981                                              scoped_ptr<StreamSocket> socket,
1982                                              int id) {
1983  return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id);
1984}
1985
1986const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 };
1987const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest);
1988
1989const char kSOCKS5GreetResponse[] = { 0x05, 0x00 };
1990const int kSOCKS5GreetResponseLength = arraysize(kSOCKS5GreetResponse);
1991
1992const char kSOCKS5OkRequest[] =
1993    { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 };
1994const int kSOCKS5OkRequestLength = arraysize(kSOCKS5OkRequest);
1995
1996const char kSOCKS5OkResponse[] =
1997    { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 };
1998const int kSOCKS5OkResponseLength = arraysize(kSOCKS5OkResponse);
1999
2000}  // namespace net
2001