1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/socket_stream/socket_stream.h"
6
7#include <string>
8#include <vector>
9
10#include "base/bind.h"
11#include "base/bind_helpers.h"
12#include "base/callback.h"
13#include "base/strings/utf_string_conversions.h"
14#include "net/base/auth.h"
15#include "net/base/net_log.h"
16#include "net/base/net_log_unittest.h"
17#include "net/base/test_completion_callback.h"
18#include "net/dns/mock_host_resolver.h"
19#include "net/http/http_network_session.h"
20#include "net/proxy/proxy_service.h"
21#include "net/socket/socket_test_util.h"
22#include "net/url_request/url_request_test_util.h"
23#include "testing/gtest/include/gtest/gtest.h"
24#include "testing/platform_test.h"
25
26namespace net {
27
28namespace {
29
30struct SocketStreamEvent {
31  enum EventType {
32    EVENT_START_OPEN_CONNECTION, EVENT_CONNECTED, EVENT_SENT_DATA,
33    EVENT_RECEIVED_DATA, EVENT_CLOSE, EVENT_AUTH_REQUIRED, EVENT_ERROR,
34  };
35
36  SocketStreamEvent(EventType type,
37                    SocketStream* socket_stream,
38                    int num,
39                    const std::string& str,
40                    AuthChallengeInfo* auth_challenge_info,
41                    int error)
42      : event_type(type), socket(socket_stream), number(num), data(str),
43        auth_info(auth_challenge_info), error_code(error) {}
44
45  EventType event_type;
46  SocketStream* socket;
47  int number;
48  std::string data;
49  scoped_refptr<AuthChallengeInfo> auth_info;
50  int error_code;
51};
52
53class SocketStreamEventRecorder : public SocketStream::Delegate {
54 public:
55  // |callback| will be run when the OnClose() or OnError() method is called.
56  // For OnClose(), |callback| is called with OK. For OnError(), it's called
57  // with the error code.
58  explicit SocketStreamEventRecorder(const CompletionCallback& callback)
59      : callback_(callback) {}
60  virtual ~SocketStreamEventRecorder() {}
61
62  void SetOnStartOpenConnection(
63      const base::Callback<int(SocketStreamEvent*)>& callback) {
64    on_start_open_connection_ = callback;
65  }
66  void SetOnConnected(
67      const base::Callback<void(SocketStreamEvent*)>& callback) {
68    on_connected_ = callback;
69  }
70  void SetOnSentData(
71      const base::Callback<void(SocketStreamEvent*)>& callback) {
72    on_sent_data_ = callback;
73  }
74  void SetOnReceivedData(
75      const base::Callback<void(SocketStreamEvent*)>& callback) {
76    on_received_data_ = callback;
77  }
78  void SetOnClose(const base::Callback<void(SocketStreamEvent*)>& callback) {
79    on_close_ = callback;
80  }
81  void SetOnAuthRequired(
82      const base::Callback<void(SocketStreamEvent*)>& callback) {
83    on_auth_required_ = callback;
84  }
85  void SetOnError(const base::Callback<void(SocketStreamEvent*)>& callback) {
86    on_error_ = callback;
87  }
88
89  virtual int OnStartOpenConnection(
90      SocketStream* socket,
91      const CompletionCallback& callback) OVERRIDE {
92    connection_callback_ = callback;
93    events_.push_back(
94        SocketStreamEvent(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
95                          socket, 0, std::string(), NULL, OK));
96    if (!on_start_open_connection_.is_null())
97      return on_start_open_connection_.Run(&events_.back());
98    return OK;
99  }
100  virtual void OnConnected(SocketStream* socket,
101                           int num_pending_send_allowed) OVERRIDE {
102    events_.push_back(
103        SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED,
104                          socket, num_pending_send_allowed, std::string(),
105                          NULL, OK));
106    if (!on_connected_.is_null())
107      on_connected_.Run(&events_.back());
108  }
109  virtual void OnSentData(SocketStream* socket,
110                          int amount_sent) OVERRIDE {
111    events_.push_back(
112        SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA, socket,
113                          amount_sent, std::string(), NULL, OK));
114    if (!on_sent_data_.is_null())
115      on_sent_data_.Run(&events_.back());
116  }
117  virtual void OnReceivedData(SocketStream* socket,
118                              const char* data, int len) OVERRIDE {
119    events_.push_back(
120        SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA, socket, len,
121                          std::string(data, len), NULL, OK));
122    if (!on_received_data_.is_null())
123      on_received_data_.Run(&events_.back());
124  }
125  virtual void OnClose(SocketStream* socket) OVERRIDE {
126    events_.push_back(
127        SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE, socket, 0,
128                          std::string(), NULL, OK));
129    if (!on_close_.is_null())
130      on_close_.Run(&events_.back());
131    if (!callback_.is_null())
132      callback_.Run(OK);
133  }
134  virtual void OnAuthRequired(SocketStream* socket,
135                              AuthChallengeInfo* auth_info) OVERRIDE {
136    events_.push_back(
137        SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED, socket, 0,
138                          std::string(), auth_info, OK));
139    if (!on_auth_required_.is_null())
140      on_auth_required_.Run(&events_.back());
141  }
142  virtual void OnError(const SocketStream* socket, int error) OVERRIDE {
143    events_.push_back(
144        SocketStreamEvent(SocketStreamEvent::EVENT_ERROR, NULL, 0,
145                          std::string(), NULL, error));
146    if (!on_error_.is_null())
147      on_error_.Run(&events_.back());
148    if (!callback_.is_null())
149      callback_.Run(error);
150  }
151
152  void DoClose(SocketStreamEvent* event) {
153    event->socket->Close();
154  }
155  void DoRestartWithAuth(SocketStreamEvent* event) {
156    VLOG(1) << "RestartWithAuth username=" << credentials_.username()
157            << " password=" << credentials_.password();
158    event->socket->RestartWithAuth(credentials_);
159  }
160  void SetAuthInfo(const AuthCredentials& credentials) {
161    credentials_ = credentials;
162  }
163  // Wakes up the SocketStream waiting for completion of OnStartOpenConnection()
164  // of its delegate.
165  void CompleteConnection(int result) {
166    connection_callback_.Run(result);
167  }
168
169  const std::vector<SocketStreamEvent>& GetSeenEvents() const {
170    return events_;
171  }
172
173 private:
174  std::vector<SocketStreamEvent> events_;
175  base::Callback<int(SocketStreamEvent*)> on_start_open_connection_;
176  base::Callback<void(SocketStreamEvent*)> on_connected_;
177  base::Callback<void(SocketStreamEvent*)> on_sent_data_;
178  base::Callback<void(SocketStreamEvent*)> on_received_data_;
179  base::Callback<void(SocketStreamEvent*)> on_close_;
180  base::Callback<void(SocketStreamEvent*)> on_auth_required_;
181  base::Callback<void(SocketStreamEvent*)> on_error_;
182  const CompletionCallback callback_;
183  CompletionCallback connection_callback_;
184  AuthCredentials credentials_;
185
186  DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder);
187};
188
189// This is used for the test OnErrorDetachDelegate.
190class SelfDeletingDelegate : public SocketStream::Delegate {
191 public:
192  // |callback| must cause the test message loop to exit when called.
193  explicit SelfDeletingDelegate(const CompletionCallback& callback)
194      : socket_stream_(), callback_(callback) {}
195
196  virtual ~SelfDeletingDelegate() {}
197
198  // Call DetachDelegate(), delete |this|, then run the callback.
199  virtual void OnError(const SocketStream* socket, int error) OVERRIDE {
200    // callback_ will be deleted when we delete |this|, so copy it to call it
201    // afterwards.
202    CompletionCallback callback = callback_;
203    socket_stream_->DetachDelegate();
204    delete this;
205    callback.Run(OK);
206  }
207
208  // This can't be passed in the constructor because this object needs to be
209  // created before SocketStream.
210  void set_socket_stream(const scoped_refptr<SocketStream>& socket_stream) {
211    socket_stream_ = socket_stream;
212    EXPECT_EQ(socket_stream_->delegate(), this);
213  }
214
215  virtual void OnConnected(SocketStream* socket, int max_pending_send_allowed)
216      OVERRIDE {
217    ADD_FAILURE() << "OnConnected() should not be called";
218  }
219  virtual void OnSentData(SocketStream* socket, int amount_sent) OVERRIDE {
220    ADD_FAILURE() << "OnSentData() should not be called";
221  }
222  virtual void OnReceivedData(SocketStream* socket, const char* data, int len)
223      OVERRIDE {
224    ADD_FAILURE() << "OnReceivedData() should not be called";
225  }
226  virtual void OnClose(SocketStream* socket) OVERRIDE {
227    ADD_FAILURE() << "OnClose() should not be called";
228  }
229
230 private:
231  scoped_refptr<SocketStream> socket_stream_;
232  const CompletionCallback callback_;
233
234  DISALLOW_COPY_AND_ASSIGN(SelfDeletingDelegate);
235};
236
// TestURLRequestContext whose proxy service is the one returned by
// ProxyService::CreateFixed() for the |proxy| string given at construction.
class TestURLRequestContextWithProxy : public TestURLRequestContext {
 public:
  explicit TestURLRequestContextWithProxy(const std::string& proxy)
      : TestURLRequestContext(true) {
    // NOTE(review): the |true| passed to the base class presumably defers
    // initialization so that the proxy service can be installed before
    // Init() runs -- confirm against TestURLRequestContext.
    context_storage_.set_proxy_service(ProxyService::CreateFixed(proxy));
    Init();
  }
  virtual ~TestURLRequestContextWithProxy() {}
};
246
247class TestSocketStreamNetworkDelegate : public TestNetworkDelegate {
248 public:
249  TestSocketStreamNetworkDelegate()
250      : before_connect_result_(OK) {}
251  virtual ~TestSocketStreamNetworkDelegate() {}
252
253  virtual int OnBeforeSocketStreamConnect(
254      SocketStream* stream,
255      const CompletionCallback& callback) OVERRIDE {
256    return before_connect_result_;
257  }
258
259  void SetBeforeConnectResult(int result) {
260    before_connect_result_ = result;
261  }
262
263 private:
264  int before_connect_result_;
265};
266
267}  // namespace
268
269class SocketStreamTest : public PlatformTest {
270 public:
271  virtual ~SocketStreamTest() {}
272  virtual void SetUp() {
273    mock_socket_factory_.reset();
274    handshake_request_ = kWebSocketHandshakeRequest;
275    handshake_response_ = kWebSocketHandshakeResponse;
276  }
277  virtual void TearDown() {
278    mock_socket_factory_.reset();
279  }
280
281  virtual void SetWebSocketHandshakeMessage(
282      const char* request, const char* response) {
283    handshake_request_ = request;
284    handshake_response_ = response;
285  }
286  virtual void AddWebSocketMessage(const std::string& message) {
287    messages_.push_back(message);
288  }
289
290  virtual MockClientSocketFactory* GetMockClientSocketFactory() {
291    mock_socket_factory_.reset(new MockClientSocketFactory);
292    return mock_socket_factory_.get();
293  }
294
295  // Functions for SocketStreamEventRecorder to handle calls to the
296  // SocketStream::Delegate methods from the SocketStream.
297
298  virtual void DoSendWebSocketHandshake(SocketStreamEvent* event) {
299    event->socket->SendData(
300        handshake_request_.data(), handshake_request_.size());
301  }
302
303  virtual void DoCloseFlushPendingWriteTest(SocketStreamEvent* event) {
304    // handshake response received.
305    for (size_t i = 0; i < messages_.size(); i++) {
306      std::vector<char> frame;
307      frame.push_back('\0');
308      frame.insert(frame.end(), messages_[i].begin(), messages_[i].end());
309      frame.push_back('\xff');
310      EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size()));
311    }
312    // Actual StreamSocket close must happen after all frames queued by
313    // SendData above are sent out.
314    event->socket->Close();
315  }
316
317  virtual void DoCloseFlushPendingWriteTestWithSetContextNull(
318      SocketStreamEvent* event) {
319    event->socket->set_context(NULL);
320    // handshake response received.
321    for (size_t i = 0; i < messages_.size(); i++) {
322      std::vector<char> frame;
323      frame.push_back('\0');
324      frame.insert(frame.end(), messages_[i].begin(), messages_[i].end());
325      frame.push_back('\xff');
326      EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size()));
327    }
328    // Actual StreamSocket close must happen after all frames queued by
329    // SendData above are sent out.
330    event->socket->Close();
331  }
332
333  virtual void DoFailByTooBigDataAndClose(SocketStreamEvent* event) {
334    std::string frame(event->number + 1, 0x00);
335    VLOG(1) << event->number;
336    EXPECT_FALSE(event->socket->SendData(&frame[0], frame.size()));
337    event->socket->Close();
338  }
339
340  virtual int DoSwitchToSpdyTest(SocketStreamEvent* event) {
341    return ERR_PROTOCOL_SWITCHED;
342  }
343
344  // Notifies |io_test_callback_| of that this method is called, and keeps the
345  // SocketStream waiting.
346  virtual int DoIOPending(SocketStreamEvent* event) {
347    io_test_callback_.callback().Run(OK);
348    return ERR_IO_PENDING;
349  }
350
351  static const char kWebSocketHandshakeRequest[];
352  static const char kWebSocketHandshakeResponse[];
353
354 protected:
355  TestCompletionCallback io_test_callback_;
356
357 private:
358  std::string handshake_request_;
359  std::string handshake_response_;
360  std::vector<std::string> messages_;
361
362  scoped_ptr<MockClientSocketFactory> mock_socket_factory_;
363};
364
// Default client handshake: an HTTP GET upgrade request for /demo on
// example.com carrying Sec-WebSocket-Key1/Key2 headers, followed by an 8-byte
// body after the blank line. SetUp() installs it as the request sent by
// DoSendWebSocketHandshake().
const char SocketStreamTest::kWebSocketHandshakeRequest[] =
    "GET /demo HTTP/1.1\r\n"
    "Host: example.com\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
    "Sec-WebSocket-Protocol: sample\r\n"
    "Upgrade: WebSocket\r\n"
    "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
    "Origin: http://example.com\r\n"
    "\r\n"
    "^n:ds[4U";
