1// Copyright (c) 2009 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <string>
6#include <vector>
7
8#include "net/base/load_log.h"
9#include "net/base/load_log_unittest.h"
10#include "net/base/mock_host_resolver.h"
11#include "net/base/test_completion_callback.h"
12#include "net/socket/socket_test_util.h"
13#include "net/socket_stream/socket_stream.h"
14#include "net/url_request/url_request_unittest.h"
15#include "testing/gtest/include/gtest/gtest.h"
16#include "testing/platform_test.h"
17
18struct SocketStreamEvent {
19  enum EventType {
20    EVENT_CONNECTED, EVENT_SENT_DATA, EVENT_RECEIVED_DATA, EVENT_CLOSE,
21    EVENT_AUTH_REQUIRED,
22  };
23
24  SocketStreamEvent(EventType type, net::SocketStream* socket_stream,
25                    int num, const std::string& str,
26                    net::AuthChallengeInfo* auth_challenge_info)
27      : event_type(type), socket(socket_stream), number(num), data(str),
28        auth_info(auth_challenge_info) {}
29
30  EventType event_type;
31  net::SocketStream* socket;
32  int number;
33  std::string data;
34  scoped_refptr<net::AuthChallengeInfo> auth_info;
35};
36
37class SocketStreamEventRecorder : public net::SocketStream::Delegate {
38 public:
39  explicit SocketStreamEventRecorder(net::CompletionCallback* callback)
40      : on_connected_(NULL),
41        on_sent_data_(NULL),
42        on_received_data_(NULL),
43        on_close_(NULL),
44        on_auth_required_(NULL),
45        callback_(callback) {}
46  virtual ~SocketStreamEventRecorder() {
47    delete on_connected_;
48    delete on_sent_data_;
49    delete on_received_data_;
50    delete on_close_;
51    delete on_auth_required_;
52  }
53
54  void SetOnConnected(Callback1<SocketStreamEvent*>::Type* callback) {
55    on_connected_ = callback;
56  }
57  void SetOnSentData(Callback1<SocketStreamEvent*>::Type* callback) {
58    on_sent_data_ = callback;
59  }
60  void SetOnReceivedData(Callback1<SocketStreamEvent*>::Type* callback) {
61    on_received_data_ = callback;
62  }
63  void SetOnClose(Callback1<SocketStreamEvent*>::Type* callback) {
64    on_close_ = callback;
65  }
66  void SetOnAuthRequired(Callback1<SocketStreamEvent*>::Type* callback) {
67    on_auth_required_ = callback;
68  }
69
70  virtual void OnConnected(net::SocketStream* socket,
71                           int num_pending_send_allowed) {
72    events_.push_back(
73        SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED,
74                          socket, num_pending_send_allowed, std::string(),
75                          NULL));
76    if (on_connected_)
77      on_connected_->Run(&events_.back());
78  }
79  virtual void OnSentData(net::SocketStream* socket,
80                          int amount_sent) {
81    events_.push_back(
82        SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA,
83                          socket, amount_sent, std::string(), NULL));
84    if (on_sent_data_)
85      on_sent_data_->Run(&events_.back());
86  }
87  virtual void OnReceivedData(net::SocketStream* socket,
88                              const char* data, int len) {
89    events_.push_back(
90        SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA,
91                          socket, len, std::string(data, len), NULL));
92    if (on_received_data_)
93      on_received_data_->Run(&events_.back());
94  }
95  virtual void OnClose(net::SocketStream* socket) {
96    events_.push_back(
97        SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE,
98                          socket, 0, std::string(), NULL));
99    if (on_close_)
100      on_close_->Run(&events_.back());
101    if (callback_)
102      callback_->Run(net::OK);
103  }
104  virtual void OnAuthRequired(net::SocketStream* socket,
105                              net::AuthChallengeInfo* auth_info) {
106    events_.push_back(
107        SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED,
108                          socket, 0, std::string(), auth_info));
109    if (on_auth_required_)
110      on_auth_required_->Run(&events_.back());
111  }
112
113  void DoClose(SocketStreamEvent* event) {
114    event->socket->Close();
115  }
116  void DoRestartWithAuth(SocketStreamEvent* event) {
117    LOG(INFO) << "RestartWithAuth username=" << username_
118              << " password=" << password_;
119    event->socket->RestartWithAuth(username_, password_);
120  }
121  void SetAuthInfo(const std::wstring& username,
122                   const std::wstring& password) {
123    username_ = username;
124    password_ = password;
125  }
126
127  const std::vector<SocketStreamEvent>& GetSeenEvents() const {
128    return events_;
129  }
130
131 private:
132  std::vector<SocketStreamEvent> events_;
133  Callback1<SocketStreamEvent*>::Type* on_connected_;
134  Callback1<SocketStreamEvent*>::Type* on_sent_data_;
135  Callback1<SocketStreamEvent*>::Type* on_received_data_;
136  Callback1<SocketStreamEvent*>::Type* on_close_;
137  Callback1<SocketStreamEvent*>::Type* on_auth_required_;
138  net::CompletionCallback* callback_;
139
140  std::wstring username_;
141  std::wstring password_;
142
143  DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder);
144};
145
146namespace net {
147
148class SocketStreamTest : public PlatformTest {
149};
150
151TEST_F(SocketStreamTest, BasicAuthProxy) {
152  MockClientSocketFactory mock_socket_factory;
153  MockWrite data_writes1[] = {
154    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
155              "Host: example.com\r\n"
156              "Proxy-Connection: keep-alive\r\n\r\n"),
157  };
158  MockRead data_reads1[] = {
159    MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"),
160    MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
161    MockRead("\r\n"),
162  };
163  StaticSocketDataProvider data1(data_reads1, data_writes1);
164  mock_socket_factory.AddSocketDataProvider(&data1);
165
166  MockWrite data_writes2[] = {
167    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
168              "Host: example.com\r\n"
169              "Proxy-Connection: keep-alive\r\n"
170              "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
171  };
172  MockRead data_reads2[] = {
173    MockRead("HTTP/1.1 200 Connection Established\r\n"),
174    MockRead("Proxy-agent: Apache/2.2.8\r\n"),
175    MockRead("\r\n"),
176  };
177  StaticSocketDataProvider data2(data_reads2, data_writes2);
178  mock_socket_factory.AddSocketDataProvider(&data2);
179
180  TestCompletionCallback callback;
181
182  scoped_ptr<SocketStreamEventRecorder> delegate(
183      new SocketStreamEventRecorder(&callback));
184  delegate->SetOnConnected(NewCallback(delegate.get(),
185                                       &SocketStreamEventRecorder::DoClose));
186  const std::wstring kUsername = L"foo";
187  const std::wstring kPassword = L"bar";
188  delegate->SetAuthInfo(kUsername, kPassword);
189  delegate->SetOnAuthRequired(
190      NewCallback(delegate.get(),
191                  &SocketStreamEventRecorder::DoRestartWithAuth));
192
193  scoped_refptr<SocketStream> socket_stream =
194      new SocketStream(GURL("ws://example.com/demo"), delegate.get());
195
196  socket_stream->set_context(new TestURLRequestContext("myproxy:70"));
197  socket_stream->SetHostResolver(new MockHostResolver());
198  socket_stream->SetClientSocketFactory(&mock_socket_factory);
199
200  socket_stream->Connect();
201
202  callback.WaitForResult();
203
204  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
205  EXPECT_EQ(3U, events.size());
206
207  EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[0].event_type);
208  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
209  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type);
210
211  // The first and last entries of the LoadLog should be for
212  // SOCKET_STREAM_CONNECT.
213  EXPECT_TRUE(LogContainsBeginEvent(
214      *socket_stream->load_log(), 0, LoadLog::TYPE_SOCKET_STREAM_CONNECT));
215  EXPECT_TRUE(LogContainsEndEvent(
216      *socket_stream->load_log(), -1, LoadLog::TYPE_SOCKET_STREAM_CONNECT));
217}
218
219}  // namespace net
220