1// Copyright (c) 2010 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <string>
6#include <vector>
7
8#include "base/callback.h"
9#include "base/utf_string_conversions.h"
10#include "net/base/auth.h"
11#include "net/base/mock_host_resolver.h"
12#include "net/base/net_log.h"
13#include "net/base/net_log_unittest.h"
14#include "net/base/test_completion_callback.h"
15#include "net/socket/socket_test_util.h"
16#include "net/socket_stream/socket_stream.h"
17#include "net/url_request/url_request_test_util.h"
18#include "testing/gtest/include/gtest/gtest.h"
19#include "testing/platform_test.h"
20
21struct SocketStreamEvent {
22  enum EventType {
23    EVENT_CONNECTED, EVENT_SENT_DATA, EVENT_RECEIVED_DATA, EVENT_CLOSE,
24    EVENT_AUTH_REQUIRED,
25  };
26
27  SocketStreamEvent(EventType type, net::SocketStream* socket_stream,
28                    int num, const std::string& str,
29                    net::AuthChallengeInfo* auth_challenge_info)
30      : event_type(type), socket(socket_stream), number(num), data(str),
31        auth_info(auth_challenge_info) {}
32
33  EventType event_type;
34  net::SocketStream* socket;
35  int number;
36  std::string data;
37  scoped_refptr<net::AuthChallengeInfo> auth_info;
38};
39
40class SocketStreamEventRecorder : public net::SocketStream::Delegate {
41 public:
42  explicit SocketStreamEventRecorder(net::CompletionCallback* callback)
43      : on_connected_(NULL),
44        on_sent_data_(NULL),
45        on_received_data_(NULL),
46        on_close_(NULL),
47        on_auth_required_(NULL),
48        callback_(callback) {}
49  virtual ~SocketStreamEventRecorder() {
50    delete on_connected_;
51    delete on_sent_data_;
52    delete on_received_data_;
53    delete on_close_;
54    delete on_auth_required_;
55  }
56
57  void SetOnConnected(Callback1<SocketStreamEvent*>::Type* callback) {
58    on_connected_ = callback;
59  }
60  void SetOnSentData(Callback1<SocketStreamEvent*>::Type* callback) {
61    on_sent_data_ = callback;
62  }
63  void SetOnReceivedData(Callback1<SocketStreamEvent*>::Type* callback) {
64    on_received_data_ = callback;
65  }
66  void SetOnClose(Callback1<SocketStreamEvent*>::Type* callback) {
67    on_close_ = callback;
68  }
69  void SetOnAuthRequired(Callback1<SocketStreamEvent*>::Type* callback) {
70    on_auth_required_ = callback;
71  }
72
73  virtual void OnConnected(net::SocketStream* socket,
74                           int num_pending_send_allowed) {
75    events_.push_back(
76        SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED,
77                          socket, num_pending_send_allowed, std::string(),
78                          NULL));
79    if (on_connected_)
80      on_connected_->Run(&events_.back());
81  }
82  virtual void OnSentData(net::SocketStream* socket,
83                          int amount_sent) {
84    events_.push_back(
85        SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA,
86                          socket, amount_sent, std::string(), NULL));
87    if (on_sent_data_)
88      on_sent_data_->Run(&events_.back());
89  }
90  virtual void OnReceivedData(net::SocketStream* socket,
91                              const char* data, int len) {
92    events_.push_back(
93        SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA,
94                          socket, len, std::string(data, len), NULL));
95    if (on_received_data_)
96      on_received_data_->Run(&events_.back());
97  }
98  virtual void OnClose(net::SocketStream* socket) {
99    events_.push_back(
100        SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE,
101                          socket, 0, std::string(), NULL));
102    if (on_close_)
103      on_close_->Run(&events_.back());
104    if (callback_)
105      callback_->Run(net::OK);
106  }
107  virtual void OnAuthRequired(net::SocketStream* socket,
108                              net::AuthChallengeInfo* auth_info) {
109    events_.push_back(
110        SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED,
111                          socket, 0, std::string(), auth_info));
112    if (on_auth_required_)
113      on_auth_required_->Run(&events_.back());
114  }
115
116  void DoClose(SocketStreamEvent* event) {
117    event->socket->Close();
118  }
119  void DoRestartWithAuth(SocketStreamEvent* event) {
120    VLOG(1) << "RestartWithAuth username=" << username_
121            << " password=" << password_;
122    event->socket->RestartWithAuth(username_, password_);
123  }
124  void SetAuthInfo(const string16& username,
125                   const string16& password) {
126    username_ = username;
127    password_ = password;
128  }
129
130  const std::vector<SocketStreamEvent>& GetSeenEvents() const {
131    return events_;
132  }
133
134 private:
135  std::vector<SocketStreamEvent> events_;
136  Callback1<SocketStreamEvent*>::Type* on_connected_;
137  Callback1<SocketStreamEvent*>::Type* on_sent_data_;
138  Callback1<SocketStreamEvent*>::Type* on_received_data_;
139  Callback1<SocketStreamEvent*>::Type* on_close_;
140  Callback1<SocketStreamEvent*>::Type* on_auth_required_;
141  net::CompletionCallback* callback_;
142
143  string16 username_;
144  string16 password_;
145
146  DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder);
147};
148
149namespace net {
150
151class SocketStreamTest : public PlatformTest {
152 public:
153  virtual ~SocketStreamTest() {}
154  virtual void SetUp() {
155    mock_socket_factory_.reset();
156    handshake_request_ = kWebSocketHandshakeRequest;
157    handshake_response_ = kWebSocketHandshakeResponse;
158  }
159  virtual void TearDown() {
160    mock_socket_factory_.reset();
161  }
162
163  virtual void SetWebSocketHandshakeMessage(
164      const char* request, const char* response) {
165    handshake_request_ = request;
166    handshake_response_ = response;
167  }
168  virtual void AddWebSocketMessage(const std::string& message) {
169    messages_.push_back(message);
170  }
171
172  virtual MockClientSocketFactory* GetMockClientSocketFactory() {
173    mock_socket_factory_.reset(new MockClientSocketFactory);
174    return mock_socket_factory_.get();
175  }
176
177  virtual void DoSendWebSocketHandshake(SocketStreamEvent* event) {
178    event->socket->SendData(
179        handshake_request_.data(), handshake_request_.size());
180  }
181
182  virtual void DoCloseFlushPendingWriteTest(SocketStreamEvent* event) {
183    // handshake response received.
184    for (size_t i = 0; i < messages_.size(); i++) {
185      std::vector<char> frame;
186      frame.push_back('\0');
187      frame.insert(frame.end(), messages_[i].begin(), messages_[i].end());
188      frame.push_back('\xff');
189      EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size()));
190    }
191    // Actual ClientSocket close must happen after all frames queued by
192    // SendData above are sent out.
193    event->socket->Close();
194  }
195
196  static const char* kWebSocketHandshakeRequest;
197  static const char* kWebSocketHandshakeResponse;
198
199 private:
200  std::string handshake_request_;
201  std::string handshake_response_;
202  std::vector<std::string> messages_;
203
204  scoped_ptr<MockClientSocketFactory> mock_socket_factory_;
205};
206
// Default handshake request written by DoSendWebSocketHandshake().
// CloseFlushPendingWrite expects the first mock write to match this text
// byte-for-byte, so do not reformat or "fix" it (note the double spaces in
// the key headers).
// NOTE(review): the Sec-WebSocket-Key1/Key2 values and the 8 bytes after the
// blank line look like the sample handshake from the draft-76 (hixie-76)
// WebSocket protocol. Confirm before changing them.
const char* SocketStreamTest::kWebSocketHandshakeRequest =
    "GET /demo HTTP/1.1\r\n"
    "Host: example.com\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
    "Sec-WebSocket-Protocol: sample\r\n"
    "Upgrade: WebSocket\r\n"
    "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
    "Origin: http://example.com\r\n"
    "\r\n"
    "^n:ds[4U";

