1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/dns/mock_host_resolver.h"
6
7#include <string>
8#include <vector>
9
10#include "base/bind.h"
11#include "base/memory/ref_counted.h"
12#include "base/message_loop/message_loop.h"
13#include "base/stl_util.h"
14#include "base/strings/string_split.h"
15#include "base/strings/string_util.h"
16#include "base/threading/platform_thread.h"
17#include "net/base/ip_endpoint.h"
18#include "net/base/net_errors.h"
19#include "net/base/net_util.h"
20#include "net/base/test_completion_callback.h"
21#include "net/dns/host_cache.h"
22
23#if defined(OS_WIN)
24#include "net/base/winsock_init.h"
25#endif
26
27namespace net {
28
namespace {

// Time-to-live, in seconds, given to cache entries that record a successful
// resolution. Failed resolutions are stored with a zero TTL instead.
const unsigned int kCacheEntryTTLSeconds = 60;

// Maximum number of entries handed to the HostCache that is created when the
// resolver is constructed with caching enabled.
const unsigned int kMaxCacheEntries = 100;

}  // namespace
37
38int ParseAddressList(const std::string& host_list,
39                     const std::string& canonical_name,
40                     AddressList* addrlist) {
41  *addrlist = AddressList();
42  std::vector<std::string> addresses;
43  base::SplitString(host_list, ',', &addresses);
44  addrlist->set_canonical_name(canonical_name);
45  for (size_t index = 0; index < addresses.size(); ++index) {
46    IPAddressNumber ip_number;
47    if (!ParseIPLiteralToNumber(addresses[index], &ip_number)) {
48      LOG(WARNING) << "Not a supported IP literal: " << addresses[index];
49      return ERR_UNEXPECTED;
50    }
51    addrlist->push_back(IPEndPoint(ip_number, -1));
52  }
53  return OK;
54}
55
56struct MockHostResolverBase::Request {
57  Request(const RequestInfo& req_info,
58          AddressList* addr,
59          const CompletionCallback& cb)
60      : info(req_info), addresses(addr), callback(cb) {}
61  RequestInfo info;
62  AddressList* addresses;
63  CompletionCallback callback;
64};
65
66MockHostResolverBase::~MockHostResolverBase() {
67  STLDeleteValues(&requests_);
68}
69
70int MockHostResolverBase::Resolve(const RequestInfo& info,
71                                  RequestPriority priority,
72                                  AddressList* addresses,
73                                  const CompletionCallback& callback,
74                                  RequestHandle* handle,
75                                  const BoundNetLog& net_log) {
76  DCHECK(CalledOnValidThread());
77  last_request_priority_ = priority;
78  num_resolve_++;
79  size_t id = next_request_id_++;
80  int rv = ResolveFromIPLiteralOrCache(info, addresses);
81  if (rv != ERR_DNS_CACHE_MISS) {
82    return rv;
83  }
84  if (synchronous_mode_) {
85    return ResolveProc(id, info, addresses);
86  }
87  // Store the request for asynchronous resolution
88  Request* req = new Request(info, addresses, callback);
89  requests_[id] = req;
90  if (handle)
91    *handle = reinterpret_cast<RequestHandle>(id);
92
93  if (!ondemand_mode_) {
94    base::MessageLoop::current()->PostTask(
95        FROM_HERE,
96        base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
97  }
98
99  return ERR_IO_PENDING;
100}
101
102int MockHostResolverBase::ResolveFromCache(const RequestInfo& info,
103                                           AddressList* addresses,
104                                           const BoundNetLog& net_log) {
105  num_resolve_from_cache_++;
106  DCHECK(CalledOnValidThread());
107  next_request_id_++;
108  int rv = ResolveFromIPLiteralOrCache(info, addresses);
109  return rv;
110}
111
112void MockHostResolverBase::CancelRequest(RequestHandle handle) {
113  DCHECK(CalledOnValidThread());
114  size_t id = reinterpret_cast<size_t>(handle);
115  RequestMap::iterator it = requests_.find(id);
116  if (it != requests_.end()) {
117    scoped_ptr<Request> req(it->second);
118    requests_.erase(it);
119  } else {
120    NOTREACHED() << "CancelRequest must NOT be called after request is "
121        "complete or canceled.";
122  }
123}
124
125HostCache* MockHostResolverBase::GetHostCache() {
126  return cache_.get();
127}
128
129void MockHostResolverBase::ResolveAllPending() {
130  DCHECK(CalledOnValidThread());
131  DCHECK(ondemand_mode_);
132  for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) {
133    base::MessageLoop::current()->PostTask(
134        FROM_HERE,
135        base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first));
136  }
137}
138
139// start id from 1 to distinguish from NULL RequestHandle
140MockHostResolverBase::MockHostResolverBase(bool use_caching)
141    : last_request_priority_(DEFAULT_PRIORITY),
142      synchronous_mode_(false),
143      ondemand_mode_(false),
144      next_request_id_(1),
145      num_resolve_(0),
146      num_resolve_from_cache_(0) {
147  rules_ = CreateCatchAllHostResolverProc();
148
149  if (use_caching) {
150    cache_.reset(new HostCache(kMaxCacheEntries));
151  }
152}
153
154int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info,
155                                                      AddressList* addresses) {
156  IPAddressNumber ip;
157  if (ParseIPLiteralToNumber(info.hostname(), &ip)) {
158    // This matches the behavior HostResolverImpl.
159    if (info.address_family() != ADDRESS_FAMILY_UNSPECIFIED &&
160        info.address_family() != GetAddressFamily(ip)) {
161      return ERR_NAME_NOT_RESOLVED;
162    }
163
164    *addresses = AddressList::CreateFromIPAddress(ip, info.port());
165    if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME)
166      addresses->SetDefaultCanonicalName();
167    return OK;
168  }
169  int rv = ERR_DNS_CACHE_MISS;
170  if (cache_.get() && info.allow_cached_response()) {
171    HostCache::Key key(info.hostname(),
172                       info.address_family(),
173                       info.host_resolver_flags());
174    const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now());
175    if (entry) {
176      rv = entry->error;
177      if (rv == OK)
178        *addresses = AddressList::CopyWithPort(entry->addrlist, info.port());
179    }
180  }
181  return rv;
182}
183
184int MockHostResolverBase::ResolveProc(size_t id,
185                                      const RequestInfo& info,
186                                      AddressList* addresses) {
187  AddressList addr;
188  int rv = rules_->Resolve(info.hostname(),
189                           info.address_family(),
190                           info.host_resolver_flags(),
191                           &addr,
192                           NULL);
193  if (cache_.get()) {
194    HostCache::Key key(info.hostname(),
195                       info.address_family(),
196                       info.host_resolver_flags());
197    // Storing a failure with TTL 0 so that it overwrites previous value.
198    base::TimeDelta ttl;
199    if (rv == OK)
200      ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds);
201    cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl);
202  }
203  if (rv == OK)
204    *addresses = AddressList::CopyWithPort(addr, info.port());
205  return rv;
206}
207
208void MockHostResolverBase::ResolveNow(size_t id) {
209  RequestMap::iterator it = requests_.find(id);
210  if (it == requests_.end())
211    return;  // was canceled
212
213  scoped_ptr<Request> req(it->second);
214  requests_.erase(it);
215  int rv = ResolveProc(id, req->info, req->addresses);
216  if (!req->callback.is_null())
217    req->callback.Run(rv);
218}
219
220//-----------------------------------------------------------------------------
221
222struct RuleBasedHostResolverProc::Rule {
223  enum ResolverType {
224    kResolverTypeFail,
225    kResolverTypeSystem,
226    kResolverTypeIPLiteral,
227  };
228
229  ResolverType resolver_type;
230  std::string host_pattern;
231  AddressFamily address_family;
232  HostResolverFlags host_resolver_flags;
233  std::string replacement;
234  std::string canonical_name;
235  int latency_ms;  // In milliseconds.
236
237  Rule(ResolverType resolver_type,
238       const std::string& host_pattern,
239       AddressFamily address_family,
240       HostResolverFlags host_resolver_flags,
241       const std::string& replacement,
242       const std::string& canonical_name,
243       int latency_ms)
244      : resolver_type(resolver_type),
245        host_pattern(host_pattern),
246        address_family(address_family),
247        host_resolver_flags(host_resolver_flags),
248        replacement(replacement),
249        canonical_name(canonical_name),
250        latency_ms(latency_ms) {}
251};
252
253RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous)
254    : HostResolverProc(previous) {
255}
256
257void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern,
258                                        const std::string& replacement) {
259  AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
260                          replacement);
261}
262
263void RuleBasedHostResolverProc::AddRuleForAddressFamily(
264    const std::string& host_pattern,
265    AddressFamily address_family,
266    const std::string& replacement) {
267  DCHECK(!replacement.empty());
268  HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
269      HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
270  Rule rule(Rule::kResolverTypeSystem,
271            host_pattern,
272            address_family,
273            flags,
274            replacement,
275            std::string(),
276            0);
277  rules_.push_back(rule);
278}
279
280void RuleBasedHostResolverProc::AddIPLiteralRule(
281    const std::string& host_pattern,
282    const std::string& ip_literal,
283    const std::string& canonical_name) {
284  // Literals are always resolved to themselves by HostResolverImpl,
285  // consequently we do not support remapping them.
286  IPAddressNumber ip_number;
287  DCHECK(!ParseIPLiteralToNumber(host_pattern, &ip_number));
288  HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
289      HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
290  if (!canonical_name.empty())
291    flags |= HOST_RESOLVER_CANONNAME;
292  Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
293            ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name,
294            0);
295  rules_.push_back(rule);
296}
297
298void RuleBasedHostResolverProc::AddRuleWithLatency(
299    const std::string& host_pattern,
300    const std::string& replacement,
301    int latency_ms) {
302  DCHECK(!replacement.empty());
303  HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
304      HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
305  Rule rule(Rule::kResolverTypeSystem,
306            host_pattern,
307            ADDRESS_FAMILY_UNSPECIFIED,
308            flags,
309            replacement,
310            std::string(),
311            latency_ms);
312  rules_.push_back(rule);
313}
314
315void RuleBasedHostResolverProc::AllowDirectLookup(
316    const std::string& host_pattern) {
317  HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
318      HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
319  Rule rule(Rule::kResolverTypeSystem,
320            host_pattern,
321            ADDRESS_FAMILY_UNSPECIFIED,
322            flags,
323            std::string(),
324            std::string(),
325            0);
326  rules_.push_back(rule);
327}
328
329void RuleBasedHostResolverProc::AddSimulatedFailure(
330    const std::string& host_pattern) {
331  HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
332      HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
333  Rule rule(Rule::kResolverTypeFail,
334            host_pattern,
335            ADDRESS_FAMILY_UNSPECIFIED,
336            flags,
337            std::string(),
338            std::string(),
339            0);
340  rules_.push_back(rule);
341}
342
343void RuleBasedHostResolverProc::ClearRules() {
344  rules_.clear();
345}
346
347int RuleBasedHostResolverProc::Resolve(const std::string& host,
348                                       AddressFamily address_family,
349                                       HostResolverFlags host_resolver_flags,
350                                       AddressList* addrlist,
351                                       int* os_error) {
352  RuleList::iterator r;
353  for (r = rules_.begin(); r != rules_.end(); ++r) {
354    bool matches_address_family =
355        r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
356        r->address_family == address_family;
357    // Ignore HOST_RESOLVER_SYSTEM_ONLY, since it should have no impact on
358    // whether a rule matches.
359    HostResolverFlags flags = host_resolver_flags & ~HOST_RESOLVER_SYSTEM_ONLY;
360    // Flags match if all of the bitflags in host_resolver_flags are enabled
361    // in the rule's host_resolver_flags. However, the rule may have additional
362    // flags specified, in which case the flags should still be considered a
363    // match.
364    bool matches_flags = (r->host_resolver_flags & flags) == flags;
365    if (matches_flags && matches_address_family &&
366        MatchPattern(host, r->host_pattern)) {
367      if (r->latency_ms != 0) {
368        base::PlatformThread::Sleep(
369            base::TimeDelta::FromMilliseconds(r->latency_ms));
370      }
371
372      // Remap to a new host.
373      const std::string& effective_host =
374          r->replacement.empty() ? host : r->replacement;
375
376      // Apply the resolving function to the remapped hostname.
377      switch (r->resolver_type) {
378        case Rule::kResolverTypeFail:
379          return ERR_NAME_NOT_RESOLVED;
380        case Rule::kResolverTypeSystem:
381#if defined(OS_WIN)
382          net::EnsureWinsockInit();
383#endif
384          return SystemHostResolverCall(effective_host,
385                                        address_family,
386                                        host_resolver_flags,
387                                        addrlist, os_error);
388        case Rule::kResolverTypeIPLiteral:
389          return ParseAddressList(effective_host,
390                                  r->canonical_name,
391                                  addrlist);
392        default:
393          NOTREACHED();
394          return ERR_UNEXPECTED;
395      }
396    }
397  }
398  return ResolveUsingPrevious(host, address_family,
399                              host_resolver_flags, addrlist, os_error);
400}
401
402RuleBasedHostResolverProc::~RuleBasedHostResolverProc() {
403}
404
405RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() {
406  RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(NULL);
407  catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
408
409  // Next add a rules-based layer the use controls.
410  return new RuleBasedHostResolverProc(catchall);
411}
412
413//-----------------------------------------------------------------------------
414
415int HangingHostResolver::Resolve(const RequestInfo& info,
416                                 RequestPriority priority,
417                                 AddressList* addresses,
418                                 const CompletionCallback& callback,
419                                 RequestHandle* out_req,
420                                 const BoundNetLog& net_log) {
421  return ERR_IO_PENDING;
422}
423
424int HangingHostResolver::ResolveFromCache(const RequestInfo& info,
425                                          AddressList* addresses,
426                                          const BoundNetLog& net_log) {
427  return ERR_DNS_CACHE_MISS;
428}
429
430//-----------------------------------------------------------------------------
431
432ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() {}
433
434ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
435    HostResolverProc* proc) {
436  Init(proc);
437}
438
439ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
440  HostResolverProc* old_proc =
441      HostResolverProc::SetDefault(previous_proc_.get());
442  // The lifetimes of multiple instances must be nested.
443  CHECK_EQ(old_proc, current_proc_.get());
444}
445
446void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
447  current_proc_ = proc;
448  previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
449  current_proc_->SetLastProc(previous_proc_.get());
450}
451
452}  // namespace net
453