1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "mojo/services/network/web_socket_impl.h"
6
7#include "base/logging.h"
8#include "base/message_loop/message_loop.h"
9#include "mojo/common/handle_watcher.h"
10#include "mojo/services/network/network_context.h"
11#include "mojo/services/public/cpp/network/web_socket_read_queue.h"
12#include "mojo/services/public/cpp/network/web_socket_write_queue.h"
13#include "net/websockets/websocket_channel.h"
14#include "net/websockets/websocket_errors.h"
15#include "net/websockets/websocket_event_interface.h"
16#include "net/websockets/websocket_frame.h"  // for WebSocketFrameHeader::OpCode
17#include "net/websockets/websocket_handshake_request_info.h"
18#include "net/websockets/websocket_handshake_response_info.h"
19#include "url/origin.h"
20
21namespace mojo {
22
23template <>
24struct TypeConverter<net::WebSocketFrameHeader::OpCode,
25                     WebSocket::MessageType> {
26  static net::WebSocketFrameHeader::OpCode Convert(
27      WebSocket::MessageType type) {
28    DCHECK(type == WebSocket::MESSAGE_TYPE_CONTINUATION ||
29           type == WebSocket::MESSAGE_TYPE_TEXT ||
30           type == WebSocket::MESSAGE_TYPE_BINARY);
31    typedef net::WebSocketFrameHeader::OpCode OpCode;
32    // These compile asserts verify that the same underlying values are used for
33    // both types, so we can simply cast between them.
34    COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_CONTINUATION) ==
35                       net::WebSocketFrameHeader::kOpCodeContinuation,
36                   enum_values_must_match_for_opcode_continuation);
37    COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_TEXT) ==
38                       net::WebSocketFrameHeader::kOpCodeText,
39                   enum_values_must_match_for_opcode_text);
40    COMPILE_ASSERT(static_cast<OpCode>(WebSocket::MESSAGE_TYPE_BINARY) ==
41                       net::WebSocketFrameHeader::kOpCodeBinary,
42                   enum_values_must_match_for_opcode_binary);
43    return static_cast<OpCode>(type);
44  }
45};
46
47template <>
48struct TypeConverter<WebSocket::MessageType,
49                     net::WebSocketFrameHeader::OpCode> {
50  static WebSocket::MessageType Convert(
51      net::WebSocketFrameHeader::OpCode type) {
52    DCHECK(type == net::WebSocketFrameHeader::kOpCodeContinuation ||
53           type == net::WebSocketFrameHeader::kOpCodeText ||
54           type == net::WebSocketFrameHeader::kOpCodeBinary);
55    return static_cast<WebSocket::MessageType>(type);
56  }
57};
58
59namespace {
60
// Shorthand for the state returned by every net::WebSocketEventInterface
// callback implemented below (CHANNEL_ALIVE or CHANNEL_DELETED).
typedef net::WebSocketEventInterface::ChannelState ChannelState;
62
63struct WebSocketEventHandler : public net::WebSocketEventInterface {
64 public:
65  WebSocketEventHandler(WebSocketClientPtr client)
66      : client_(client.Pass()) {
67  }
68  virtual ~WebSocketEventHandler() {}
69
70 private:
71  // net::WebSocketEventInterface methods:
72  virtual ChannelState OnAddChannelResponse(
73      bool fail,
74      const std::string& selected_subprotocol,
75      const std::string& extensions) OVERRIDE;
76  virtual ChannelState OnDataFrame(bool fin,
77                                   WebSocketMessageType type,
78                                   const std::vector<char>& data) OVERRIDE;
79  virtual ChannelState OnClosingHandshake() OVERRIDE;
80  virtual ChannelState OnFlowControl(int64 quota) OVERRIDE;
81  virtual ChannelState OnDropChannel(bool was_clean,
82                                     uint16 code,
83                                     const std::string& reason) OVERRIDE;
84  virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE;
85  virtual ChannelState OnStartOpeningHandshake(
86      scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE;
87  virtual ChannelState OnFinishOpeningHandshake(
88      scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE;
89  virtual ChannelState OnSSLCertificateError(
90      scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
91      const GURL& url,
92      const net::SSLInfo& ssl_info,
93      bool fatal) OVERRIDE;
94
95  // Called once we've written to |receive_stream_|.
96  void DidWriteToReceiveStream(bool fin,
97                               net::WebSocketFrameHeader::OpCode type,
98                               uint32_t num_bytes,
99                               const char* buffer);
100  WebSocketClientPtr client_;
101  ScopedDataPipeProducerHandle receive_stream_;
102  scoped_ptr<WebSocketWriteQueue> write_queue_;
103
104  DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
105};
106
107ChannelState WebSocketEventHandler::OnAddChannelResponse(
108    bool fail,
109    const std::string& selected_protocol,
110    const std::string& extensions) {
111  DataPipe data_pipe;
112  receive_stream_ = data_pipe.producer_handle.Pass();
113  write_queue_.reset(new WebSocketWriteQueue(receive_stream_.get()));
114  client_->DidConnect(
115      fail, selected_protocol, extensions, data_pipe.consumer_handle.Pass());
116  if (fail)
117    return WebSocketEventInterface::CHANNEL_DELETED;
118  return WebSocketEventInterface::CHANNEL_ALIVE;
119}
120
121ChannelState WebSocketEventHandler::OnDataFrame(
122    bool fin,
123    net::WebSocketFrameHeader::OpCode type,
124    const std::vector<char>& data) {
125  uint32_t size = static_cast<uint32_t>(data.size());
126  write_queue_->Write(
127      &data[0], size,
128      base::Bind(&WebSocketEventHandler::DidWriteToReceiveStream,
129                 base::Unretained(this),
130                 fin, type, size));
131  return WebSocketEventInterface::CHANNEL_ALIVE;
132}
133
134ChannelState WebSocketEventHandler::OnClosingHandshake() {
135  return WebSocketEventInterface::CHANNEL_ALIVE;
136}
137
138ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) {
139  client_->DidReceiveFlowControl(quota);
140  return WebSocketEventInterface::CHANNEL_ALIVE;
141}
142
143ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean,
144                                                  uint16 code,
145                                                  const std::string& reason) {
146  client_->DidClose(was_clean, code, reason);
147  return WebSocketEventInterface::CHANNEL_DELETED;
148}
149
150ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) {
151  client_->DidFail(message);
152  return WebSocketEventInterface::CHANNEL_DELETED;
153}
154
155ChannelState WebSocketEventHandler::OnStartOpeningHandshake(
156    scoped_ptr<net::WebSocketHandshakeRequestInfo> request) {
157  return WebSocketEventInterface::CHANNEL_ALIVE;
158}
159
160ChannelState WebSocketEventHandler::OnFinishOpeningHandshake(
161    scoped_ptr<net::WebSocketHandshakeResponseInfo> response) {
162  return WebSocketEventInterface::CHANNEL_ALIVE;
163}
164
165ChannelState WebSocketEventHandler::OnSSLCertificateError(
166    scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
167    const GURL& url,
168    const net::SSLInfo& ssl_info,
169    bool fatal) {
170  client_->DidFail("SSL Error");
171  return WebSocketEventInterface::CHANNEL_DELETED;
172}
173
174void WebSocketEventHandler::DidWriteToReceiveStream(
175    bool fin,
176    net::WebSocketFrameHeader::OpCode type,
177    uint32_t num_bytes,
178    const char* buffer) {
179  client_->DidReceiveData(
180      fin, ConvertTo<WebSocket::MessageType>(type), num_bytes);
181}
182
}  // namespace
184
185WebSocketImpl::WebSocketImpl(NetworkContext* context) : context_(context) {
186}
187
188WebSocketImpl::~WebSocketImpl() {
189}
190
191void WebSocketImpl::Connect(const String& url,
192                            Array<String> protocols,
193                            const String& origin,
194                            ScopedDataPipeConsumerHandle send_stream,
195                            WebSocketClientPtr client) {
196  DCHECK(!channel_);
197  send_stream_ = send_stream.Pass();
198  read_queue_.reset(new WebSocketReadQueue(send_stream_.get()));
199  scoped_ptr<net::WebSocketEventInterface> event_interface(
200      new WebSocketEventHandler(client.Pass()));
201  channel_.reset(new net::WebSocketChannel(event_interface.Pass(),
202                                           context_->url_request_context()));
203  channel_->SendAddChannelRequest(GURL(url.get()),
204                                  protocols.To<std::vector<std::string> >(),
205                                  url::Origin(origin.get()));
206}
207
208void WebSocketImpl::Send(bool fin,
209                         WebSocket::MessageType type,
210                         uint32_t num_bytes) {
211  DCHECK(channel_);
212  read_queue_->Read(num_bytes,
213                    base::Bind(&WebSocketImpl::DidReadFromSendStream,
214                               base::Unretained(this),
215                               fin, type, num_bytes));
216}
217
218void WebSocketImpl::FlowControl(int64_t quota) {
219  DCHECK(channel_);
220  channel_->SendFlowControl(quota);
221}
222
223void WebSocketImpl::Close(uint16_t code, const String& reason) {
224  DCHECK(channel_);
225  channel_->StartClosingHandshake(code, reason);
226}
227
228void WebSocketImpl::DidReadFromSendStream(bool fin,
229                                          WebSocket::MessageType type,
230                                          uint32_t num_bytes,
231                                          const char* data) {
232  std::vector<char> buffer(num_bytes);
233  memcpy(&buffer[0], data, num_bytes);
234  DCHECK(channel_);
235  channel_->SendFrame(
236      fin, ConvertTo<net::WebSocketFrameHeader::OpCode>(type), buffer);
237}
238
239}  // namespace mojo
240