1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/dns/dns_socket_pool.h"
6
#include "base/bind.h"
#include "base/logging.h"
#include "base/rand_util.h"
#include "base/stl_util.h"
#include "net/base/address_list.h"
#include "net/base/ip_endpoint.h"
#include "net/base/net_errors.h"
#include "net/base/rand_callback.h"
#include "net/socket/client_socket_factory.h"
#include "net/socket/stream_socket.h"
#include "net/udp/datagram_client_socket.h"
17
18namespace net {
19
20namespace {
21
22// When we initialize the SocketPool, we allocate kInitialPoolSize sockets.
23// When we allocate a socket, we ensure we have at least kAllocateMinSize
24// sockets to choose from.  Freed sockets are not retained.
25
26// On Windows, we can't request specific (random) ports, since that will
27// trigger firewall prompts, so request default ones, but keep a pile of
28// them.  Everywhere else, request fresh, random ports each time.
29#if defined(OS_WIN)
30const DatagramSocket::BindType kBindType = DatagramSocket::DEFAULT_BIND;
31const unsigned kInitialPoolSize = 256;
32const unsigned kAllocateMinSize = 256;
33#else
34const DatagramSocket::BindType kBindType = DatagramSocket::RANDOM_BIND;
35const unsigned kInitialPoolSize = 0;
36const unsigned kAllocateMinSize = 1;
37#endif
38
39} // namespace
40
41DnsSocketPool::DnsSocketPool(ClientSocketFactory* socket_factory)
42    : socket_factory_(socket_factory),
43      net_log_(NULL),
44      nameservers_(NULL),
45      initialized_(false) {
46}
47
48void DnsSocketPool::InitializeInternal(
49    const std::vector<IPEndPoint>* nameservers,
50    NetLog* net_log) {
51  DCHECK(nameservers);
52  DCHECK(!initialized_);
53
54  net_log_ = net_log;
55  nameservers_ = nameservers;
56  initialized_ = true;
57}
58
59scoped_ptr<StreamSocket> DnsSocketPool::CreateTCPSocket(
60    unsigned server_index,
61    const NetLog::Source& source) {
62  DCHECK_LT(server_index, nameservers_->size());
63
64  return scoped_ptr<StreamSocket>(
65      socket_factory_->CreateTransportClientSocket(
66          AddressList((*nameservers_)[server_index]), net_log_, source));
67}
68
69scoped_ptr<DatagramClientSocket> DnsSocketPool::CreateConnectedSocket(
70    unsigned server_index) {
71  DCHECK_LT(server_index, nameservers_->size());
72
73  scoped_ptr<DatagramClientSocket> socket;
74
75  NetLog::Source no_source;
76  socket = socket_factory_->CreateDatagramClientSocket(
77      kBindType, base::Bind(&base::RandInt), net_log_, no_source);
78
79  if (socket.get()) {
80    int rv = socket->Connect((*nameservers_)[server_index]);
81    if (rv != OK) {
82      VLOG(1) << "Failed to connect socket: " << rv;
83      socket.reset();
84    }
85  } else {
86    LOG(WARNING) << "Failed to create socket.";
87  }
88
89  return socket.Pass();
90}
91
92class NullDnsSocketPool : public DnsSocketPool {
93 public:
94  NullDnsSocketPool(ClientSocketFactory* factory)
95     : DnsSocketPool(factory) {
96  }
97
98  virtual void Initialize(
99      const std::vector<IPEndPoint>* nameservers,
100      NetLog* net_log) OVERRIDE {
101    InitializeInternal(nameservers, net_log);
102  }
103
104  virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
105      unsigned server_index) OVERRIDE {
106    return CreateConnectedSocket(server_index);
107  }
108
109  virtual void FreeSocket(
110      unsigned server_index,
111      scoped_ptr<DatagramClientSocket> socket) OVERRIDE {
112  }
113
114 private:
115  DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool);
116};
117
118// static
119scoped_ptr<DnsSocketPool> DnsSocketPool::CreateNull(
120    ClientSocketFactory* factory) {
121  return scoped_ptr<DnsSocketPool>(new NullDnsSocketPool(factory));
122}
123
124class DefaultDnsSocketPool : public DnsSocketPool {
125 public:
126  DefaultDnsSocketPool(ClientSocketFactory* factory)
127     : DnsSocketPool(factory) {
128  };
129
130  virtual ~DefaultDnsSocketPool();
131
132  virtual void Initialize(
133      const std::vector<IPEndPoint>* nameservers,
134      NetLog* net_log) OVERRIDE;
135
136  virtual scoped_ptr<DatagramClientSocket> AllocateSocket(
137      unsigned server_index) OVERRIDE;
138
139  virtual void FreeSocket(
140      unsigned server_index,
141      scoped_ptr<DatagramClientSocket> socket) OVERRIDE;
142
143 private:
144  void FillPool(unsigned server_index, unsigned size);
145
146  typedef std::vector<DatagramClientSocket*> SocketVector;
147
148  std::vector<SocketVector> pools_;
149
150  DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool);
151};
152
153// static
154scoped_ptr<DnsSocketPool> DnsSocketPool::CreateDefault(
155    ClientSocketFactory* factory) {
156  return scoped_ptr<DnsSocketPool>(new DefaultDnsSocketPool(factory));
157}
158
159void DefaultDnsSocketPool::Initialize(
160    const std::vector<IPEndPoint>* nameservers,
161    NetLog* net_log) {
162  InitializeInternal(nameservers, net_log);
163
164  DCHECK(pools_.empty());
165  const unsigned num_servers = nameservers->size();
166  pools_.resize(num_servers);
167  for (unsigned server_index = 0; server_index < num_servers; ++server_index)
168    FillPool(server_index, kInitialPoolSize);
169}
170
171DefaultDnsSocketPool::~DefaultDnsSocketPool() {
172  unsigned num_servers = pools_.size();
173  for (unsigned server_index = 0; server_index < num_servers; ++server_index) {
174    SocketVector& pool = pools_[server_index];
175    STLDeleteElements(&pool);
176  }
177}
178
179scoped_ptr<DatagramClientSocket> DefaultDnsSocketPool::AllocateSocket(
180    unsigned server_index) {
181  DCHECK_LT(server_index, pools_.size());
182  SocketVector& pool = pools_[server_index];
183
184  FillPool(server_index, kAllocateMinSize);
185  if (pool.size() == 0) {
186    LOG(WARNING) << "No DNS sockets available in pool " << server_index << "!";
187    return scoped_ptr<DatagramClientSocket>();
188  }
189
190  if (pool.size() < kAllocateMinSize) {
191    LOG(WARNING) << "Low DNS port entropy: wanted " << kAllocateMinSize
192                 << " sockets to choose from, but only have " << pool.size()
193                 << " in pool " << server_index << ".";
194  }
195
196  unsigned socket_index = base::RandInt(0, pool.size() - 1);
197  DatagramClientSocket* socket = pool[socket_index];
198  pool[socket_index] = pool.back();
199  pool.pop_back();
200
201  return scoped_ptr<DatagramClientSocket>(socket);
202}
203
204void DefaultDnsSocketPool::FreeSocket(
205    unsigned server_index,
206    scoped_ptr<DatagramClientSocket> socket) {
207  DCHECK_LT(server_index, pools_.size());
208}
209
210void DefaultDnsSocketPool::FillPool(unsigned server_index, unsigned size) {
211  SocketVector& pool = pools_[server_index];
212
213  for (unsigned pool_index = pool.size(); pool_index < size; ++pool_index) {
214    DatagramClientSocket* socket =
215        CreateConnectedSocket(server_index).release();
216    if (!socket)
217      break;
218    pool.push_back(socket);
219  }
220}
221
222} // namespace net
223