1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#ifndef EXTENSIONS_BROWSER_API_SOCKET_SOCKET_API_H_
6#define EXTENSIONS_BROWSER_API_SOCKET_SOCKET_API_H_
7
8#include <string>
9
10#include "base/gtest_prod_util.h"
11#include "base/memory/ref_counted.h"
12#include "extensions/browser/api/api_resource_manager.h"
13#include "extensions/browser/api/async_api_function.h"
14#include "extensions/browser/extension_function.h"
15#include "extensions/common/api/socket.h"
16#include "net/base/address_list.h"
17#include "net/dns/host_resolver.h"
18#include "net/socket/tcp_client_socket.h"
19
20namespace content {
21class BrowserContext;
22class ResourceContext;
23}
24
25namespace net {
26class IOBuffer;
27class URLRequestContextGetter;
28class SSLClientSocket;
29}
30
31namespace extensions {
32class TLSSocket;
33class Socket;
34
// A simple interface to ApiResourceManager<Socket> or derived class. The goal
// of this interface is to allow Socket API functions to use distinct instances
// of ApiResourceManager<> depending on the type of socket (old version in
// "socket" namespace vs new version in "socket.xxx" namespaces).
//
// SocketResourceManager<T> (below) is the adapter that implements this
// interface on top of a concrete ApiResourceManager<T>.
class SocketResourceManagerInterface {
 public:
  virtual ~SocketResourceManagerInterface() {}

  // Looks up the underlying resource manager for |context|. Returns whether a
  // manager was found; the other methods should only be used after this
  // returned true.
  virtual bool SetBrowserContext(content::BrowserContext* context) = 0;
  // Registers |socket| and returns the integer resource id assigned to it.
  // NOTE(review): ownership transfer of |socket| is defined by the underlying
  // ApiResourceManager<> — confirm before relying on it.
  virtual int Add(Socket* socket) = 0;
  // Returns the socket registered under |api_resource_id| for |extension_id|.
  // NOTE(review): presumably NULL when there is no such socket — confirm.
  virtual Socket* Get(const std::string& extension_id, int api_resource_id) = 0;
  // Unregisters the socket identified by (|extension_id|, |api_resource_id|).
  virtual void Remove(const std::string& extension_id, int api_resource_id) = 0;
  // Swaps the socket stored under (|extension_id|, |api_resource_id|) for
  // |socket|.
  virtual void Replace(const std::string& extension_id,
                       int api_resource_id,
                       Socket* socket) = 0;
  // Returns the set of resource ids registered for |extension_id|.
  // NOTE(review): ownership and possible NULL-ness of the returned pointer
  // come from ApiResourceManager<> — confirm there.
  virtual base::hash_set<int>* GetResourceIds(
      const std::string& extension_id) = 0;
};
53
54// Implementation of SocketResourceManagerInterface using an
55// ApiResourceManager<T> instance (where T derives from Socket).
56template <typename T>
57class SocketResourceManager : public SocketResourceManagerInterface {
58 public:
59  SocketResourceManager() : manager_(NULL) {}
60
61  virtual bool SetBrowserContext(content::BrowserContext* context) OVERRIDE {
62    manager_ = ApiResourceManager<T>::Get(context);
63    DCHECK(manager_)
64        << "There is no socket manager. "
65           "If this assertion is failing during a test, then it is likely that "
66           "TestExtensionSystem is failing to provide an instance of "
67           "ApiResourceManager<Socket>.";
68    return manager_ != NULL;
69  }
70
71  virtual int Add(Socket* socket) OVERRIDE {
72    // Note: Cast needed here, because "T" may be a subclass of "Socket".
73    return manager_->Add(static_cast<T*>(socket));
74  }
75
76  virtual Socket* Get(const std::string& extension_id,
77                      int api_resource_id) OVERRIDE {
78    return manager_->Get(extension_id, api_resource_id);
79  }
80
81  virtual void Replace(const std::string& extension_id,
82                       int api_resource_id,
83                       Socket* socket) OVERRIDE {
84    manager_->Replace(extension_id, api_resource_id, static_cast<T*>(socket));
85  }
86
87  virtual void Remove(const std::string& extension_id,
88                      int api_resource_id) OVERRIDE {
89    manager_->Remove(extension_id, api_resource_id);
90  }
91
92  virtual base::hash_set<int>* GetResourceIds(const std::string& extension_id)
93      OVERRIDE {
94    return manager_->GetResourceIds(extension_id);
95  }
96
97 private:
98  ApiResourceManager<T>* manager_;
99};
100
// Common base for the socket.* API functions built on AsyncApiFunction. It owns
// a SocketResourceManagerInterface and offers helpers that add, look up,
// replace and remove sockets by integer resource id.
class SocketAsyncApiFunction : public AsyncApiFunction {
 public:
  SocketAsyncApiFunction();

 protected:
  virtual ~SocketAsyncApiFunction();

  // AsyncApiFunction:
  virtual bool PrePrepare() OVERRIDE;
  virtual bool Respond() OVERRIDE;

  // Factory for the resource manager held in |manager_|. Virtual so that a
  // subclass can supply a manager for a different socket type.
  // NOTE(review): presumably called from PrePrepare() — confirm in the .cc.
  virtual scoped_ptr<SocketResourceManagerInterface>
      CreateSocketResourceManager();

  // Convenience wrappers over |manager_|. Unlike the interface methods they
  // take no extension id. NOTE(review): presumably the id of the extension
  // that invoked this function is supplied internally — confirm in the .cc.
  int AddSocket(Socket* socket);
  Socket* GetSocket(int api_resource_id);
  void ReplaceSocket(int api_resource_id, Socket* socket);
  void RemoveSocket(int api_resource_id);
  base::hash_set<int>* GetSocketIds();

 private:
  // Resource manager used by the helpers above; owned via scoped_ptr.
  scoped_ptr<SocketResourceManagerInterface> manager_;
};
124
// Base for socket functions that need to resolve a hostname before doing their
// work; SocketConnectFunction and SocketSendToFunction derive from it.
class SocketExtensionWithDnsLookupFunction : public SocketAsyncApiFunction {
 protected:
  SocketExtensionWithDnsLookupFunction();
  virtual ~SocketExtensionWithDnsLookupFunction();

  // AsyncApiFunction:
  virtual bool PrePrepare() OVERRIDE;

  // Starts resolving |hostname|. Subclasses implement AfterDnsLookup().
  // NOTE(review): it is presumably invoked with the net result code once the
  // lookup completes (or fails) — confirm in the .cc.
  void StartDnsLookup(const std::string& hostname);
  virtual void AfterDnsLookup(int lookup_result) = 0;

  // Address produced by the lookup, readable by subclasses.
  // NOTE(review): presumably a textual IP address — confirm the format.
  std::string resolved_address_;

 private:
  // Completion handler for the resolver request.
  // NOTE(review): presumably fills |resolved_address_| and then calls
  // AfterDnsLookup() — confirm in the .cc.
  void OnDnsLookup(int resolve_result);

  // Weak pointer to the resource context.
  content::ResourceContext* resource_context_;

  // Handle of the pending net::HostResolver request.
  scoped_ptr<net::HostResolver::RequestHandle> request_handle_;
  // Address list filled in by the resolver.
  scoped_ptr<net::AddressList> addresses_;
};
147
// Implements the "socket.create" extension function.
class SocketCreateFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.create", SOCKET_CREATE)

  SocketCreateFunction();

 protected:
  virtual ~SocketCreateFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Grants the SocketUnitTest.Create test access to private members.
  FRIEND_TEST_ALL_PREFIXES(SocketUnitTest, Create);
  // Kind of socket to create. NOTE(review): kSocketTypeInvalid (-1) is
  // presumably used for an unrecognized type string — confirm in the .cc.
  enum SocketType { kSocketTypeInvalid = -1, kSocketTypeTCP, kSocketTypeUDP };

  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::Create::Params> params_;
  SocketType socket_type_;
};
168
169class SocketDestroyFunction : public SocketAsyncApiFunction {
170 public:
171  DECLARE_EXTENSION_FUNCTION("socket.destroy", SOCKET_DESTROY)
172
173 protected:
174  virtual ~SocketDestroyFunction() {}
175
176  // AsyncApiFunction:
177  virtual bool Prepare() OVERRIDE;
178  virtual void Work() OVERRIDE;
179
180 private:
181  int socket_id_;
182};
183
// Implements the "socket.connect" extension function. It derives from
// SocketExtensionWithDnsLookupFunction so that |hostname_| can be resolved
// before the connection is attempted.
class SocketConnectFunction : public SocketExtensionWithDnsLookupFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.connect", SOCKET_CONNECT)

  SocketConnectFunction();

 protected:
  virtual ~SocketConnectFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void AsyncWorkStart() OVERRIDE;

  // SocketExtensionWithDnsLookupFunction:
  virtual void AfterDnsLookup(int lookup_result) OVERRIDE;

 private:
  // Issues the connect; OnConnect() receives its result.
  // NOTE(review): |result| is presumably a net error code — confirm.
  void StartConnect();
  void OnConnect(int result);

  // Target socket id, destination host and port from the call arguments.
  int socket_id_;
  std::string hostname_;
  int port_;
  // NOTE(review): presumably a non-owning pointer to the socket looked up by
  // |socket_id_| — confirm ownership in the .cc.
  Socket* socket_;
};
209
210class SocketDisconnectFunction : public SocketAsyncApiFunction {
211 public:
212  DECLARE_EXTENSION_FUNCTION("socket.disconnect", SOCKET_DISCONNECT)
213
214 protected:
215  virtual ~SocketDisconnectFunction() {}
216
217  // AsyncApiFunction:
218  virtual bool Prepare() OVERRIDE;
219  virtual void Work() OVERRIDE;
220
221 private:
222  int socket_id_;
223};
224
225class SocketBindFunction : public SocketAsyncApiFunction {
226 public:
227  DECLARE_EXTENSION_FUNCTION("socket.bind", SOCKET_BIND)
228
229 protected:
230  virtual ~SocketBindFunction() {}
231
232  // AsyncApiFunction:
233  virtual bool Prepare() OVERRIDE;
234  virtual void Work() OVERRIDE;
235
236 private:
237  int socket_id_;
238  std::string address_;
239  int port_;
240};
241
// Implements the "socket.listen" extension function.
class SocketListenFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.listen", SOCKET_LISTEN)

  SocketListenFunction();

 protected:
  virtual ~SocketListenFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::Listen::Params> params_;
};
258
// Implements the "socket.accept" extension function. The accept completes
// asynchronously (AsyncWorkStart() rather than Work()).
class SocketAcceptFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.accept", SOCKET_ACCEPT)

  SocketAcceptFunction();

 protected:
  virtual ~SocketAcceptFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void AsyncWorkStart() OVERRIDE;

 private:
  // Completion callback for the accept.
  // NOTE(review): |result_code| is presumably a net error code, and ownership
  // of |socket| on success/failure should be confirmed in the .cc.
  void OnAccept(int result_code, net::TCPClientSocket* socket);

  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::Accept::Params> params_;
};
277
// Implements the "socket.read" extension function.
class SocketReadFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.read", SOCKET_READ)

  SocketReadFunction();

 protected:
  virtual ~SocketReadFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void AsyncWorkStart() OVERRIDE;
  // Completion callback for the read. NOTE(review): |result| is presumably a
  // byte count or net error code, with the data in |io_buffer| — confirm.
  void OnCompleted(int result, scoped_refptr<net::IOBuffer> io_buffer);

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::Read::Params> params_;
};
295
// Implements the "socket.write" extension function.
class SocketWriteFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.write", SOCKET_WRITE)

  SocketWriteFunction();

 protected:
  virtual ~SocketWriteFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void AsyncWorkStart() OVERRIDE;
  // Completion callback for the write. NOTE(review): |result| is presumably
  // a byte count or net error code — confirm.
  void OnCompleted(int result);

 private:
  // Target socket, the data to send, and the size of that data.
  int socket_id_;
  scoped_refptr<net::IOBuffer> io_buffer_;
  size_t io_buffer_size_;
};
315
// Implements the "socket.recvFrom" extension function.
class SocketRecvFromFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.recvFrom", SOCKET_RECVFROM)

  SocketRecvFromFunction();

 protected:
  virtual ~SocketRecvFromFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void AsyncWorkStart() OVERRIDE;
  // Completion callback for the receive. NOTE(review): |address|/|port| are
  // presumably the sender's endpoint and |result| a byte count or net error
  // code — confirm in the .cc.
  void OnCompleted(int result,
                   scoped_refptr<net::IOBuffer> io_buffer,
                   const std::string& address,
                   int port);

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::RecvFrom::Params> params_;
};
336
// Implements the "socket.sendTo" extension function. It derives from
// SocketExtensionWithDnsLookupFunction so that |hostname_| can be resolved
// before sending.
class SocketSendToFunction : public SocketExtensionWithDnsLookupFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.sendTo", SOCKET_SENDTO)

  SocketSendToFunction();

 protected:
  virtual ~SocketSendToFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void AsyncWorkStart() OVERRIDE;
  // Completion callback for the send. NOTE(review): |result| is presumably a
  // byte count or net error code — confirm.
  void OnCompleted(int result);

  // SocketExtensionWithDnsLookupFunction:
  virtual void AfterDnsLookup(int lookup_result) OVERRIDE;

 private:
  // Issues the actual send once the destination is known.
  void StartSendTo();

  // Source socket, payload and its size, and destination host/port.
  int socket_id_;
  scoped_refptr<net::IOBuffer> io_buffer_;
  size_t io_buffer_size_;
  std::string hostname_;
  int port_;
  // NOTE(review): presumably a non-owning pointer to the socket looked up by
  // |socket_id_| — confirm ownership in the .cc.
  Socket* socket_;
};
364
// Implements the "socket.setKeepAlive" extension function.
class SocketSetKeepAliveFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.setKeepAlive", SOCKET_SETKEEPALIVE)

  SocketSetKeepAliveFunction();

 protected:
  virtual ~SocketSetKeepAliveFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::SetKeepAlive::Params> params_;
};
381
// Implements the "socket.setNoDelay" extension function.
class SocketSetNoDelayFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.setNoDelay", SOCKET_SETNODELAY)

  SocketSetNoDelayFunction();

 protected:
  virtual ~SocketSetNoDelayFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::SetNoDelay::Params> params_;
};
398
// Implements the "socket.getInfo" extension function.
class SocketGetInfoFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.getInfo", SOCKET_GETINFO)

  SocketGetInfoFunction();

 protected:
  virtual ~SocketGetInfoFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::GetInfo::Params> params_;
};
415
// Implements the "socket.getNetworkList" extension function. Unlike the other
// functions in this file it derives directly from AsyncExtensionFunction (not
// SocketAsyncApiFunction) and holds no socket state.
class SocketGetNetworkListFunction : public AsyncExtensionFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.getNetworkList", SOCKET_GETNETWORKLIST)

 protected:
  virtual ~SocketGetNetworkListFunction() {}
  virtual bool RunAsync() OVERRIDE;

 private:
  // Helpers named after the thread each is expected to run on.
  // NOTE(review): the actual thread hops live in the .cc — confirm there.
  void GetNetworkListOnFileThread();
  void HandleGetNetworkListError();
  void SendResponseOnUIThread(const net::NetworkInterfaceList& interface_list);
};
429
// Implements the "socket.joinGroup" (multicast) extension function.
class SocketJoinGroupFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.joinGroup", SOCKET_MULTICAST_JOIN_GROUP)

  SocketJoinGroupFunction();

 protected:
  virtual ~SocketJoinGroupFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::JoinGroup::Params> params_;
};
446
// Implements the "socket.leaveGroup" (multicast) extension function.
class SocketLeaveGroupFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.leaveGroup", SOCKET_MULTICAST_LEAVE_GROUP)

  SocketLeaveGroupFunction();

 protected:
  virtual ~SocketLeaveGroupFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::LeaveGroup::Params> params_;
};
463
// Implements the "socket.setMulticastTimeToLive" extension function.
class SocketSetMulticastTimeToLiveFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.setMulticastTimeToLive",
                             SOCKET_MULTICAST_SET_TIME_TO_LIVE)

  SocketSetMulticastTimeToLiveFunction();

 protected:
  virtual ~SocketSetMulticastTimeToLiveFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::SetMulticastTimeToLive::Params> params_;
};
481
// Implements the "socket.setMulticastLoopbackMode" extension function.
class SocketSetMulticastLoopbackModeFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.setMulticastLoopbackMode",
                             SOCKET_MULTICAST_SET_LOOPBACK_MODE)

  SocketSetMulticastLoopbackModeFunction();

 protected:
  virtual ~SocketSetMulticastLoopbackModeFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::SetMulticastLoopbackMode::Params> params_;
};
499
// Implements the "socket.getJoinedGroups" extension function.
class SocketGetJoinedGroupsFunction : public SocketAsyncApiFunction {
 public:
  DECLARE_EXTENSION_FUNCTION("socket.getJoinedGroups",
                             SOCKET_MULTICAST_GET_JOINED_GROUPS)

