1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "content/browser/renderer_host/websocket_host.h"
6
7#include "base/basictypes.h"
8#include "base/memory/weak_ptr.h"
9#include "base/strings/string_util.h"
10#include "content/browser/renderer_host/websocket_dispatcher_host.h"
11#include "content/browser/ssl/ssl_error_handler.h"
12#include "content/browser/ssl/ssl_manager.h"
13#include "content/common/websocket_messages.h"
14#include "ipc/ipc_message_macros.h"
15#include "net/http/http_request_headers.h"
16#include "net/http/http_response_headers.h"
17#include "net/http/http_util.h"
18#include "net/ssl/ssl_info.h"
19#include "net/websockets/websocket_channel.h"
20#include "net/websockets/websocket_errors.h"
21#include "net/websockets/websocket_event_interface.h"
22#include "net/websockets/websocket_frame.h"  // for WebSocketFrameHeader::OpCode
23#include "net/websockets/websocket_handshake_request_info.h"
24#include "net/websockets/websocket_handshake_response_info.h"
25#include "url/origin.h"
26
27namespace content {
28
29namespace {
30
// Shorthand for the state enum that every net::WebSocketEventInterface
// callback implemented below returns.
typedef net::WebSocketEventInterface::ChannelState ChannelState;
32
33// Convert a content::WebSocketMessageType to a
34// net::WebSocketFrameHeader::OpCode
35net::WebSocketFrameHeader::OpCode MessageTypeToOpCode(
36    WebSocketMessageType type) {
37  DCHECK(type == WEB_SOCKET_MESSAGE_TYPE_CONTINUATION ||
38         type == WEB_SOCKET_MESSAGE_TYPE_TEXT ||
39         type == WEB_SOCKET_MESSAGE_TYPE_BINARY);
40  typedef net::WebSocketFrameHeader::OpCode OpCode;
41  // These compile asserts verify that the same underlying values are used for
42  // both types, so we can simply cast between them.
43  COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_CONTINUATION) ==
44                     net::WebSocketFrameHeader::kOpCodeContinuation,
45                 enum_values_must_match_for_opcode_continuation);
46  COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_TEXT) ==
47                     net::WebSocketFrameHeader::kOpCodeText,
48                 enum_values_must_match_for_opcode_text);
49  COMPILE_ASSERT(static_cast<OpCode>(WEB_SOCKET_MESSAGE_TYPE_BINARY) ==
50                     net::WebSocketFrameHeader::kOpCodeBinary,
51                 enum_values_must_match_for_opcode_binary);
52  return static_cast<OpCode>(type);
53}
54
55WebSocketMessageType OpCodeToMessageType(
56    net::WebSocketFrameHeader::OpCode opCode) {
57  DCHECK(opCode == net::WebSocketFrameHeader::kOpCodeContinuation ||
58         opCode == net::WebSocketFrameHeader::kOpCodeText ||
59         opCode == net::WebSocketFrameHeader::kOpCodeBinary);
60  // This cast is guaranteed valid by the COMPILE_ASSERT() statements above.
61  return static_cast<WebSocketMessageType>(opCode);
62}
63
64ChannelState StateCast(WebSocketDispatcherHost::WebSocketHostState host_state) {
65  const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_ALIVE =
66      WebSocketDispatcherHost::WEBSOCKET_HOST_ALIVE;
67  const WebSocketDispatcherHost::WebSocketHostState WEBSOCKET_HOST_DELETED =
68      WebSocketDispatcherHost::WEBSOCKET_HOST_DELETED;
69
70  DCHECK(host_state == WEBSOCKET_HOST_ALIVE ||
71         host_state == WEBSOCKET_HOST_DELETED);
72  // These compile asserts verify that we can get away with using static_cast<>
73  // for the conversion.
74  COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_ALIVE) ==
75                     net::WebSocketEventInterface::CHANNEL_ALIVE,
76                 enum_values_must_match_for_state_alive);
77  COMPILE_ASSERT(static_cast<ChannelState>(WEBSOCKET_HOST_DELETED) ==
78                     net::WebSocketEventInterface::CHANNEL_DELETED,
79                 enum_values_must_match_for_state_deleted);
80  return static_cast<ChannelState>(host_state);
81}
82
83// Implementation of net::WebSocketEventInterface. Receives events from our
84// WebSocketChannel object. Each event is translated to an IPC and sent to the
85// renderer or child process via WebSocketDispatcherHost.
86class WebSocketEventHandler : public net::WebSocketEventInterface {
87 public:
88  WebSocketEventHandler(WebSocketDispatcherHost* dispatcher,
89                        int routing_id,
90                        int render_frame_id);
91  virtual ~WebSocketEventHandler();
92
93  // net::WebSocketEventInterface implementation
94
95  virtual ChannelState OnAddChannelResponse(
96      bool fail,
97      const std::string& selected_subprotocol,
98      const std::string& extensions) OVERRIDE;
99  virtual ChannelState OnDataFrame(bool fin,
100                                   WebSocketMessageType type,
101                                   const std::vector<char>& data) OVERRIDE;
102  virtual ChannelState OnClosingHandshake() OVERRIDE;
103  virtual ChannelState OnFlowControl(int64 quota) OVERRIDE;
104  virtual ChannelState OnDropChannel(bool was_clean,
105                                     uint16 code,
106                                     const std::string& reason) OVERRIDE;
107  virtual ChannelState OnFailChannel(const std::string& message) OVERRIDE;
108  virtual ChannelState OnStartOpeningHandshake(
109      scoped_ptr<net::WebSocketHandshakeRequestInfo> request) OVERRIDE;
110  virtual ChannelState OnFinishOpeningHandshake(
111      scoped_ptr<net::WebSocketHandshakeResponseInfo> response) OVERRIDE;
112  virtual ChannelState OnSSLCertificateError(
113      scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
114      const GURL& url,
115      const net::SSLInfo& ssl_info,
116      bool fatal) OVERRIDE;
117
118 private:
119  class SSLErrorHandlerDelegate : public SSLErrorHandler::Delegate {
120   public:
121    SSLErrorHandlerDelegate(
122        scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks);
123    virtual ~SSLErrorHandlerDelegate();
124
125    base::WeakPtr<SSLErrorHandler::Delegate> GetWeakPtr();
126
127    // SSLErrorHandler::Delegate methods
128    virtual void CancelSSLRequest(const GlobalRequestID& id,
129                                  int error,
130                                  const net::SSLInfo* ssl_info) OVERRIDE;
131    virtual void ContinueSSLRequest(const GlobalRequestID& id) OVERRIDE;
132
133   private:
134    scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks_;
135    base::WeakPtrFactory<SSLErrorHandlerDelegate> weak_ptr_factory_;
136
137    DISALLOW_COPY_AND_ASSIGN(SSLErrorHandlerDelegate);
138  };
139
140  WebSocketDispatcherHost* const dispatcher_;
141  const int routing_id_;
142  const int render_frame_id_;
143  scoped_ptr<SSLErrorHandlerDelegate> ssl_error_handler_delegate_;
144
145  DISALLOW_COPY_AND_ASSIGN(WebSocketEventHandler);
146};
147
148WebSocketEventHandler::WebSocketEventHandler(
149    WebSocketDispatcherHost* dispatcher,
150    int routing_id,
151    int render_frame_id)
152    : dispatcher_(dispatcher),
153      routing_id_(routing_id),
154      render_frame_id_(render_frame_id) {
155}
156
157WebSocketEventHandler::~WebSocketEventHandler() {
158  DVLOG(1) << "WebSocketEventHandler destroyed routing_id=" << routing_id_;
159}
160
161ChannelState WebSocketEventHandler::OnAddChannelResponse(
162    bool fail,
163    const std::string& selected_protocol,
164    const std::string& extensions) {
165  DVLOG(3) << "WebSocketEventHandler::OnAddChannelResponse"
166           << " routing_id=" << routing_id_ << " fail=" << fail
167           << " selected_protocol=\"" << selected_protocol << "\""
168           << " extensions=\"" << extensions << "\"";
169
170  return StateCast(dispatcher_->SendAddChannelResponse(
171      routing_id_, fail, selected_protocol, extensions));
172}
173
174ChannelState WebSocketEventHandler::OnDataFrame(
175    bool fin,
176    net::WebSocketFrameHeader::OpCode type,
177    const std::vector<char>& data) {
178  DVLOG(3) << "WebSocketEventHandler::OnDataFrame"
179           << " routing_id=" << routing_id_ << " fin=" << fin
180           << " type=" << type << " data is " << data.size() << " bytes";
181
182  return StateCast(dispatcher_->SendFrame(
183      routing_id_, fin, OpCodeToMessageType(type), data));
184}
185
186ChannelState WebSocketEventHandler::OnClosingHandshake() {
187  DVLOG(3) << "WebSocketEventHandler::OnClosingHandshake"
188           << " routing_id=" << routing_id_;
189
190  return StateCast(dispatcher_->NotifyClosingHandshake(routing_id_));
191}
192
193ChannelState WebSocketEventHandler::OnFlowControl(int64 quota) {
194  DVLOG(3) << "WebSocketEventHandler::OnFlowControl"
195           << " routing_id=" << routing_id_ << " quota=" << quota;
196
197  return StateCast(dispatcher_->SendFlowControl(routing_id_, quota));
198}
199
200ChannelState WebSocketEventHandler::OnDropChannel(bool was_clean,
201                                                  uint16 code,
202                                                  const std::string& reason) {
203  DVLOG(3) << "WebSocketEventHandler::OnDropChannel"
204           << " routing_id=" << routing_id_ << " was_clean=" << was_clean
205           << " code=" << code << " reason=\"" << reason << "\"";
206
207  return StateCast(
208      dispatcher_->DoDropChannel(routing_id_, was_clean, code, reason));
209}
210
211ChannelState WebSocketEventHandler::OnFailChannel(const std::string& message) {
212  DVLOG(3) << "WebSocketEventHandler::OnFailChannel"
213           << " routing_id=" << routing_id_
214           << " message=\"" << message << "\"";
215
216  return StateCast(dispatcher_->NotifyFailure(routing_id_, message));
217}
218
219ChannelState WebSocketEventHandler::OnStartOpeningHandshake(
220    scoped_ptr<net::WebSocketHandshakeRequestInfo> request) {
221  bool should_send = dispatcher_->CanReadRawCookies();
222  DVLOG(3) << "WebSocketEventHandler::OnStartOpeningHandshake "
223           << "should_send=" << should_send;
224
225  if (!should_send)
226    return WebSocketEventInterface::CHANNEL_ALIVE;
227
228  WebSocketHandshakeRequest request_to_pass;
229  request_to_pass.url.Swap(&request->url);
230  net::HttpRequestHeaders::Iterator it(request->headers);
231  while (it.GetNext())
232    request_to_pass.headers.push_back(std::make_pair(it.name(), it.value()));
233  request_to_pass.headers_text =
234      base::StringPrintf("GET %s HTTP/1.1\r\n",
235                         request_to_pass.url.spec().c_str()) +
236      request->headers.ToString();
237  request_to_pass.request_time = request->request_time;
238
239  return StateCast(dispatcher_->NotifyStartOpeningHandshake(routing_id_,
240                                                            request_to_pass));
241}
242
243ChannelState WebSocketEventHandler::OnFinishOpeningHandshake(
244    scoped_ptr<net::WebSocketHandshakeResponseInfo> response) {
245  bool should_send = dispatcher_->CanReadRawCookies();
246  DVLOG(3) << "WebSocketEventHandler::OnFinishOpeningHandshake "
247           << "should_send=" << should_send;
248
249  if (!should_send)
250    return WebSocketEventInterface::CHANNEL_ALIVE;
251
252  WebSocketHandshakeResponse response_to_pass;
253  response_to_pass.url.Swap(&response->url);
254  response_to_pass.status_code = response->status_code;
255  response_to_pass.status_text.swap(response->status_text);
256  void* iter = NULL;
257  std::string name, value;
258  while (response->headers->EnumerateHeaderLines(&iter, &name, &value))
259    response_to_pass.headers.push_back(std::make_pair(name, value));
260  response_to_pass.headers_text =
261      net::HttpUtil::ConvertHeadersBackToHTTPResponse(
262          response->headers->raw_headers());
263  response_to_pass.response_time = response->response_time;
264
265  return StateCast(dispatcher_->NotifyFinishOpeningHandshake(routing_id_,
266                                                             response_to_pass));
267}
268
269ChannelState WebSocketEventHandler::OnSSLCertificateError(
270    scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks,
271    const GURL& url,
272    const net::SSLInfo& ssl_info,
273    bool fatal) {
274  DVLOG(3) << "WebSocketEventHandler::OnSSLCertificateError"
275           << " routing_id=" << routing_id_ << " url=" << url.spec()
276           << " cert_status=" << ssl_info.cert_status << " fatal=" << fatal;
277  ssl_error_handler_delegate_.reset(
278      new SSLErrorHandlerDelegate(callbacks.Pass()));
279  // We don't need request_id to be unique so just make a fake one.
280  GlobalRequestID request_id(-1, -1);
281  SSLManager::OnSSLCertificateError(ssl_error_handler_delegate_->GetWeakPtr(),
282                                    request_id,
283                                    RESOURCE_TYPE_SUB_RESOURCE,
284                                    url,
285                                    dispatcher_->render_process_id(),
286                                    render_frame_id_,
287                                    ssl_info,
288                                    fatal);
289  // The above method is always asynchronous.
290  return WebSocketEventInterface::CHANNEL_ALIVE;
291}
292
293WebSocketEventHandler::SSLErrorHandlerDelegate::SSLErrorHandlerDelegate(
294    scoped_ptr<net::WebSocketEventInterface::SSLErrorCallbacks> callbacks)
295    : callbacks_(callbacks.Pass()), weak_ptr_factory_(this) {}
296
297WebSocketEventHandler::SSLErrorHandlerDelegate::~SSLErrorHandlerDelegate() {}
298
299base::WeakPtr<SSLErrorHandler::Delegate>
300WebSocketEventHandler::SSLErrorHandlerDelegate::GetWeakPtr() {
301  return weak_ptr_factory_.GetWeakPtr();
302}
303
304void WebSocketEventHandler::SSLErrorHandlerDelegate::CancelSSLRequest(
305    const GlobalRequestID& id,
306    int error,
307    const net::SSLInfo* ssl_info) {
308  DVLOG(3) << "SSLErrorHandlerDelegate::CancelSSLRequest"
309           << " error=" << error
310           << " cert_status=" << (ssl_info ? ssl_info->cert_status
311                                           : static_cast<net::CertStatus>(-1));
312  callbacks_->CancelSSLRequest(error, ssl_info);
313}
314
315void WebSocketEventHandler::SSLErrorHandlerDelegate::ContinueSSLRequest(
316    const GlobalRequestID& id) {
317  DVLOG(3) << "SSLErrorHandlerDelegate::ContinueSSLRequest";
318  callbacks_->ContinueSSLRequest();
319}
320
321}  // namespace
322
323WebSocketHost::WebSocketHost(int routing_id,
324                             WebSocketDispatcherHost* dispatcher,
325                             net::URLRequestContext* url_request_context)
326    : dispatcher_(dispatcher),
327      url_request_context_(url_request_context),
328      routing_id_(routing_id) {
329  DVLOG(1) << "WebSocketHost: created routing_id=" << routing_id;
330}
331
332WebSocketHost::~WebSocketHost() {}
333
334void WebSocketHost::GoAway() {
335  OnDropChannel(false, static_cast<uint16>(net::kWebSocketErrorGoingAway), "");
336}
337
338bool WebSocketHost::OnMessageReceived(const IPC::Message& message) {
339  bool handled = true;
340  IPC_BEGIN_MESSAGE_MAP(WebSocketHost, message)
341    IPC_MESSAGE_HANDLER(WebSocketHostMsg_AddChannelRequest, OnAddChannelRequest)
342    IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, OnSendFrame)
343    IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, OnFlowControl)
344    IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, OnDropChannel)
345    IPC_MESSAGE_UNHANDLED(handled = false)
346  IPC_END_MESSAGE_MAP()
347  return handled;
348}
349
350void WebSocketHost::OnAddChannelRequest(
351    const GURL& socket_url,
352    const std::vector<std::string>& requested_protocols,
353    const url::Origin& origin,
354    int render_frame_id) {
355  DVLOG(3) << "WebSocketHost::OnAddChannelRequest"
356           << " routing_id=" << routing_id_ << " socket_url=\"" << socket_url
357           << "\" requested_protocols=\""
358           << JoinString(requested_protocols, ", ") << "\" origin=\""
359           << origin.string() << "\"";
360
361  DCHECK(!channel_);
362  scoped_ptr<net::WebSocketEventInterface> event_interface(
363      new WebSocketEventHandler(dispatcher_, routing_id_, render_frame_id));
364  channel_.reset(
365      new net::WebSocketChannel(event_interface.Pass(), url_request_context_));
366  channel_->SendAddChannelRequest(socket_url, requested_protocols, origin);
367}
368
369void WebSocketHost::OnSendFrame(bool fin,
370                                WebSocketMessageType type,
371                                const std::vector<char>& data) {
372  DVLOG(3) << "WebSocketHost::OnSendFrame"
373           << " routing_id=" << routing_id_ << " fin=" << fin
374           << " type=" << type << " data is " << data.size() << " bytes";
375
376  DCHECK(channel_);
377  channel_->SendFrame(fin, MessageTypeToOpCode(type), data);
378}
379
380void WebSocketHost::OnFlowControl(int64 quota) {
381  DVLOG(3) << "WebSocketHost::OnFlowControl"
382           << " routing_id=" << routing_id_ << " quota=" << quota;
383
384  DCHECK(channel_);
385  channel_->SendFlowControl(quota);
386}
387
388void WebSocketHost::OnDropChannel(bool was_clean,
389                                  uint16 code,
390                                  const std::string& reason) {
391  DVLOG(3) << "WebSocketHost::OnDropChannel"
392           << " routing_id=" << routing_id_ << " was_clean=" << was_clean
393           << " code=" << code << " reason=\"" << reason << "\"";
394
395  DCHECK(channel_);
396  // TODO(yhirano): Handle |was_clean| appropriately.
397  channel_->StartClosingHandshake(code, reason);
398}
399
400}  // namespace content
401