1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "content/browser/renderer_host/socket_stream_dispatcher_host.h"
6
7#include <string>
8
9#include "base/logging.h"
10#include "content/browser/renderer_host/socket_stream_host.h"
11#include "content/browser/ssl/ssl_manager.h"
12#include "content/common/resource_messages.h"
13#include "content/common/socket_stream.h"
14#include "content/common/socket_stream_messages.h"
15#include "content/public/browser/content_browser_client.h"
16#include "content/public/browser/global_request_id.h"
17#include "net/base/net_errors.h"
18#include "net/cookies/canonical_cookie.h"
19#include "net/url_request/url_request_context_getter.h"
20#include "net/websockets/websocket_job.h"
21#include "net/websockets/websocket_throttle.h"
22
23namespace content {
24
namespace {

// Upper bound on the number of SocketStreamHosts a single dispatcher keeps
// registered in |hosts_| at once. OnConnect() answers further connection
// requests with ERR_TOO_MANY_SOCKET_STREAMS once this many are registered.
const size_t kMaxSocketStreamHosts = 16384;  // == 16 * 1024

}  // namespace
30
// Creates the dispatcher for the renderer process |render_process_id|.
// |request_context_callback| is stored and later used by
// GetURLRequestContext() to obtain the URLRequestContext new streams connect
// through; |resource_context| is stored and passed to the embedder's cookie
// policy checks. |weak_ptr_factory_| hands out the WeakPtrs given to
// SSLManager in OnSSLCertificateError().
SocketStreamDispatcherHost::SocketStreamDispatcherHost(
    int render_process_id,
    const GetRequestContextCallback& request_context_callback,
    ResourceContext* resource_context)
    : render_process_id_(render_process_id),
      request_context_callback_(request_context_callback),
      resource_context_(resource_context),
      weak_ptr_factory_(this),
      on_shutdown_(false) {
  // NOTE(review): presumably one-time global setup for net::WebSocketJob;
  // confirm against net/websockets/websocket_job.h.
  net::WebSocketJob::EnsureInit();
}
42
// Routes renderer-originated socket stream IPCs:
//   SocketStreamHostMsg_Connect  -> OnConnect()
//   SocketStreamHostMsg_SendData -> OnSendData()
//   SocketStreamHostMsg_Close    -> OnCloseReq()
// Returns true if |message| was one of the above, false otherwise. After
// Shutdown() has run, every message is reported as unhandled.
// |*message_was_ok| is written by the message-map macros (presumably set to
// false when a known message fails to deserialize — confirm against
// ipc/ipc_message_macros.h).
bool SocketStreamDispatcherHost::OnMessageReceived(const IPC::Message& message,
                                                   bool* message_was_ok) {
  // Once shut down there are no hosts left to act on; decline the message.
  if (on_shutdown_)
    return false;

  bool handled = true;
  IPC_BEGIN_MESSAGE_MAP_EX(SocketStreamDispatcherHost, message, *message_was_ok)
    IPC_MESSAGE_HANDLER(SocketStreamHostMsg_Connect, OnConnect)
    IPC_MESSAGE_HANDLER(SocketStreamHostMsg_SendData, OnSendData)
    IPC_MESSAGE_HANDLER(SocketStreamHostMsg_Close, OnCloseReq)
    IPC_MESSAGE_UNHANDLED(handled = false)
  IPC_END_MESSAGE_MAP_EX()
  return handled;
}
57
// SocketStream::Delegate method implementations.
59void SocketStreamDispatcherHost::OnConnected(net::SocketStream* socket,
60                                             int max_pending_send_allowed) {
61  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
62  DVLOG(2) << "SocketStreamDispatcherHost::OnConnected socket_id=" << socket_id
63           << " max_pending_send_allowed=" << max_pending_send_allowed;
64  if (socket_id == kNoSocketId) {
65    DVLOG(1) << "NoSocketId in OnConnected";
66    return;
67  }
68  if (!Send(new SocketStreamMsg_Connected(
69          socket_id, max_pending_send_allowed))) {
70    DVLOG(1) << "SocketStreamMsg_Connected failed.";
71    DeleteSocketStreamHost(socket_id);
72  }
73}
74
75void SocketStreamDispatcherHost::OnSentData(net::SocketStream* socket,
76                                            int amount_sent) {
77  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
78  DVLOG(2) << "SocketStreamDispatcherHost::OnSentData socket_id=" << socket_id
79           << " amount_sent=" << amount_sent;
80  if (socket_id == kNoSocketId) {
81    DVLOG(1) << "NoSocketId in OnSentData";
82    return;
83  }
84  if (!Send(new SocketStreamMsg_SentData(socket_id, amount_sent))) {
85    DVLOG(1) << "SocketStreamMsg_SentData failed.";
86    DeleteSocketStreamHost(socket_id);
87  }
88}
89
90void SocketStreamDispatcherHost::OnReceivedData(
91    net::SocketStream* socket, const char* data, int len) {
92  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
93  DVLOG(2) << "SocketStreamDispatcherHost::OnReceiveData socket_id="
94           << socket_id;
95  if (socket_id == kNoSocketId) {
96    DVLOG(1) << "NoSocketId in OnReceivedData";
97    return;
98  }
99  if (!Send(new SocketStreamMsg_ReceivedData(
100          socket_id, std::vector<char>(data, data + len)))) {
101    DVLOG(1) << "SocketStreamMsg_ReceivedData failed.";
102    DeleteSocketStreamHost(socket_id);
103  }
104}
105
106void SocketStreamDispatcherHost::OnClose(net::SocketStream* socket) {
107  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
108  DVLOG(2) << "SocketStreamDispatcherHost::OnClosed socket_id=" << socket_id;
109  if (socket_id == kNoSocketId) {
110    DVLOG(1) << "NoSocketId in OnClose";
111    return;
112  }
113  DeleteSocketStreamHost(socket_id);
114}
115
116void SocketStreamDispatcherHost::OnError(const net::SocketStream* socket,
117                                         int error) {
118  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
119  DVLOG(2) << "SocketStreamDispatcherHost::OnError socket_id=" << socket_id;
120  if (socket_id == content::kNoSocketId) {
121    DVLOG(1) << "NoSocketId in OnError";
122    return;
123  }
124  // SocketStream::Delegate::OnError() events are handled as WebSocket error
125  // event when user agent was required to fail WebSocket connection or the
126  // WebSocket connection is closed with prejudice.
127  if (!Send(new SocketStreamMsg_Failed(socket_id, error))) {
128    DVLOG(1) << "SocketStreamMsg_Failed failed.";
129    DeleteSocketStreamHost(socket_id);
130  }
131}
132
133void SocketStreamDispatcherHost::OnSSLCertificateError(
134    net::SocketStream* socket, const net::SSLInfo& ssl_info, bool fatal) {
135  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
136  DVLOG(2) << "SocketStreamDispatcherHost::OnSSLCertificateError socket_id="
137           << socket_id;
138  if (socket_id == kNoSocketId) {
139    DVLOG(1) << "NoSocketId in OnSSLCertificateError";
140    return;
141  }
142  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
143  DCHECK(socket_stream_host);
144  GlobalRequestID request_id(-1, socket_id);
145  SSLManager::OnSSLCertificateError(
146      weak_ptr_factory_.GetWeakPtr(), request_id, ResourceType::SUB_RESOURCE,
147      socket->url(), render_process_id_, socket_stream_host->render_view_id(),
148      ssl_info, fatal);
149}
150
151bool SocketStreamDispatcherHost::CanGetCookies(net::SocketStream* socket,
152                                               const GURL& url) {
153  return GetContentClient()->browser()->AllowGetCookie(
154      url, url, net::CookieList(), resource_context_, 0, MSG_ROUTING_NONE);
155}
156
157bool SocketStreamDispatcherHost::CanSetCookie(net::SocketStream* request,
158                                              const GURL& url,
159                                              const std::string& cookie_line,
160                                              net::CookieOptions* options) {
161  return GetContentClient()->browser()->AllowSetCookie(
162      url, url, cookie_line, resource_context_, 0, MSG_ROUTING_NONE, options);
163}
164
165void SocketStreamDispatcherHost::CancelSSLRequest(
166    const GlobalRequestID& id,
167    int error,
168    const net::SSLInfo* ssl_info) {
169  int socket_id = id.request_id;
170  DVLOG(2) << "SocketStreamDispatcherHost::CancelSSLRequest socket_id="
171           << socket_id;
172  DCHECK_NE(kNoSocketId, socket_id);
173  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
174  DCHECK(socket_stream_host);
175  if (ssl_info)
176    socket_stream_host->CancelWithSSLError(*ssl_info);
177  else
178    socket_stream_host->CancelWithError(error);
179}
180
181void SocketStreamDispatcherHost::ContinueSSLRequest(
182    const GlobalRequestID& id) {
183  int socket_id = id.request_id;
184  DVLOG(2) << "SocketStreamDispatcherHost::ContinueSSLRequest socket_id="
185           << socket_id;
186  DCHECK_NE(kNoSocketId, socket_id);
187  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
188  DCHECK(socket_stream_host);
189  socket_stream_host->ContinueDespiteError();
190}
191
// Must run on the IO thread (DCHECKed). Deletes every SocketStreamHost that
// is still registered by delegating to Shutdown().
SocketStreamDispatcherHost::~SocketStreamDispatcherHost() {
  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
  Shutdown();
}
196
197// Message handlers called by OnMessageReceived.
198void SocketStreamDispatcherHost::OnConnect(int render_view_id,
199                                           const GURL& url,
200                                           int socket_id) {
201  DVLOG(2) << "SocketStreamDispatcherHost::OnConnect"
202           << " render_view_id=" << render_view_id
203           << " url=" << url
204           << " socket_id=" << socket_id;
205  DCHECK_NE(kNoSocketId, socket_id);
206
207  if (hosts_.size() >= kMaxSocketStreamHosts) {
208    if (!Send(new SocketStreamMsg_Failed(socket_id,
209                                         net::ERR_TOO_MANY_SOCKET_STREAMS))) {
210      DVLOG(1) << "SocketStreamMsg_Failed failed.";
211    }
212    if (!Send(new SocketStreamMsg_Closed(socket_id))) {
213      DVLOG(1) << "SocketStreamMsg_Closed failed.";
214    }
215    return;
216  }
217
218  if (hosts_.Lookup(socket_id)) {
219    DVLOG(1) << "socket_id=" << socket_id << " already registered.";
220    return;
221  }
222
223  // Note that the SocketStreamHost is responsible for checking that |url|
224  // is valid.
225  SocketStreamHost* socket_stream_host =
226      new SocketStreamHost(this, render_view_id, socket_id);
227  hosts_.AddWithID(socket_stream_host, socket_id);
228  socket_stream_host->Connect(url, GetURLRequestContext());
229  DVLOG(2) << "SocketStreamDispatcherHost::OnConnect -> " << socket_id;
230}
231
232void SocketStreamDispatcherHost::OnSendData(
233    int socket_id, const std::vector<char>& data) {
234  DVLOG(2) << "SocketStreamDispatcherHost::OnSendData socket_id=" << socket_id;
235  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
236  if (!socket_stream_host) {
237    DVLOG(1) << "socket_id=" << socket_id << " already closed.";
238    return;
239  }
240  if (!socket_stream_host->SendData(data)) {
241    // Cannot accept more data to send.
242    socket_stream_host->Close();
243  }
244}
245
246void SocketStreamDispatcherHost::OnCloseReq(int socket_id) {
247  DVLOG(2) << "SocketStreamDispatcherHost::OnCloseReq socket_id=" << socket_id;
248  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
249  if (!socket_stream_host)
250    return;
251  socket_stream_host->Close();
252}
253
254void SocketStreamDispatcherHost::DeleteSocketStreamHost(int socket_id) {
255  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
256  DCHECK(socket_stream_host);
257  delete socket_stream_host;
258  hosts_.Remove(socket_id);
259  if (!Send(new SocketStreamMsg_Closed(socket_id))) {
260    DVLOG(1) << "SocketStreamMsg_Closed failed.";
261  }
262}
263
// Returns the URLRequestContext new socket streams connect through, obtained
// by running |request_context_callback_| with ResourceType::SUB_RESOURCE.
net::URLRequestContext* SocketStreamDispatcherHost::GetURLRequestContext() {
  return request_context_callback_.Run(ResourceType::SUB_RESOURCE);
}
267
// IO thread only (DCHECKed). Deletes and unregisters every SocketStreamHost
// in |hosts_|, then sets |on_shutdown_| so that OnMessageReceived() declines
// all further IPCs.
void SocketStreamDispatcherHost::Shutdown() {
  DCHECK(BrowserThread::CurrentlyOn(BrowserThread::IO));
  // TODO(ukai): Implement IDMap::RemoveAll().
  // NOTE(review): entries are removed from |hosts_| while it is being
  // iterated; this relies on IDMap deferring Remove() while an iterator is
  // live — confirm against base/id_map.h. The map does not own its values,
  // which is why each host is deleted explicitly before Remove().
  for (IDMap<SocketStreamHost>::const_iterator iter(&hosts_);
       !iter.IsAtEnd();
       iter.Advance()) {
    int socket_id = iter.GetCurrentKey();
    const SocketStreamHost* socket_stream_host = iter.GetCurrentValue();
    delete socket_stream_host;
    hosts_.Remove(socket_id);
  }
  on_shutdown_ = true;
}
281
282}  // namespace content
283