  SocketGetJoinedGroupsFunction();

 protected:
  virtual ~SocketGetJoinedGroupsFunction();

  // AsyncApiFunction:
  virtual bool Prepare() OVERRIDE;
  virtual void Work() OVERRIDE;

 private:
  // Parsed call arguments; NOTE(review): presumably populated in Prepare().
  scoped_ptr<core_api::socket::GetJoinedGroups::Params> params_;
};
517
518class SocketSecureFunction : public SocketAsyncApiFunction {
519 public:
520  DECLARE_EXTENSION_FUNCTION("socket.secure", SOCKET_SECURE);
521  SocketSecureFunction();
522
523 protected:
524  virtual ~SocketSecureFunction();
525
526  // AsyncApiFunction
527  virtual bool Prepare() OVERRIDE;
528  virtual void AsyncWorkStart() OVERRIDE;
529
530 private:
531  // Callback from TLSSocket::UpgradeSocketToTLS().
532  void TlsConnectDone(scoped_ptr<TLSSocket> socket, int result);
533
534  scoped_ptr<core_api::socket::Secure::Params> params_;
535  scoped_refptr<net::URLRequestContextGetter> url_request_getter_;
536
537  DISALLOW_COPY_AND_ASSIGN(SocketSecureFunction);
538};
539
540}  // namespace extensions
541
542#endif  // EXTENSIONS_BROWSER_API_SOCKET_SOCKET_API_H_
543