1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "extensions/browser/api/socket/socket_api.h"
6
7#include <vector>
8
9#include "base/bind.h"
10#include "base/containers/hash_tables.h"
11#include "content/public/browser/browser_context.h"
12#include "content/public/browser/resource_context.h"
13#include "extensions/browser/api/dns/host_resolver_wrapper.h"
14#include "extensions/browser/api/socket/socket.h"
15#include "extensions/browser/api/socket/tcp_socket.h"
16#include "extensions/browser/api/socket/tls_socket.h"
17#include "extensions/browser/api/socket/udp_socket.h"
18#include "extensions/browser/extension_system.h"
19#include "extensions/common/extension.h"
20#include "extensions/common/permissions/permissions_data.h"
21#include "extensions/common/permissions/socket_permission.h"
22#include "net/base/host_port_pair.h"
23#include "net/base/io_buffer.h"
24#include "net/base/ip_endpoint.h"
25#include "net/base/net_errors.h"
26#include "net/base/net_log.h"
27#include "net/base/net_util.h"
28#include "net/url_request/url_request_context.h"
29#include "net/url_request/url_request_context_getter.h"
30
31namespace extensions {
32
33using content::SocketPermissionRequest;
34
// Keys of the dictionaries handed back to the calling extension.
const char kAddressKey[] = "address";
const char kBytesWrittenKey[] = "bytesWritten";
const char kDataKey[] = "data";
const char kPortKey[] = "port";
const char kResultCodeKey[] = "resultCode";
const char kSocketIdKey[] = "socketId";

// Messages stored in |error_| when a call fails.
const char kDnsLookupFailedError[] = "DNS resolution failed";
const char kMulticastSocketTypeError[] = "Only UDP socket supports multicast.";
const char kNetworkListError[] = "Network lookup failed or unsupported";
const char kPermissionError[] = "App does not have permission";
const char kSecureSocketTypeError[] = "Only TCP sockets are supported for TLS.";
const char kSocketNotConnectedError[] = "Socket not connected";
const char kSocketNotFoundError[] = "Socket not found";
const char kTCPSocketBindError[] =
    "TCP socket does not support bind. "
    "For TCP server please use listen.";

// Host/port pair passed to SocketPermission checks that are not tied to a
// concrete endpoint (the multicast membership checks use it).
const char kWildcardAddress[] = "*";
const int kWildcardPort = 0;
53
54SocketAsyncApiFunction::SocketAsyncApiFunction() {}
55
56SocketAsyncApiFunction::~SocketAsyncApiFunction() {}
57
58bool SocketAsyncApiFunction::PrePrepare() {
59  manager_ = CreateSocketResourceManager();
60  return manager_->SetBrowserContext(browser_context());
61}
62
63bool SocketAsyncApiFunction::Respond() { return error_.empty(); }
64
65scoped_ptr<SocketResourceManagerInterface>
66SocketAsyncApiFunction::CreateSocketResourceManager() {
67  return scoped_ptr<SocketResourceManagerInterface>(
68             new SocketResourceManager<Socket>()).Pass();
69}
70
71int SocketAsyncApiFunction::AddSocket(Socket* socket) {
72  return manager_->Add(socket);
73}
74
75Socket* SocketAsyncApiFunction::GetSocket(int api_resource_id) {
76  return manager_->Get(extension_->id(), api_resource_id);
77}
78
79void SocketAsyncApiFunction::ReplaceSocket(int api_resource_id,
80                                           Socket* socket) {
81  manager_->Replace(extension_->id(), api_resource_id, socket);
82}
83
84base::hash_set<int>* SocketAsyncApiFunction::GetSocketIds() {
85  return manager_->GetResourceIds(extension_->id());
86}
87
88void SocketAsyncApiFunction::RemoveSocket(int api_resource_id) {
89  manager_->Remove(extension_->id(), api_resource_id);
90}
91
92SocketExtensionWithDnsLookupFunction::SocketExtensionWithDnsLookupFunction()
93    : resource_context_(NULL),
94      request_handle_(new net::HostResolver::RequestHandle),
95      addresses_(new net::AddressList) {}
96
97SocketExtensionWithDnsLookupFunction::~SocketExtensionWithDnsLookupFunction() {}
98
99bool SocketExtensionWithDnsLookupFunction::PrePrepare() {
100  if (!SocketAsyncApiFunction::PrePrepare())
101    return false;
102  resource_context_ = browser_context()->GetResourceContext();
103  return resource_context_ != NULL;
104}
105
106void SocketExtensionWithDnsLookupFunction::StartDnsLookup(
107    const std::string& hostname) {
108  net::HostResolver* host_resolver =
109      HostResolverWrapper::GetInstance()->GetHostResolver(resource_context_);
110  DCHECK(host_resolver);
111
112  // Yes, we are passing zero as the port. There are some interesting but not
113  // presently relevant reasons why HostResolver asks for the port of the
114  // hostname you'd like to resolve, even though it doesn't use that value in
115  // determining its answer.
116  net::HostPortPair host_port_pair(hostname, 0);
117
118  net::HostResolver::RequestInfo request_info(host_port_pair);
119  int resolve_result = host_resolver->Resolve(
120      request_info,
121      net::DEFAULT_PRIORITY,
122      addresses_.get(),
123      base::Bind(&SocketExtensionWithDnsLookupFunction::OnDnsLookup, this),
124      request_handle_.get(),
125      net::BoundNetLog());
126
127  if (resolve_result != net::ERR_IO_PENDING)
128    OnDnsLookup(resolve_result);
129}
130
131void SocketExtensionWithDnsLookupFunction::OnDnsLookup(int resolve_result) {
132  if (resolve_result == net::OK) {
133    DCHECK(!addresses_->empty());
134    resolved_address_ = addresses_->front().ToStringWithoutPort();
135  } else {
136    error_ = kDnsLookupFailedError;
137  }
138  AfterDnsLookup(resolve_result);
139}
140
141SocketCreateFunction::SocketCreateFunction()
142    : socket_type_(kSocketTypeInvalid) {}
143
144SocketCreateFunction::~SocketCreateFunction() {}
145
146bool SocketCreateFunction::Prepare() {
147  params_ = core_api::socket::Create::Params::Create(*args_);
148  EXTENSION_FUNCTION_VALIDATE(params_.get());
149
150  switch (params_->type) {
151    case extensions::core_api::socket::SOCKET_TYPE_TCP:
152      socket_type_ = kSocketTypeTCP;
153      break;
154    case extensions::core_api::socket::SOCKET_TYPE_UDP:
155      socket_type_ = kSocketTypeUDP;
156      break;
157    case extensions::core_api::socket::SOCKET_TYPE_NONE:
158      NOTREACHED();
159      break;
160  }
161
162  return true;
163}
164
165void SocketCreateFunction::Work() {
166  Socket* socket = NULL;
167  if (socket_type_ == kSocketTypeTCP) {
168    socket = new TCPSocket(extension_->id());
169  } else if (socket_type_ == kSocketTypeUDP) {
170    socket = new UDPSocket(extension_->id());
171  }
172  DCHECK(socket);
173
174  base::DictionaryValue* result = new base::DictionaryValue();
175  result->SetInteger(kSocketIdKey, AddSocket(socket));
176  SetResult(result);
177}
178
179bool SocketDestroyFunction::Prepare() {
180  EXTENSION_FUNCTION_VALIDATE(args_->GetInteger(0, &socket_id_));
181  return true;
182}
183
184void SocketDestroyFunction::Work() { RemoveSocket(socket_id_); }
185
186SocketConnectFunction::SocketConnectFunction()
187    : socket_id_(0), hostname_(), port_(0), socket_(NULL) {}
188
189SocketConnectFunction::~SocketConnectFunction() {}
190
191bool SocketConnectFunction::Prepare() {
192  EXTENSION_FUNCTION_VALIDATE(args_->GetInteger(0, &socket_id_));
193  EXTENSION_FUNCTION_VALIDATE(args_->GetString(1, &hostname_));
194  EXTENSION_FUNCTION_VALIDATE(args_->GetInteger(2, &port_));
195  return true;
196}
197
198void SocketConnectFunction::AsyncWorkStart() {
199  socket_ = GetSocket(socket_id_);
200  if (!socket_) {
201    error_ = kSocketNotFoundError;
202    SetResult(new base::FundamentalValue(-1));
203    AsyncWorkCompleted();
204    return;
205  }
206
207  socket_->set_hostname(hostname_);
208
209  SocketPermissionRequest::OperationType operation_type;
210  switch (socket_->GetSocketType()) {
211    case Socket::TYPE_TCP:
212      operation_type = SocketPermissionRequest::TCP_CONNECT;
213      break;
214    case Socket::TYPE_UDP:
215      operation_type = SocketPermissionRequest::UDP_SEND_TO;
216      break;
217    default:
218      NOTREACHED() << "Unknown socket type.";
219      operation_type = SocketPermissionRequest::NONE;
220      break;
221  }
222
223  SocketPermission::CheckParam param(operation_type, hostname_, port_);
224  if (!extension()->permissions_data()->CheckAPIPermissionWithParam(
225          APIPermission::kSocket, &param)) {
226    error_ = kPermissionError;
227    SetResult(new base::FundamentalValue(-1));
228    AsyncWorkCompleted();
229    return;
230  }
231
232  StartDnsLookup(hostname_);
233}
234
235void SocketConnectFunction::AfterDnsLookup(int lookup_result) {
236  if (lookup_result == net::OK) {
237    StartConnect();
238  } else {
239    SetResult(new base::FundamentalValue(lookup_result));
240    AsyncWorkCompleted();
241  }
242}
243
244void SocketConnectFunction::StartConnect() {
245  socket_->Connect(resolved_address_,
246                   port_,
247                   base::Bind(&SocketConnectFunction::OnConnect, this));
248}
249
250void SocketConnectFunction::OnConnect(int result) {
251  SetResult(new base::FundamentalValue(result));
252  AsyncWorkCompleted();
253}
254
255bool SocketDisconnectFunction::Prepare() {
256  EXTENSION_FUNCTION_VALIDATE(args_->GetInteger(0, &socket_id_));
257  return true;
258}
259
260void SocketDisconnectFunction::Work() {
261  Socket* socket = GetSocket(socket_id_);
262  if (socket)
263    socket->Disconnect();
264  else
265    error_ = kSocketNotFoundError;
266  SetResult(base::Value::CreateNullValue());
267}
268
269bool SocketBindFunction::Prepare() {
270  EXTENSION_FUNCTION_VALIDATE(args_->GetInteger(0, &socket_id_));
271  EXTENSION_FUNCTION_VALIDATE(args_->GetString(1, &address_));
272  EXTENSION_FUNCTION_VALIDATE(args_->GetInteger(2, &port_));
273  return true;
274}
275
276void SocketBindFunction::Work() {
277  int result = -1;
278  Socket* socket = GetSocket(socket_id_);
279
280  if (!socket) {
281    error_ = kSocketNotFoundError;
282    SetResult(new base::FundamentalValue(result));
283    return;
284  }
285
286  if (socket->GetSocketType() == Socket::TYPE_UDP) {
287    SocketPermission::CheckParam param(
288        SocketPermissionRequest::UDP_BIND, address_, port_);
289    if (!extension()->permissions_data()->CheckAPIPermissionWithParam(
290            APIPermission::kSocket, &param)) {
291      error_ = kPermissionError;
292      SetResult(new base::FundamentalValue(result));
293      return;
294    }
295  } else if (socket->GetSocketType() == Socket::TYPE_TCP) {
296    error_ = kTCPSocketBindError;
297    SetResult(new base::FundamentalValue(result));
298    return;
299  }
300
301  result = socket->Bind(address_, port_);
302  SetResult(new base::FundamentalValue(result));
303}
304
305SocketListenFunction::SocketListenFunction() {}
306
307SocketListenFunction::~SocketListenFunction() {}
308
309bool SocketListenFunction::Prepare() {
310  params_ = core_api::socket::Listen::Params::Create(*args_);
311  EXTENSION_FUNCTION_VALIDATE(params_.get());
312  return true;
313}
314
315void SocketListenFunction::Work() {
316  int result = -1;
317
318  Socket* socket = GetSocket(params_->socket_id);
319  if (socket) {
320    SocketPermission::CheckParam param(
321        SocketPermissionRequest::TCP_LISTEN, params_->address, params_->port);
322    if (!extension()->permissions_data()->CheckAPIPermissionWithParam(
323            APIPermission::kSocket, &param)) {
324      error_ = kPermissionError;
325      SetResult(new base::FundamentalValue(result));
326      return;
327    }
328
329    result =
330        socket->Listen(params_->address,
331                       params_->port,
332                       params_->backlog.get() ? *params_->backlog.get() : 5,
333                       &error_);
334  } else {
335    error_ = kSocketNotFoundError;
336  }
337
338  SetResult(new base::FundamentalValue(result));
339}
340
341SocketAcceptFunction::SocketAcceptFunction() {}
342
343SocketAcceptFunction::~SocketAcceptFunction() {}
344
345bool SocketAcceptFunction::Prepare() {
346  params_ = core_api::socket::Accept::Params::Create(*args_);
347  EXTENSION_FUNCTION_VALIDATE(params_.get());
348  return true;
349}
350
351void SocketAcceptFunction::AsyncWorkStart() {
352  Socket* socket = GetSocket(params_->socket_id);
353  if (socket) {
354    socket->Accept(base::Bind(&SocketAcceptFunction::OnAccept, this));
355  } else {
356    error_ = kSocketNotFoundError;
357    OnAccept(-1, NULL);
358  }
359}
360
361void SocketAcceptFunction::OnAccept(int result_code,
362                                    net::TCPClientSocket* socket) {
363  base::DictionaryValue* result = new base::DictionaryValue();
364  result->SetInteger(kResultCodeKey, result_code);
365  if (socket) {
366    Socket* client_socket = new TCPSocket(socket, extension_id(), true);
367    result->SetInteger(kSocketIdKey, AddSocket(client_socket));
368  }
369  SetResult(result);
370
371  AsyncWorkCompleted();
372}
373
374SocketReadFunction::SocketReadFunction() {}
375
376SocketReadFunction::~SocketReadFunction() {}
377
378bool SocketReadFunction::Prepare() {
379  params_ = core_api::socket::Read::Params::Create(*args_);
380  EXTENSION_FUNCTION_VALIDATE(params_.get());
381  return true;
382}
383
384void SocketReadFunction::AsyncWorkStart() {
385  Socket* socket = GetSocket(params_->socket_id);
386  if (!socket) {
387    error_ = kSocketNotFoundError;
388    OnCompleted(-1, NULL);
389    return;
390  }
391
392  socket->Read(params_->buffer_size.get() ? *params_->buffer_size.get() : 4096,
393               base::Bind(&SocketReadFunction::OnCompleted, this));
394}
395
396void SocketReadFunction::OnCompleted(int bytes_read,
397                                     scoped_refptr<net::IOBuffer> io_buffer) {
398  base::DictionaryValue* result = new base::DictionaryValue();
399  result->SetInteger(kResultCodeKey, bytes_read);
400  if (bytes_read > 0) {
401    result->Set(kDataKey,
402                base::BinaryValue::CreateWithCopiedBuffer(io_buffer->data(),
403                                                          bytes_read));
404  } else {
405    result->Set(kDataKey, new base::BinaryValue());
406  }
407  SetResult(result);
408
409  AsyncWorkCompleted();
410}
411
412SocketWriteFunction::SocketWriteFunction()
413    : socket_id_(0), io_buffer_(NULL), io_buffer_size_(0) {}
414
415SocketWriteFunction::~SocketWriteFunction() {}
416
417bool SocketWriteFunction::Prepare() {
418  EXTENSION_FUNCTION_VALIDATE(args_->GetInteger(0, &socket_id_));
419  base::BinaryValue* data = NULL;
420  EXTENSION_FUNCTION_VALIDATE(args_->GetBinary(1, &data));
421
422  io_buffer_size_ = data->GetSize();
423  io_buffer_ = new net::WrappedIOBuffer(data->GetBuffer());
424  return true;
425}
426
427void SocketWriteFunction::AsyncWorkStart() {
428  Socket* socket = GetSocket(socket_id_);
429
430  if (!socket) {
431    error_ = kSocketNotFoundError;
432    OnCompleted(-1);
433    return;
434  }
435
436  socket->Write(io_buffer_,
437                io_buffer_size_,
438                base::Bind(&SocketWriteFunction::OnCompleted, this));
439}
440
441void SocketWriteFunction::OnCompleted(int bytes_written) {
442  base::DictionaryValue* result = new base::DictionaryValue();
443  result->SetInteger(kBytesWrittenKey, bytes_written);
444  SetResult(result);
445
446  AsyncWorkCompleted();
447}
448
449SocketRecvFromFunction::SocketRecvFromFunction() {}
450
451SocketRecvFromFunction::~SocketRecvFromFunction() {}
452
453bool SocketRecvFromFunction::Prepare() {
454  params_ = core_api::socket::RecvFrom::Params::Create(*args_);
455  EXTENSION_FUNCTION_VALIDATE(params_.get());
456  return true;
457}
458
459void SocketRecvFromFunction::AsyncWorkStart() {
460  Socket* socket = GetSocket(params_->socket_id);
461  if (!socket) {
462    error_ = kSocketNotFoundError;
463    OnCompleted(-1, NULL, std::string(), 0);
464    return;
465  }
466
467  socket->RecvFrom(params_->buffer_size.get() ? *params_->buffer_size : 4096,
468                   base::Bind(&SocketRecvFromFunction::OnCompleted, this));
469}
470
471void SocketRecvFromFunction::OnCompleted(int bytes_read,
472                                         scoped_refptr<net::IOBuffer> io_buffer,
473                                         const std::string& address,
474                                         int port) {
475  base::DictionaryValue* result = new base::DictionaryValue();
476  result->SetInteger(kResultCodeKey, bytes_read);
477  if (bytes_read > 0) {
478    result->Set(kDataKey,
479                base::BinaryValue::CreateWithCopiedBuffer(io_buffer->data(),
480                                                          bytes_read));
481  } else {
482    result->Set(kDataKey, new base::BinaryValue());
483  }
484  result->SetString(kAddressKey, address);
485  result->SetInteger(kPortKey, port);
486  SetResult(result);
487
488  AsyncWorkCompleted();
489}
490
491SocketSendToFunction::SocketSendToFunction()
492    : socket_id_(0),
493      io_buffer_(NULL),
494      io_buffer_size_(0),
495      port_(0),
496      socket_(NULL) {}
497
498SocketSendToFunction::~SocketSendToFunction() {}
499
500bool SocketSendToFunction::Prepare() {
501  EXTENSION_FUNCTION_VALIDATE(args_->GetInteger(0, &socket_id_));
502  base::BinaryValue* data = NULL;
503  EXTENSION_FUNCTION_VALIDATE(args_->GetBinary(1, &data));
504  EXTENSION_FUNCTION_VALIDATE(args_->GetString(2, &hostname_));
505  EXTENSION_FUNCTION_VALIDATE(args_->GetInteger(3, &port_));
506
507  io_buffer_size_ = data->GetSize();
508  io_buffer_ = new net::WrappedIOBuffer(data->GetBuffer());
509  return true;
510}
511
512void SocketSendToFunction::AsyncWorkStart() {
513  socket_ = GetSocket(socket_id_);
514  if (!socket_) {
515    error_ = kSocketNotFoundError;
516    SetResult(new base::FundamentalValue(-1));
517    AsyncWorkCompleted();
518    return;
519  }
520
521  if (socket_->GetSocketType() == Socket::TYPE_UDP) {
522    SocketPermission::CheckParam param(
523        SocketPermissionRequest::UDP_SEND_TO, hostname_, port_);
524    if (!extension()->permissions_data()->CheckAPIPermissionWithParam(
525            APIPermission::kSocket, &param)) {
526      error_ = kPermissionError;
527      SetResult(new base::FundamentalValue(-1));
528      AsyncWorkCompleted();
529      return;
530    }
531  }
532
533  StartDnsLookup(hostname_);
534}
535
536void SocketSendToFunction::AfterDnsLookup(int lookup_result) {
537  if (lookup_result == net::OK) {
538    StartSendTo();
539  } else {
540    SetResult(new base::FundamentalValue(lookup_result));
541    AsyncWorkCompleted();
542  }
543}
544
545void SocketSendToFunction::StartSendTo() {
546  socket_->SendTo(io_buffer_,
547                  io_buffer_size_,
548                  resolved_address_,
549                  port_,
550                  base::Bind(&SocketSendToFunction::OnCompleted, this));
551}
552
553void SocketSendToFunction::OnCompleted(int bytes_written) {
554  base::DictionaryValue* result = new base::DictionaryValue();
555  result->SetInteger(kBytesWrittenKey, bytes_written);
556  SetResult(result);
557
558  AsyncWorkCompleted();
559}
560
561SocketSetKeepAliveFunction::SocketSetKeepAliveFunction() {}
562
563SocketSetKeepAliveFunction::~SocketSetKeepAliveFunction() {}
564
565bool SocketSetKeepAliveFunction::Prepare() {
566  params_ = core_api::socket::SetKeepAlive::Params::Create(*args_);
567  EXTENSION_FUNCTION_VALIDATE(params_.get());
568  return true;
569}
570
571void SocketSetKeepAliveFunction::Work() {
572  bool result = false;
573  Socket* socket = GetSocket(params_->socket_id);
574  if (socket) {
575    int delay = 0;
576    if (params_->delay.get())
577      delay = *params_->delay;
578    result = socket->SetKeepAlive(params_->enable, delay);
579  } else {
580    error_ = kSocketNotFoundError;
581  }
582  SetResult(new base::FundamentalValue(result));
583}
584
585SocketSetNoDelayFunction::SocketSetNoDelayFunction() {}
586
587SocketSetNoDelayFunction::~SocketSetNoDelayFunction() {}
588
589bool SocketSetNoDelayFunction::Prepare() {
590  params_ = core_api::socket::SetNoDelay::Params::Create(*args_);
591  EXTENSION_FUNCTION_VALIDATE(params_.get());
592  return true;
593}
594
595void SocketSetNoDelayFunction::Work() {
596  bool result = false;
597  Socket* socket = GetSocket(params_->socket_id);
598  if (socket)
599    result = socket->SetNoDelay(params_->no_delay);
600  else
601    error_ = kSocketNotFoundError;
602  SetResult(new base::FundamentalValue(result));
603}
604
605SocketGetInfoFunction::SocketGetInfoFunction() {}
606
607SocketGetInfoFunction::~SocketGetInfoFunction() {}
608
609bool SocketGetInfoFunction::Prepare() {
610  params_ = core_api::socket::GetInfo::Params::Create(*args_);
611  EXTENSION_FUNCTION_VALIDATE(params_.get());
612  return true;
613}
614
615void SocketGetInfoFunction::Work() {
616  Socket* socket = GetSocket(params_->socket_id);
617  if (!socket) {
618    error_ = kSocketNotFoundError;
619    return;
620  }
621
622  core_api::socket::SocketInfo info;
623  // This represents what we know about the socket, and does not call through
624  // to the system.
625  if (socket->GetSocketType() == Socket::TYPE_TCP)
626    info.socket_type = extensions::core_api::socket::SOCKET_TYPE_TCP;
627  else
628    info.socket_type = extensions::core_api::socket::SOCKET_TYPE_UDP;
629  info.connected = socket->IsConnected();
630
631  // Grab the peer address as known by the OS. This and the call below will
632  // always succeed while the socket is connected, even if the socket has
633  // been remotely closed by the peer; only reading the socket will reveal
634  // that it should be closed locally.
635  net::IPEndPoint peerAddress;
636  if (socket->GetPeerAddress(&peerAddress)) {
637    info.peer_address.reset(new std::string(peerAddress.ToStringWithoutPort()));
638    info.peer_port.reset(new int(peerAddress.port()));
639  }
640
641  // Grab the local address as known by the OS.
642  net::IPEndPoint localAddress;
643  if (socket->GetLocalAddress(&localAddress)) {
644    info.local_address.reset(
645        new std::string(localAddress.ToStringWithoutPort()));
646    info.local_port.reset(new int(localAddress.port()));
647  }
648
649  SetResult(info.ToValue().release());
650}
651
652bool SocketGetNetworkListFunction::RunAsync() {
653  content::BrowserThread::PostTask(
654      content::BrowserThread::FILE,
655      FROM_HERE,
656      base::Bind(&SocketGetNetworkListFunction::GetNetworkListOnFileThread,
657                 this));
658  return true;
659}
660
661void SocketGetNetworkListFunction::GetNetworkListOnFileThread() {
662  net::NetworkInterfaceList interface_list;
663  if (GetNetworkList(&interface_list,
664                     net::INCLUDE_HOST_SCOPE_VIRTUAL_INTERFACES)) {
665    content::BrowserThread::PostTask(
666        content::BrowserThread::UI,
667        FROM_HERE,
668        base::Bind(&SocketGetNetworkListFunction::SendResponseOnUIThread,
669                   this,
670                   interface_list));
671    return;
672  }
673
674  content::BrowserThread::PostTask(
675      content::BrowserThread::UI,
676      FROM_HERE,
677      base::Bind(&SocketGetNetworkListFunction::HandleGetNetworkListError,
678                 this));
679}
680
681void SocketGetNetworkListFunction::HandleGetNetworkListError() {
682  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
683  error_ = kNetworkListError;
684  SendResponse(false);
685}
686
687void SocketGetNetworkListFunction::SendResponseOnUIThread(
688    const net::NetworkInterfaceList& interface_list) {
689  DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
690
691  std::vector<linked_ptr<core_api::socket::NetworkInterface> > create_arg;
692  create_arg.reserve(interface_list.size());
693  for (net::NetworkInterfaceList::const_iterator i = interface_list.begin();
694       i != interface_list.end();
695       ++i) {
696    linked_ptr<core_api::socket::NetworkInterface> info =
697        make_linked_ptr(new core_api::socket::NetworkInterface);
698    info->name = i->name;
699    info->address = net::IPAddressToString(i->address);
700    info->prefix_length = i->network_prefix;
701    create_arg.push_back(info);
702  }
703
704  results_ = core_api::socket::GetNetworkList::Results::Create(create_arg);
705  SendResponse(true);
706}
707
708SocketJoinGroupFunction::SocketJoinGroupFunction() {}
709
710SocketJoinGroupFunction::~SocketJoinGroupFunction() {}
711
712bool SocketJoinGroupFunction::Prepare() {
713  params_ = core_api::socket::JoinGroup::Params::Create(*args_);
714  EXTENSION_FUNCTION_VALIDATE(params_.get());
715  return true;
716}
717
718void SocketJoinGroupFunction::Work() {
719  int result = -1;
720  Socket* socket = GetSocket(params_->socket_id);
721  if (!socket) {
722    error_ = kSocketNotFoundError;
723    SetResult(new base::FundamentalValue(result));
724    return;
725  }
726
727  if (socket->GetSocketType() != Socket::TYPE_UDP) {
728    error_ = kMulticastSocketTypeError;
729    SetResult(new base::FundamentalValue(result));
730    return;
731  }
732
733  SocketPermission::CheckParam param(
734      SocketPermissionRequest::UDP_MULTICAST_MEMBERSHIP,
735      kWildcardAddress,
736      kWildcardPort);
737
738  if (!extension()->permissions_data()->CheckAPIPermissionWithParam(
739          APIPermission::kSocket, &param)) {
740    error_ = kPermissionError;
741    SetResult(new base::FundamentalValue(result));
742    return;
743  }
744
745  result = static_cast<UDPSocket*>(socket)->JoinGroup(params_->address);
746  if (result != 0) {
747    error_ = net::ErrorToString(result);
748  }
749  SetResult(new base::FundamentalValue(result));
750}
751
752SocketLeaveGroupFunction::SocketLeaveGroupFunction() {}
753
754SocketLeaveGroupFunction::~SocketLeaveGroupFunction() {}
755
756bool SocketLeaveGroupFunction::Prepare() {
757  params_ = core_api::socket::LeaveGroup::Params::Create(*args_);
758  EXTENSION_FUNCTION_VALIDATE(params_.get());
759  return true;
760}
761
762void SocketLeaveGroupFunction::Work() {
763  int result = -1;
764  Socket* socket = GetSocket(params_->socket_id);
765
766  if (!socket) {
767    error_ = kSocketNotFoundError;
768    SetResult(new base::FundamentalValue(result));
769    return;
770  }
771
772  if (socket->GetSocketType() != Socket::TYPE_UDP) {
773    error_ = kMulticastSocketTypeError;
774    SetResult(new base::FundamentalValue(result));
775    return;
776  }
777
778  SocketPermission::CheckParam param(
779      SocketPermissionRequest::UDP_MULTICAST_MEMBERSHIP,
780      kWildcardAddress,
781      kWildcardPort);
782  if (!extension()->permissions_data()->CheckAPIPermissionWithParam(
783          APIPermission::kSocket, &param)) {
784    error_ = kPermissionError;
785    SetResult(new base::FundamentalValue(result));
786    return;
787  }
788
789  result = static_cast<UDPSocket*>(socket)->LeaveGroup(params_->address);
790  if (result != 0)
791    error_ = net::ErrorToString(result);
792  SetResult(new base::FundamentalValue(result));
793}
794
795SocketSetMulticastTimeToLiveFunction::SocketSetMulticastTimeToLiveFunction() {}
796
797SocketSetMulticastTimeToLiveFunction::~SocketSetMulticastTimeToLiveFunction() {}
798
799bool SocketSetMulticastTimeToLiveFunction::Prepare() {
800  params_ = core_api::socket::SetMulticastTimeToLive::Params::Create(*args_);
801  EXTENSION_FUNCTION_VALIDATE(params_.get());
802  return true;
803}
804void SocketSetMulticastTimeToLiveFunction::Work() {
805  int result = -1;
806  Socket* socket = GetSocket(params_->socket_id);
807  if (!socket) {
808    error_ = kSocketNotFoundError;
809    SetResult(new base::FundamentalValue(result));
810    return;
811  }
812
813  if (socket->GetSocketType() != Socket::TYPE_UDP) {
814    error_ = kMulticastSocketTypeError;
815    SetResult(new base::FundamentalValue(result));
816    return;
817  }
818
819  result =
820      static_cast<UDPSocket*>(socket)->SetMulticastTimeToLive(params_->ttl);
821  if (result != 0)
822    error_ = net::ErrorToString(result);
823  SetResult(new base::FundamentalValue(result));
824}
825
826SocketSetMulticastLoopbackModeFunction::
827    SocketSetMulticastLoopbackModeFunction() {}
828
829SocketSetMulticastLoopbackModeFunction::
830    ~SocketSetMulticastLoopbackModeFunction() {}
831
832bool SocketSetMulticastLoopbackModeFunction::Prepare() {
833  params_ = core_api::socket::SetMulticastLoopbackMode::Params::Create(*args_);
834  EXTENSION_FUNCTION_VALIDATE(params_.get());
835  return true;
836}
837
838void SocketSetMulticastLoopbackModeFunction::Work() {
839  int result = -1;
840  Socket* socket = GetSocket(params_->socket_id);
841  if (!socket) {
842    error_ = kSocketNotFoundError;
843    SetResult(new base::FundamentalValue(result));
844    return;
845  }
846
847  if (socket->GetSocketType() != Socket::TYPE_UDP) {
848    error_ = kMulticastSocketTypeError;
849    SetResult(new base::FundamentalValue(result));
850    return;
851  }
852
853  result = static_cast<UDPSocket*>(socket)
854               ->SetMulticastLoopbackMode(params_->enabled);
855  if (result != 0)
856    error_ = net::ErrorToString(result);
857  SetResult(new base::FundamentalValue(result));
858}
859
860SocketGetJoinedGroupsFunction::SocketGetJoinedGroupsFunction() {}
861
862SocketGetJoinedGroupsFunction::~SocketGetJoinedGroupsFunction() {}
863
864bool SocketGetJoinedGroupsFunction::Prepare() {
865  params_ = core_api::socket::GetJoinedGroups::Params::Create(*args_);
866  EXTENSION_FUNCTION_VALIDATE(params_.get());
867  return true;
868}
869
870void SocketGetJoinedGroupsFunction::Work() {
871  int result = -1;
872  Socket* socket = GetSocket(params_->socket_id);
873  if (!socket) {
874    error_ = kSocketNotFoundError;
875    SetResult(new base::FundamentalValue(result));
876    return;
877  }
878
879  if (socket->GetSocketType() != Socket::TYPE_UDP) {
880    error_ = kMulticastSocketTypeError;
881    SetResult(new base::FundamentalValue(result));
882    return;
883  }
884
885  SocketPermission::CheckParam param(
886      SocketPermissionRequest::UDP_MULTICAST_MEMBERSHIP,
887      kWildcardAddress,
888      kWildcardPort);
889  if (!extension()->permissions_data()->CheckAPIPermissionWithParam(
890          APIPermission::kSocket, &param)) {
891    error_ = kPermissionError;
892    SetResult(new base::FundamentalValue(result));
893    return;
894  }
895
896  base::ListValue* values = new base::ListValue();
897  values->AppendStrings((std::vector<std::string>&)static_cast<UDPSocket*>(
898                            socket)->GetJoinedGroups());
899  SetResult(values);
900}
901
902SocketSecureFunction::SocketSecureFunction() {
903}
904
905SocketSecureFunction::~SocketSecureFunction() {
906}
907
908bool SocketSecureFunction::Prepare() {
909  DCHECK(content::BrowserThread::CurrentlyOn(content::BrowserThread::UI));
910  params_ = core_api::socket::Secure::Params::Create(*args_);
911  EXTENSION_FUNCTION_VALIDATE(params_.get());
912  url_request_getter_ = browser_context()->GetRequestContext();
913  return true;
914}
915
916// Override the regular implementation, which would call AsyncWorkCompleted
917// immediately after Work().
918void SocketSecureFunction::AsyncWorkStart() {
919  DCHECK(content::BrowserThread::CurrentlyOn(content::BrowserThread::IO));
920
921  Socket* socket = GetSocket(params_->socket_id);
922  if (!socket) {
923    SetResult(new base::FundamentalValue(net::ERR_INVALID_ARGUMENT));
924    error_ = kSocketNotFoundError;
925    AsyncWorkCompleted();
926    return;
927  }
928
929  // Make sure that the socket is a TCP client socket.
930  if (socket->GetSocketType() != Socket::TYPE_TCP ||
931      static_cast<TCPSocket*>(socket)->ClientStream() == NULL) {
932    SetResult(new base::FundamentalValue(net::ERR_INVALID_ARGUMENT));
933    error_ = kSecureSocketTypeError;
934    AsyncWorkCompleted();
935    return;
936  }
937
938  if (!socket->IsConnected()) {
939    SetResult(new base::FundamentalValue(net::ERR_INVALID_ARGUMENT));
940    error_ = kSocketNotConnectedError;
941    AsyncWorkCompleted();
942    return;
943  }
944
945  net::URLRequestContext* url_request_context =
946      url_request_getter_->GetURLRequestContext();
947
948  TLSSocket::UpgradeSocketToTLS(
949      socket,
950      url_request_context->ssl_config_service(),
951      url_request_context->cert_verifier(),
952      url_request_context->transport_security_state(),
953      extension_id(),
954      params_->options.get(),
955      base::Bind(&SocketSecureFunction::TlsConnectDone, this));
956}
957
958void SocketSecureFunction::TlsConnectDone(scoped_ptr<TLSSocket> socket,
959                                          int result) {
960  // if an error occurred, socket MUST be NULL.
961  DCHECK(result == net::OK || socket == NULL);
962
963  if (socket && result == net::OK) {
964    ReplaceSocket(params_->socket_id, socket.release());
965  } else {
966    RemoveSocket(params_->socket_id);
967    error_ = net::ErrorToString(result);
968  }
969
970  results_ = core_api::socket::Secure::Results::Create(result);
971  AsyncWorkCompleted();
972}
973
974}  // namespace extensions
975