1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <stdio.h>
6#include <string>
7
8#include "base/at_exit.h"
9#include "base/bind.h"
10#include "base/cancelable_callback.h"
11#include "base/command_line.h"
12#include "base/files/file_util.h"
13#include "base/memory/scoped_ptr.h"
14#include "base/message_loop/message_loop.h"
15#include "base/strings/string_number_conversions.h"
16#include "base/strings/string_split.h"
17#include "base/strings/string_util.h"
18#include "base/strings/stringprintf.h"
19#include "base/strings/utf_string_conversions.h"
20#include "base/time/time.h"
21#include "net/base/address_list.h"
22#include "net/base/ip_endpoint.h"
23#include "net/base/net_errors.h"
24#include "net/base/net_log.h"
25#include "net/base/net_util.h"
26#include "net/dns/dns_client.h"
27#include "net/dns/dns_config_service.h"
28#include "net/dns/dns_protocol.h"
29#include "net/dns/host_cache.h"
30#include "net/dns/host_resolver_impl.h"
31#include "net/tools/gdig/file_net_log.h"
32
33#if defined(OS_MACOSX)
34#include "base/mac/scoped_nsautorelease_pool.h"
35#endif
36
37namespace net {
38
39namespace {
40
41bool StringToIPEndPoint(const std::string& ip_address_and_port,
42                        IPEndPoint* ip_end_point) {
43  DCHECK(ip_end_point);
44
45  std::string ip;
46  int port;
47  if (!ParseHostAndPort(ip_address_and_port, &ip, &port))
48    return false;
49  if (port == -1)
50    port = dns_protocol::kDefaultPort;
51
52  net::IPAddressNumber ip_number;
53  if (!net::ParseIPLiteralToNumber(ip, &ip_number))
54    return false;
55
56  *ip_end_point = net::IPEndPoint(ip_number, port);
57  return true;
58}
59
60// Convert DnsConfig to human readable text omitting the hosts member.
61std::string DnsConfigToString(const DnsConfig& dns_config) {
62  std::string output;
63  output.append("search ");
64  for (size_t i = 0; i < dns_config.search.size(); ++i) {
65    output.append(dns_config.search[i] + " ");
66  }
67  output.append("\n");
68
69  for (size_t i = 0; i < dns_config.nameservers.size(); ++i) {
70    output.append("nameserver ");
71    output.append(dns_config.nameservers[i].ToString()).append("\n");
72  }
73
74  base::StringAppendF(&output, "options ndots:%d\n", dns_config.ndots);
75  base::StringAppendF(&output, "options timeout:%d\n",
76                      static_cast<int>(dns_config.timeout.InMilliseconds()));
77  base::StringAppendF(&output, "options attempts:%d\n", dns_config.attempts);
78  if (dns_config.rotate)
79    output.append("options rotate\n");
80  if (dns_config.edns0)
81    output.append("options edns0\n");
82  return output;
83}
84
85// Convert DnsConfig hosts member to a human readable text.
86std::string DnsHostsToString(const DnsHosts& dns_hosts) {
87  std::string output;
88  for (DnsHosts::const_iterator i = dns_hosts.begin();
89       i != dns_hosts.end();
90       ++i) {
91    const DnsHostsKey& key = i->first;
92    std::string host_name = key.first;
93    output.append(IPEndPoint(i->second, -1).ToStringWithoutPort());
94    output.append(" ").append(host_name).append("\n");
95  }
96  return output;
97}
98
// One lookup to replay: resolve |domain_name| once |start_time| has elapsed
// since the replay started.
struct ReplayLogEntry {
  // Offset from the start of the replay at which the lookup is issued.
  base::TimeDelta start_time;
  // Name to resolve.
  std::string domain_name;
};

// Lookups in issue order; LoadReplayLog() rejects files whose timestamps
// decrease.
typedef std::vector<ReplayLogEntry> ReplayLog;
105
106// Loads and parses a replay log file and fills |replay_log| with a structured
107// representation. Returns whether the operation was successful. If not, the
108// contents of |replay_log| are undefined.
109//
110// The replay log is a text file where each line contains
111//
112//   timestamp_in_milliseconds domain_name
113//
114// The timestamp_in_milliseconds needs to be an integral delta from start of
115// resolution and is in milliseconds. domain_name is the name to be resolved.
116//
117// The file should be sorted by timestamp in ascending time.
118bool LoadReplayLog(const base::FilePath& file_path, ReplayLog* replay_log) {
119  std::string original_replay_log_contents;
120  if (!base::ReadFileToString(file_path, &original_replay_log_contents)) {
121    fprintf(stderr, "Unable to open replay file %s\n",
122            file_path.MaybeAsASCII().c_str());
123    return false;
124  }
125
126  // Strip out \r characters for Windows files. This isn't as efficient as a
127  // smarter line splitter, but this particular use does not need to target
128  // efficiency.
129  std::string replay_log_contents;
130  base::RemoveChars(original_replay_log_contents, "\r", &replay_log_contents);
131
132  std::vector<std::string> lines;
133  base::SplitString(replay_log_contents, '\n', &lines);
134  base::TimeDelta previous_delta;
135  bool bad_parse = false;
136  for (unsigned i = 0; i < lines.size(); ++i) {
137    if (lines[i].empty())
138      continue;
139    std::vector<std::string> time_and_name;
140    base::SplitString(lines[i], ' ', &time_and_name);
141    if (time_and_name.size() != 2) {
142      fprintf(
143          stderr,
144          "[%s %u] replay log should have format 'timestamp domain_name\\n'\n",
145          file_path.MaybeAsASCII().c_str(),
146          i + 1);
147      bad_parse = true;
148      continue;
149    }
150
151    int64 delta_in_milliseconds;
152    if (!base::StringToInt64(time_and_name[0], &delta_in_milliseconds)) {
153      fprintf(
154          stderr,
155          "[%s %u] replay log should have format 'timestamp domain_name\\n'\n",
156          file_path.MaybeAsASCII().c_str(),
157          i + 1);
158      bad_parse = true;
159      continue;
160    }
161
162    base::TimeDelta delta =
163        base::TimeDelta::FromMilliseconds(delta_in_milliseconds);
164    if (delta < previous_delta) {
165      fprintf(
166          stderr,
167          "[%s %u] replay log should be sorted by time\n",
168          file_path.MaybeAsASCII().c_str(),
169          i + 1);
170      bad_parse = true;
171      continue;
172    }
173
174    previous_delta = delta;
175    ReplayLogEntry entry;
176    entry.start_time = delta;
177    entry.domain_name = time_and_name[1];
178    replay_log->push_back(entry);
179  }
180  return !bad_parse;
181}
182
183class GDig {
184 public:
185  GDig();
186  ~GDig();
187
188  enum Result {
189    RESULT_NO_RESOLVE = -3,
190    RESULT_NO_CONFIG = -2,
191    RESULT_WRONG_USAGE = -1,
192    RESULT_OK = 0,
193    RESULT_PENDING = 1,
194  };
195
196  Result Main(int argc, const char* argv[]);
197
198 private:
199  bool ParseCommandLine(int argc, const char* argv[]);
200
201  void Start();
202  void Finish(Result);
203
204  void OnDnsConfig(const DnsConfig& dns_config_const);
205  void OnResolveComplete(unsigned index, AddressList* address_list,
206                         base::TimeDelta time_since_start, int val);
207  void OnTimeout();
208  void ReplayNextEntry();
209
210  base::TimeDelta config_timeout_;
211  bool print_config_;
212  bool print_hosts_;
213  net::IPEndPoint nameserver_;
214  base::TimeDelta timeout_;
215  int parallellism_;
216  ReplayLog replay_log_;
217  unsigned replay_log_index_;
218  base::Time start_time_;
219  int active_resolves_;
220  Result result_;
221
222  base::CancelableClosure timeout_closure_;
223  scoped_ptr<DnsConfigService> dns_config_service_;
224  scoped_ptr<FileNetLogObserver> log_observer_;
225  scoped_ptr<NetLog> log_;
226  scoped_ptr<HostResolver> resolver_;
227
228#if defined(OS_MACOSX)
229  // Without this there will be a mem leak on osx.
230  base::mac::ScopedNSAutoreleasePool scoped_pool_;
231#endif
232
233  // Need AtExitManager to support AsWeakPtr (in NetLog).
234  base::AtExitManager exit_manager_;
235};
236
237GDig::GDig()
238    : config_timeout_(base::TimeDelta::FromSeconds(5)),
239      print_config_(false),
240      print_hosts_(false),
241      parallellism_(6),
242      replay_log_index_(0u),
243      active_resolves_(0) {
244}
245
246GDig::~GDig() {
247  if (log_)
248    log_->RemoveThreadSafeObserver(log_observer_.get());
249}
250
251GDig::Result GDig::Main(int argc, const char* argv[]) {
252  if (!ParseCommandLine(argc, argv)) {
253      fprintf(stderr,
254              "usage: %s [--net_log[=<basic|no_bytes|all>]]"
255              " [--print_config] [--print_hosts]"
256              " [--nameserver=<ip_address[:port]>]"
257              " [--timeout=<milliseconds>]"
258              " [--config_timeout=<seconds>]"
259              " [--j=<parallel resolves>]"
260              " [--replay_file=<path>]"
261              " [domain_name]\n",
262              argv[0]);
263      return RESULT_WRONG_USAGE;
264  }
265
266  base::MessageLoopForIO loop;
267
268  result_ = RESULT_PENDING;
269  Start();
270  if (result_ == RESULT_PENDING)
271    base::MessageLoop::current()->Run();
272
273  // Destroy it while MessageLoopForIO is alive.
274  dns_config_service_.reset();
275  return result_;
276}
277
278bool GDig::ParseCommandLine(int argc, const char* argv[]) {
279  base::CommandLine::Init(argc, argv);
280  const base::CommandLine& parsed_command_line =
281      *base::CommandLine::ForCurrentProcess();
282
283  if (parsed_command_line.HasSwitch("config_timeout")) {
284    int timeout_seconds = 0;
285    bool parsed = base::StringToInt(
286        parsed_command_line.GetSwitchValueASCII("config_timeout"),
287        &timeout_seconds);
288    if (parsed && timeout_seconds > 0) {
289      config_timeout_ = base::TimeDelta::FromSeconds(timeout_seconds);
290    } else {
291      fprintf(stderr, "Invalid config_timeout parameter\n");
292      return false;
293    }
294  }
295
296  if (parsed_command_line.HasSwitch("net_log")) {
297    std::string log_param = parsed_command_line.GetSwitchValueASCII("net_log");
298    NetLog::LogLevel level = NetLog::LOG_ALL_BUT_BYTES;
299
300    if (log_param.length() > 0) {
301      std::map<std::string, NetLog::LogLevel> log_levels;
302      log_levels["all"] = NetLog::LOG_ALL;
303      log_levels["no_bytes"] = NetLog::LOG_ALL_BUT_BYTES;
304
305      if (log_levels.find(log_param) != log_levels.end()) {
306        level = log_levels[log_param];
307      } else {
308        fprintf(stderr, "Invalid net_log parameter\n");
309        return false;
310      }
311    }
312    log_.reset(new NetLog);
313    log_observer_.reset(new FileNetLogObserver(stderr));
314    log_->AddThreadSafeObserver(log_observer_.get(), level);
315  }
316
317  print_config_ = parsed_command_line.HasSwitch("print_config");
318  print_hosts_ = parsed_command_line.HasSwitch("print_hosts");
319
320  if (parsed_command_line.HasSwitch("nameserver")) {
321    std::string nameserver =
322      parsed_command_line.GetSwitchValueASCII("nameserver");
323    if (!StringToIPEndPoint(nameserver, &nameserver_)) {
324      fprintf(stderr,
325              "Cannot parse the namerserver string into an IPEndPoint\n");
326      return false;
327    }
328  }
329
330  if (parsed_command_line.HasSwitch("timeout")) {
331    int timeout_millis = 0;
332    bool parsed = base::StringToInt(
333        parsed_command_line.GetSwitchValueASCII("timeout"),
334        &timeout_millis);
335    if (parsed && timeout_millis > 0) {
336      timeout_ = base::TimeDelta::FromMilliseconds(timeout_millis);
337    } else {
338      fprintf(stderr, "Invalid timeout parameter\n");
339      return false;
340    }
341  }
342
343  if (parsed_command_line.HasSwitch("replay_file")) {
344    base::FilePath replay_path =
345        parsed_command_line.GetSwitchValuePath("replay_file");
346    if (!LoadReplayLog(replay_path, &replay_log_))
347      return false;
348  }
349
350  if (parsed_command_line.HasSwitch("j")) {
351    int parallellism = 0;
352    bool parsed = base::StringToInt(
353        parsed_command_line.GetSwitchValueASCII("j"),
354        &parallellism);
355    if (parsed && parallellism > 0) {
356      parallellism_ = parallellism;
357    } else {
358      fprintf(stderr, "Invalid parallellism parameter\n");
359    }
360  }
361
362  if (parsed_command_line.GetArgs().size() == 1) {
363    ReplayLogEntry entry;
364    entry.start_time = base::TimeDelta();
365#if defined(OS_WIN)
366    entry.domain_name = base::UTF16ToASCII(parsed_command_line.GetArgs()[0]);
367#else
368    entry.domain_name = parsed_command_line.GetArgs()[0];
369#endif
370    replay_log_.push_back(entry);
371  } else if (parsed_command_line.GetArgs().size() != 0) {
372    return false;
373  }
374  return print_config_ || print_hosts_ || !replay_log_.empty();
375}
376
377void GDig::Start() {
378  if (nameserver_.address().size() > 0) {
379    DnsConfig dns_config;
380    dns_config.attempts = 1;
381    dns_config.nameservers.push_back(nameserver_);
382    OnDnsConfig(dns_config);
383  } else {
384    dns_config_service_ = DnsConfigService::CreateSystemService();
385    dns_config_service_->ReadConfig(base::Bind(&GDig::OnDnsConfig,
386                                               base::Unretained(this)));
387    timeout_closure_.Reset(base::Bind(&GDig::OnTimeout,
388                                      base::Unretained(this)));
389    base::MessageLoop::current()->PostDelayedTask(
390        FROM_HERE, timeout_closure_.callback(), config_timeout_);
391  }
392}
393
394void GDig::Finish(Result result) {
395  DCHECK_NE(RESULT_PENDING, result);
396  result_ = result;
397  if (base::MessageLoop::current())
398    base::MessageLoop::current()->Quit();
399}
400
401void GDig::OnDnsConfig(const DnsConfig& dns_config_const) {
402  timeout_closure_.Cancel();
403  DCHECK(dns_config_const.IsValid());
404  DnsConfig dns_config = dns_config_const;
405
406  if (timeout_.InMilliseconds() > 0)
407    dns_config.timeout = timeout_;
408  if (print_config_) {
409    printf("# Dns Configuration\n"
410           "%s", DnsConfigToString(dns_config).c_str());
411  }
412  if (print_hosts_) {
413    printf("# Host Database\n"
414           "%s", DnsHostsToString(dns_config.hosts).c_str());
415  }
416
417  if (replay_log_.empty()) {
418    Finish(RESULT_OK);
419    return;
420  }
421
422  scoped_ptr<DnsClient> dns_client(DnsClient::CreateClient(NULL));
423  dns_client->SetConfig(dns_config);
424  HostResolver::Options options;
425  options.max_concurrent_resolves = parallellism_;
426  options.max_retry_attempts = 1u;
427  scoped_ptr<HostResolverImpl> resolver(
428      new HostResolverImpl(options, log_.get()));
429  resolver->SetDnsClient(dns_client.Pass());
430  resolver_ = resolver.Pass();
431
432  start_time_ = base::Time::Now();
433
434  ReplayNextEntry();
435}
436
// Issues, in order, every |replay_log_| entry whose start_time has already
// elapsed since |start_time_|. When the next entry is not yet due, it posts a
// delayed call to itself for the remaining time and returns. Each issued
// lookup reports through OnResolveComplete().
void GDig::ReplayNextEntry() {
  // Must only be called while unreplayed entries remain.
  DCHECK_LT(replay_log_index_, replay_log_.size());

  // Sampled once per call: every entry issued in this pass records this same
  // value as its start time.
  base::TimeDelta time_since_start = base::Time::Now() - start_time_;
  while (replay_log_index_ < replay_log_.size()) {
    const ReplayLogEntry& entry = replay_log_[replay_log_index_];
    if (time_since_start < entry.start_time) {
      // Delay call to next time and return.
      base::MessageLoop::current()->PostDelayedTask(
          FROM_HERE,
          base::Bind(&GDig::ReplayNextEntry, base::Unretained(this)),
          entry.start_time - time_since_start);
      return;
    }

    // NOTE(review): port 80 looks arbitrary; presumably it is only needed to
    // build a HostPortPair for a name lookup. Confirm.
    HostResolver::RequestInfo info(HostPortPair(entry.domain_name.c_str(), 80));
    // |addrlist| is handed to the callback via base::Owned, so the callback
    // owns it and keeps it alive for as long as the callback exists.
    AddressList* addrlist = new AddressList();
    unsigned current_index = replay_log_index_;
    CompletionCallback callback = base::Bind(&GDig::OnResolveComplete,
                                             base::Unretained(this),
                                             current_index,
                                             base::Owned(addrlist),
                                             time_since_start);
    // Update the counters before resolving. A non-pending result runs
    // |callback| just below, and OnResolveComplete() decrements
    // |active_resolves_| and compares |replay_log_index_| with the log size to
    // decide whether the run is finished.
    ++active_resolves_;
    ++replay_log_index_;
    int ret = resolver_->Resolve(
        info,
        DEFAULT_PRIORITY,
        addrlist,
        callback,
        NULL,
        BoundNetLog::Make(log_.get(), net::NetLog::SOURCE_NONE));
    // Presumably Resolve() completed synchronously and will not invoke
    // |callback| itself, so report the result directly.
    if (ret != ERR_IO_PENDING)
      callback.Run(ret);
  }
}
473
474void GDig::OnResolveComplete(unsigned entry_index,
475                             AddressList* address_list,
476                             base::TimeDelta resolve_start_time,
477                             int val) {
478  DCHECK_GT(active_resolves_, 0);
479  DCHECK(address_list);
480  DCHECK_LT(entry_index, replay_log_.size());
481  --active_resolves_;
482  base::TimeDelta resolve_end_time = base::Time::Now() - start_time_;
483  base::TimeDelta resolve_time = resolve_end_time - resolve_start_time;
484  printf("%u %d %d %s %d ",
485         entry_index,
486         static_cast<int>(resolve_end_time.InMilliseconds()),
487         static_cast<int>(resolve_time.InMilliseconds()),
488         replay_log_[entry_index].domain_name.c_str(), val);
489  if (val != OK) {
490    std::string error_string = ErrorToString(val);
491    printf("%s", error_string.c_str());
492  } else {
493    for (size_t i = 0; i < address_list->size(); ++i) {
494      if (i != 0)
495        printf(" ");
496      printf("%s", (*address_list)[i].ToStringWithoutPort().c_str());
497    }
498  }
499  printf("\n");
500  if (active_resolves_ == 0 && replay_log_index_ >= replay_log_.size())
501    Finish(RESULT_OK);
502}
503
504void GDig::OnTimeout() {
505  fprintf(stderr, "Timed out waiting to load the dns config\n");
506  Finish(RESULT_NO_CONFIG);
507}
508
509}  // empty namespace
510
511}  // namespace net
512
513int main(int argc, const char* argv[]) {
514  net::GDig dig;
515  return dig.Main(argc, argv);
516}
517