1// Copyright (c) 2009 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include <string> 6#include <vector> 7 8#include "net/base/load_log.h" 9#include "net/base/load_log_unittest.h" 10#include "net/base/mock_host_resolver.h" 11#include "net/base/test_completion_callback.h" 12#include "net/socket/socket_test_util.h" 13#include "net/socket_stream/socket_stream.h" 14#include "net/url_request/url_request_unittest.h" 15#include "testing/gtest/include/gtest/gtest.h" 16#include "testing/platform_test.h" 17 18struct SocketStreamEvent { 19 enum EventType { 20 EVENT_CONNECTED, EVENT_SENT_DATA, EVENT_RECEIVED_DATA, EVENT_CLOSE, 21 EVENT_AUTH_REQUIRED, 22 }; 23 24 SocketStreamEvent(EventType type, net::SocketStream* socket_stream, 25 int num, const std::string& str, 26 net::AuthChallengeInfo* auth_challenge_info) 27 : event_type(type), socket(socket_stream), number(num), data(str), 28 auth_info(auth_challenge_info) {} 29 30 EventType event_type; 31 net::SocketStream* socket; 32 int number; 33 std::string data; 34 scoped_refptr<net::AuthChallengeInfo> auth_info; 35}; 36 37class SocketStreamEventRecorder : public net::SocketStream::Delegate { 38 public: 39 explicit SocketStreamEventRecorder(net::CompletionCallback* callback) 40 : on_connected_(NULL), 41 on_sent_data_(NULL), 42 on_received_data_(NULL), 43 on_close_(NULL), 44 on_auth_required_(NULL), 45 callback_(callback) {} 46 virtual ~SocketStreamEventRecorder() { 47 delete on_connected_; 48 delete on_sent_data_; 49 delete on_received_data_; 50 delete on_close_; 51 delete on_auth_required_; 52 } 53 54 void SetOnConnected(Callback1<SocketStreamEvent*>::Type* callback) { 55 on_connected_ = callback; 56 } 57 void SetOnSentData(Callback1<SocketStreamEvent*>::Type* callback) { 58 on_sent_data_ = callback; 59 } 60 void SetOnReceivedData(Callback1<SocketStreamEvent*>::Type* callback) { 61 on_received_data_ = callback; 62 } 63 void SetOnClose(Callback1<SocketStreamEvent*>::Type* callback) { 64 on_close_ = callback; 65 } 66 void SetOnAuthRequired(Callback1<SocketStreamEvent*>::Type* callback) { 67 on_auth_required_ = callback; 68 } 69 70 virtual void OnConnected(net::SocketStream* socket, 71 int num_pending_send_allowed) { 72 events_.push_back( 73 SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED, 74 socket, num_pending_send_allowed, std::string(), 75 NULL)); 76 if (on_connected_) 77 on_connected_->Run(&events_.back()); 78 } 79 virtual void OnSentData(net::SocketStream* socket, 80 int amount_sent) { 81 events_.push_back( 82 SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA, 83 socket, amount_sent, std::string(), NULL)); 84 if (on_sent_data_) 85 on_sent_data_->Run(&events_.back()); 86 } 87 virtual void OnReceivedData(net::SocketStream* socket, 88 const char* data, int len) { 89 events_.push_back( 90 SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA, 91 socket, len, std::string(data, len), NULL)); 92 if (on_received_data_) 93 on_received_data_->Run(&events_.back()); 94 } 95 virtual void OnClose(net::SocketStream* socket) { 96 events_.push_back( 97 SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE, 98 socket, 0, std::string(), NULL)); 99 if (on_close_) 100 on_close_->Run(&events_.back()); 101 if (callback_) 102 callback_->Run(net::OK); 103 } 104 virtual void OnAuthRequired(net::SocketStream* socket, 105 net::AuthChallengeInfo* auth_info) { 106 events_.push_back( 107 SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED, 108 socket, 0, std::string(), auth_info)); 109 if (on_auth_required_) 110 on_auth_required_->Run(&events_.back()); 111 } 112 113 void DoClose(SocketStreamEvent* event) { 114 event->socket->Close(); 115 } 116 void DoRestartWithAuth(SocketStreamEvent* event) { 117 LOG(INFO) << "RestartWithAuth username=" << username_ 118 << " password=" << password_; 119 event->socket->RestartWithAuth(username_, password_); 120 } 121 void SetAuthInfo(const std::wstring& username, 122 const std::wstring& password) { 123 username_ = username; 124 password_ = password; 125 } 126 127 const std::vector<SocketStreamEvent>& GetSeenEvents() const { 128 return events_; 129 } 130 131 private: 132 std::vector<SocketStreamEvent> events_; 133 Callback1<SocketStreamEvent*>::Type* on_connected_; 134 Callback1<SocketStreamEvent*>::Type* on_sent_data_; 135 Callback1<SocketStreamEvent*>::Type* on_received_data_; 136 Callback1<SocketStreamEvent*>::Type* on_close_; 137 Callback1<SocketStreamEvent*>::Type* on_auth_required_; 138 net::CompletionCallback* callback_; 139 140 std::wstring username_; 141 std::wstring password_; 142 143 DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder); 144}; 145 146namespace net { 147 148class SocketStreamTest : public PlatformTest { 149}; 150 151TEST_F(SocketStreamTest, BasicAuthProxy) { 152 MockClientSocketFactory mock_socket_factory; 153 MockWrite data_writes1[] = { 154 MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" 155 "Host: example.com\r\n" 156 "Proxy-Connection: keep-alive\r\n\r\n"), 157 }; 158 MockRead data_reads1[] = { 159 MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), 160 MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), 161 MockRead("\r\n"), 162 }; 163 StaticSocketDataProvider data1(data_reads1, data_writes1); 164 mock_socket_factory.AddSocketDataProvider(&data1); 165 166 MockWrite data_writes2[] = { 167 MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" 168 "Host: example.com\r\n" 169 "Proxy-Connection: keep-alive\r\n" 170 "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), 171 }; 172 MockRead data_reads2[] = { 173 MockRead("HTTP/1.1 200 Connection Established\r\n"), 174 MockRead("Proxy-agent: Apache/2.2.8\r\n"), 175 MockRead("\r\n"), 176 }; 177 StaticSocketDataProvider data2(data_reads2, data_writes2); 178 mock_socket_factory.AddSocketDataProvider(&data2); 179 180 TestCompletionCallback callback; 181 182 scoped_ptr<SocketStreamEventRecorder> delegate( 183 new SocketStreamEventRecorder(&callback)); 184 delegate->SetOnConnected(NewCallback(delegate.get(), 185 &SocketStreamEventRecorder::DoClose)); 186 const std::wstring kUsername = L"foo"; 187 const std::wstring kPassword = L"bar"; 188 delegate->SetAuthInfo(kUsername, kPassword); 189 delegate->SetOnAuthRequired( 190 NewCallback(delegate.get(), 191 &SocketStreamEventRecorder::DoRestartWithAuth)); 192 193 scoped_refptr<SocketStream> socket_stream = 194 new SocketStream(GURL("ws://example.com/demo"), delegate.get()); 195 196 socket_stream->set_context(new TestURLRequestContext("myproxy:70")); 197 socket_stream->SetHostResolver(new MockHostResolver()); 198 socket_stream->SetClientSocketFactory(&mock_socket_factory); 199 200 socket_stream->Connect(); 201 202 callback.WaitForResult(); 203 204 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 205 EXPECT_EQ(3U, events.size()); 206 207 EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[0].event_type); 208 EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); 209 EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type); 210 211 // The first and last entries of the LoadLog should be for 212 // SOCKET_STREAM_CONNECT. 213 EXPECT_TRUE(LogContainsBeginEvent( 214 *socket_stream->load_log(), 0, LoadLog::TYPE_SOCKET_STREAM_CONNECT)); 215 EXPECT_TRUE(LogContainsEndEvent( 216 *socket_stream->load_log(), -1, LoadLog::TYPE_SOCKET_STREAM_CONNECT)); 217} 218 219} // namespace net 220