// Default handshake response served by the mock socket in
// CloseFlushPendingWrite. The trailing 16 bytes follow the blank line that
// ends the headers.
// NOTE(review): these appear to be the challenge response that pairs with
// the request keys above. Keep both constants in sync.
const char* SocketStreamTest::kWebSocketHandshakeResponse =
    "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    "Upgrade: WebSocket\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Origin: http://example.com\r\n"
    "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    "Sec-WebSocket-Protocol: sample\r\n"
    "\r\n"
    "8jKS'y:G*Co,Wxa-";
228
229TEST_F(SocketStreamTest, CloseFlushPendingWrite) {
230  TestCompletionCallback callback;
231
232  scoped_ptr<SocketStreamEventRecorder> delegate(
233      new SocketStreamEventRecorder(&callback));
234  // Necessary for NewCallback.
235  SocketStreamTest* test = this;
236  delegate->SetOnConnected(NewCallback(
237      test, &SocketStreamTest::DoSendWebSocketHandshake));
238  delegate->SetOnReceivedData(NewCallback(
239      test, &SocketStreamTest::DoCloseFlushPendingWriteTest));
240
241  MockHostResolver host_resolver;
242
243  scoped_refptr<SocketStream> socket_stream(
244      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
245
246  socket_stream->set_context(new TestURLRequestContext());
247  socket_stream->SetHostResolver(&host_resolver);
248
249  MockWrite data_writes[] = {
250    MockWrite(SocketStreamTest::kWebSocketHandshakeRequest),
251    MockWrite(true, "\0message1\xff", 10),
252    MockWrite(true, "\0message2\xff", 10)
253  };
254  MockRead data_reads[] = {
255    MockRead(SocketStreamTest::kWebSocketHandshakeResponse),
256    // Server doesn't close the connection after handshake.
257    MockRead(true, ERR_IO_PENDING)
258  };
259  AddWebSocketMessage("message1");
260  AddWebSocketMessage("message2");
261
262  scoped_refptr<DelayedSocketData> data_provider(
263      new DelayedSocketData(1,
264                            data_reads, arraysize(data_reads),
265                            data_writes, arraysize(data_writes)));
266
267  MockClientSocketFactory* mock_socket_factory =
268      GetMockClientSocketFactory();
269  mock_socket_factory->AddSocketDataProvider(data_provider.get());
270
271  socket_stream->SetClientSocketFactory(mock_socket_factory);
272
273  socket_stream->Connect();
274
275  callback.WaitForResult();
276
277  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
278  EXPECT_EQ(6U, events.size());
279
280  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[0].event_type);
281  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[1].event_type);
282  EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[2].event_type);
283  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[3].event_type);
284  EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type);
285  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[5].event_type);
286}
287
288TEST_F(SocketStreamTest, BasicAuthProxy) {
289  MockClientSocketFactory mock_socket_factory;
290  MockWrite data_writes1[] = {
291    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
292              "Host: example.com\r\n"
293              "Proxy-Connection: keep-alive\r\n\r\n"),
294  };
295  MockRead data_reads1[] = {
296    MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"),
297    MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
298    MockRead("\r\n"),
299  };
300  StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1),
301                                 data_writes1, arraysize(data_writes1));
302  mock_socket_factory.AddSocketDataProvider(&data1);
303
304  MockWrite data_writes2[] = {
305    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
306              "Host: example.com\r\n"
307              "Proxy-Connection: keep-alive\r\n"
308              "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
309  };
310  MockRead data_reads2[] = {
311    MockRead("HTTP/1.1 200 Connection Established\r\n"),
312    MockRead("Proxy-agent: Apache/2.2.8\r\n"),
313    MockRead("\r\n"),
314    // SocketStream::DoClose is run asynchronously.  Socket can be read after
315    // "\r\n".  We have to give ERR_IO_PENDING to SocketStream then to indicate
316    // server doesn't close the connection.
317    MockRead(true, ERR_IO_PENDING)
318  };
319  StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2),
320                                 data_writes2, arraysize(data_writes2));
321  mock_socket_factory.AddSocketDataProvider(&data2);
322
323  TestCompletionCallback callback;
324
325  scoped_ptr<SocketStreamEventRecorder> delegate(
326      new SocketStreamEventRecorder(&callback));
327  delegate->SetOnConnected(NewCallback(delegate.get(),
328                                       &SocketStreamEventRecorder::DoClose));
329  delegate->SetAuthInfo(ASCIIToUTF16("foo"), ASCIIToUTF16("bar"));
330  delegate->SetOnAuthRequired(
331      NewCallback(delegate.get(),
332                  &SocketStreamEventRecorder::DoRestartWithAuth));
333
334  scoped_refptr<SocketStream> socket_stream(
335      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
336
337  socket_stream->set_context(new TestURLRequestContext("myproxy:70"));
338  MockHostResolver host_resolver;
339  socket_stream->SetHostResolver(&host_resolver);
340  socket_stream->SetClientSocketFactory(&mock_socket_factory);
341
342  socket_stream->Connect();
343
344  callback.WaitForResult();
345
346  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
347  EXPECT_EQ(3U, events.size());
348
349  EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[0].event_type);
350  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
351  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type);
352
353  // TODO(eroman): Add back NetLogTest here...
354}
355
356}  // namespace net
357