1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/dns/address_sorter.h"
6
7#include <winsock2.h>
8
9#include <algorithm>
10
11#include "base/bind.h"
12#include "base/location.h"
13#include "base/logging.h"
14#include "base/threading/worker_pool.h"
15#include "base/win/windows_version.h"
16#include "net/base/address_list.h"
17#include "net/base/ip_endpoint.h"
18#include "net/base/winsock_init.h"
19
20namespace net {
21
22namespace {
23
24class AddressSorterWin : public AddressSorter {
25 public:
26  AddressSorterWin() {
27    EnsureWinsockInit();
28  }
29
30  virtual ~AddressSorterWin() {}
31
32  // AddressSorter:
33  virtual void Sort(const AddressList& list,
34                    const CallbackType& callback) const OVERRIDE {
35    DCHECK(!list.empty());
36    scoped_refptr<Job> job = new Job(list, callback);
37  }
38
39 private:
40  // Executes the SIO_ADDRESS_LIST_SORT ioctl on the WorkerPool, and
41  // performs the necessary conversions to/from AddressList.
42  class Job : public base::RefCountedThreadSafe<Job> {
43   public:
44    Job(const AddressList& list, const CallbackType& callback)
45        : callback_(callback),
46          buffer_size_(sizeof(SOCKET_ADDRESS_LIST) +
47                       list.size() * (sizeof(SOCKET_ADDRESS) +
48                                      sizeof(SOCKADDR_STORAGE))),
49          input_buffer_(reinterpret_cast<SOCKET_ADDRESS_LIST*>(
50              malloc(buffer_size_))),
51          output_buffer_(reinterpret_cast<SOCKET_ADDRESS_LIST*>(
52              malloc(buffer_size_))),
53          success_(false) {
54      input_buffer_->iAddressCount = list.size();
55      SOCKADDR_STORAGE* storage = reinterpret_cast<SOCKADDR_STORAGE*>(
56          input_buffer_->Address + input_buffer_->iAddressCount);
57
58      for (size_t i = 0; i < list.size(); ++i) {
59        IPEndPoint ipe = list[i];
60        // Addresses must be sockaddr_in6.
61        if (ipe.GetFamily() == ADDRESS_FAMILY_IPV4) {
62          ipe = IPEndPoint(ConvertIPv4NumberToIPv6Number(ipe.address()),
63                           ipe.port());
64        }
65
66        struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(storage + i);
67        socklen_t addr_len = sizeof(SOCKADDR_STORAGE);
68        bool result = ipe.ToSockAddr(addr, &addr_len);
69        DCHECK(result);
70        input_buffer_->Address[i].lpSockaddr = addr;
71        input_buffer_->Address[i].iSockaddrLength = addr_len;
72      }
73
74      if (!base::WorkerPool::PostTaskAndReply(
75          FROM_HERE,
76          base::Bind(&Job::Run, this),
77          base::Bind(&Job::OnComplete, this),
78          false /* task is slow */)) {
79        LOG(ERROR) << "WorkerPool::PostTaskAndReply failed";
80        OnComplete();
81      }
82    }
83
84   private:
85    friend class base::RefCountedThreadSafe<Job>;
86    ~Job() {}
87
88    // Executed on the WorkerPool.
89    void Run() {
90      SOCKET sock = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
91      if (sock == INVALID_SOCKET)
92        return;
93      DWORD result_size = 0;
94      int result = WSAIoctl(sock, SIO_ADDRESS_LIST_SORT, input_buffer_.get(),
95                            buffer_size_, output_buffer_.get(), buffer_size_,
96                            &result_size, NULL, NULL);
97      if (result == SOCKET_ERROR) {
98        LOG(ERROR) << "SIO_ADDRESS_LIST_SORT failed " << WSAGetLastError();
99      } else {
100        success_ = true;
101      }
102      closesocket(sock);
103    }
104
105    // Executed on the calling thread.
106    void OnComplete() {
107      AddressList list;
108      if (success_) {
109        list.reserve(output_buffer_->iAddressCount);
110        for (int i = 0; i < output_buffer_->iAddressCount; ++i) {
111          IPEndPoint ipe;
112          ipe.FromSockAddr(output_buffer_->Address[i].lpSockaddr,
113                           output_buffer_->Address[i].iSockaddrLength);
114          // Unmap V4MAPPED IPv6 addresses so that Happy Eyeballs works.
115          if (IsIPv4Mapped(ipe.address())) {
116            ipe = IPEndPoint(ConvertIPv4MappedToIPv4(ipe.address()),
117                                                     ipe.port());
118          }
119          list.push_back(ipe);
120        }
121      }
122      callback_.Run(success_, list);
123    }
124
125    const CallbackType callback_;
126    const size_t buffer_size_;
127    scoped_ptr<SOCKET_ADDRESS_LIST, base::FreeDeleter> input_buffer_;
128    scoped_ptr<SOCKET_ADDRESS_LIST, base::FreeDeleter> output_buffer_;
129    bool success_;
130
131    DISALLOW_COPY_AND_ASSIGN(Job);
132  };
133
134  DISALLOW_COPY_AND_ASSIGN(AddressSorterWin);
135};
136
137// Merges |list_ipv4| and |list_ipv6| before passing it to |callback|, but
138// only if |success| is true.
139void MergeResults(const AddressSorter::CallbackType& callback,
140                  const AddressList& list_ipv4,
141                  bool success,
142                  const AddressList& list_ipv6) {
143  if (!success) {
144    callback.Run(false, AddressList());
145    return;
146  }
147  AddressList list;
148  list.insert(list.end(), list_ipv6.begin(), list_ipv6.end());
149  list.insert(list.end(), list_ipv4.begin(), list_ipv4.end());
150  callback.Run(true, list);
151}
152
153// Wrapper for AddressSorterWin which does not sort IPv4 or IPv4-mapped
154// addresses but always puts them at the end of the list. Needed because the
155// SIO_ADDRESS_LIST_SORT does not support IPv4 addresses on Windows XP.
156class AddressSorterWinXP : public AddressSorter {
157 public:
158  AddressSorterWinXP() {}
159  virtual ~AddressSorterWinXP() {}
160
161  // AddressSorter:
162  virtual void Sort(const AddressList& list,
163                    const CallbackType& callback) const OVERRIDE {
164    AddressList list_ipv4;
165    AddressList list_ipv6;
166    for (size_t i = 0; i < list.size(); ++i) {
167      const IPEndPoint& ipe = list[i];
168      if (ipe.GetFamily() == ADDRESS_FAMILY_IPV4) {
169        list_ipv4.push_back(ipe);
170      } else {
171        list_ipv6.push_back(ipe);
172      }
173    }
174    if (!list_ipv6.empty()) {
175      sorter_.Sort(list_ipv6, base::Bind(&MergeResults, callback, list_ipv4));
176    } else {
177      NOTREACHED() << "Should not be called with IPv4-only addresses.";
178      callback.Run(true, list);
179    }
180  }
181
182 private:
183  AddressSorterWin sorter_;
184
185  DISALLOW_COPY_AND_ASSIGN(AddressSorterWinXP);
186};
187
188}  // namespace
189
190// static
191scoped_ptr<AddressSorter> AddressSorter::CreateAddressSorter() {
192  if (base::win::GetVersion() < base::win::VERSION_VISTA)
193    return scoped_ptr<AddressSorter>(new AddressSorterWinXP());
194  return scoped_ptr<AddressSorter>(new AddressSorterWin());
195}
196
197}  // namespace net
198
199