1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/udp/udp_socket_libevent.h"
6
7#include <errno.h>
8#include <fcntl.h>
9#include <netdb.h>
10#include <sys/socket.h>
11
12#include "base/eintr_wrapper.h"
13#include "base/logging.h"
14#include "base/message_loop.h"
15#include "base/metrics/stats_counters.h"
16#include "net/base/io_buffer.h"
17#include "net/base/ip_endpoint.h"
18#include "net/base/net_errors.h"
19#include "net/base/net_log.h"
20#include "net/base/net_util.h"
21#if defined(OS_POSIX)
22#include <netinet/in.h>
23#endif
24#if defined(USE_SYSTEM_LIBEVENT)
25#include <event.h>
26#else
27#include "third_party/libevent/event.h"
28#endif
29
30namespace net {
31
32UDPSocketLibevent::UDPSocketLibevent(net::NetLog* net_log,
33                                     const net::NetLog::Source& source)
34    : socket_(kInvalidSocket),
35      read_watcher_(this),
36      write_watcher_(this),
37      read_buf_len_(0),
38      recv_from_address_(NULL),
39      write_buf_len_(0),
40      read_callback_(NULL),
41      write_callback_(NULL),
42      net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
43  scoped_refptr<NetLog::EventParameters> params;
44  if (source.is_valid())
45    params = new NetLogSourceParameter("source_dependency", source);
46  net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, params);
47}
48
49UDPSocketLibevent::~UDPSocketLibevent() {
50  Close();
51  net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE, NULL);
52}
53
54void UDPSocketLibevent::Close() {
55  DCHECK(CalledOnValidThread());
56
57  if (!is_connected())
58    return;
59
60  // Zero out any pending read/write callback state.
61  read_buf_ = NULL;
62  read_buf_len_ = 0;
63  read_callback_ = NULL;
64  recv_from_address_ = NULL;
65  write_buf_ = NULL;
66  write_buf_len_ = 0;
67  write_callback_ = NULL;
68  send_to_address_.reset();
69
70  bool ok = read_socket_watcher_.StopWatchingFileDescriptor();
71  DCHECK(ok);
72  ok = write_socket_watcher_.StopWatchingFileDescriptor();
73  DCHECK(ok);
74
75  if (HANDLE_EINTR(close(socket_)) < 0)
76    PLOG(ERROR) << "close";
77
78  socket_ = kInvalidSocket;
79}
80
81int UDPSocketLibevent::GetPeerAddress(IPEndPoint* address) const {
82  DCHECK(CalledOnValidThread());
83  DCHECK(address);
84  if (!is_connected())
85    return ERR_SOCKET_NOT_CONNECTED;
86
87  if (!remote_address_.get()) {
88    struct sockaddr_storage addr_storage;
89    socklen_t addr_len = sizeof(addr_storage);
90    struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
91    if (getpeername(socket_, addr, &addr_len))
92      return MapSystemError(errno);
93    scoped_ptr<IPEndPoint> address(new IPEndPoint());
94    if (!address->FromSockAddr(addr, addr_len))
95      return ERR_FAILED;
96    remote_address_.reset(address.release());
97  }
98
99  *address = *remote_address_;
100  return OK;
101}
102
103int UDPSocketLibevent::GetLocalAddress(IPEndPoint* address) const {
104  DCHECK(CalledOnValidThread());
105  DCHECK(address);
106  if (!is_connected())
107    return ERR_SOCKET_NOT_CONNECTED;
108
109  if (!local_address_.get()) {
110    struct sockaddr_storage addr_storage;
111    socklen_t addr_len = sizeof(addr_storage);
112    struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
113    if (getsockname(socket_, addr, &addr_len))
114      return MapSystemError(errno);
115    scoped_ptr<IPEndPoint> address(new IPEndPoint());
116    if (!address->FromSockAddr(addr, addr_len))
117      return ERR_FAILED;
118    local_address_.reset(address.release());
119  }
120
121  *address = *local_address_;
122  return OK;
123}
124
125int UDPSocketLibevent::Read(IOBuffer* buf,
126                            int buf_len,
127                            CompletionCallback* callback) {
128  return RecvFrom(buf, buf_len, NULL, callback);
129}
130
131int UDPSocketLibevent::RecvFrom(IOBuffer* buf,
132                                int buf_len,
133                                IPEndPoint* address,
134                                CompletionCallback* callback) {
135  DCHECK(CalledOnValidThread());
136  DCHECK_NE(kInvalidSocket, socket_);
137  DCHECK(!read_callback_);
138  DCHECK(!recv_from_address_);
139  DCHECK(callback);  // Synchronous operation not supported
140  DCHECK_GT(buf_len, 0);
141
142  int nread = InternalRecvFrom(buf, buf_len, address);
143  if (nread != ERR_IO_PENDING)
144    return nread;
145
146  if (!MessageLoopForIO::current()->WatchFileDescriptor(
147          socket_, true, MessageLoopForIO::WATCH_READ,
148          &read_socket_watcher_, &read_watcher_)) {
149    PLOG(ERROR) << "WatchFileDescriptor failed on read";
150    return MapSystemError(errno);
151  }
152
153  read_buf_ = buf;
154  read_buf_len_ = buf_len;
155  recv_from_address_ = address;
156  read_callback_ = callback;
157  return ERR_IO_PENDING;
158}
159
160int UDPSocketLibevent::Write(IOBuffer* buf,
161                             int buf_len,
162                             CompletionCallback* callback) {
163  return SendToOrWrite(buf, buf_len, NULL, callback);
164}
165
166int UDPSocketLibevent::SendTo(IOBuffer* buf,
167                              int buf_len,
168                              const IPEndPoint& address,
169                              CompletionCallback* callback) {
170  return SendToOrWrite(buf, buf_len, &address, callback);
171}
172
173int UDPSocketLibevent::SendToOrWrite(IOBuffer* buf,
174                                     int buf_len,
175                                     const IPEndPoint* address,
176                                     CompletionCallback* callback) {
177  DCHECK(CalledOnValidThread());
178  DCHECK_NE(kInvalidSocket, socket_);
179  DCHECK(!write_callback_);
180  DCHECK(callback);  // Synchronous operation not supported
181  DCHECK_GT(buf_len, 0);
182
183  int nwrite = InternalSendTo(buf, buf_len, address);
184  if (nwrite >= 0) {
185    base::StatsCounter write_bytes("udp.write_bytes");
186    write_bytes.Add(nwrite);
187    return nwrite;
188  }
189  if (errno != EAGAIN && errno != EWOULDBLOCK)
190    return MapSystemError(errno);
191
192  if (!MessageLoopForIO::current()->WatchFileDescriptor(
193          socket_, true, MessageLoopForIO::WATCH_WRITE,
194          &write_socket_watcher_, &write_watcher_)) {
195    DVLOG(1) << "WatchFileDescriptor failed on write, errno " << errno;
196    return MapSystemError(errno);
197  }
198
199  write_buf_ = buf;
200  write_buf_len_ = buf_len;
201  DCHECK(!send_to_address_.get());
202  if (address) {
203    send_to_address_.reset(new IPEndPoint(*address));
204  }
205  write_callback_ = callback;
206  return ERR_IO_PENDING;
207}
208
209int UDPSocketLibevent::Connect(const IPEndPoint& address) {
210  DCHECK(!is_connected());
211  DCHECK(!remote_address_.get());
212  int rv = CreateSocket(address);
213  if (rv < 0)
214    return rv;
215
216  struct sockaddr_storage addr_storage;
217  size_t addr_len = sizeof(addr_storage);
218  struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
219  if (!address.ToSockAddr(addr, &addr_len))
220    return ERR_FAILED;
221
222  rv = HANDLE_EINTR(connect(socket_, addr, addr_len));
223  if (rv < 0)
224    return MapSystemError(errno);
225
226  remote_address_.reset(new IPEndPoint(address));
227  return rv;
228}
229
230int UDPSocketLibevent::Bind(const IPEndPoint& address) {
231  DCHECK(!is_connected());
232  DCHECK(!local_address_.get());
233  int rv = CreateSocket(address);
234  if (rv < 0)
235    return rv;
236
237  struct sockaddr_storage addr_storage;
238  size_t addr_len = sizeof(addr_storage);
239  struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
240  if (!address.ToSockAddr(addr, &addr_len))
241    return ERR_FAILED;
242
243  rv = bind(socket_, addr, addr_len);
244  if (rv < 0)
245    return MapSystemError(errno);
246
247  local_address_.reset();
248  return rv;
249}
250
251void UDPSocketLibevent::DoReadCallback(int rv) {
252  DCHECK_NE(rv, ERR_IO_PENDING);
253  DCHECK(read_callback_);
254
255  // since Run may result in Read being called, clear read_callback_ up front.
256  CompletionCallback* c = read_callback_;
257  read_callback_ = NULL;
258  c->Run(rv);
259}
260
261void UDPSocketLibevent::DoWriteCallback(int rv) {
262  DCHECK_NE(rv, ERR_IO_PENDING);
263  DCHECK(write_callback_);
264
265  // since Run may result in Write being called, clear write_callback_ up front.
266  CompletionCallback* c = write_callback_;
267  write_callback_ = NULL;
268  c->Run(rv);
269}
270
271void UDPSocketLibevent::DidCompleteRead() {
272  int result = InternalRecvFrom(read_buf_, read_buf_len_, recv_from_address_);
273  if (result != ERR_IO_PENDING) {
274    read_buf_ = NULL;
275    read_buf_len_ = 0;
276    recv_from_address_ = NULL;
277    bool ok = read_socket_watcher_.StopWatchingFileDescriptor();
278    DCHECK(ok);
279    DoReadCallback(result);
280  }
281}
282
283int UDPSocketLibevent::CreateSocket(const IPEndPoint& address) {
284  socket_ = socket(address.GetFamily(), SOCK_DGRAM, 0);
285  if (socket_ == kInvalidSocket)
286    return MapSystemError(errno);
287  if (SetNonBlocking(socket_)) {
288    const int err = MapSystemError(errno);
289    Close();
290    return err;
291  }
292  return OK;
293}
294
295void UDPSocketLibevent::DidCompleteWrite() {
296  int result = InternalSendTo(write_buf_, write_buf_len_,
297                              send_to_address_.get());
298  if (result >= 0) {
299    base::StatsCounter write_bytes("udp.write_bytes");
300    write_bytes.Add(result);
301  } else {
302    result = MapSystemError(errno);
303  }
304
305  if (result != ERR_IO_PENDING) {
306    write_buf_ = NULL;
307    write_buf_len_ = 0;
308    send_to_address_.reset();
309    write_socket_watcher_.StopWatchingFileDescriptor();
310    DoWriteCallback(result);
311  }
312}
313
314int UDPSocketLibevent::InternalRecvFrom(IOBuffer* buf, int buf_len,
315                                        IPEndPoint* address) {
316  int bytes_transferred;
317  int flags = 0;
318
319  struct sockaddr_storage addr_storage;
320  socklen_t addr_len = sizeof(addr_storage);
321  struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
322
323  bytes_transferred =
324      HANDLE_EINTR(recvfrom(socket_,
325                            buf->data(),
326                            buf_len,
327                            flags,
328                            addr,
329                            &addr_len));
330  int result;
331  if (bytes_transferred >= 0) {
332    result = bytes_transferred;
333    base::StatsCounter read_bytes("udp.read_bytes");
334    read_bytes.Add(bytes_transferred);
335    if (address) {
336      if (!address->FromSockAddr(addr, addr_len))
337        result = ERR_FAILED;
338    }
339  } else {
340    result = MapSystemError(errno);
341  }
342  return result;
343}
344
345int UDPSocketLibevent::InternalSendTo(IOBuffer* buf, int buf_len,
346                                      const IPEndPoint* address) {
347  struct sockaddr_storage addr_storage;
348  size_t addr_len = sizeof(addr_storage);
349  struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
350
351  if (!address) {
352    addr = NULL;
353    addr_len = 0;
354  } else {
355    if (!address->ToSockAddr(addr, &addr_len))
356      return ERR_FAILED;
357  }
358
359  return HANDLE_EINTR(sendto(socket_,
360                             buf->data(),
361                             buf_len,
362                             0,
363                             addr,
364                             addr_len));
365}
366
367}  // namespace net
368