1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "content/browser/renderer_host/socket_stream_dispatcher_host.h"
6
7#include <string>
8
9#include "base/logging.h"
10#include "content/browser/renderer_host/socket_stream_host.h"
11#include "content/browser/ssl/ssl_manager.h"
12#include "content/common/resource_messages.h"
13#include "content/common/socket_stream.h"
14#include "content/common/socket_stream_messages.h"
15#include "content/public/browser/content_browser_client.h"
16#include "content/public/browser/global_request_id.h"
17#include "net/base/net_errors.h"
18#include "net/cookies/canonical_cookie.h"
19#include "net/url_request/url_request_context_getter.h"
20#include "net/websockets/websocket_job.h"
21#include "net/websockets/websocket_throttle.h"
22
23namespace content {
24
namespace {

// Upper bound on how many SocketStreamHost objects one dispatcher keeps in
// |hosts_|; OnConnect() refuses further connections once it is reached.
const size_t kMaxSocketStreamHosts = 16384;  // 16 * 1024

}  // namespace
30
31SocketStreamDispatcherHost::SocketStreamDispatcherHost(
32    int render_process_id,
33    const GetRequestContextCallback& request_context_callback,
34    ResourceContext* resource_context)
35    : BrowserMessageFilter(SocketStreamMsgStart),
36      render_process_id_(render_process_id),
37      request_context_callback_(request_context_callback),
38      resource_context_(resource_context),
39      on_shutdown_(false),
40      weak_ptr_factory_(this) {
41  net::WebSocketJob::EnsureInit();
42}
43
44bool SocketStreamDispatcherHost::OnMessageReceived(
45    const IPC::Message& message) {
46  if (on_shutdown_)
47    return false;
48
49  bool handled = true;
50  IPC_BEGIN_MESSAGE_MAP(SocketStreamDispatcherHost, message)
51    IPC_MESSAGE_HANDLER(SocketStreamHostMsg_Connect, OnConnect)
52    IPC_MESSAGE_HANDLER(SocketStreamHostMsg_SendData, OnSendData)
53    IPC_MESSAGE_HANDLER(SocketStreamHostMsg_Close, OnCloseReq)
54    IPC_MESSAGE_UNHANDLED(handled = false)
55  IPC_END_MESSAGE_MAP()
56  return handled;
57}
58
// Implementations of the SocketStream::Delegate methods.
60void SocketStreamDispatcherHost::OnConnected(net::SocketStream* socket,
61                                             int max_pending_send_allowed) {
62  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
63  DVLOG(2) << "SocketStreamDispatcherHost::OnConnected socket_id=" << socket_id
64           << " max_pending_send_allowed=" << max_pending_send_allowed;
65  if (socket_id == kNoSocketId) {
66    DVLOG(1) << "NoSocketId in OnConnected";
67    return;
68  }
69  if (!Send(new SocketStreamMsg_Connected(
70          socket_id, max_pending_send_allowed))) {
71    DVLOG(1) << "SocketStreamMsg_Connected failed.";
72    DeleteSocketStreamHost(socket_id);
73  }
74}
75
76void SocketStreamDispatcherHost::OnSentData(net::SocketStream* socket,
77                                            int amount_sent) {
78  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
79  DVLOG(2) << "SocketStreamDispatcherHost::OnSentData socket_id=" << socket_id
80           << " amount_sent=" << amount_sent;
81  if (socket_id == kNoSocketId) {
82    DVLOG(1) << "NoSocketId in OnSentData";
83    return;
84  }
85  if (!Send(new SocketStreamMsg_SentData(socket_id, amount_sent))) {
86    DVLOG(1) << "SocketStreamMsg_SentData failed.";
87    DeleteSocketStreamHost(socket_id);
88  }
89}
90
91void SocketStreamDispatcherHost::OnReceivedData(
92    net::SocketStream* socket, const char* data, int len) {
93  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
94  DVLOG(2) << "SocketStreamDispatcherHost::OnReceiveData socket_id="
95           << socket_id;
96  if (socket_id == kNoSocketId) {
97    DVLOG(1) << "NoSocketId in OnReceivedData";
98    return;
99  }
100  if (!Send(new SocketStreamMsg_ReceivedData(
101          socket_id, std::vector<char>(data, data + len)))) {
102    DVLOG(1) << "SocketStreamMsg_ReceivedData failed.";
103    DeleteSocketStreamHost(socket_id);
104  }
105}
106
107void SocketStreamDispatcherHost::OnClose(net::SocketStream* socket) {
108  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
109  DVLOG(2) << "SocketStreamDispatcherHost::OnClosed socket_id=" << socket_id;
110  if (socket_id == kNoSocketId) {
111    DVLOG(1) << "NoSocketId in OnClose";
112    return;
113  }
114  DeleteSocketStreamHost(socket_id);
115}
116
117void SocketStreamDispatcherHost::OnError(const net::SocketStream* socket,
118                                         int error) {
119  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
120  DVLOG(2) << "SocketStreamDispatcherHost::OnError socket_id=" << socket_id;
121  if (socket_id == content::kNoSocketId) {
122    DVLOG(1) << "NoSocketId in OnError";
123    return;
124  }
125  // SocketStream::Delegate::OnError() events are handled as WebSocket error
126  // event when user agent was required to fail WebSocket connection or the
127  // WebSocket connection is closed with prejudice.
128  if (!Send(new SocketStreamMsg_Failed(socket_id, error))) {
129    DVLOG(1) << "SocketStreamMsg_Failed failed.";
130    DeleteSocketStreamHost(socket_id);
131  }
132}
133
134void SocketStreamDispatcherHost::OnSSLCertificateError(
135    net::SocketStream* socket, const net::SSLInfo& ssl_info, bool fatal) {
136  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
137  DVLOG(2) << "SocketStreamDispatcherHost::OnSSLCertificateError socket_id="
138           << socket_id;
139  if (socket_id == kNoSocketId) {
140    DVLOG(1) << "NoSocketId in OnSSLCertificateError";
141    return;
142  }
143  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
144  DCHECK(socket_stream_host);
145  GlobalRequestID request_id(-1, socket_id);
146  SSLManager::OnSSLCertificateError(
147      weak_ptr_factory_.GetWeakPtr(), request_id, RESOURCE_TYPE_SUB_RESOURCE,
148      socket->url(), render_process_id_, socket_stream_host->render_frame_id(),
149      ssl_info, fatal);
150}
151
152bool SocketStreamDispatcherHost::CanGetCookies(net::SocketStream* socket,
153                                               const GURL& url) {
154  int socket_id = SocketStreamHost::SocketIdFromSocketStream(socket);
155  if (socket_id == kNoSocketId) {
156    return false;
157  }
158  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
159  DCHECK(socket_stream_host);
160  return GetContentClient()->browser()->AllowGetCookie(
161      url,
162      url,
163      net::CookieList(),
164      resource_context_,
165      render_process_id_,
166      socket_stream_host->render_frame_id());
167}
168
169bool SocketStreamDispatcherHost::CanSetCookie(net::SocketStream* request,
170                                              const GURL& url,
171                                              const std::string& cookie_line,
172                                              net::CookieOptions* options) {
173  int socket_id = SocketStreamHost::SocketIdFromSocketStream(request);
174  if (socket_id == kNoSocketId) {
175    return false;
176  }
177  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
178  DCHECK(socket_stream_host);
179  return GetContentClient()->browser()->AllowSetCookie(
180      url,
181      url,
182      cookie_line,
183      resource_context_,
184      render_process_id_,
185      socket_stream_host->render_frame_id(),
186      options);
187}
188
189void SocketStreamDispatcherHost::CancelSSLRequest(
190    const GlobalRequestID& id,
191    int error,
192    const net::SSLInfo* ssl_info) {
193  int socket_id = id.request_id;
194  DVLOG(2) << "SocketStreamDispatcherHost::CancelSSLRequest socket_id="
195           << socket_id;
196  DCHECK_NE(kNoSocketId, socket_id);
197  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
198  DCHECK(socket_stream_host);
199  if (ssl_info)
200    socket_stream_host->CancelWithSSLError(*ssl_info);
201  else
202    socket_stream_host->CancelWithError(error);
203}
204
205void SocketStreamDispatcherHost::ContinueSSLRequest(
206    const GlobalRequestID& id) {
207  int socket_id = id.request_id;
208  DVLOG(2) << "SocketStreamDispatcherHost::ContinueSSLRequest socket_id="
209           << socket_id;
210  DCHECK_NE(kNoSocketId, socket_id);
211  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
212  DCHECK(socket_stream_host);
213  socket_stream_host->ContinueDespiteError();
214}
215
216SocketStreamDispatcherHost::~SocketStreamDispatcherHost() {
217  DCHECK_CURRENTLY_ON(BrowserThread::IO);
218  Shutdown();
219}
220
221// Message handlers called by OnMessageReceived.
222void SocketStreamDispatcherHost::OnConnect(int render_frame_id,
223                                           const GURL& url,
224                                           int socket_id) {
225  DVLOG(2) << "SocketStreamDispatcherHost::OnConnect"
226           << " render_frame_id=" << render_frame_id
227           << " url=" << url
228           << " socket_id=" << socket_id;
229  DCHECK_NE(kNoSocketId, socket_id);
230
231  if (hosts_.size() >= kMaxSocketStreamHosts) {
232    if (!Send(new SocketStreamMsg_Failed(socket_id,
233                                         net::ERR_TOO_MANY_SOCKET_STREAMS))) {
234      DVLOG(1) << "SocketStreamMsg_Failed failed.";
235    }
236    if (!Send(new SocketStreamMsg_Closed(socket_id))) {
237      DVLOG(1) << "SocketStreamMsg_Closed failed.";
238    }
239    return;
240  }
241
242  if (hosts_.Lookup(socket_id)) {
243    DVLOG(1) << "socket_id=" << socket_id << " already registered.";
244    return;
245  }
246
247  // Note that the SocketStreamHost is responsible for checking that |url|
248  // is valid.
249  SocketStreamHost* socket_stream_host =
250      new SocketStreamHost(this, render_process_id_, render_frame_id,
251                           socket_id);
252  hosts_.AddWithID(socket_stream_host, socket_id);
253  socket_stream_host->Connect(url, GetURLRequestContext());
254  DVLOG(2) << "SocketStreamDispatcherHost::OnConnect -> " << socket_id;
255}
256
257void SocketStreamDispatcherHost::OnSendData(
258    int socket_id, const std::vector<char>& data) {
259  DVLOG(2) << "SocketStreamDispatcherHost::OnSendData socket_id=" << socket_id;
260  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
261  if (!socket_stream_host) {
262    DVLOG(1) << "socket_id=" << socket_id << " already closed.";
263    return;
264  }
265  if (!socket_stream_host->SendData(data)) {
266    // Cannot accept more data to send.
267    socket_stream_host->Close();
268  }
269}
270
271void SocketStreamDispatcherHost::OnCloseReq(int socket_id) {
272  DVLOG(2) << "SocketStreamDispatcherHost::OnCloseReq socket_id=" << socket_id;
273  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
274  if (!socket_stream_host)
275    return;
276  socket_stream_host->Close();
277}
278
279void SocketStreamDispatcherHost::DeleteSocketStreamHost(int socket_id) {
280  SocketStreamHost* socket_stream_host = hosts_.Lookup(socket_id);
281  DCHECK(socket_stream_host);
282  delete socket_stream_host;
283  hosts_.Remove(socket_id);
284  if (!Send(new SocketStreamMsg_Closed(socket_id))) {
285    DVLOG(1) << "SocketStreamMsg_Closed failed.";
286  }
287}
288
289net::URLRequestContext* SocketStreamDispatcherHost::GetURLRequestContext() {
290  return request_context_callback_.Run(RESOURCE_TYPE_SUB_RESOURCE);
291}
292
// Deletes every SocketStreamHost still registered in |hosts_|, then sets
// |on_shutdown_|. After that, OnMessageReceived() returns false for every
// message. Must be called on the IO thread.
void SocketStreamDispatcherHost::Shutdown() {
  DCHECK_CURRENTLY_ON(BrowserThread::IO);
  // TODO(ukai): Implement IDMap::RemoveAll().
  // NOTE(review): entries are Remove()d while |iter| is still walking
  // |hosts_|. This relies on IDMap tolerating removal during iteration
  // (presumably by deferring the erase until iteration ends); confirm
  // against base/id_map.h before restructuring this loop.
  for (IDMap<SocketStreamHost>::const_iterator iter(&hosts_);
       !iter.IsAtEnd();
       iter.Advance()) {
    int socket_id = iter.GetCurrentKey();
    const SocketStreamHost* socket_stream_host = iter.GetCurrentValue();
    // Hosts are freed explicitly, here and in DeleteSocketStreamHost().
    // Presumably |hosts_| does not own its values.
    delete socket_stream_host;
    hosts_.Remove(socket_id);
  }
  on_shutdown_ = true;
}
306
307}  // namespace content
308