1// Copyright 2014 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "mojo/services/network/web_socket_impl.h" 6 7#include "base/logging.h" 8#include "base/message_loop/message_loop.h" 9#include "mojo/common/handle_watcher.h" 10#include "mojo/services/network/network_context.h" 11#include "mojo/services/public/cpp/network/web_socket_read_queue.h" 12#include "mojo/services/public/cpp/network/web_socket_write_queue.h" 13#include "net/websockets/websocket_channel.h" 14#include "net/websockets/websocket_errors.h" 15#include "net/websockets/websocket_event_interface.h" 16#include "net/websockets/websocket_frame.h" // for WebSocketFrameHeader::OpCode 17#include "net/websockets/websocket_handshake_request_info.h" 18#include "net/websockets/websocket_handshake_response_info.h" 19#include "url/origin.h" 20 21namespace mojo { 22 23template <> 24struct TypeConverter<net::WebSocketFrameHeader::OpCode, 25 WebSocket::MessageType> { 26 static net::WebSocketFrameHeader::OpCode Convert( 27 WebSocket::MessageType type) { 28 DCHECK(type == WebSocket::MESSAGE_TYPE_CONTINUATION || 29 type == WebSocket::MESSAGE_TYPE_TEXT || 30 type == WebSocket::MESSAGE_TYPE_BINARY); 31 typedef net::WebSocketFrameHeader::OpCode OpCode; 32 // These compile asserts verify that the same underlying values are used for 33 // both types, so we can simply cast between them. 34 COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_CONTINUATION) == 35 net::WebSocketFrameHeader::kOpCodeContinuation, 36 enum_values_must_match_for_opcode_continuation); 37 COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_TEXT) == 38 net::WebSocketFrameHeader::kOpCodeText, 39 enum_values_must_match_for_opcode_text); 40 COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_BINARY) == 41 net::WebSocketFrameHeader::kOpCodeBinary, 42 enum_values_must_match_for_opcode_binary); 43 return static_cast<OpCode>(type); 44 } 45}; 46 47template <> 48struct TypeConverter<WebSocket::MessageType, 49 net::WebSocketFrameHeader::OpCode> { 50 static WebSocket::MessageType Convert( 51 net::WebSocketFrameHeader::OpCode type) { 52 DCHECK(type == net::WebSocketFrameHeader::kOpCodeContinuation || 53 type == net::WebSocketFrameHeader::kOpCodeText || 54 type == net::WebSocketFrameHeader::kOpCodeBinary); 55 return static_cast<WebSocket::MessageType>(type); 56 } 57}; 58 59namespace { 60 61typedef net::WebSocketEventInterface::ChannelState ChannelState; 62 63struct WebSocketEventHandler : public net::WebSocketEventInterface { 64 public: 65 WebSocketEventHandler(WebSocketClientPtr client) 66 : client_(client.Pass()) { 67 } 68 virtual ~WebSocketEventHandler() {} 69 70 private: 71 // net::WebSocketEventInterface methods: 72 virtual ChannelState OnAddChannelResponse( 73 bool fail, 74 const std::string& selected_subprotocol, 75 const std::string& extensions) OVERRIDE; 76 virtual ChannelState OnDataFrame(bool fin, 77 WebSocketMessageType type, 78 const std::vector<char>& data) OVERRIDE; 79 virtual ChannelState OnClosingHandshake() OVERRIDE; 80 virtual ChannelState OnFlowControl(int64 quota) OVERRIDE; 81 virtual ChannelState OnDropChannel(bool was_clean, 82 uint16 code, 83 const std::string& reason) OVERRIDE; 84 virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE; 85 virtual ChannelState OnStartOpeningHandshake( 86 scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE; 87 virtual ChannelState OnFinishOpeningHandshake( 88 scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE; 89 virtual ChannelState OnSSLCertificateError( 90 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, 91 const GURL& url, 92 const net::SSLInfo& ssl_info, 93 bool fatal) OVERRIDE; 94 95 // Called once we've written to |receive_stream_|. 96 void DidWriteToReceiveStream(bool fin, 97 net::WebSocketFrameHeader::OpCode type, 98 uint32_t num_bytes, 99 const char* buffer); 100 WebSocketClientPtr client_; 101 ScopedDataPipeProducerHandle receive_stream_; 102 scoped_ptr<WebSocketWriteQueue> write_queue_; 103 104 DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler); 105}; 106 107ChannelState WebSocketEventHandler::OnAddChannelResponse( 108 bool fail, 109 const std::string& selected_protocol, 110 const std::string& extensions) { 111 DataPipe data_pipe; 112 receive_stream_ = data_pipe.producer_handle.Pass(); 113 write_queue_.reset(new WebSocketWriteQueue(receive_stream_.get())); 114 client_->DidConnect( 115 fail, selected_protocol, extensions, data_pipe.consumer_handle.Pass()); 116 if (fail) 117 return WebSocketEventInterface::CHANNEL_DELETED; 118 return WebSocketEventInterface::CHANNEL_ALIVE; 119} 120 121ChannelState WebSocketEventHandler::OnDataFrame( 122 bool fin, 123 net::WebSocketFrameHeader::OpCode type, 124 const std::vector<char>& data) { 125 uint32_t size = static_cast<uint32_t>(data.size()); 126 write_queue_->Write( 127 &data[0], size, 128 base::Bind(&WebSocketEventHandler::DidWriteToReceiveStream, 129 base::Unretained(this), 130 fin, type, size)); 131 return WebSocketEventInterface::CHANNEL_ALIVE; 132} 133 134ChannelState WebSocketEventHandler::OnClosingHandshake() { 135 return WebSocketEventInterface::CHANNEL_ALIVE; 136} 137 138ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) { 139 client_->DidReceiveFlowControl(quota); 140 return WebSocketEventInterface::CHANNEL_ALIVE; 141} 142 143ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean, 144 uint16 code, 145 const std::string& reason) { 146 client_->DidClose(was_clean, code, reason); 147 return WebSocketEventInterface::CHANNEL_DELETED; 148} 149 150ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) { 151 client_->DidFail(message); 152 return WebSocketEventInterface::CHANNEL_DELETED; 153} 154 155ChannelState WebSocketEventHandler::OnStartOpeningHandshake( 156 scoped_ptr<net::WebSocketHandshakeRequestInfo> request) { 157 return WebSocketEventInterface::CHANNEL_ALIVE; 158} 159 160ChannelState WebSocketEventHandler::OnFinishOpeningHandshake( 161 scoped_ptr<net::WebSocketHandshakeResponseInfo> response) { 162 return WebSocketEventInterface::CHANNEL_ALIVE; 163} 164 165ChannelState WebSocketEventHandler::OnSSLCertificateError( 166 scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks, 167 const GURL& url, 168 const net::SSLInfo& ssl_info, 169 bool fatal) { 170 client_->DidFail("SSL Error"); 171 return WebSocketEventInterface::CHANNEL_DELETED; 172} 173 174void WebSocketEventHandler::DidWriteToReceiveStream( 175 bool fin, 176 net::WebSocketFrameHeader::OpCode type, 177 uint32_t num_bytes, 178 const char* buffer) { 179 client_->DidReceiveData( 180 fin, ConvertTo<WebSocket::MessageType>(type), num_bytes); 181} 182 183} // namespace mojo 184 185WebSocketImpl::WebSocketImpl(NetworkContext* context) : context_(context) { 186} 187 188WebSocketImpl::~WebSocketImpl() { 189} 190 191void WebSocketImpl::Connect(const String& url, 192 Array<String> protocols, 193 const String& origin, 194 ScopedDataPipeConsumerHandle send_stream, 195 WebSocketClientPtr client) { 196 DCHECK(!channel_); 197 send_stream_ = send_stream.Pass(); 198 read_queue_.reset(new WebSocketReadQueue(send_stream_.get())); 199 scoped_ptr<net::WebSocketEventInterface> event_interface( 200 new WebSocketEventHandler(client.Pass())); 201 channel_.reset(new net::WebSocketChannel(event_interface.Pass(), 202 context_->url_request_context())); 203 channel_->SendAddChannelRequest(GURL(url.get()), 204 protocols.To<std::vector<std::string> >(), 205 url::Origin(origin.get())); 206} 207 208void WebSocketImpl::Send(bool fin, 209 WebSocket::MessageType type, 210 uint32_t num_bytes) { 211 DCHECK(channel_); 212 read_queue_->Read(num_bytes, 213 base::Bind(&WebSocketImpl::DidReadFromSendStream, 214 base::Unretained(this), 215 fin, type, num_bytes)); 216} 217 218void WebSocketImpl::FlowControl(int64_t quota) { 219 DCHECK(channel_); 220 channel_->SendFlowControl(quota); 221} 222 223void WebSocketImpl::Close(uint16_t code, const String& reason) { 224 DCHECK(channel_); 225 channel_->StartClosingHandshake(code, reason); 226} 227 228void WebSocketImpl::DidReadFromSendStream(bool fin, 229 WebSocket::MessageType type, 230 uint32_t num_bytes, 231 const char* data) { 232 std::vector<char> buffer(num_bytes); 233 memcpy(&buffer[0], data, num_bytes); 234 DCHECK(channel_); 235 channel_->SendFrame( 236 fin, ConvertTo<net::WebSocketFrameHeader::OpCode>(type), buffer); 237} 238 239} // namespace mojo 240