376
// Default server handshake: an HTTP 101 response with the WebSocket upgrade
// headers for ws://example.com/demo, followed by a 16-byte body after the
// blank line. Tests in this file use it as the first scripted MockRead.
const char SocketStreamTest::kWebSocketHandshakeResponse[] =
    "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    "Upgrade: WebSocket\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Origin: http://example.com\r\n"
    "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    "Sec-WebSocket-Protocol: sample\r\n"
    "\r\n"
    "8jKS'y:G*Co,Wxa-";
386
387TEST_F(SocketStreamTest, CloseFlushPendingWrite) {
388  TestCompletionCallback test_callback;
389
390  scoped_ptr<SocketStreamEventRecorder> delegate(
391      new SocketStreamEventRecorder(test_callback.callback()));
392  delegate->SetOnConnected(base::Bind(
393      &SocketStreamTest::DoSendWebSocketHandshake, base::Unretained(this)));
394  delegate->SetOnReceivedData(base::Bind(
395      &SocketStreamTest::DoCloseFlushPendingWriteTest,
396      base::Unretained(this)));
397
398  TestURLRequestContext context;
399
400  scoped_refptr<SocketStream> socket_stream(
401      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
402
403  socket_stream->set_context(&context);
404
405  MockWrite data_writes[] = {
406    MockWrite(SocketStreamTest::kWebSocketHandshakeRequest),
407    MockWrite(ASYNC, "\0message1\xff", 10),
408    MockWrite(ASYNC, "\0message2\xff", 10)
409  };
410  MockRead data_reads[] = {
411    MockRead(SocketStreamTest::kWebSocketHandshakeResponse),
412    // Server doesn't close the connection after handshake.
413    MockRead(ASYNC, ERR_IO_PENDING)
414  };
415  AddWebSocketMessage("message1");
416  AddWebSocketMessage("message2");
417
418  DelayedSocketData data_provider(
419      1, data_reads, arraysize(data_reads),
420      data_writes, arraysize(data_writes));
421
422  MockClientSocketFactory* mock_socket_factory =
423      GetMockClientSocketFactory();
424  mock_socket_factory->AddSocketDataProvider(&data_provider);
425
426  socket_stream->SetClientSocketFactory(mock_socket_factory);
427
428  socket_stream->Connect();
429
430  test_callback.WaitForResult();
431
432  EXPECT_TRUE(data_provider.at_read_eof());
433  EXPECT_TRUE(data_provider.at_write_eof());
434
435  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
436  ASSERT_EQ(7U, events.size());
437
438  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
439            events[0].event_type);
440  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
441  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[2].event_type);
442  EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[3].event_type);
443  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type);
444  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[5].event_type);
445  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[6].event_type);
446}
447
448TEST_F(SocketStreamTest, ResolveFailure) {
449  TestCompletionCallback test_callback;
450
451  scoped_ptr<SocketStreamEventRecorder> delegate(
452      new SocketStreamEventRecorder(test_callback.callback()));
453
454  scoped_refptr<SocketStream> socket_stream(
455      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
456
457  // Make resolver fail.
458  TestURLRequestContext context;
459  scoped_ptr<MockHostResolver> mock_host_resolver(
460      new MockHostResolver());
461  mock_host_resolver->rules()->AddSimulatedFailure("example.com");
462  context.set_host_resolver(mock_host_resolver.get());
463  socket_stream->set_context(&context);
464
465  // No read/write on socket is expected.
466  StaticSocketDataProvider data_provider(NULL, 0, NULL, 0);
467  MockClientSocketFactory* mock_socket_factory =
468      GetMockClientSocketFactory();
469  mock_socket_factory->AddSocketDataProvider(&data_provider);
470  socket_stream->SetClientSocketFactory(mock_socket_factory);
471
472  socket_stream->Connect();
473
474  test_callback.WaitForResult();
475
476  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
477  ASSERT_EQ(2U, events.size());
478
479  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[0].event_type);
480  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[1].event_type);
481}
482
483TEST_F(SocketStreamTest, ExceedMaxPendingSendAllowed) {
484  TestCompletionCallback test_callback;
485
486  scoped_ptr<SocketStreamEventRecorder> delegate(
487      new SocketStreamEventRecorder(test_callback.callback()));
488  delegate->SetOnConnected(base::Bind(
489      &SocketStreamTest::DoFailByTooBigDataAndClose, base::Unretained(this)));
490
491  TestURLRequestContext context;
492
493  scoped_refptr<SocketStream> socket_stream(
494      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
495
496  socket_stream->set_context(&context);
497
498  DelayedSocketData data_provider(1, NULL, 0, NULL, 0);
499
500  MockClientSocketFactory* mock_socket_factory =
501      GetMockClientSocketFactory();
502  mock_socket_factory->AddSocketDataProvider(&data_provider);
503
504  socket_stream->SetClientSocketFactory(mock_socket_factory);
505
506  socket_stream->Connect();
507
508  test_callback.WaitForResult();
509
510  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
511  ASSERT_EQ(4U, events.size());
512
513  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
514            events[0].event_type);
515  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
516  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[2].event_type);
517  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[3].event_type);
518}
519
520TEST_F(SocketStreamTest, BasicAuthProxy) {
521  MockClientSocketFactory mock_socket_factory;
522  MockWrite data_writes1[] = {
523    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
524              "Host: example.com\r\n"
525              "Proxy-Connection: keep-alive\r\n\r\n"),
526  };
527  MockRead data_reads1[] = {
528    MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"),
529    MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
530    MockRead("\r\n"),
531  };
532  StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1),
533                                 data_writes1, arraysize(data_writes1));
534  mock_socket_factory.AddSocketDataProvider(&data1);
535
536  MockWrite data_writes2[] = {
537    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
538              "Host: example.com\r\n"
539              "Proxy-Connection: keep-alive\r\n"
540              "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
541  };
542  MockRead data_reads2[] = {
543    MockRead("HTTP/1.1 200 Connection Established\r\n"),
544    MockRead("Proxy-agent: Apache/2.2.8\r\n"),
545    MockRead("\r\n"),
546    // SocketStream::DoClose is run asynchronously.  Socket can be read after
547    // "\r\n".  We have to give ERR_IO_PENDING to SocketStream then to indicate
548    // server doesn't close the connection.
549    MockRead(ASYNC, ERR_IO_PENDING)
550  };
551  StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2),
552                                 data_writes2, arraysize(data_writes2));
553  mock_socket_factory.AddSocketDataProvider(&data2);
554
555  TestCompletionCallback test_callback;
556
557  scoped_ptr<SocketStreamEventRecorder> delegate(
558      new SocketStreamEventRecorder(test_callback.callback()));
559  delegate->SetOnConnected(base::Bind(&SocketStreamEventRecorder::DoClose,
560                                      base::Unretained(delegate.get())));
561  delegate->SetAuthInfo(AuthCredentials(ASCIIToUTF16("foo"),
562                                        ASCIIToUTF16("bar")));
563  delegate->SetOnAuthRequired(base::Bind(
564      &SocketStreamEventRecorder::DoRestartWithAuth,
565      base::Unretained(delegate.get())));
566
567  scoped_refptr<SocketStream> socket_stream(
568      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
569
570  TestURLRequestContextWithProxy context("myproxy:70");
571
572  socket_stream->set_context(&context);
573  socket_stream->SetClientSocketFactory(&mock_socket_factory);
574
575  socket_stream->Connect();
576
577  test_callback.WaitForResult();
578
579  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
580  ASSERT_EQ(5U, events.size());
581
582  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
583            events[0].event_type);
584  EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[1].event_type);
585  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[2].event_type);
586  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[3].event_type);
587  EXPECT_EQ(ERR_ABORTED, events[3].error_code);
588  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[4].event_type);
589
590  // TODO(eroman): Add back NetLogTest here...
591}
592
593TEST_F(SocketStreamTest, BasicAuthProxyWithAuthCache) {
594  MockClientSocketFactory mock_socket_factory;
595  MockWrite data_writes[] = {
596    // WebSocket(SocketStream) always uses CONNECT when it is configured to use
597    // proxy so the port may not be 443.
598    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
599              "Host: example.com\r\n"
600              "Proxy-Connection: keep-alive\r\n"
601              "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
602  };
603  MockRead data_reads[] = {
604    MockRead("HTTP/1.1 200 Connection Established\r\n"),
605    MockRead("Proxy-agent: Apache/2.2.8\r\n"),
606    MockRead("\r\n"),
607    MockRead(ASYNC, ERR_IO_PENDING)
608  };
609  StaticSocketDataProvider data(data_reads, arraysize(data_reads),
610                                 data_writes, arraysize(data_writes));
611  mock_socket_factory.AddSocketDataProvider(&data);
612
613  TestCompletionCallback test_callback;
614  scoped_ptr<SocketStreamEventRecorder> delegate(
615      new SocketStreamEventRecorder(test_callback.callback()));
616  delegate->SetOnConnected(base::Bind(&SocketStreamEventRecorder::DoClose,
617                                      base::Unretained(delegate.get())));
618
619  scoped_refptr<SocketStream> socket_stream(
620      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
621
622  TestURLRequestContextWithProxy context("myproxy:70");
623  HttpAuthCache* auth_cache =
624      context.http_transaction_factory()->GetSession()->http_auth_cache();
625  auth_cache->Add(GURL("http://myproxy:70"),
626                  "MyRealm1",
627                  HttpAuth::AUTH_SCHEME_BASIC,
628                  "Basic realm=MyRealm1",
629                  AuthCredentials(ASCIIToUTF16("foo"),
630                                  ASCIIToUTF16("bar")),
631                  "/");
632
633  socket_stream->set_context(&context);
634  socket_stream->SetClientSocketFactory(&mock_socket_factory);
635
636  socket_stream->Connect();
637
638  test_callback.WaitForResult();
639
640  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
641  ASSERT_EQ(4U, events.size());
642  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
643            events[0].event_type);
644  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
645  EXPECT_EQ(ERR_ABORTED, events[2].error_code);
646  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[3].event_type);
647}
648
649TEST_F(SocketStreamTest, WSSBasicAuthProxyWithAuthCache) {
650  MockClientSocketFactory mock_socket_factory;
651  MockWrite data_writes1[] = {
652    MockWrite("CONNECT example.com:443 HTTP/1.1\r\n"
653              "Host: example.com\r\n"
654              "Proxy-Connection: keep-alive\r\n"
655              "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
656  };
657  MockRead data_reads1[] = {
658    MockRead("HTTP/1.1 200 Connection Established\r\n"),
659    MockRead("Proxy-agent: Apache/2.2.8\r\n"),
660    MockRead("\r\n"),
661    MockRead(ASYNC, ERR_IO_PENDING)
662  };
663  StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1),
664                                 data_writes1, arraysize(data_writes1));
665  mock_socket_factory.AddSocketDataProvider(&data1);
666
667  SSLSocketDataProvider data2(ASYNC, OK);
668  mock_socket_factory.AddSSLSocketDataProvider(&data2);
669
670  TestCompletionCallback test_callback;
671  scoped_ptr<SocketStreamEventRecorder> delegate(
672      new SocketStreamEventRecorder(test_callback.callback()));
673  delegate->SetOnConnected(base::Bind(&SocketStreamEventRecorder::DoClose,
674                                      base::Unretained(delegate.get())));
675
676  scoped_refptr<SocketStream> socket_stream(
677      new SocketStream(GURL("wss://example.com/demo"), delegate.get()));
678
679  TestURLRequestContextWithProxy context("myproxy:70");
680  HttpAuthCache* auth_cache =
681      context.http_transaction_factory()->GetSession()->http_auth_cache();
682  auth_cache->Add(GURL("http://myproxy:70"),
683                  "MyRealm1",
684                  HttpAuth::AUTH_SCHEME_BASIC,
685                  "Basic realm=MyRealm1",
686                  AuthCredentials(ASCIIToUTF16("foo"),
687                                  ASCIIToUTF16("bar")),
688                  "/");
689
690  socket_stream->set_context(&context);
691  socket_stream->SetClientSocketFactory(&mock_socket_factory);
692
693  socket_stream->Connect();
694
695  test_callback.WaitForResult();
696
697  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
698  ASSERT_EQ(4U, events.size());
699  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
700            events[0].event_type);
701  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
702  EXPECT_EQ(ERR_ABORTED, events[2].error_code);
703  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[3].event_type);
704}
705
706TEST_F(SocketStreamTest, IOPending) {
707  TestCompletionCallback test_callback;
708
709  scoped_ptr<SocketStreamEventRecorder> delegate(
710      new SocketStreamEventRecorder(test_callback.callback()));
711  delegate->SetOnStartOpenConnection(base::Bind(
712      &SocketStreamTest::DoIOPending, base::Unretained(this)));
713  delegate->SetOnConnected(base::Bind(
714      &SocketStreamTest::DoSendWebSocketHandshake, base::Unretained(this)));
715  delegate->SetOnReceivedData(base::Bind(
716      &SocketStreamTest::DoCloseFlushPendingWriteTest,
717      base::Unretained(this)));
718
719  TestURLRequestContext context;
720
721  scoped_refptr<SocketStream> socket_stream(
722      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
723
724  socket_stream->set_context(&context);
725
726  MockWrite data_writes[] = {
727    MockWrite(SocketStreamTest::kWebSocketHandshakeRequest),
728    MockWrite(ASYNC, "\0message1\xff", 10),
729    MockWrite(ASYNC, "\0message2\xff", 10)
730  };
731  MockRead data_reads[] = {
732    MockRead(SocketStreamTest::kWebSocketHandshakeResponse),
733    // Server doesn't close the connection after handshake.
734    MockRead(ASYNC, ERR_IO_PENDING)
735  };
736  AddWebSocketMessage("message1");
737  AddWebSocketMessage("message2");
738
739  DelayedSocketData data_provider(
740      1, data_reads, arraysize(data_reads),
741      data_writes, arraysize(data_writes));
742
743  MockClientSocketFactory* mock_socket_factory =
744      GetMockClientSocketFactory();
745  mock_socket_factory->AddSocketDataProvider(&data_provider);
746
747  socket_stream->SetClientSocketFactory(mock_socket_factory);
748
749  socket_stream->Connect();
750  io_test_callback_.WaitForResult();
751  EXPECT_EQ(SocketStream::STATE_RESOLVE_PROTOCOL_COMPLETE,
752            socket_stream->next_state_);
753  delegate->CompleteConnection(OK);
754
755  EXPECT_EQ(OK, test_callback.WaitForResult());
756
757  EXPECT_TRUE(data_provider.at_read_eof());
758  EXPECT_TRUE(data_provider.at_write_eof());
759
760  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
761  ASSERT_EQ(7U, events.size());
762
763  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
764            events[0].event_type);
765  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
766  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[2].event_type);
767  EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[3].event_type);
768  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type);
769  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[5].event_type);
770  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[6].event_type);
771}
772
773TEST_F(SocketStreamTest, SwitchToSpdy) {
774  TestCompletionCallback test_callback;
775
776  scoped_ptr<SocketStreamEventRecorder> delegate(
777      new SocketStreamEventRecorder(test_callback.callback()));
778  delegate->SetOnStartOpenConnection(base::Bind(
779      &SocketStreamTest::DoSwitchToSpdyTest, base::Unretained(this)));
780
781  TestURLRequestContext context;
782
783  scoped_refptr<SocketStream> socket_stream(
784      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
785
786  socket_stream->set_context(&context);
787
788  socket_stream->Connect();
789
790  EXPECT_EQ(ERR_PROTOCOL_SWITCHED, test_callback.WaitForResult());
791
792  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
793  ASSERT_EQ(2U, events.size());
794
795  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
796            events[0].event_type);
797  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[1].event_type);
798  EXPECT_EQ(ERR_PROTOCOL_SWITCHED, events[1].error_code);
799}
800
801TEST_F(SocketStreamTest, SwitchAfterPending) {
802  TestCompletionCallback test_callback;
803
804  scoped_ptr<SocketStreamEventRecorder> delegate(
805      new SocketStreamEventRecorder(test_callback.callback()));
806  delegate->SetOnStartOpenConnection(base::Bind(
807      &SocketStreamTest::DoIOPending, base::Unretained(this)));
808
809  TestURLRequestContext context;
810
811  scoped_refptr<SocketStream> socket_stream(
812      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
813
814  socket_stream->set_context(&context);
815
816  socket_stream->Connect();
817  io_test_callback_.WaitForResult();
818
819  EXPECT_EQ(SocketStream::STATE_RESOLVE_PROTOCOL_COMPLETE,
820            socket_stream->next_state_);
821  delegate->CompleteConnection(ERR_PROTOCOL_SWITCHED);
822
823  EXPECT_EQ(ERR_PROTOCOL_SWITCHED, test_callback.WaitForResult());
824
825  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
826  ASSERT_EQ(2U, events.size());
827
828  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
829            events[0].event_type);
830  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[1].event_type);
831  EXPECT_EQ(ERR_PROTOCOL_SWITCHED, events[1].error_code);
832}
833
834// Test a connection though a secure proxy.
835TEST_F(SocketStreamTest, SecureProxyConnectError) {
836  MockClientSocketFactory mock_socket_factory;
837  MockWrite data_writes[] = {
838    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
839              "Host: example.com\r\n"
840              "Proxy-Connection: keep-alive\r\n\r\n")
841  };
842  MockRead data_reads[] = {
843    MockRead("HTTP/1.1 200 Connection Established\r\n"),
844    MockRead("Proxy-agent: Apache/2.2.8\r\n"),
845    MockRead("\r\n"),
846    // SocketStream::DoClose is run asynchronously.  Socket can be read after
847    // "\r\n".  We have to give ERR_IO_PENDING to SocketStream then to indicate
848    // server doesn't close the connection.
849    MockRead(ASYNC, ERR_IO_PENDING)
850  };
851  StaticSocketDataProvider data(data_reads, arraysize(data_reads),
852                                data_writes, arraysize(data_writes));
853  mock_socket_factory.AddSocketDataProvider(&data);
854  SSLSocketDataProvider ssl(SYNCHRONOUS, ERR_SSL_PROTOCOL_ERROR);
855  mock_socket_factory.AddSSLSocketDataProvider(&ssl);
856
857  TestCompletionCallback test_callback;
858  TestURLRequestContextWithProxy context("https://myproxy:70");
859
860  scoped_ptr<SocketStreamEventRecorder> delegate(
861      new SocketStreamEventRecorder(test_callback.callback()));
862  delegate->SetOnConnected(base::Bind(&SocketStreamEventRecorder::DoClose,
863                                      base::Unretained(delegate.get())));
864
865  scoped_refptr<SocketStream> socket_stream(
866      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
867
868  socket_stream->set_context(&context);
869  socket_stream->SetClientSocketFactory(&mock_socket_factory);
870
871  socket_stream->Connect();
872
873  test_callback.WaitForResult();
874
875  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
876  ASSERT_EQ(3U, events.size());
877
878  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
879            events[0].event_type);
880  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[1].event_type);
881  EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, events[1].error_code);
882  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type);
883}
884
885// Test a connection though a secure proxy.
886TEST_F(SocketStreamTest, SecureProxyConnect) {
887  MockClientSocketFactory mock_socket_factory;
888  MockWrite data_writes[] = {
889    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
890              "Host: example.com\r\n"
891              "Proxy-Connection: keep-alive\r\n\r\n")
892  };
893  MockRead data_reads[] = {
894    MockRead("HTTP/1.1 200 Connection Established\r\n"),
895    MockRead("Proxy-agent: Apache/2.2.8\r\n"),
896    MockRead("\r\n"),
897    // SocketStream::DoClose is run asynchronously.  Socket can be read after
898    // "\r\n".  We have to give ERR_IO_PENDING to SocketStream then to indicate
899    // server doesn't close the connection.
900    MockRead(ASYNC, ERR_IO_PENDING)
901  };
902  StaticSocketDataProvider data(data_reads, arraysize(data_reads),
903                                data_writes, arraysize(data_writes));
904  mock_socket_factory.AddSocketDataProvider(&data);
905  SSLSocketDataProvider ssl(SYNCHRONOUS, OK);
906  mock_socket_factory.AddSSLSocketDataProvider(&ssl);
907
908  TestCompletionCallback test_callback;
909  TestURLRequestContextWithProxy context("https://myproxy:70");
910
911  scoped_ptr<SocketStreamEventRecorder> delegate(
912      new SocketStreamEventRecorder(test_callback.callback()));
913  delegate->SetOnConnected(base::Bind(&SocketStreamEventRecorder::DoClose,
914                                      base::Unretained(delegate.get())));
915
916  scoped_refptr<SocketStream> socket_stream(
917      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
918
919  socket_stream->set_context(&context);
920  socket_stream->SetClientSocketFactory(&mock_socket_factory);
921
922  socket_stream->Connect();
923
924  test_callback.WaitForResult();
925
926  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
927  ASSERT_EQ(4U, events.size());
928
929  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
930            events[0].event_type);
931  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
932  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[2].event_type);
933  EXPECT_EQ(ERR_ABORTED, events[2].error_code);
934  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[3].event_type);
935}
936
937TEST_F(SocketStreamTest, BeforeConnectFailed) {
938  TestCompletionCallback test_callback;
939
940  scoped_ptr<SocketStreamEventRecorder> delegate(
941      new SocketStreamEventRecorder(test_callback.callback()));
942
943  TestURLRequestContext context;
944  TestSocketStreamNetworkDelegate network_delegate;
945  network_delegate.SetBeforeConnectResult(ERR_ACCESS_DENIED);
946  context.set_network_delegate(&network_delegate);
947
948  scoped_refptr<SocketStream> socket_stream(
949      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
950
951  socket_stream->set_context(&context);
952
953  socket_stream->Connect();
954
955  test_callback.WaitForResult();
956
957  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
958  ASSERT_EQ(2U, events.size());
959
960  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[0].event_type);
961  EXPECT_EQ(ERR_ACCESS_DENIED, events[0].error_code);
962  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[1].event_type);
963}
964
965// Check that a connect failure, followed by the delegate calling DetachDelegate
966// and deleting itself in the OnError callback, is handled correctly.
967TEST_F(SocketStreamTest, OnErrorDetachDelegate) {
968  MockClientSocketFactory mock_socket_factory;
969  TestCompletionCallback test_callback;
970
971  // SelfDeletingDelegate is self-owning; we just need a pointer to it to
972  // connect it and the SocketStream.
973  SelfDeletingDelegate* delegate =
974      new SelfDeletingDelegate(test_callback.callback());
975  MockConnect mock_connect(ASYNC, ERR_CONNECTION_REFUSED);
976  StaticSocketDataProvider data;
977  data.set_connect_data(mock_connect);
978  mock_socket_factory.AddSocketDataProvider(&data);
979
980  TestURLRequestContext context;
981  scoped_refptr<SocketStream> socket_stream(
982      new SocketStream(GURL("ws://localhost:9998/echo"), delegate));
983  socket_stream->set_context(&context);
984  socket_stream->SetClientSocketFactory(&mock_socket_factory);
985  delegate->set_socket_stream(socket_stream);
986  // The delegate pointer will become invalid during the test. Set it to NULL to
987  // avoid holding a dangling pointer.
988  delegate = NULL;
989
990  socket_stream->Connect();
991
992  EXPECT_EQ(OK, test_callback.WaitForResult());
993}
994
995TEST_F(SocketStreamTest, NullContextSocketStreamShouldNotCrash) {
996  TestCompletionCallback test_callback;
997
998  scoped_ptr<SocketStreamEventRecorder> delegate(
999      new SocketStreamEventRecorder(test_callback.callback()));
1000  TestURLRequestContext context;
1001  scoped_refptr<SocketStream> socket_stream(
1002      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
1003  delegate->SetOnStartOpenConnection(base::Bind(
1004      &SocketStreamTest::DoIOPending, base::Unretained(this)));
1005  delegate->SetOnConnected(base::Bind(
1006      &SocketStreamTest::DoSendWebSocketHandshake, base::Unretained(this)));
1007  delegate->SetOnReceivedData(base::Bind(
1008      &SocketStreamTest::DoCloseFlushPendingWriteTestWithSetContextNull,
1009      base::Unretained(this)));
1010
1011  socket_stream->set_context(&context);
1012
1013  MockWrite data_writes[] = {
1014    MockWrite(SocketStreamTest::kWebSocketHandshakeRequest),
1015  };
1016  MockRead data_reads[] = {
1017    MockRead(SocketStreamTest::kWebSocketHandshakeResponse),
1018  };
1019  AddWebSocketMessage("message1");
1020  AddWebSocketMessage("message2");
1021
1022  DelayedSocketData data_provider(
1023      1, data_reads, arraysize(data_reads),
1024      data_writes, arraysize(data_writes));
1025
1026  MockClientSocketFactory* mock_socket_factory = GetMockClientSocketFactory();
1027  mock_socket_factory->AddSocketDataProvider(&data_provider);
1028  socket_stream->SetClientSocketFactory(mock_socket_factory);
1029
1030  socket_stream->Connect();
1031  io_test_callback_.WaitForResult();
1032  delegate->CompleteConnection(OK);
1033  EXPECT_EQ(OK, test_callback.WaitForResult());
1034
1035  EXPECT_TRUE(data_provider.at_read_eof());
1036  EXPECT_TRUE(data_provider.at_write_eof());
1037
1038  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
1039  ASSERT_EQ(5U, events.size());
1040
1041  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
1042            events[0].event_type);
1043  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
1044  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[2].event_type);
1045  EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[3].event_type);
1046  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[4].event_type);
1047}
1048
1049}  // namespace net
1050