1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "content/browser/renderer_host/websocket_host.h" 6 7#include "base/basictypes.h" 8#include "base/memory/weak_ptr.h" 9#include "base/strings/string_util.h" 10#include "content/browser/renderer_host/websocket_dispatcher_host.h" 11#include "content/browser/ssl/ssl_error_handler.h" 12#include "content/browser/ssl/ssl_manager.h" 13#include "content/common/websocket_messages.h" 14#include "ipc/ipc_message_macros.h" 15#include "net/http/http_request_headers.h" 16#include "net/http/http_response_headers.h" 17#include "net/http/http_util.h" 18#include "net/ssl/ssl_info.h" 19#include "net/websockets/websocket_channel.h" 20#include "net/websockets/websocket_errors.h" 21#include "net/websockets/websocket_event_interface.h" 22#include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode 23#include "net/websockets/websocket_handshake_request_info.h" 24#include "net/websockets/websocket_handshake_response_info.h" 25#include "url/origin.h" 26 27namespace content { 28 29namespace { 30 31typedef net::WebSocketEventInterface::ChannelState ChannelState; 32 33// Convert a content::WebSocketMessageType to a 34// net::WebSocketFrameHeader::OpCode 35net::WebSocketFrameHeader::OpCode MessageTypeToOpCode( 36 WebSocketMessageType type) { 37 DCHECK(type == WEB_SOCKET_MESSAGE_TYPE_CONTINUATION || 38 type == WEB_SOCKET_MESSAGE_TYPE_TEXT || 39 type == WEB_SOCKET_MESSAGE_TYPE_BINARY); 40 typedef net::WebSocketFrameHeader::OpCode OpCode; 41 // These compile asserts verify that the same underlying values are used for 42 // both types, so we can simply cast between them. 43 COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION) == 44 net::WebSocketFrameHeader::kOpCodeContinuation, 45 enum_values_must_match_for_opcode_continuation); 46 COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_TEXT) == 47 net::WebSocketFrameHeader::kOpCodeText, 48 enum_values_must_match_for_opcode_text); 49 COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_BINARY) == 50 net::WebSocketFrameHeader::kOpCodeBinary, 51 enum_values_must_match_for_opcode_binary); 52 return static_cast<OpCode>(type); 53} 54 55WebSocketMessageType OpCodeToMessageType( 56 net::WebSocketFrameHeader::OpCode opCode) { 57 DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation || 58 opCode == net::WebSocketFrameHeader::kOpCodeText || 59 opCode == net::WebSocketFrameHeader::kOpCodeBinary); 60 // This cast is guaranteed valid by the COMPILE_ASSERT() statements above. 61 return static_cast<WebSocketMessageType>(opCode); 62} 63 64ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) { 65 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE = 66 WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE; 67 const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED = 68 WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED; 69 70 DCHECK(host_state == WEBSOCKET_HOST_ALIVE || 71 host_state == WEBSOCKET_HOST_DELETED); 72 // These compile asserts verify that we can get away with using static_cast<> 73 // for the conversion. 74 COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_ALIVE) == 75 net::WebSocketEventInterface::CHANNEL_ALIVE, 76 enum_values_must_match_for_state_alive); 77 COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_DELETED) == 78 net::WebSocketEventInterface::CHANNEL_DELETED, 79 enum_values_must_match_for_state_deleted); 80 return static_cast<ChannelState>(host_state); 81} 82 83// Implementation of net::WebSocketEventInterface. Receives events from our 84// WebSocketChannel object. Each event is translated to an IPC and sent to the 85// renderer or child process via WebSocketDispatcherHost. 86class WebSocketEventHandler : public net::WebSocketEventInterface { 87 public: 88 WebSocketEventHandler(WebSocketDispatcherHost* dispatcher, 89 int routing_id, 90 int render_frame_id); 91 virtual ~WebSocketEventHandler(); 92 93 // net::WebSocketEventInterface implementation 94 95 virtual ChannelState OnAddChannelResponse( 96 bool fail, 97 const std::string& selected_subprotocol, 98 const std::string& extensions) OVERRIDE; 99 virtual ChannelState OnDataFrame(bool fin, 100 WebSocketMessageType type, 101 const std::vector<char>& data) OVERRIDE; 102 virtual ChannelState OnClosingHandshake() OVERRIDE; 103 virtual ChannelState OnFlowControl(int64 quota) OVERRIDE; 104 virtual ChannelState OnDropChannel(bool was_clean, 105 uint16 code, 106 const std::string& reason) OVERRIDE; 107 virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE; 108 virtual ChannelState OnStartOpeningHandshake( 109 scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE; 110 virtual ChannelState OnFinishOpeningHandshake( 111 scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE; 112 virtual ChannelState OnSSLCertificateError( 113 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, 114 const GURL& url, 115 const net::SSLInfo& ssl_info, 116 bool fatal) OVERRIDE; 117 118 private: 119 class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate { 120 public: 121 SSLErrorHandlerDelegate( 122 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks); 123 virtual ~SSLErrorHandlerDelegate(); 124 125 base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr(); 126 127 // SSLErrorHandler::Delegate methods 128 virtual void CancelSSLRequest(const GlobalRequestID& id, 129 int error, 130 const net::SSLInfo* ssl_info) OVERRIDE; 131 virtual void ContinueSSLRequest(const GlobalRequestID& id) OVERRIDE; 132 133 private: 134 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_; 135 base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_; 136 137 DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate); 138 }; 139 140 WebSocketDispatcherHost* const dispatcher_; 141 const int routing_id_; 142 const int render_frame_id_; 143 scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_; 144 145 DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler); 146}; 147 148WebSocketEventHandler::WebSocketEventHandler( 149 WebSocketDispatcherHost* dispatcher, 150 int routing_id, 151 int render_frame_id) 152 : dispatcher_(dispatcher), 153 routing_id_(routing_id), 154 render_frame_id_(render_frame_id) { 155} 156 157WebSocketEventHandler::~WebSocketEventHandler() { 158 DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_; 159} 160 161ChannelState WebSocketEventHandler::OnAddChannelResponse( 162 bool fail, 163 const std::string& selected_protocol, 164 const std::string& extensions) { 165 DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse" 166 << " routing_id=" << routing_id_ << " fail=" << fail 167 << " selected_protocol=\"" << selected_protocol << "\"" 168 << " extensions=\"" << extensions << "\""; 169 170 return StateCast(dispatcher_->SendAddChannelResponse( 171 routing_id_, fail, selected_protocol, extensions)); 172} 173 174ChannelState WebSocketEventHandler::OnDataFrame( 175 bool fin, 176 net::WebSocketFrameHeader::OpCode type, 177 const std::vector<char>& data) { 178 DVLOG(3) << "WebSocketEventHandler::OnDataFrame" 179 << " routing_id=" << routing_id_ << " fin=" << fin 180 << " type=" << type << " data is " << data.size() << " bytes"; 181 182 return StateCast(dispatcher_->SendFrame( 183 routing_id_, fin, OpCodeToMessageType(type), data)); 184} 185 186ChannelState WebSocketEventHandler::OnClosingHandshake() { 187 DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake" 188 << " routing_id=" << routing_id_; 189 190 return StateCast(dispatcher_->NotifyClosingHandshake(routing_id_)); 191} 192 193ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) { 194 DVLOG(3) << "WebSocketEventHandler::OnFlowControl" 195 << " routing_id=" << routing_id_ << " quota=" << quota; 196 197 return StateCast(dispatcher_->SendFlowControl(routing_id_, quota)); 198} 199 200ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean, 201 uint16 code, 202 const std::string& reason) { 203 DVLOG(3) << "WebSocketEventHandler::OnDropChannel" 204 << " routing_id=" << routing_id_ << " was_clean=" << was_clean 205 << " code=" << code << " reason=\"" << reason << "\""; 206 207 return StateCast( 208 dispatcher_->DoDropChannel(routing_id_, was_clean, code, reason)); 209} 210 211ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) { 212 DVLOG(3) << "WebSocketEventHandler::OnFailChannel" 213 << " routing_id=" << routing_id_ 214 << " message=\"" << message << "\""; 215 216 return StateCast(dispatcher_->NotifyFailure(routing_id_, message)); 217} 218 219ChannelState WebSocketEventHandler::OnStartOpeningHandshake( 220 scoped_ptr<net::WebSocketHandshakeRequestInfo> request) { 221 bool should_send = dispatcher_->CanReadRawCookies(); 222 DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake " 223 << "should_send=" << should_send; 224 225 if (!should_send) 226 return WebSocketEventInterface::CHANNEL_ALIVE; 227 228 WebSocketHandshakeRequest request_to_pass; 229 request_to_pass.url.Swap(&request->url); 230 net::HttpRequestHeaders::Iterator it(request->headers); 231 while (it.GetNext()) 232 request_to_pass.headers.push_back(std::make_pair(it.name(), it.value())); 233 request_to_pass.headers_text = 234 base::StringPrintf("GET %s HTTP/1.1\r\n", 235 request_to_pass.url.spec().c_str()) + 236 request->headers.ToString(); 237 request_to_pass.request_time = request->request_time; 238 239 return StateCast(dispatcher_->NotifyStartOpeningHandshake(routing_id_, 240 request_to_pass)); 241} 242 243ChannelState WebSocketEventHandler::OnFinishOpeningHandshake( 244 scoped_ptr<net::WebSocketHandshakeResponseInfo> response) { 245 bool should_send = dispatcher_->CanReadRawCookies(); 246 DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake " 247 << "should_send=" << should_send; 248 249 if (!should_send) 250 return WebSocketEventInterface::CHANNEL_ALIVE; 251 252 WebSocketHandshakeResponse response_to_pass; 253 response_to_pass.url.Swap(&response->url); 254 response_to_pass.status_code = response->status_code; 255 response_to_pass.status_text.swap(response->status_text); 256 void* iter = NULL; 257 std::string name, value; 258 while (response->headers->EnumerateHeaderLines(&iter, &name, &value)) 259 response_to_pass.headers.push_back(std::make_pair(name, value)); 260 response_to_pass.headers_text = 261 net::HttpUtil::ConvertHeadersBackToHTTPResponse( 262 response->headers->raw_headers()); 263 response_to_pass.response_time = response->response_time; 264 265 return StateCast(dispatcher_->NotifyFinishOpeningHandshake(routing_id_, 266 response_to_pass)); 267} 268 269ChannelState WebSocketEventHandler::OnSSLCertificateError( 270 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, 271 const GURL& url, 272 const net::SSLInfo& ssl_info, 273 bool fatal) { 274 DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError" 275 << " routing_id=" << routing_id_ << " url=" << url.spec() 276 << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal; 277 ssl_error_handler_delegate_.reset( 278 new SSLErrorHandlerDelegate(callbacks.Pass())); 279 // We don't need request_id to be unique so just make a fake one. 280 GlobalRequestID request_id(-1, -1); 281 SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_->GetWeakPtr(), 282 request_id, 283 RESOURCE_TYPE_SUB_RESOURCE, 284 url, 285 dispatcher_->render_process_id(), 286 render_frame_id_, 287 ssl_info, 288 fatal); 289 // The above method is always asynchronous. 290 return WebSocketEventInterface::CHANNEL_ALIVE; 291} 292 293WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate( 294 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks) 295 : callbacks_(callbacks.Pass()), weak_ptr_factory_(this) {} 296 297WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {} 298 299base::WeakPtr<SSLErrorHandler::Delegate> 300WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() { 301 return weak_ptr_factory_.GetWeakPtr(); 302} 303 304void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest( 305 const GlobalRequestID& id, 306 int error, 307 const net::SSLInfo* ssl_info) { 308 DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest" 309 << " error=" << error 310 << " cert_status=" << (ssl_info ? ssl_info->cert_status 311 : static_cast<net::CertStatus>(-1)); 312 callbacks_->CancelSSLRequest(error, ssl_info); 313} 314 315void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest( 316 const GlobalRequestID& id) { 317 DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest"; 318 callbacks_->ContinueSSLRequest(); 319} 320 321} // namespace 322 323WebSocketHost::WebSocketHost(int routing_id, 324 WebSocketDispatcherHost* dispatcher, 325 net::URLRequestContext* url_request_context) 326 : dispatcher_(dispatcher), 327 url_request_context_(url_request_context), 328 routing_id_(routing_id) { 329 DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id; 330} 331 332WebSocketHost::~WebSocketHost() {} 333 334void WebSocketHost::GoAway() { 335 OnDropChannel(false, static_cast<uint16>(net::kWebSocketErrorGoingAway), ""); 336} 337 338bool WebSocketHost::OnMessageReceived(const IPC::Message& message) { 339 bool handled = true; 340 IPC_BEGIN_MESSAGE_MAP(WebSocketHost, message) 341 IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest, OnAddChannelRequest) 342 IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, OnSendFrame) 343 IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, OnFlowControl) 344 IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, OnDropChannel) 345 IPC_MESSAGE_UNHANDLED(handled = false) 346 IPC_END_MESSAGE_MAP() 347 return handled; 348} 349 350void WebSocketHost::OnAddChannelRequest( 351 const GURL& socket_url, 352 const std::vector<std::string>& requested_protocols, 353 const url::Origin& origin, 354 int render_frame_id) { 355 DVLOG(3) << "WebSocketHost::OnAddChannelRequest" 356 << " routing_id=" << routing_id_ << " socket_url=\"" << socket_url 357 << "\" requested_protocols=\"" 358 << JoinString(requested_protocols, ", ") << "\" origin=\"" 359 << origin.string() << "\""; 360 361 DCHECK(!channel_); 362 scoped_ptr<net::WebSocketEventInterface> event_interface( 363 new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id)); 364 channel_.reset( 365 new net::WebSocketChannel(event_interface.Pass(), url_request_context_)); 366 channel_->SendAddChannelRequest(socket_url, requested_protocols, origin); 367} 368 369void WebSocketHost::OnSendFrame(bool fin, 370 WebSocketMessageType type, 371 const std::vector<char>& data) { 372 DVLOG(3) << "WebSocketHost::OnSendFrame" 373 << " routing_id=" << routing_id_ << " fin=" << fin 374 << " type=" << type << " data is " << data.size() << " bytes"; 375 376 DCHECK(channel_); 377 channel_->SendFrame(fin, MessageTypeToOpCode(type), data); 378} 379 380void WebSocketHost::OnFlowControl(int64 quota) { 381 DVLOG(3) << "WebSocketHost::OnFlowControl" 382 << " routing_id=" << routing_id_ << " quota=" << quota; 383 384 DCHECK(channel_); 385 channel_->SendFlowControl(quota); 386} 387 388void WebSocketHost::OnDropChannel(bool was_clean, 389 uint16 code, 390 const std::string& reason) { 391 DVLOG(3) << "WebSocketHost::OnDropChannel" 392 << " routing_id=" << routing_id_ << " was_clean=" << was_clean 393 << " code=" << code << " reason=\"" << reason << "\""; 394 395 DCHECK(channel_); 396 // TODO(yhirano): Handle |was_clean| appropriately. 397 channel_->StartClosingHandshake(code, reason); 398} 399 400} // namespace content 401