1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "extensions/browser/api/socket/udp_socket.h"
6
7#include <algorithm>
8
9#include "base/lazy_instance.h"
10#include "extensions/browser/api/api_resource.h"
11#include "net/base/ip_endpoint.h"
12#include "net/base/net_errors.h"
13#include "net/udp/datagram_socket.h"
14#include "net/udp/udp_client_socket.h"
15
16namespace extensions {
17
18static base::LazyInstance<
19    BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableUDPSocket> > >
20    g_factory = LAZY_INSTANCE_INITIALIZER;
21
22// static
23template <>
24BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableUDPSocket> >*
25ApiResourceManager<ResumableUDPSocket>::GetFactoryInstance() {
26  return g_factory.Pointer();
27}
28
29UDPSocket::UDPSocket(const std::string& owner_extension_id)
30    : Socket(owner_extension_id),
31      socket_(net::DatagramSocket::DEFAULT_BIND,
32              net::RandIntCallback(),
33              NULL,
34              net::NetLog::Source()) {}
35
36UDPSocket::~UDPSocket() { Disconnect(); }
37
38void UDPSocket::Connect(const std::string& address,
39                        int port,
40                        const CompletionCallback& callback) {
41  int result = net::ERR_CONNECTION_FAILED;
42  do {
43    if (is_connected_)
44      break;
45
46    net::IPEndPoint ip_end_point;
47    if (!StringAndPortToIPEndPoint(address, port, &ip_end_point)) {
48      result = net::ERR_ADDRESS_INVALID;
49      break;
50    }
51
52    result = socket_.Connect(ip_end_point);
53    is_connected_ = (result == net::OK);
54  } while (false);
55
56  callback.Run(result);
57}
58
59int UDPSocket::Bind(const std::string& address, int port) {
60  if (IsBound())
61    return net::ERR_CONNECTION_FAILED;
62
63  net::IPEndPoint ip_end_point;
64  if (!StringAndPortToIPEndPoint(address, port, &ip_end_point))
65    return net::ERR_INVALID_ARGUMENT;
66
67  return socket_.Bind(ip_end_point);
68}
69
70void UDPSocket::Disconnect() {
71  is_connected_ = false;
72  socket_.Close();
73  read_callback_.Reset();
74  recv_from_callback_.Reset();
75  send_to_callback_.Reset();
76  multicast_groups_.clear();
77}
78
79void UDPSocket::Read(int count, const ReadCompletionCallback& callback) {
80  DCHECK(!callback.is_null());
81
82  if (!read_callback_.is_null()) {
83    callback.Run(net::ERR_IO_PENDING, NULL);
84    return;
85  } else {
86    read_callback_ = callback;
87  }
88
89  int result = net::ERR_FAILED;
90  scoped_refptr<net::IOBuffer> io_buffer;
91  do {
92    if (count < 0) {
93      result = net::ERR_INVALID_ARGUMENT;
94      break;
95    }
96
97    if (!socket_.is_connected()) {
98      result = net::ERR_SOCKET_NOT_CONNECTED;
99      break;
100    }
101
102    io_buffer = new net::IOBuffer(count);
103    result = socket_.Read(
104        io_buffer.get(),
105        count,
106        base::Bind(
107            &UDPSocket::OnReadComplete, base::Unretained(this), io_buffer));
108  } while (false);
109
110  if (result != net::ERR_IO_PENDING)
111    OnReadComplete(io_buffer, result);
112}
113
114int UDPSocket::WriteImpl(net::IOBuffer* io_buffer,
115                         int io_buffer_size,
116                         const net::CompletionCallback& callback) {
117  if (!socket_.is_connected())
118    return net::ERR_SOCKET_NOT_CONNECTED;
119  else
120    return socket_.Write(io_buffer, io_buffer_size, callback);
121}
122
123void UDPSocket::RecvFrom(int count,
124                         const RecvFromCompletionCallback& callback) {
125  DCHECK(!callback.is_null());
126
127  if (!recv_from_callback_.is_null()) {
128    callback.Run(net::ERR_IO_PENDING, NULL, std::string(), 0);
129    return;
130  } else {
131    recv_from_callback_ = callback;
132  }
133
134  int result = net::ERR_FAILED;
135  scoped_refptr<net::IOBuffer> io_buffer;
136  scoped_refptr<IPEndPoint> address;
137  do {
138    if (count < 0) {
139      result = net::ERR_INVALID_ARGUMENT;
140      break;
141    }
142
143    if (!socket_.is_connected()) {
144      result = net::ERR_SOCKET_NOT_CONNECTED;
145      break;
146    }
147
148    io_buffer = new net::IOBuffer(count);
149    address = new IPEndPoint();
150    result = socket_.RecvFrom(io_buffer.get(),
151                              count,
152                              &address->data,
153                              base::Bind(&UDPSocket::OnRecvFromComplete,
154                                         base::Unretained(this),
155                                         io_buffer,
156                                         address));
157  } while (false);
158
159  if (result != net::ERR_IO_PENDING)
160    OnRecvFromComplete(io_buffer, address, result);
161}
162
163void UDPSocket::SendTo(scoped_refptr<net::IOBuffer> io_buffer,
164                       int byte_count,
165                       const std::string& address,
166                       int port,
167                       const CompletionCallback& callback) {
168  DCHECK(!callback.is_null());
169
170  if (!send_to_callback_.is_null()) {
171    // TODO(penghuang): Put requests in a pending queue to support multiple
172    // sendTo calls.
173    callback.Run(net::ERR_IO_PENDING);
174    return;
175  } else {
176    send_to_callback_ = callback;
177  }
178
179  int result = net::ERR_FAILED;
180  do {
181    net::IPEndPoint ip_end_point;
182    if (!StringAndPortToIPEndPoint(address, port, &ip_end_point)) {
183      result = net::ERR_ADDRESS_INVALID;
184      break;
185    }
186
187    if (!socket_.is_connected()) {
188      result = net::ERR_SOCKET_NOT_CONNECTED;
189      break;
190    }
191
192    result = socket_.SendTo(
193        io_buffer.get(),
194        byte_count,
195        ip_end_point,
196        base::Bind(&UDPSocket::OnSendToComplete, base::Unretained(this)));
197  } while (false);
198
199  if (result != net::ERR_IO_PENDING)
200    OnSendToComplete(result);
201}
202
203bool UDPSocket::IsConnected() { return is_connected_; }
204
205bool UDPSocket::GetPeerAddress(net::IPEndPoint* address) {
206  return !socket_.GetPeerAddress(address);
207}
208
209bool UDPSocket::GetLocalAddress(net::IPEndPoint* address) {
210  return !socket_.GetLocalAddress(address);
211}
212
213Socket::SocketType UDPSocket::GetSocketType() const { return Socket::TYPE_UDP; }
214
215void UDPSocket::OnReadComplete(scoped_refptr<net::IOBuffer> io_buffer,
216                               int result) {
217  DCHECK(!read_callback_.is_null());
218  read_callback_.Run(result, io_buffer);
219  read_callback_.Reset();
220}
221
222void UDPSocket::OnRecvFromComplete(scoped_refptr<net::IOBuffer> io_buffer,
223                                   scoped_refptr<IPEndPoint> address,
224                                   int result) {
225  DCHECK(!recv_from_callback_.is_null());
226  std::string ip;
227  int port = 0;
228  if (result > 0 && address.get()) {
229    IPEndPointToStringAndPort(address->data, &ip, &port);
230  }
231  recv_from_callback_.Run(result, io_buffer, ip, port);
232  recv_from_callback_.Reset();
233}
234
235void UDPSocket::OnSendToComplete(int result) {
236  DCHECK(!send_to_callback_.is_null());
237  send_to_callback_.Run(result);
238  send_to_callback_.Reset();
239}
240
241bool UDPSocket::IsBound() { return socket_.is_connected(); }
242
243int UDPSocket::JoinGroup(const std::string& address) {
244  net::IPAddressNumber ip;
245  if (!net::ParseIPLiteralToNumber(address, &ip))
246    return net::ERR_ADDRESS_INVALID;
247
248  std::string normalized_address = net::IPAddressToString(ip);
249  std::vector<std::string>::iterator find_result = std::find(
250      multicast_groups_.begin(), multicast_groups_.end(), normalized_address);
251  if (find_result != multicast_groups_.end())
252    return net::ERR_ADDRESS_INVALID;
253
254  int rv = socket_.JoinGroup(ip);
255  if (rv == 0)
256    multicast_groups_.push_back(normalized_address);
257  return rv;
258}
259
260int UDPSocket::LeaveGroup(const std::string& address) {
261  net::IPAddressNumber ip;
262  if (!net::ParseIPLiteralToNumber(address, &ip))
263    return net::ERR_ADDRESS_INVALID;
264
265  std::string normalized_address = net::IPAddressToString(ip);
266  std::vector<std::string>::iterator find_result = std::find(
267      multicast_groups_.begin(), multicast_groups_.end(), normalized_address);
268  if (find_result == multicast_groups_.end())
269    return net::ERR_ADDRESS_INVALID;
270
271  int rv = socket_.LeaveGroup(ip);
272  if (rv == 0)
273    multicast_groups_.erase(find_result);
274  return rv;
275}
276
277int UDPSocket::SetMulticastTimeToLive(int ttl) {
278  return socket_.SetMulticastTimeToLive(ttl);
279}
280
281int UDPSocket::SetMulticastLoopbackMode(bool loopback) {
282  return socket_.SetMulticastLoopbackMode(loopback);
283}
284
285const std::vector<std::string>& UDPSocket::GetJoinedGroups() const {
286  return multicast_groups_;
287}
288
289ResumableUDPSocket::ResumableUDPSocket(const std::string& owner_extension_id)
290    : UDPSocket(owner_extension_id),
291      persistent_(false),
292      buffer_size_(0),
293      paused_(false) {}
294
295bool ResumableUDPSocket::IsPersistent() const { return persistent(); }
296
297}  // namespace extensions
298