1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "content/child/websocket_bridge.h"
6
7#include <stdint.h>
8#include <string>
9#include <utility>
10#include <vector>
11
12#include "base/logging.h"
13#include "base/strings/string_util.h"
14#include "content/child/child_thread.h"
15#include "content/child/websocket_dispatcher.h"
16#include "content/common/websocket.h"
17#include "content/common/websocket_messages.h"
18#include "ipc/ipc_message.h"
19#include "ipc/ipc_message_macros.h"
20#include "third_party/WebKit/public/platform/WebSerializedOrigin.h"
21#include "third_party/WebKit/public/platform/WebSocketHandle.h"
22#include "third_party/WebKit/public/platform/WebSocketHandleClient.h"
23#include "third_party/WebKit/public/platform/WebSocketHandshakeRequestInfo.h"
24#include "third_party/WebKit/public/platform/WebSocketHandshakeResponseInfo.h"
25#include "third_party/WebKit/public/platform/WebString.h"
26#include "third_party/WebKit/public/platform/WebURL.h"
27#include "third_party/WebKit/public/platform/WebVector.h"
28#include "url/gurl.h"
29#include "url/origin.h"
30
31using blink::WebSerializedOrigin;
32using blink::WebSocketHandle;
33using blink::WebSocketHandleClient;
34using blink::WebString;
35using blink::WebURL;
36using blink::WebVector;
37
38namespace content {
39
namespace {

// Close code reported when a channel is dropped without a closing handshake
// (1006 is the "abnormal closure" status code in RFC 6455, section 7.4.1).
const unsigned short kAbnormalShutdownOpCode = 1006;

}  // namespace
45
46WebSocketBridge::WebSocketBridge()
47    : channel_id_(kInvalidChannelId),
48      render_frame_id_(MSG_ROUTING_NONE),
49      client_(NULL) {}
50
51WebSocketBridge::~WebSocketBridge() {
52  if (channel_id_ != kInvalidChannelId) {
53    // The connection is abruptly disconnected by the renderer without
54    // closing handshake.
55    ChildThread::current()->Send(
56        new WebSocketMsg_DropChannel(channel_id_,
57                                     false,
58                                     kAbnormalShutdownOpCode,
59                                     std::string()));
60  }
61  Disconnect();
62}
63
64bool WebSocketBridge::OnMessageReceived(const IPC::Message& msg) {
65  bool handled = true;
66  IPC_BEGIN_MESSAGE_MAP(WebSocketBridge, msg)
67    IPC_MESSAGE_HANDLER(WebSocketMsg_AddChannelResponse, DidConnect)
68    IPC_MESSAGE_HANDLER(WebSocketMsg_NotifyStartOpeningHandshake,
69                        DidStartOpeningHandshake)
70    IPC_MESSAGE_HANDLER(WebSocketMsg_NotifyFinishOpeningHandshake,
71                        DidFinishOpeningHandshake)
72    IPC_MESSAGE_HANDLER(WebSocketMsg_NotifyFailure, DidFail)
73    IPC_MESSAGE_HANDLER(WebSocketMsg_SendFrame, DidReceiveData)
74    IPC_MESSAGE_HANDLER(WebSocketMsg_FlowControl, DidReceiveFlowControl)
75    IPC_MESSAGE_HANDLER(WebSocketMsg_DropChannel, DidClose)
76    IPC_MESSAGE_HANDLER(WebSocketMsg_NotifyClosing,
77                        DidStartClosingHandshake)
78    IPC_MESSAGE_UNHANDLED(handled = false)
79  IPC_END_MESSAGE_MAP()
80  return handled;
81}
82
83void WebSocketBridge::DidConnect(bool fail,
84                                 const std::string& selected_protocol,
85                                 const std::string& extensions) {
86  WebSocketHandleClient* client = client_;
87  DVLOG(1) << "WebSocketBridge::DidConnect("
88           << fail << ", "
89           << selected_protocol << ", "
90           << extensions << ")";
91  if (fail)
92    Disconnect();
93  if (!client)
94    return;
95
96  WebString protocol_to_pass = WebString::fromUTF8(selected_protocol);
97  WebString extensions_to_pass = WebString::fromUTF8(extensions);
98  client->didConnect(this, fail, protocol_to_pass, extensions_to_pass);
99  // |this| can be deleted here.
100}
101
102void WebSocketBridge::DidStartOpeningHandshake(
103    const WebSocketHandshakeRequest& request) {
104  DVLOG(1) << "WebSocketBridge::DidStartOpeningHandshake("
105           << request.url << ")";
106  // All strings are already encoded to ASCII in the browser.
107  blink::WebSocketHandshakeRequestInfo request_to_pass;
108  request_to_pass.setURL(WebURL(request.url));
109  for (size_t i = 0; i < request.headers.size(); ++i) {
110    const std::pair<std::string, std::string>& header = request.headers[i];
111    request_to_pass.addHeaderField(WebString::fromLatin1(header.first),
112                                   WebString::fromLatin1(header.second));
113  }
114  request_to_pass.setHeadersText(WebString::fromLatin1(request.headers_text));
115  client_->didStartOpeningHandshake(this, request_to_pass);
116}
117
118void WebSocketBridge::DidFinishOpeningHandshake(
119    const WebSocketHandshakeResponse& response) {
120  DVLOG(1) << "WebSocketBridge::DidFinishOpeningHandshake("
121           << response.url << ")";
122  // All strings are already encoded to ASCII in the browser.
123  blink::WebSocketHandshakeResponseInfo response_to_pass;
124  response_to_pass.setStatusCode(response.status_code);
125  response_to_pass.setStatusText(WebString::fromLatin1(response.status_text));
126  for (size_t i = 0; i < response.headers.size(); ++i) {
127    const std::pair<std::string, std::string>& header = response.headers[i];
128    response_to_pass.addHeaderField(WebString::fromLatin1(header.first),
129                                    WebString::fromLatin1(header.second));
130  }
131  response_to_pass.setHeadersText(WebString::fromLatin1(response.headers_text));
132  client_->didFinishOpeningHandshake(this, response_to_pass);
133}
134
135void WebSocketBridge::DidFail(const std::string& message) {
136  DVLOG(1) << "WebSocketBridge::DidFail(" << message << ")";
137  WebSocketHandleClient* client = client_;
138  Disconnect();
139  if (!client)
140    return;
141
142  WebString message_to_pass = WebString::fromUTF8(message);
143  client->didFail(this, message_to_pass);
144  // |this| can be deleted here.
145}
146
147void WebSocketBridge::DidReceiveData(bool fin,
148                                     WebSocketMessageType type,
149                                     const std::vector<char>& data) {
150  DVLOG(1) << "WebSocketBridge::DidReceiveData("
151           << fin << ", "
152           << type << ", "
153           << "(data size = " << data.size() << "))";
154  if (!client_)
155    return;
156
157  WebSocketHandle::MessageType type_to_pass =
158      WebSocketHandle::MessageTypeContinuation;
159  switch (type) {
160    case WEB_SOCKET_MESSAGE_TYPE_CONTINUATION:
161      type_to_pass = WebSocketHandle::MessageTypeContinuation;
162      break;
163    case WEB_SOCKET_MESSAGE_TYPE_TEXT:
164      type_to_pass = WebSocketHandle::MessageTypeText;
165      break;
166    case WEB_SOCKET_MESSAGE_TYPE_BINARY:
167      type_to_pass = WebSocketHandle::MessageTypeBinary;
168      break;
169  }
170  const char* data_to_pass = data.empty() ? NULL : &data[0];
171  client_->didReceiveData(this, fin, type_to_pass, data_to_pass, data.size());
172  // |this| can be deleted here.
173}
174
175void WebSocketBridge::DidReceiveFlowControl(int64_t quota) {
176  DVLOG(1) << "WebSocketBridge::DidReceiveFlowControl(" << quota << ")";
177  if (!client_)
178    return;
179
180  client_->didReceiveFlowControl(this, quota);
181  // |this| can be deleted here.
182}
183
184void WebSocketBridge::DidClose(bool was_clean,
185                               unsigned short code,
186                               const std::string& reason) {
187  DVLOG(1) << "WebSocketBridge::DidClose("
188           << was_clean << ", "
189           << code << ", "
190           << reason << ")";
191  WebSocketHandleClient* client = client_;
192  Disconnect();
193  if (!client)
194    return;
195
196  WebString reason_to_pass = WebString::fromUTF8(reason);
197  client->didClose(this, was_clean, code, reason_to_pass);
198  // |this| can be deleted here.
199}
200
201void WebSocketBridge::DidStartClosingHandshake() {
202  DVLOG(1) << "WebSocketBridge::DidStartClosingHandshake()";
203  if (!client_)
204    return;
205
206  client_->didStartClosingHandshake(this);
207  // |this| can be deleted here.
208}
209
210void WebSocketBridge::connect(
211    const WebURL& url,
212    const WebVector<WebString>& protocols,
213    const WebSerializedOrigin& origin,
214    WebSocketHandleClient* client) {
215  DCHECK_EQ(kInvalidChannelId, channel_id_);
216  WebSocketDispatcher* dispatcher =
217      ChildThread::current()->websocket_dispatcher();
218  channel_id_ = dispatcher->AddBridge(this);
219  client_ = client;
220
221  std::vector<std::string> protocols_to_pass;
222  for (size_t i = 0; i < protocols.size(); ++i)
223    protocols_to_pass.push_back(protocols[i].utf8());
224  url::Origin origin_to_pass(origin);
225
226  DVLOG(1) << "Bridge#" << channel_id_ << " Connect(" << url << ", ("
227           << JoinString(protocols_to_pass, ", ") << "), "
228           << origin_to_pass.string() << ")";
229
230  ChildThread::current()->Send(new WebSocketHostMsg_AddChannelRequest(
231      channel_id_, url, protocols_to_pass, origin_to_pass, render_frame_id_));
232}
233
234void WebSocketBridge::send(bool fin,
235                           WebSocketHandle::MessageType type,
236                           const char* data,
237                           size_t size) {
238  if (channel_id_ == kInvalidChannelId)
239    return;
240
241  WebSocketMessageType type_to_pass = WEB_SOCKET_MESSAGE_TYPE_CONTINUATION;
242  switch (type) {
243    case WebSocketHandle::MessageTypeContinuation:
244      type_to_pass = WEB_SOCKET_MESSAGE_TYPE_CONTINUATION;
245      break;
246    case WebSocketHandle::MessageTypeText:
247      type_to_pass = WEB_SOCKET_MESSAGE_TYPE_TEXT;
248      break;
249    case WebSocketHandle::MessageTypeBinary:
250      type_to_pass = WEB_SOCKET_MESSAGE_TYPE_BINARY;
251      break;
252  }
253
254  DVLOG(1) << "Bridge #" << channel_id_ << " Send("
255           << fin << ", " << type_to_pass << ", "
256           << "(data size = "  << size << "))";
257
258  ChildThread::current()->Send(
259      new WebSocketMsg_SendFrame(channel_id_,
260                                 fin,
261                                 type_to_pass,
262                                 std::vector<char>(data, data + size)));
263}
264
265void WebSocketBridge::flowControl(int64_t quota) {
266  if (channel_id_ == kInvalidChannelId)
267    return;
268
269  DVLOG(1) << "Bridge #" << channel_id_ << " FlowControl(" << quota << ")";
270
271  ChildThread::current()->Send(
272      new WebSocketMsg_FlowControl(channel_id_, quota));
273}
274
275void WebSocketBridge::close(unsigned short code,
276                            const WebString& reason) {
277  if (channel_id_ == kInvalidChannelId)
278    return;
279
280  std::string reason_to_pass = reason.utf8();
281  DVLOG(1) << "Bridge #" << channel_id_ << " Close("
282           << code << ", " << reason_to_pass << ")";
283  // This method is for closing handshake and hence |was_clean| shall be true.
284  ChildThread::current()->Send(
285      new WebSocketMsg_DropChannel(channel_id_, true, code, reason_to_pass));
286}
287
288void WebSocketBridge::Disconnect() {
289  if (channel_id_ == kInvalidChannelId)
290    return;
291  WebSocketDispatcher* dispatcher =
292      ChildThread::current()->websocket_dispatcher();
293  dispatcher->RemoveBridge(channel_id_);
294
295  channel_id_ = kInvalidChannelId;
296  client_ = NULL;
297}
298
299}  // namespace content
300