1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/dns/dns_test_util.h"
6
7#include <string>
8
9#include "base/big_endian.h"
10#include "base/bind.h"
11#include "base/memory/weak_ptr.h"
12#include "base/message_loop/message_loop.h"
13#include "base/sys_byteorder.h"
14#include "net/base/dns_util.h"
15#include "net/base/io_buffer.h"
16#include "net/base/net_errors.h"
17#include "net/dns/address_sorter.h"
18#include "net/dns/dns_query.h"
19#include "net/dns/dns_response.h"
20#include "net/dns/dns_transaction.h"
21#include "testing/gtest/include/gtest/gtest.h"
22
23namespace net {
24namespace {
25
26class MockAddressSorter : public AddressSorter {
27 public:
28  virtual ~MockAddressSorter() {}
29  virtual void Sort(const AddressList& list,
30                    const CallbackType& callback) const OVERRIDE {
31    // Do nothing.
32    callback.Run(true, list);
33  }
34};
35
36// A DnsTransaction which uses MockDnsClientRuleList to determine the response.
37class MockTransaction : public DnsTransaction,
38                        public base::SupportsWeakPtr<MockTransaction> {
39 public:
40  MockTransaction(const MockDnsClientRuleList& rules,
41                  const std::string& hostname,
42                  uint16 qtype,
43                  const DnsTransactionFactory::CallbackType& callback)
44      : result_(MockDnsClientRule::FAIL),
45        hostname_(hostname),
46        qtype_(qtype),
47        callback_(callback),
48        started_(false),
49        delayed_(false) {
50    // Find the relevant rule which matches |qtype| and prefix of |hostname|.
51    for (size_t i = 0; i < rules.size(); ++i) {
52      const std::string& prefix = rules[i].prefix;
53      if ((rules[i].qtype == qtype) &&
54          (hostname.size() >= prefix.size()) &&
55          (hostname.compare(0, prefix.size(), prefix) == 0)) {
56        result_ = rules[i].result;
57        delayed_ = rules[i].delay;
58        break;
59      }
60    }
61  }
62
63  virtual const std::string& GetHostname() const OVERRIDE {
64    return hostname_;
65  }
66
67  virtual uint16 GetType() const OVERRIDE {
68    return qtype_;
69  }
70
71  virtual void Start() OVERRIDE {
72    EXPECT_FALSE(started_);
73    started_ = true;
74    if (delayed_)
75      return;
76    // Using WeakPtr to cleanly cancel when transaction is destroyed.
77    base::MessageLoop::current()->PostTask(
78        FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr()));
79  }
80
81  void FinishDelayedTransaction() {
82    EXPECT_TRUE(delayed_);
83    delayed_ = false;
84    Finish();
85  }
86
87  bool delayed() const { return delayed_; }
88
89 private:
90  void Finish() {
91    switch (result_) {
92      case MockDnsClientRule::EMPTY:
93      case MockDnsClientRule::OK: {
94        std::string qname;
95        DNSDomainFromDot(hostname_, &qname);
96        DnsQuery query(0, qname, qtype_);
97
98        DnsResponse response;
99        char* buffer = response.io_buffer()->data();
100        int nbytes = query.io_buffer()->size();
101        memcpy(buffer, query.io_buffer()->data(), nbytes);
102        dns_protocol::Header* header =
103            reinterpret_cast<dns_protocol::Header*>(buffer);
104        header->flags |= dns_protocol::kFlagResponse;
105
106        if (MockDnsClientRule::OK == result_) {
107          const uint16 kPointerToQueryName =
108              static_cast<uint16>(0xc000 | sizeof(*header));
109
110          const uint32 kTTL = 86400;  // One day.
111
112          // Size of RDATA which is a IPv4 or IPv6 address.
113          size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ?
114                              net::kIPv4AddressSize : net::kIPv6AddressSize;
115
116          // 12 is the sum of sizes of the compressed name reference, TYPE,
117          // CLASS, TTL and RDLENGTH.
118          size_t answer_size = 12 + rdata_size;
119
120          // Write answer with loopback IP address.
121          header->ancount = base::HostToNet16(1);
122          base::BigEndianWriter writer(buffer + nbytes, answer_size);
123          writer.WriteU16(kPointerToQueryName);
124          writer.WriteU16(qtype_);
125          writer.WriteU16(net::dns_protocol::kClassIN);
126          writer.WriteU32(kTTL);
127          writer.WriteU16(rdata_size);
128          if (qtype_ == net::dns_protocol::kTypeA) {
129            char kIPv4Loopback[] = { 0x7f, 0, 0, 1 };
130            writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback));
131          } else {
132            char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0,
133                                     0, 0, 0, 0, 0, 0, 0, 1 };
134            writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback));
135          }
136          nbytes += answer_size;
137        }
138        EXPECT_TRUE(response.InitParse(nbytes, query));
139        callback_.Run(this, OK, &response);
140      } break;
141      case MockDnsClientRule::FAIL:
142        callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL);
143        break;
144      case MockDnsClientRule::TIMEOUT:
145        callback_.Run(this, ERR_DNS_TIMED_OUT, NULL);
146        break;
147      default:
148        NOTREACHED();
149        break;
150    }
151  }
152
153  MockDnsClientRule::Result result_;
154  const std::string hostname_;
155  const uint16 qtype_;
156  DnsTransactionFactory::CallbackType callback_;
157  bool started_;
158  bool delayed_;
159};
160
161}  // namespace
162
163// A DnsTransactionFactory which creates MockTransaction.
164class MockTransactionFactory : public DnsTransactionFactory {
165 public:
166  explicit MockTransactionFactory(const MockDnsClientRuleList& rules)
167      : rules_(rules) {}
168
169  virtual ~MockTransactionFactory() {}
170
171  virtual scoped_ptr<DnsTransaction> CreateTransaction(
172      const std::string& hostname,
173      uint16 qtype,
174      const DnsTransactionFactory::CallbackType& callback,
175      const BoundNetLog&) OVERRIDE {
176    MockTransaction* transaction =
177        new MockTransaction(rules_, hostname, qtype, callback);
178    if (transaction->delayed())
179      delayed_transactions_.push_back(transaction->AsWeakPtr());
180    return scoped_ptr<DnsTransaction>(transaction);
181  }
182
183  void CompleteDelayedTransactions() {
184    DelayedTransactionList old_delayed_transactions;
185    old_delayed_transactions.swap(delayed_transactions_);
186    for (DelayedTransactionList::iterator it = old_delayed_transactions.begin();
187         it != old_delayed_transactions.end(); ++it) {
188      if (it->get())
189        (*it)->FinishDelayedTransaction();
190    }
191  }
192
193 private:
194  typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList;
195
196  MockDnsClientRuleList rules_;
197  DelayedTransactionList delayed_transactions_;
198};
199
200MockDnsClient::MockDnsClient(const DnsConfig& config,
201                             const MockDnsClientRuleList& rules)
202      : config_(config),
203        factory_(new MockTransactionFactory(rules)),
204        address_sorter_(new MockAddressSorter()) {
205}
206
207MockDnsClient::~MockDnsClient() {}
208
209void MockDnsClient::SetConfig(const DnsConfig& config) {
210  config_ = config;
211}
212
213const DnsConfig* MockDnsClient::GetConfig() const {
214  return config_.IsValid() ? &config_ : NULL;
215}
216
217DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
218  return config_.IsValid() ? factory_.get() : NULL;
219}
220
221AddressSorter* MockDnsClient::GetAddressSorter() {
222  return address_sorter_.get();
223}
224
225void MockDnsClient::CompleteDelayedTransactions() {
226  factory_->CompleteDelayedTransactions();
227}
228
229}  // namespace net
230