1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <string>
6#include <vector>
7
8#include "base/memory/ref_counted.h"
9#include "base/string_split.h"
10#include "base/string_util.h"
11#include "googleurl/src/gurl.h"
12#include "net/base/cookie_policy.h"
13#include "net/base/cookie_store.h"
14#include "net/base/net_errors.h"
15#include "net/base/sys_addrinfo.h"
16#include "net/base/transport_security_state.h"
17#include "net/socket_stream/socket_stream.h"
18#include "net/url_request/url_request_context.h"
19#include "net/websockets/websocket_job.h"
20#include "net/websockets/websocket_throttle.h"
21#include "testing/gtest/include/gtest/gtest.h"
22#include "testing/gmock/include/gmock/gmock.h"
23#include "testing/platform_test.h"
24
25namespace net {
26
27class MockSocketStream : public SocketStream {
28 public:
29  MockSocketStream(const GURL& url, SocketStream::Delegate* delegate)
30      : SocketStream(url, delegate) {}
31  virtual ~MockSocketStream() {}
32
33  virtual void Connect() {}
34  virtual bool SendData(const char* data, int len) {
35    sent_data_ += std::string(data, len);
36    return true;
37  }
38
39  virtual void Close() {}
40  virtual void RestartWithAuth(
41      const string16& username, const string16& password) {}
42  virtual void DetachDelegate() {
43    delegate_ = NULL;
44  }
45
46  const std::string& sent_data() const {
47    return sent_data_;
48  }
49
50 private:
51  std::string sent_data_;
52};
53
54class MockSocketStreamDelegate : public SocketStream::Delegate {
55 public:
56  MockSocketStreamDelegate()
57      : amount_sent_(0) {}
58  virtual ~MockSocketStreamDelegate() {}
59
60  virtual void OnConnected(SocketStream* socket, int max_pending_send_allowed) {
61  }
62  virtual void OnSentData(SocketStream* socket, int amount_sent) {
63    amount_sent_ += amount_sent;
64  }
65  virtual void OnReceivedData(SocketStream* socket,
66                              const char* data, int len) {
67    received_data_ += std::string(data, len);
68  }
69  virtual void OnClose(SocketStream* socket) {
70  }
71
72  size_t amount_sent() const { return amount_sent_; }
73  const std::string& received_data() const { return received_data_; }
74
75 private:
76  int amount_sent_;
77  std::string received_data_;
78};
79
80class MockCookieStore : public CookieStore {
81 public:
82  struct Entry {
83    GURL url;
84    std::string cookie_line;
85    CookieOptions options;
86  };
87  MockCookieStore() {}
88
89  virtual bool SetCookieWithOptions(const GURL& url,
90                                    const std::string& cookie_line,
91                                    const CookieOptions& options) {
92    Entry entry;
93    entry.url = url;
94    entry.cookie_line = cookie_line;
95    entry.options = options;
96    entries_.push_back(entry);
97    return true;
98  }
99  virtual std::string GetCookiesWithOptions(const GURL& url,
100                                            const CookieOptions& options) {
101    std::string result;
102    for (size_t i = 0; i < entries_.size(); i++) {
103      Entry &entry = entries_[i];
104      if (url == entry.url) {
105        if (!result.empty()) {
106          result += "; ";
107        }
108        result += entry.cookie_line;
109      }
110    }
111    return result;
112  }
113  virtual void DeleteCookie(const GURL& url,
114                            const std::string& cookie_name) {}
115  virtual CookieMonster* GetCookieMonster() { return NULL; }
116
117  const std::vector<Entry>& entries() const { return entries_; }
118
119 private:
120  friend class base::RefCountedThreadSafe<MockCookieStore>;
121  virtual ~MockCookieStore() {}
122
123  std::vector<Entry> entries_;
124};
125
126class MockCookiePolicy : public CookiePolicy {
127 public:
128  MockCookiePolicy() : allow_all_cookies_(true) {}
129  virtual ~MockCookiePolicy() {}
130
131  void set_allow_all_cookies(bool allow_all_cookies) {
132    allow_all_cookies_ = allow_all_cookies;
133  }
134
135  virtual int CanGetCookies(const GURL& url,
136                            const GURL& first_party_for_cookies) const {
137    if (allow_all_cookies_)
138      return OK;
139    return ERR_ACCESS_DENIED;
140  }
141
142  virtual int CanSetCookie(const GURL& url,
143                           const GURL& first_party_for_cookies,
144                           const std::string& cookie_line) const {
145    if (allow_all_cookies_)
146      return OK;
147    return ERR_ACCESS_DENIED;
148  }
149
150 private:
151  bool allow_all_cookies_;
152};
153
154class MockURLRequestContext : public URLRequestContext {
155 public:
156  MockURLRequestContext(CookieStore* cookie_store,
157                        CookiePolicy* cookie_policy) {
158    set_cookie_store(cookie_store);
159    set_cookie_policy(cookie_policy);
160    transport_security_state_ = new TransportSecurityState();
161    set_transport_security_state(transport_security_state_.get());
162    TransportSecurityState::DomainState state;
163    state.expiry = base::Time::Now() + base::TimeDelta::FromSeconds(1000);
164    transport_security_state_->EnableHost("upgrademe.com", state);
165  }
166
167 private:
168  friend class base::RefCountedThreadSafe<MockURLRequestContext>;
169  virtual ~MockURLRequestContext() {}
170
171  scoped_refptr<TransportSecurityState> transport_security_state_;
172};
173
174class WebSocketJobTest : public PlatformTest {
175 public:
176  virtual void SetUp() {
177    cookie_store_ = new MockCookieStore;
178    cookie_policy_.reset(new MockCookiePolicy);
179    context_ = new MockURLRequestContext(
180        cookie_store_.get(), cookie_policy_.get());
181  }
182  virtual void TearDown() {
183    cookie_store_ = NULL;
184    cookie_policy_.reset();
185    context_ = NULL;
186    websocket_ = NULL;
187    socket_ = NULL;
188  }
189 protected:
190  void InitWebSocketJob(const GURL& url, MockSocketStreamDelegate* delegate) {
191    websocket_ = new WebSocketJob(delegate);
192    socket_ = new MockSocketStream(url, websocket_.get());
193    websocket_->InitSocketStream(socket_.get());
194    websocket_->set_context(context_.get());
195    websocket_->state_ = WebSocketJob::CONNECTING;
196    struct addrinfo addr;
197    memset(&addr, 0, sizeof(struct addrinfo));
198    addr.ai_family = AF_INET;
199    addr.ai_addrlen = sizeof(struct sockaddr_in);
200    struct sockaddr_in sa_in;
201    memset(&sa_in, 0, sizeof(struct sockaddr_in));
202    memcpy(&sa_in.sin_addr, "\x7f\0\0\1", 4);
203    addr.ai_addr = reinterpret_cast<sockaddr*>(&sa_in);
204    addr.ai_next = NULL;
205    websocket_->addresses_.Copy(&addr, true);
206    WebSocketThrottle::GetInstance()->PutInQueue(websocket_);
207  }
208  WebSocketJob::State GetWebSocketJobState() {
209    return websocket_->state_;
210  }
211  void CloseWebSocketJob() {
212    if (websocket_->socket_) {
213      websocket_->socket_->DetachDelegate();
214      WebSocketThrottle::GetInstance()->RemoveFromQueue(websocket_);
215    }
216    websocket_->state_ = WebSocketJob::CLOSED;
217    websocket_->delegate_ = NULL;
218    websocket_->socket_ = NULL;
219  }
220  SocketStream* GetSocket(SocketStreamJob* job) {
221    return job->socket_.get();
222  }
223
224  scoped_refptr<MockCookieStore> cookie_store_;
225  scoped_ptr<MockCookiePolicy> cookie_policy_;
226  scoped_refptr<MockURLRequestContext> context_;
227  scoped_refptr<WebSocketJob> websocket_;
228  scoped_refptr<MockSocketStream> socket_;
229};
230
231TEST_F(WebSocketJobTest, SimpleHandshake) {
232  GURL url("ws://example.com/demo");
233  MockSocketStreamDelegate delegate;
234  InitWebSocketJob(url, &delegate);
235
236  static const char* kHandshakeRequestMessage =
237      "GET /demo HTTP/1.1\r\n"
238      "Host: example.com\r\n"
239      "Connection: Upgrade\r\n"
240      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
241      "Sec-WebSocket-Protocol: sample\r\n"
242      "Upgrade: WebSocket\r\n"
243      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
244      "Origin: http://example.com\r\n"
245      "\r\n"
246      "^n:ds[4U";
247
248  bool sent = websocket_->SendData(kHandshakeRequestMessage,
249                                   strlen(kHandshakeRequestMessage));
250  EXPECT_TRUE(sent);
251  MessageLoop::current()->RunAllPending();
252  EXPECT_EQ(kHandshakeRequestMessage, socket_->sent_data());
253  EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
254  websocket_->OnSentData(socket_.get(), strlen(kHandshakeRequestMessage));
255  EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());
256
257  const char kHandshakeResponseMessage[] =
258      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
259      "Upgrade: WebSocket\r\n"
260      "Connection: Upgrade\r\n"
261      "Sec-WebSocket-Origin: http://example.com\r\n"
262      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
263      "Sec-WebSocket-Protocol: sample\r\n"
264      "\r\n"
265      "8jKS'y:G*Co,Wxa-";
266
267  websocket_->OnReceivedData(socket_.get(),
268                             kHandshakeResponseMessage,
269                             strlen(kHandshakeResponseMessage));
270  MessageLoop::current()->RunAllPending();
271  EXPECT_EQ(kHandshakeResponseMessage, delegate.received_data());
272  EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());
273  CloseWebSocketJob();
274}
275
276TEST_F(WebSocketJobTest, SlowHandshake) {
277  GURL url("ws://example.com/demo");
278  MockSocketStreamDelegate delegate;
279  InitWebSocketJob(url, &delegate);
280
281  static const char* kHandshakeRequestMessage =
282      "GET /demo HTTP/1.1\r\n"
283      "Host: example.com\r\n"
284      "Connection: Upgrade\r\n"
285      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
286      "Sec-WebSocket-Protocol: sample\r\n"
287      "Upgrade: WebSocket\r\n"
288      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
289      "Origin: http://example.com\r\n"
290      "\r\n"
291      "^n:ds[4U";
292
293  bool sent = websocket_->SendData(kHandshakeRequestMessage,
294                                   strlen(kHandshakeRequestMessage));
295  EXPECT_TRUE(sent);
296  // We assume request is sent in one data chunk (from WebKit)
297  // We don't support streaming request.
298  MessageLoop::current()->RunAllPending();
299  EXPECT_EQ(kHandshakeRequestMessage, socket_->sent_data());
300  EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
301  websocket_->OnSentData(socket_.get(), strlen(kHandshakeRequestMessage));
302  EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());
303
304  const char kHandshakeResponseMessage[] =
305      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
306      "Upgrade: WebSocket\r\n"
307      "Connection: Upgrade\r\n"
308      "Sec-WebSocket-Origin: http://example.com\r\n"
309      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
310      "Sec-WebSocket-Protocol: sample\r\n"
311      "\r\n"
312      "8jKS'y:G*Co,Wxa-";
313
314  std::vector<std::string> lines;
315  base::SplitString(kHandshakeResponseMessage, '\n', &lines);
316  for (size_t i = 0; i < lines.size() - 2; i++) {
317    std::string line = lines[i] + "\r\n";
318    SCOPED_TRACE("Line: " + line);
319    websocket_->OnReceivedData(socket_,
320                               line.c_str(),
321                               line.size());
322    MessageLoop::current()->RunAllPending();
323    EXPECT_TRUE(delegate.received_data().empty());
324    EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
325  }
326  websocket_->OnReceivedData(socket_.get(), "\r\n", 2);
327  MessageLoop::current()->RunAllPending();
328  EXPECT_TRUE(delegate.received_data().empty());
329  EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
330  websocket_->OnReceivedData(socket_.get(), "8jKS'y:G*Co,Wxa-", 16);
331  EXPECT_EQ(kHandshakeResponseMessage, delegate.received_data());
332  EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());
333  CloseWebSocketJob();
334}
335
336TEST_F(WebSocketJobTest, HandshakeWithCookie) {
337  GURL url("ws://example.com/demo");
338  GURL cookieUrl("http://example.com/demo");
339  CookieOptions cookie_options;
340  cookie_store_->SetCookieWithOptions(
341      cookieUrl, "CR-test=1", cookie_options);
342  cookie_options.set_include_httponly();
343  cookie_store_->SetCookieWithOptions(
344      cookieUrl, "CR-test-httponly=1", cookie_options);
345
346  MockSocketStreamDelegate delegate;
347  InitWebSocketJob(url, &delegate);
348
349  static const char* kHandshakeRequestMessage =
350      "GET /demo HTTP/1.1\r\n"
351      "Host: example.com\r\n"
352      "Connection: Upgrade\r\n"
353      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
354      "Sec-WebSocket-Protocol: sample\r\n"
355      "Upgrade: WebSocket\r\n"
356      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
357      "Origin: http://example.com\r\n"
358      "Cookie: WK-test=1\r\n"
359      "\r\n"
360      "^n:ds[4U";
361
362  static const char* kHandshakeRequestExpected =
363      "GET /demo HTTP/1.1\r\n"
364      "Host: example.com\r\n"
365      "Connection: Upgrade\r\n"
366      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
367      "Sec-WebSocket-Protocol: sample\r\n"
368      "Upgrade: WebSocket\r\n"
369      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
370      "Origin: http://example.com\r\n"
371      "Cookie: CR-test=1; CR-test-httponly=1\r\n"
372      "\r\n"
373      "^n:ds[4U";
374
375  bool sent = websocket_->SendData(kHandshakeRequestMessage,
376                                   strlen(kHandshakeRequestMessage));
377  EXPECT_TRUE(sent);
378  MessageLoop::current()->RunAllPending();
379  EXPECT_EQ(kHandshakeRequestExpected, socket_->sent_data());
380  EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
381  websocket_->OnSentData(socket_, strlen(kHandshakeRequestExpected));
382  EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());
383
384  const char kHandshakeResponseMessage[] =
385      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
386      "Upgrade: WebSocket\r\n"
387      "Connection: Upgrade\r\n"
388      "Sec-WebSocket-Origin: http://example.com\r\n"
389      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
390      "Sec-WebSocket-Protocol: sample\r\n"
391      "Set-Cookie: CR-set-test=1\r\n"
392      "\r\n"
393      "8jKS'y:G*Co,Wxa-";
394
395  static const char* kHandshakeResponseExpected =
396      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
397      "Upgrade: WebSocket\r\n"
398      "Connection: Upgrade\r\n"
399      "Sec-WebSocket-Origin: http://example.com\r\n"
400      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
401      "Sec-WebSocket-Protocol: sample\r\n"
402      "\r\n"
403      "8jKS'y:G*Co,Wxa-";
404
405  websocket_->OnReceivedData(socket_.get(),
406                             kHandshakeResponseMessage,
407                             strlen(kHandshakeResponseMessage));
408  MessageLoop::current()->RunAllPending();
409  EXPECT_EQ(kHandshakeResponseExpected, delegate.received_data());
410  EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());
411
412  EXPECT_EQ(3U, cookie_store_->entries().size());
413  EXPECT_EQ(cookieUrl, cookie_store_->entries()[0].url);
414  EXPECT_EQ("CR-test=1", cookie_store_->entries()[0].cookie_line);
415  EXPECT_EQ(cookieUrl, cookie_store_->entries()[1].url);
416  EXPECT_EQ("CR-test-httponly=1", cookie_store_->entries()[1].cookie_line);
417  EXPECT_EQ(cookieUrl, cookie_store_->entries()[2].url);
418  EXPECT_EQ("CR-set-test=1", cookie_store_->entries()[2].cookie_line);
419
420  CloseWebSocketJob();
421}
422
423TEST_F(WebSocketJobTest, HandshakeWithCookieButNotAllowed) {
424  GURL url("ws://example.com/demo");
425  GURL cookieUrl("http://example.com/demo");
426  CookieOptions cookie_options;
427  cookie_store_->SetCookieWithOptions(
428      cookieUrl, "CR-test=1", cookie_options);
429  cookie_options.set_include_httponly();
430  cookie_store_->SetCookieWithOptions(
431      cookieUrl, "CR-test-httponly=1", cookie_options);
432  cookie_policy_->set_allow_all_cookies(false);
433
434  MockSocketStreamDelegate delegate;
435  InitWebSocketJob(url, &delegate);
436
437  static const char* kHandshakeRequestMessage =
438      "GET /demo HTTP/1.1\r\n"
439      "Host: example.com\r\n"
440      "Connection: Upgrade\r\n"
441      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
442      "Sec-WebSocket-Protocol: sample\r\n"
443      "Upgrade: WebSocket\r\n"
444      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
445      "Origin: http://example.com\r\n"
446      "Cookie: WK-test=1\r\n"
447      "\r\n"
448      "^n:ds[4U";
449
450  static const char* kHandshakeRequestExpected =
451      "GET /demo HTTP/1.1\r\n"
452      "Host: example.com\r\n"
453      "Connection: Upgrade\r\n"
454      "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
455      "Sec-WebSocket-Protocol: sample\r\n"
456      "Upgrade: WebSocket\r\n"
457      "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
458      "Origin: http://example.com\r\n"
459      "\r\n"
460      "^n:ds[4U";
461
462  bool sent = websocket_->SendData(kHandshakeRequestMessage,
463                                   strlen(kHandshakeRequestMessage));
464  EXPECT_TRUE(sent);
465  MessageLoop::current()->RunAllPending();
466  EXPECT_EQ(kHandshakeRequestExpected, socket_->sent_data());
467  EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
468  websocket_->OnSentData(socket_, strlen(kHandshakeRequestExpected));
469  EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());
470
471  const char kHandshakeResponseMessage[] =
472      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
473      "Upgrade: WebSocket\r\n"
474      "Connection: Upgrade\r\n"
475      "Sec-WebSocket-Origin: http://example.com\r\n"
476      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
477      "Sec-WebSocket-Protocol: sample\r\n"
478      "Set-Cookie: CR-set-test=1\r\n"
479      "\r\n"
480      "8jKS'y:G*Co,Wxa-";
481
482  static const char* kHandshakeResponseExpected =
483      "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
484      "Upgrade: WebSocket\r\n"
485      "Connection: Upgrade\r\n"
486      "Sec-WebSocket-Origin: http://example.com\r\n"
487      "Sec-WebSocket-Location: ws://example.com/demo\r\n"
488      "Sec-WebSocket-Protocol: sample\r\n"
489      "\r\n"
490      "8jKS'y:G*Co,Wxa-";
491
492  websocket_->OnReceivedData(socket_.get(),
493                             kHandshakeResponseMessage,
494                             strlen(kHandshakeResponseMessage));
495  MessageLoop::current()->RunAllPending();
496  EXPECT_EQ(kHandshakeResponseExpected, delegate.received_data());
497  EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());
498
499  EXPECT_EQ(2U, cookie_store_->entries().size());
500  EXPECT_EQ(cookieUrl, cookie_store_->entries()[0].url);
501  EXPECT_EQ("CR-test=1", cookie_store_->entries()[0].cookie_line);
502  EXPECT_EQ(cookieUrl, cookie_store_->entries()[1].url);
503  EXPECT_EQ("CR-test-httponly=1", cookie_store_->entries()[1].cookie_line);
504
505  CloseWebSocketJob();
506}
507
508TEST_F(WebSocketJobTest, HSTSUpgrade) {
509  GURL url("ws://upgrademe.com/");
510  MockSocketStreamDelegate delegate;
511  scoped_refptr<SocketStreamJob> job = SocketStreamJob::CreateSocketStreamJob(
512      url, &delegate, *context_.get());
513  EXPECT_TRUE(GetSocket(job.get())->is_secure());
514  job->DetachDelegate();
515
516  url = GURL("ws://donotupgrademe.com/");
517  job = SocketStreamJob::CreateSocketStreamJob(
518      url, &delegate, *context_.get());
519  EXPECT_FALSE(GetSocket(job.get())->is_secure());
520  job->DetachDelegate();
521}
522
523}  // namespace net
524