mock_host_resolver.cc revision c2e0dbddbe15c98d52c4786dac06cb8952a8ae6d
1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/dns/mock_host_resolver.h"
6
7#include <string>
8#include <vector>
9
10#include "base/bind.h"
11#include "base/memory/ref_counted.h"
12#include "base/message_loop.h"
13#include "base/stl_util.h"
14#include "base/string_util.h"
15#include "base/strings/string_split.h"
16#include "base/threading/platform_thread.h"
17#include "net/base/net_errors.h"
18#include "net/base/net_util.h"
19#include "net/base/test_completion_callback.h"
20#include "net/dns/host_cache.h"
21
22#if defined(OS_WIN)
23#include "net/base/winsock_init.h"
24#endif
25
26namespace net {
27
namespace {

// Capacity of the HostCache created when MockHostResolverBase is constructed
// with |use_caching| set.
const unsigned int kMaxCacheEntries = 100;

// How long, in seconds, a successful resolution stays in that cache.
// Failures are not cached.
const unsigned int kCacheEntryTTLSeconds = 60;

}  // namespace
36
37int ParseAddressList(const std::string& host_list,
38                     const std::string& canonical_name,
39                     AddressList* addrlist) {
40  *addrlist = AddressList();
41  std::vector<std::string> addresses;
42  base::SplitString(host_list, ',', &addresses);
43  addrlist->set_canonical_name(canonical_name);
44  for (size_t index = 0; index < addresses.size(); ++index) {
45    IPAddressNumber ip_number;
46    if (!ParseIPLiteralToNumber(addresses[index], &ip_number)) {
47      LOG(WARNING) << "Not a supported IP literal: " << addresses[index];
48      return ERR_UNEXPECTED;
49    }
50    addrlist->push_back(IPEndPoint(ip_number, -1));
51  }
52  return OK;
53}
54
55struct MockHostResolverBase::Request {
56  Request(const RequestInfo& req_info,
57          AddressList* addr,
58          const CompletionCallback& cb)
59      : info(req_info), addresses(addr), callback(cb) {}
60  RequestInfo info;
61  AddressList* addresses;
62  CompletionCallback callback;
63};
64
65MockHostResolverBase::~MockHostResolverBase() {
66  STLDeleteValues(&requests_);
67}
68
69int MockHostResolverBase::Resolve(const RequestInfo& info,
70                                  AddressList* addresses,
71                                  const CompletionCallback& callback,
72                                  RequestHandle* handle,
73                                  const BoundNetLog& net_log) {
74  DCHECK(CalledOnValidThread());
75  num_resolve_++;
76  size_t id = next_request_id_++;
77  int rv = ResolveFromIPLiteralOrCache(info, addresses);
78  if (rv != ERR_DNS_CACHE_MISS) {
79    return rv;
80  }
81  if (synchronous_mode_) {
82    return ResolveProc(id, info, addresses);
83  }
84  // Store the request for asynchronous resolution
85  Request* req = new Request(info, addresses, callback);
86  requests_[id] = req;
87  if (handle)
88    *handle = reinterpret_cast<RequestHandle>(id);
89
90  if (!ondemand_mode_) {
91    MessageLoop::current()->PostTask(
92        FROM_HERE,
93        base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
94  }
95
96  return ERR_IO_PENDING;
97}
98
99int MockHostResolverBase::ResolveFromCache(const RequestInfo& info,
100                                           AddressList* addresses,
101                                           const BoundNetLog& net_log) {
102  num_resolve_from_cache_++;
103  DCHECK(CalledOnValidThread());
104  next_request_id_++;
105  int rv = ResolveFromIPLiteralOrCache(info, addresses);
106  return rv;
107}
108
109void MockHostResolverBase::CancelRequest(RequestHandle handle) {
110  DCHECK(CalledOnValidThread());
111  size_t id = reinterpret_cast<size_t>(handle);
112  RequestMap::iterator it = requests_.find(id);
113  if (it != requests_.end()) {
114    scoped_ptr<Request> req(it->second);
115    requests_.erase(it);
116  } else {
117    NOTREACHED() << "CancelRequest must NOT be called after request is "
118        "complete or canceled.";
119  }
120}
121
122HostCache* MockHostResolverBase::GetHostCache() {
123  return cache_.get();
124}
125
126void MockHostResolverBase::ResolveAllPending() {
127  DCHECK(CalledOnValidThread());
128  DCHECK(ondemand_mode_);
129  for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) {
130    MessageLoop::current()->PostTask(
131        FROM_HERE,
132        base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first));
133  }
134}
135
136// start id from 1 to distinguish from NULL RequestHandle
137MockHostResolverBase::MockHostResolverBase(bool use_caching)
138    : synchronous_mode_(false),
139      ondemand_mode_(false),
140      next_request_id_(1),
141      num_resolve_(0),
142      num_resolve_from_cache_(0) {
143  rules_ = CreateCatchAllHostResolverProc();
144
145  if (use_caching) {
146    cache_.reset(new HostCache(kMaxCacheEntries));
147  }
148}
149
150int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info,
151                                                      AddressList* addresses) {
152  IPAddressNumber ip;
153  if (ParseIPLiteralToNumber(info.hostname(), &ip)) {
154    *addresses = AddressList::CreateFromIPAddress(ip, info.port());
155    if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME)
156      addresses->SetDefaultCanonicalName();
157    return OK;
158  }
159  int rv = ERR_DNS_CACHE_MISS;
160  if (cache_.get() && info.allow_cached_response()) {
161    HostCache::Key key(info.hostname(),
162                       info.address_family(),
163                       info.host_resolver_flags());
164    const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now());
165    if (entry) {
166      rv = entry->error;
167      if (rv == OK)
168        *addresses = AddressList::CopyWithPort(entry->addrlist, info.port());
169    }
170  }
171  return rv;
172}
173
174int MockHostResolverBase::ResolveProc(size_t id,
175                                      const RequestInfo& info,
176                                      AddressList* addresses) {
177  AddressList addr;
178  int rv = rules_->Resolve(info.hostname(),
179                           info.address_family(),
180                           info.host_resolver_flags(),
181                           &addr,
182                           NULL);
183  if (cache_.get()) {
184    HostCache::Key key(info.hostname(),
185                       info.address_family(),
186                       info.host_resolver_flags());
187    // Storing a failure with TTL 0 so that it overwrites previous value.
188    base::TimeDelta ttl;
189    if (rv == OK)
190      ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds);
191    cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl);
192  }
193  if (rv == OK)
194    *addresses = AddressList::CopyWithPort(addr, info.port());
195  return rv;
196}
197
198void MockHostResolverBase::ResolveNow(size_t id) {
199  RequestMap::iterator it = requests_.find(id);
200  if (it == requests_.end())
201    return;  // was canceled
202
203  scoped_ptr<Request> req(it->second);
204  requests_.erase(it);
205  int rv = ResolveProc(id, req->info, req->addresses);
206  if (!req->callback.is_null())
207    req->callback.Run(rv);
208}
209
210//-----------------------------------------------------------------------------
211
212struct RuleBasedHostResolverProc::Rule {
213  enum ResolverType {
214    kResolverTypeFail,
215    kResolverTypeSystem,
216    kResolverTypeIPLiteral,
217  };
218
219  ResolverType resolver_type;
220  std::string host_pattern;
221  AddressFamily address_family;
222  HostResolverFlags host_resolver_flags;
223  std::string replacement;
224  std::string canonical_name;
225  int latency_ms;  // In milliseconds.
226
227  Rule(ResolverType resolver_type,
228       const std::string& host_pattern,
229       AddressFamily address_family,
230       HostResolverFlags host_resolver_flags,
231       const std::string& replacement,
232       const std::string& canonical_name,
233       int latency_ms)
234      : resolver_type(resolver_type),
235        host_pattern(host_pattern),
236        address_family(address_family),
237        host_resolver_flags(host_resolver_flags),
238        replacement(replacement),
239        canonical_name(canonical_name),
240        latency_ms(latency_ms) {}
241};
242
243RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous)
244    : HostResolverProc(previous) {
245}
246
247void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern,
248                                        const std::string& replacement) {
249  AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
250                          replacement);
251}
252
253void RuleBasedHostResolverProc::AddRuleForAddressFamily(
254    const std::string& host_pattern,
255    AddressFamily address_family,
256    const std::string& replacement) {
257  DCHECK(!replacement.empty());
258  HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
259      HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
260  Rule rule(Rule::kResolverTypeSystem,
261            host_pattern,
262            address_family,
263            flags,
264            replacement,
265            std::string(),
266            0);
267  rules_.push_back(rule);
268}
269
270void RuleBasedHostResolverProc::AddIPLiteralRule(
271    const std::string& host_pattern,
272    const std::string& ip_literal,
273    const std::string& canonical_name) {
274  // Literals are always resolved to themselves by HostResolverImpl,
275  // consequently we do not support remapping them.
276  IPAddressNumber ip_number;
277  DCHECK(!ParseIPLiteralToNumber(host_pattern, &ip_number));
278  HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
279      HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
280  if (!canonical_name.empty())
281    flags |= HOST_RESOLVER_CANONNAME;
282  Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
283            ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name,
284            0);
285  rules_.push_back(rule);
286}
287
288void RuleBasedHostResolverProc::AddRuleWithLatency(
289    const std::string& host_pattern,
290    const std::string& replacement,
291    int latency_ms) {
292  DCHECK(!replacement.empty());
293  HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
294      HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
295  Rule rule(Rule::kResolverTypeSystem,
296            host_pattern,
297            ADDRESS_FAMILY_UNSPECIFIED,
298            flags,
299            replacement,
300            std::string(),
301            latency_ms);
302  rules_.push_back(rule);
303}
304
305void RuleBasedHostResolverProc::AllowDirectLookup(
306    const std::string& host_pattern) {
307  HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
308      HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
309  Rule rule(Rule::kResolverTypeSystem,
310            host_pattern,
311            ADDRESS_FAMILY_UNSPECIFIED,
312            flags,
313            std::string(),
314            std::string(),
315            0);
316  rules_.push_back(rule);
317}
318
319void RuleBasedHostResolverProc::AddSimulatedFailure(
320    const std::string& host_pattern) {
321  HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
322      HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
323  Rule rule(Rule::kResolverTypeFail,
324            host_pattern,
325            ADDRESS_FAMILY_UNSPECIFIED,
326            flags,
327            std::string(),
328            std::string(),
329            0);
330  rules_.push_back(rule);
331}
332
333void RuleBasedHostResolverProc::ClearRules() {
334  rules_.clear();
335}
336
337int RuleBasedHostResolverProc::Resolve(const std::string& host,
338                                       AddressFamily address_family,
339                                       HostResolverFlags host_resolver_flags,
340                                       AddressList* addrlist,
341                                       int* os_error) {
342  RuleList::iterator r;
343  for (r = rules_.begin(); r != rules_.end(); ++r) {
344    bool matches_address_family =
345        r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
346        r->address_family == address_family;
347    // Flags match if all of the bitflags in host_resolver_flags are enabled
348    // in the rule's host_resolver_flags. However, the rule may have additional
349    // flags specified, in which case the flags should still be considered a
350    // match.
351    bool matches_flags = (r->host_resolver_flags & host_resolver_flags) ==
352        host_resolver_flags;
353    if (matches_flags && matches_address_family &&
354        MatchPattern(host, r->host_pattern)) {
355      if (r->latency_ms != 0) {
356        base::PlatformThread::Sleep(
357            base::TimeDelta::FromMilliseconds(r->latency_ms));
358      }
359
360      // Remap to a new host.
361      const std::string& effective_host =
362          r->replacement.empty() ? host : r->replacement;
363
364      // Apply the resolving function to the remapped hostname.
365      switch (r->resolver_type) {
366        case Rule::kResolverTypeFail:
367          return ERR_NAME_NOT_RESOLVED;
368        case Rule::kResolverTypeSystem:
369#if defined(OS_WIN)
370          net::EnsureWinsockInit();
371#endif
372          return SystemHostResolverCall(effective_host,
373                                        address_family,
374                                        host_resolver_flags,
375                                        addrlist, os_error);
376        case Rule::kResolverTypeIPLiteral:
377          return ParseAddressList(effective_host,
378                                  r->canonical_name,
379                                  addrlist);
380        default:
381          NOTREACHED();
382          return ERR_UNEXPECTED;
383      }
384    }
385  }
386  return ResolveUsingPrevious(host, address_family,
387                              host_resolver_flags, addrlist, os_error);
388}
389
390RuleBasedHostResolverProc::~RuleBasedHostResolverProc() {
391}
392
393RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() {
394  RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(NULL);
395  catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
396
397  // Next add a rules-based layer the use controls.
398  return new RuleBasedHostResolverProc(catchall);
399}
400
401//-----------------------------------------------------------------------------
402
403int HangingHostResolver::Resolve(const RequestInfo& info,
404                                 AddressList* addresses,
405                                 const CompletionCallback& callback,
406                                 RequestHandle* out_req,
407                                 const BoundNetLog& net_log) {
408  return ERR_IO_PENDING;
409}
410
411int HangingHostResolver::ResolveFromCache(const RequestInfo& info,
412                                          AddressList* addresses,
413                                          const BoundNetLog& net_log) {
414  return ERR_DNS_CACHE_MISS;
415}
416
417//-----------------------------------------------------------------------------
418
419ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() {}
420
421ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
422    HostResolverProc* proc) {
423  Init(proc);
424}
425
426ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
427  HostResolverProc* old_proc = HostResolverProc::SetDefault(previous_proc_);
428  // The lifetimes of multiple instances must be nested.
429  CHECK_EQ(old_proc, current_proc_);
430}
431
432void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
433  current_proc_ = proc;
434  previous_proc_ = HostResolverProc::SetDefault(current_proc_);
435  current_proc_->SetLastProc(previous_proc_);
436}
437
438}  // namespace net
439