1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/udp/udp_socket_win.h"
6
7#include <mstcpip.h>
8
9#include "base/eintr_wrapper.h"
10#include "base/logging.h"
11#include "base/memory/memory_debug.h"
12#include "base/message_loop.h"
13#include "base/metrics/stats_counters.h"
14#include "net/base/io_buffer.h"
15#include "net/base/ip_endpoint.h"
16#include "net/base/net_errors.h"
17#include "net/base/net_log.h"
18#include "net/base/net_util.h"
19#include "net/base/winsock_init.h"
20#include "net/base/winsock_util.h"
21
22namespace net {
23
24void UDPSocketWin::ReadDelegate::OnObjectSignaled(HANDLE object) {
25  DCHECK_EQ(object, socket_->read_overlapped_.hEvent);
26  socket_->DidCompleteRead();
27}
28
29void UDPSocketWin::WriteDelegate::OnObjectSignaled(HANDLE object) {
30  DCHECK_EQ(object, socket_->write_overlapped_.hEvent);
31  socket_->DidCompleteWrite();
32}
33
34UDPSocketWin::UDPSocketWin(net::NetLog* net_log,
35                           const net::NetLog::Source& source)
36    : socket_(INVALID_SOCKET),
37      ALLOW_THIS_IN_INITIALIZER_LIST(read_delegate_(this)),
38      ALLOW_THIS_IN_INITIALIZER_LIST(write_delegate_(this)),
39      recv_from_address_(NULL),
40      read_callback_(NULL),
41      write_callback_(NULL),
42      net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {
43  EnsureWinsockInit();
44  scoped_refptr<NetLog::EventParameters> params;
45  if (source.is_valid())
46    params = new NetLogSourceParameter("source_dependency", source);
47  net_log_.BeginEvent(NetLog::TYPE_SOCKET_ALIVE, params);
48  memset(&read_overlapped_, 0, sizeof(read_overlapped_));
49  read_overlapped_.hEvent = WSACreateEvent();
50  memset(&write_overlapped_, 0, sizeof(write_overlapped_));
51  write_overlapped_.hEvent = WSACreateEvent();
52}
53
54UDPSocketWin::~UDPSocketWin() {
55  Close();
56  net_log_.EndEvent(NetLog::TYPE_SOCKET_ALIVE, NULL);
57}
58
59void UDPSocketWin::Close() {
60  DCHECK(CalledOnValidThread());
61
62  if (!is_connected())
63    return;
64
65  // Zero out any pending read/write callback state.
66  read_callback_ = NULL;
67  recv_from_address_ = NULL;
68  write_callback_ = NULL;
69
70  read_watcher_.StopWatching();
71  write_watcher_.StopWatching();
72
73  closesocket(socket_);
74  socket_ = INVALID_SOCKET;
75}
76
77int UDPSocketWin::GetPeerAddress(IPEndPoint* address) const {
78  DCHECK(CalledOnValidThread());
79  DCHECK(address);
80  if (!is_connected())
81    return ERR_SOCKET_NOT_CONNECTED;
82
83  if (!remote_address_.get()) {
84    struct sockaddr_storage addr_storage;
85    int addr_len = sizeof(addr_storage);
86    struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
87    if (getpeername(socket_, addr, &addr_len))
88      return MapSystemError(WSAGetLastError());
89    scoped_ptr<IPEndPoint> address(new IPEndPoint());
90    if (!address->FromSockAddr(addr, addr_len))
91      return ERR_FAILED;
92    remote_address_.reset(address.release());
93  }
94
95  *address = *remote_address_;
96  return OK;
97}
98
99int UDPSocketWin::GetLocalAddress(IPEndPoint* address) const {
100  DCHECK(CalledOnValidThread());
101  DCHECK(address);
102  if (!is_connected())
103    return ERR_SOCKET_NOT_CONNECTED;
104
105  if (!local_address_.get()) {
106    struct sockaddr_storage addr_storage;
107    socklen_t addr_len = sizeof(addr_storage);
108    struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
109    if (getsockname(socket_, addr, &addr_len))
110      return MapSystemError(WSAGetLastError());
111    scoped_ptr<IPEndPoint> address(new IPEndPoint());
112    if (!address->FromSockAddr(addr, addr_len))
113      return ERR_FAILED;
114    local_address_.reset(address.release());
115  }
116
117  *address = *local_address_;
118  return OK;
119}
120
121int UDPSocketWin::Read(IOBuffer* buf,
122                       int buf_len,
123                       CompletionCallback* callback) {
124  return RecvFrom(buf, buf_len, NULL, callback);
125}
126
127int UDPSocketWin::RecvFrom(IOBuffer* buf,
128                           int buf_len,
129                           IPEndPoint* address,
130                           CompletionCallback* callback) {
131  DCHECK(CalledOnValidThread());
132  DCHECK_NE(INVALID_SOCKET, socket_);
133  DCHECK(!read_callback_);
134  DCHECK(!recv_from_address_);
135  DCHECK(callback);  // Synchronous operation not supported.
136  DCHECK_GT(buf_len, 0);
137
138  int nread = InternalRecvFrom(buf, buf_len, address);
139  if (nread != ERR_IO_PENDING)
140    return nread;
141
142  read_iobuffer_ = buf;
143  read_callback_ = callback;
144  recv_from_address_ = address;
145  return ERR_IO_PENDING;
146}
147
148int UDPSocketWin::Write(IOBuffer* buf,
149                        int buf_len,
150                        CompletionCallback* callback) {
151  return SendToOrWrite(buf, buf_len, NULL, callback);
152}
153
154int UDPSocketWin::SendTo(IOBuffer* buf,
155                         int buf_len,
156                         const IPEndPoint& address,
157                         CompletionCallback* callback) {
158  return SendToOrWrite(buf, buf_len, &address, callback);
159}
160
161int UDPSocketWin::SendToOrWrite(IOBuffer* buf,
162                                int buf_len,
163                                const IPEndPoint* address,
164                                CompletionCallback* callback) {
165  DCHECK(CalledOnValidThread());
166  DCHECK_NE(INVALID_SOCKET, socket_);
167  DCHECK(!write_callback_);
168  DCHECK(callback);  // Synchronous operation not supported.
169  DCHECK_GT(buf_len, 0);
170
171  int nwrite = InternalSendTo(buf, buf_len, address);
172  if (nwrite != ERR_IO_PENDING)
173    return nwrite;
174
175  write_iobuffer_ = buf;
176  write_callback_ = callback;
177  return ERR_IO_PENDING;
178}
179
180int UDPSocketWin::Connect(const IPEndPoint& address) {
181  DCHECK(!is_connected());
182  DCHECK(!remote_address_.get());
183  int rv = CreateSocket(address);
184  if (rv < 0)
185    return rv;
186
187  struct sockaddr_storage addr_storage;
188  size_t addr_len = sizeof(addr_storage);
189  struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
190  if (!address.ToSockAddr(addr, &addr_len))
191    return ERR_FAILED;
192
193  rv = connect(socket_, addr, addr_len);
194  if (rv < 0)
195    return MapSystemError(WSAGetLastError());
196
197  remote_address_.reset(new IPEndPoint(address));
198  return rv;
199}
200
201int UDPSocketWin::Bind(const IPEndPoint& address) {
202  DCHECK(!is_connected());
203  DCHECK(!local_address_.get());
204  int rv = CreateSocket(address);
205  if (rv < 0)
206    return rv;
207
208  struct sockaddr_storage addr_storage;
209  size_t addr_len = sizeof(addr_storage);
210  struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
211  if (!address.ToSockAddr(addr, &addr_len))
212    return ERR_FAILED;
213
214  rv = bind(socket_, addr, addr_len);
215  if (rv < 0)
216    return MapSystemError(WSAGetLastError());
217
218  local_address_.reset();
219  return rv;
220}
221
222int UDPSocketWin::CreateSocket(const IPEndPoint& address) {
223  socket_ = WSASocket(address.GetFamily(), SOCK_DGRAM, IPPROTO_UDP, NULL, 0,
224                      WSA_FLAG_OVERLAPPED);
225  if (socket_ == INVALID_SOCKET)
226    return MapSystemError(WSAGetLastError());
227  return OK;
228}
229
230void UDPSocketWin::DoReadCallback(int rv) {
231  DCHECK_NE(rv, ERR_IO_PENDING);
232  DCHECK(read_callback_);
233
234  // since Run may result in Read being called, clear read_callback_ up front.
235  CompletionCallback* c = read_callback_;
236  read_callback_ = NULL;
237  c->Run(rv);
238}
239
240void UDPSocketWin::DoWriteCallback(int rv) {
241  DCHECK_NE(rv, ERR_IO_PENDING);
242  DCHECK(write_callback_);
243
244  // since Run may result in Write being called, clear write_callback_ up front.
245  CompletionCallback* c = write_callback_;
246  write_callback_ = NULL;
247  c->Run(rv);
248}
249
250void UDPSocketWin::DidCompleteRead() {
251  DWORD num_bytes, flags;
252  BOOL ok = WSAGetOverlappedResult(socket_, &read_overlapped_,
253                                   &num_bytes, FALSE, &flags);
254  WSAResetEvent(read_overlapped_.hEvent);
255  int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
256  if (ok) {
257    if (!ProcessSuccessfulRead(num_bytes, recv_from_address_))
258      result = ERR_FAILED;
259  }
260  read_iobuffer_ = NULL;
261  recv_from_address_ = NULL;
262  DoReadCallback(result);
263}
264
265bool UDPSocketWin::ProcessSuccessfulRead(int num_bytes, IPEndPoint* address) {
266  base::StatsCounter read_bytes("udp.read_bytes");
267  read_bytes.Add(num_bytes);
268
269  // Convert address.
270  if (address) {
271    struct sockaddr* addr =
272        reinterpret_cast<struct sockaddr*>(&recv_addr_storage_);
273    if (!address->FromSockAddr(addr, recv_addr_len_))
274      return false;
275  }
276
277  return true;
278}
279
280void UDPSocketWin::DidCompleteWrite() {
281  DWORD num_bytes, flags;
282  BOOL ok = WSAGetOverlappedResult(socket_, &write_overlapped_,
283                                   &num_bytes, FALSE, &flags);
284  WSAResetEvent(write_overlapped_.hEvent);
285  int result = ok ? num_bytes : MapSystemError(WSAGetLastError());
286  if (ok)
287    ProcessSuccessfulWrite(num_bytes);
288  write_iobuffer_ = NULL;
289  DoWriteCallback(result);
290}
291
292void UDPSocketWin::ProcessSuccessfulWrite(int num_bytes) {
293  base::StatsCounter write_bytes("udp.write_bytes");
294  write_bytes.Add(num_bytes);
295}
296
297int UDPSocketWin::InternalRecvFrom(IOBuffer* buf, int buf_len,
298                                   IPEndPoint* address) {
299  recv_addr_len_ = sizeof(recv_addr_storage_);
300  struct sockaddr* addr =
301      reinterpret_cast<struct sockaddr*>(&recv_addr_storage_);
302
303  WSABUF read_buffer;
304  read_buffer.buf = buf->data();
305  read_buffer.len = buf_len;
306
307  DWORD flags = 0;
308  DWORD num;
309  AssertEventNotSignaled(read_overlapped_.hEvent);
310  int rv = WSARecvFrom(socket_, &read_buffer, 1, &num, &flags, addr,
311                       &recv_addr_len_, &read_overlapped_, NULL);
312  if (rv == 0) {
313    if (ResetEventIfSignaled(read_overlapped_.hEvent)) {
314      // Because of how WSARecv fills memory when used asynchronously, Purify
315      // isn't able to detect that it's been initialized, so it scans for 0xcd
316      // in the buffer and reports UMRs (uninitialized memory reads) for those
317      // individual bytes. We override that in PURIFY builds to avoid the
318      // false error reports.
319      // See bug 5297.
320      base::MemoryDebug::MarkAsInitialized(read_buffer.buf, num);
321      if (!ProcessSuccessfulRead(num, address))
322        return ERR_FAILED;
323      return static_cast<int>(num);
324    }
325  } else {
326    int os_error = WSAGetLastError();
327    if (os_error != WSA_IO_PENDING)
328      return MapSystemError(os_error);
329  }
330  read_watcher_.StartWatching(read_overlapped_.hEvent, &read_delegate_);
331  return ERR_IO_PENDING;
332}
333
334int UDPSocketWin::InternalSendTo(IOBuffer* buf, int buf_len,
335                                 const IPEndPoint* address) {
336  struct sockaddr_storage addr_storage;
337  size_t addr_len = sizeof(addr_storage);
338  struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage);
339
340  // Convert address.
341  if (!address) {
342    addr = NULL;
343    addr_len = 0;
344  } else {
345    if (!address->ToSockAddr(addr, &addr_len))
346      return ERR_FAILED;
347  }
348
349  WSABUF write_buffer;
350  write_buffer.buf = buf->data();
351  write_buffer.len = buf_len;
352
353  DWORD flags = 0;
354  DWORD num;
355  AssertEventNotSignaled(write_overlapped_.hEvent);
356  int rv = WSASendTo(socket_, &write_buffer, 1, &num, flags,
357                     addr, addr_len, &write_overlapped_, NULL);
358  if (rv == 0) {
359    if (ResetEventIfSignaled(write_overlapped_.hEvent)) {
360      ProcessSuccessfulWrite(num);
361      return static_cast<int>(num);
362    }
363  } else {
364    int os_error = WSAGetLastError();
365    if (os_error != WSA_IO_PENDING)
366      return MapSystemError(os_error);
367  }
368
369  write_watcher_.StartWatching(write_overlapped_.hEvent, &write_delegate_);
370  return ERR_IO_PENDING;
371}
372
373}  // namespace net
374