1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16 */
17
18#include <arpa/inet.h>
19#include <errno.h>
20#include <netdb.h>
21#include <stdarg.h>
22#include <stdio.h>
23#include <stdlib.h>
24#include <unistd.h>
25
26#include <cutils/sockets.h>
27#include <android-base/stringprintf.h>
28#include <private/android_filesystem_config.h>
29
30#include <openssl/base64.h>
31
32#include <algorithm>
33#include <chrono>
34#include <iterator>
35#include <numeric>
36#include <thread>
37
38#define LOG_TAG "netd_test"
39// TODO: make this dynamic and stop depending on implementation details.
40#define TEST_OEM_NETWORK "oem29"
41#define TEST_NETID 30
42
43#include "NetdClient.h"
44
45#include <gtest/gtest.h>
46
47#include <utils/Log.h>
48
49#include "dns_responder.h"
50#include "dns_responder_client.h"
51#include "dns_tls_frontend.h"
52#include "resolv_params.h"
53#include "ResolverStats.h"
54
55#include "android/net/INetd.h"
56#include "android/net/metrics/INetdEventListener.h"
57#include "binder/IServiceManager.h"
58
59using android::base::StringPrintf;
60using android::base::StringAppendF;
61using android::net::ResolverStats;
62using android::net::metrics::INetdEventListener;
63
// Stand-in for UnorderedElementsAreArray, which currently cannot be used.
// TODO: Use UnorderedElementsAreArray, which depends on being able to compile libgmock_host,
// if that is not possible, improve this hacky algorithm, which is O(n**2)
//
// Returns true when |a| and |b| have the same size and every element of |a| occurs the same
// number of times in |a| as it does in |b|. Element order is irrelevant.
template <class A, class B>
bool UnorderedCompareArray(const A& a, const B& b) {
    if (a.size() != b.size()) {
        return false;
    }
    for (const auto& needle : a) {
        // Keep |needle| on the left-hand side of == for both containers.
        const auto equals_needle = [&needle](const auto& candidate) {
            return needle == candidate;
        };
        const auto occurrences_in_a = std::count_if(a.begin(), a.end(), equals_needle);
        const auto occurrences_in_b = std::count_if(b.begin(), b.end(), equals_needle);
        if (occurrences_in_a != occurrences_in_b) {
            return false;
        }
    }
    return true;
}
85
// Owning wrapper around the addrinfo list returned by getaddrinfo().
//
// The list is released with freeaddrinfo() whenever the wrapper is cleared, re-initialized or
// destroyed. The wrapper is move-only: it holds a raw owning pointer, so a copy would make two
// objects free the same list.
class AddrInfo {
  public:
    // Creates an empty wrapper: no list, error() == 0.
    AddrInfo() {}

    // Resolves |node|/|service| using |hints|; inspect error() for the getaddrinfo() result.
    AddrInfo(const char* node, const char* service, const addrinfo& hints) {
        init(node, service, hints);
    }

    // Resolves |node|/|service| without hints; inspect error() for the getaddrinfo() result.
    AddrInfo(const char* node, const char* service) {
        init(node, service);
    }

    AddrInfo(const AddrInfo&) = delete;
    AddrInfo& operator=(const AddrInfo&) = delete;

    // Takes over the list and error code of |other|, leaving |other| empty.
    AddrInfo(AddrInfo&& other) noexcept : ai_(other.ai_), error_(other.error_) {
        other.ai_ = nullptr;
        other.error_ = 0;
    }

    // Releases the current list and takes over the state of |other|.
    AddrInfo& operator=(AddrInfo&& other) noexcept {
        // Detach the source first; this keeps self-move harmless without needing the address
        // of |other| (operator& is overloaded on this class).
        addrinfo* const taken_ai = other.ai_;
        const int taken_error = other.error_;
        other.ai_ = nullptr;
        other.error_ = 0;
        clear();
        ai_ = taken_ai;
        error_ = taken_error;
        return *this;
    }

    ~AddrInfo() { clear(); }

    // Drops any previous list, then calls getaddrinfo() with |hints|.
    // Returns the getaddrinfo() result, which is also available through error().
    int init(const char* node, const char* service, const addrinfo& hints) {
        clear();
        error_ = getaddrinfo(node, service, &hints, &ai_);
        return error_;
    }

    // Drops any previous list, then calls getaddrinfo() without hints.
    // Returns the getaddrinfo() result, which is also available through error().
    int init(const char* node, const char* service) {
        clear();
        error_ = getaddrinfo(node, service, nullptr, &ai_);
        return error_;
    }

    // Frees the held list, if any, and resets the error code. When no list is held the stored
    // error code is left untouched.
    void clear() {
        if (ai_ != nullptr) {
            freeaddrinfo(ai_);
            ai_ = nullptr;
            error_ = 0;
        }
    }

    // Dereferences the head of the list; must only be used while a list is held.
    const addrinfo& operator*() const { return *ai_; }
    // Returns the head of the list, or nullptr when empty.
    const addrinfo* get() const { return ai_; }
    // Note: overloaded address-of yields the held list, not the address of the wrapper.
    const addrinfo* operator&() const { return ai_; }
    // Result of the most recent getaddrinfo() call (0 for a fresh wrapper).
    int error() const { return error_; }

  private:
    addrinfo* ai_ = nullptr;
    int error_ = 0;
};
129
130class ResolverTest : public ::testing::Test, public DnsResponderClient {
131private:
132    int mOriginalMetricsLevel;
133
134protected:
135    virtual void SetUp() {
136        // Ensure resolutions go via proxy.
137        DnsResponderClient::SetUp();
138
139        // If DNS reporting is off: turn it on so we run through everything.
140        auto rv = mNetdSrv->getMetricsReportingLevel(&mOriginalMetricsLevel);
141        ASSERT_TRUE(rv.isOk());
142        if (mOriginalMetricsLevel != INetdEventListener::REPORTING_LEVEL_FULL) {
143            rv = mNetdSrv->setMetricsReportingLevel(INetdEventListener::REPORTING_LEVEL_FULL);
144            ASSERT_TRUE(rv.isOk());
145        }
146    }
147
148    virtual void TearDown() {
149        if (mOriginalMetricsLevel != INetdEventListener::REPORTING_LEVEL_FULL) {
150            auto rv = mNetdSrv->setMetricsReportingLevel(mOriginalMetricsLevel);
151            ASSERT_TRUE(rv.isOk());
152        }
153
154        DnsResponderClient::TearDown();
155    }
156
157    bool GetResolverInfo(std::vector<std::string>* servers, std::vector<std::string>* domains,
158            __res_params* params, std::vector<ResolverStats>* stats) {
159        using android::net::INetd;
160        std::vector<int32_t> params32;
161        std::vector<int32_t> stats32;
162        auto rv = mNetdSrv->getResolverInfo(TEST_NETID, servers, domains, &params32, &stats32);
163        if (!rv.isOk() || params32.size() != INetd::RESOLVER_PARAMS_COUNT) {
164            return false;
165        }
166        *params = __res_params {
167            .sample_validity = static_cast<uint16_t>(
168                    params32[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY]),
169            .success_threshold = static_cast<uint8_t>(
170                    params32[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD]),
171            .min_samples = static_cast<uint8_t>(
172                    params32[INetd::RESOLVER_PARAMS_MIN_SAMPLES]),
173            .max_samples = static_cast<uint8_t>(
174                    params32[INetd::RESOLVER_PARAMS_MAX_SAMPLES])
175        };
176        return ResolverStats::decodeAll(stats32, stats);
177    }
178
179    std::string ToString(const hostent* he) const {
180        if (he == nullptr) return "<null>";
181        char buffer[INET6_ADDRSTRLEN];
182        if (!inet_ntop(he->h_addrtype, he->h_addr_list[0], buffer, sizeof(buffer))) {
183            return "<invalid>";
184        }
185        return buffer;
186    }
187
188    std::string ToString(const addrinfo* ai) const {
189        if (!ai)
190            return "<null>";
191        for (const auto* aip = ai ; aip != nullptr ; aip = aip->ai_next) {
192            char host[NI_MAXHOST];
193            int rv = getnameinfo(aip->ai_addr, aip->ai_addrlen, host, sizeof(host), nullptr, 0,
194                    NI_NUMERICHOST);
195            if (rv != 0)
196                return gai_strerror(rv);
197            return host;
198        }
199        return "<invalid>";
200    }
201
202    size_t GetNumQueries(const test::DNSResponder& dns, const char* name) const {
203        auto queries = dns.queries();
204        size_t found = 0;
205        for (const auto& p : queries) {
206            if (p.first == name) {
207                ++found;
208            }
209        }
210        return found;
211    }
212
213    size_t GetNumQueriesForType(const test::DNSResponder& dns, ns_type type,
214            const char* name) const {
215        auto queries = dns.queries();
216        size_t found = 0;
217        for (const auto& p : queries) {
218            if (p.second == type && p.first == name) {
219                ++found;
220            }
221        }
222        return found;
223    }
224
225    void RunGetAddrInfoStressTest_Binder(unsigned num_hosts, unsigned num_threads,
226            unsigned num_queries) {
227        std::vector<std::string> domains = { "example.com" };
228        std::vector<std::unique_ptr<test::DNSResponder>> dns;
229        std::vector<std::string> servers;
230        std::vector<DnsResponderClient::Mapping> mappings;
231        ASSERT_NO_FATAL_FAILURE(SetupMappings(num_hosts, domains, &mappings));
232        ASSERT_NO_FATAL_FAILURE(SetupDNSServers(MAXNS, mappings, &dns, &servers));
233
234        ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));
235
236        auto t0 = std::chrono::steady_clock::now();
237        std::vector<std::thread> threads(num_threads);
238        for (std::thread& thread : threads) {
239           thread = std::thread([this, &mappings, num_queries]() {
240                for (unsigned i = 0 ; i < num_queries ; ++i) {
241                    uint32_t ofs = arc4random_uniform(mappings.size());
242                    auto& mapping = mappings[ofs];
243                    addrinfo* result = nullptr;
244                    int rv = getaddrinfo(mapping.host.c_str(), nullptr, nullptr, &result);
245                    EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
246                    if (rv == 0) {
247                        std::string result_str = ToString(result);
248                        EXPECT_TRUE(result_str == mapping.ip4 || result_str == mapping.ip6)
249                            << "result='" << result_str << "', ip4='" << mapping.ip4
250                            << "', ip6='" << mapping.ip6;
251                    }
252                    if (result) {
253                        freeaddrinfo(result);
254                        result = nullptr;
255                    }
256                }
257            });
258        }
259
260        for (std::thread& thread : threads) {
261            thread.join();
262        }
263        auto t1 = std::chrono::steady_clock::now();
264        ALOGI("%u hosts, %u threads, %u queries, %Es", num_hosts, num_threads, num_queries,
265                std::chrono::duration<double>(t1 - t0).count());
266        ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
267    }
268
269    const std::vector<std::string> mDefaultSearchDomains = { "example.com" };
270    // <sample validity in s> <success threshold in percent> <min samples> <max samples>
271    const std::string mDefaultParams = "300 25 8 8";
272    const std::vector<int> mDefaultParams_Binder = { 300, 25, 8, 8 };
273};
274
275TEST_F(ResolverTest, GetHostByName) {
276    const char* listen_addr = "127.0.0.3";
277    const char* listen_srv = "53";
278    const char* host_name = "hello.example.com.";
279    const char *nonexistent_host_name = "nonexistent.example.com.";
280    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
281    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
282    ASSERT_TRUE(dns.startServer());
283    std::vector<std::string> servers = { listen_addr };
284    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
285
286    const hostent* result;
287
288    dns.clearQueries();
289    result = gethostbyname("nonexistent");
290    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, nonexistent_host_name));
291    ASSERT_TRUE(result == nullptr);
292    ASSERT_EQ(HOST_NOT_FOUND, h_errno);
293
294    dns.clearQueries();
295    result = gethostbyname("hello");
296    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
297    ASSERT_FALSE(result == nullptr);
298    ASSERT_EQ(4, result->h_length);
299    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
300    EXPECT_EQ("1.2.3.3", ToString(result));
301    EXPECT_TRUE(result->h_addr_list[1] == nullptr);
302
303    dns.stopServer();
304}
305
306TEST_F(ResolverTest, TestBinderSerialization) {
307    using android::net::INetd;
308    std::vector<int> params_offsets = {
309        INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY,
310        INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD,
311        INetd::RESOLVER_PARAMS_MIN_SAMPLES,
312        INetd::RESOLVER_PARAMS_MAX_SAMPLES
313    };
314    int size = static_cast<int>(params_offsets.size());
315    EXPECT_EQ(size, INetd::RESOLVER_PARAMS_COUNT);
316    std::sort(params_offsets.begin(), params_offsets.end());
317    for (int i = 0 ; i < size ; ++i) {
318        EXPECT_EQ(params_offsets[i], i);
319    }
320}
321
322TEST_F(ResolverTest, GetHostByName_Binder) {
323    using android::net::INetd;
324
325    std::vector<std::string> domains = { "example.com" };
326    std::vector<std::unique_ptr<test::DNSResponder>> dns;
327    std::vector<std::string> servers;
328    std::vector<Mapping> mappings;
329    ASSERT_NO_FATAL_FAILURE(SetupMappings(1, domains, &mappings));
330    ASSERT_NO_FATAL_FAILURE(SetupDNSServers(4, mappings, &dns, &servers));
331    ASSERT_EQ(1U, mappings.size());
332    const Mapping& mapping = mappings[0];
333
334    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));
335
336    const hostent* result = gethostbyname(mapping.host.c_str());
337    size_t total_queries = std::accumulate(dns.begin(), dns.end(), 0,
338            [this, &mapping](size_t total, auto& d) {
339                return total + GetNumQueriesForType(*d, ns_type::ns_t_a, mapping.entry.c_str());
340            });
341
342    EXPECT_LE(1U, total_queries);
343    ASSERT_FALSE(result == nullptr);
344    ASSERT_EQ(4, result->h_length);
345    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
346    EXPECT_EQ(mapping.ip4, ToString(result));
347    EXPECT_TRUE(result->h_addr_list[1] == nullptr);
348
349    std::vector<std::string> res_servers;
350    std::vector<std::string> res_domains;
351    __res_params res_params;
352    std::vector<ResolverStats> res_stats;
353    ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_params, &res_stats));
354    EXPECT_EQ(servers.size(), res_servers.size());
355    EXPECT_EQ(domains.size(), res_domains.size());
356    ASSERT_EQ(INetd::RESOLVER_PARAMS_COUNT, mDefaultParams_Binder.size());
357    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY],
358            res_params.sample_validity);
359    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD],
360            res_params.success_threshold);
361    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_MIN_SAMPLES], res_params.min_samples);
362    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_MAX_SAMPLES], res_params.max_samples);
363    EXPECT_EQ(servers.size(), res_stats.size());
364
365    EXPECT_TRUE(UnorderedCompareArray(res_servers, servers));
366    EXPECT_TRUE(UnorderedCompareArray(res_domains, domains));
367
368    ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
369}
370
371TEST_F(ResolverTest, GetAddrInfo) {
372    addrinfo* result = nullptr;
373
374    const char* listen_addr = "127.0.0.4";
375    const char* listen_addr2 = "127.0.0.5";
376    const char* listen_srv = "53";
377    const char* host_name = "howdy.example.com.";
378    test::DNSResponder dns(listen_addr, listen_srv, 250,
379                           ns_rcode::ns_r_servfail, 1.0);
380    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
381    dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
382    ASSERT_TRUE(dns.startServer());
383
384    test::DNSResponder dns2(listen_addr2, listen_srv, 250,
385                            ns_rcode::ns_r_servfail, 1.0);
386    dns2.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
387    dns2.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
388    ASSERT_TRUE(dns2.startServer());
389
390
391    std::vector<std::string> servers = { listen_addr };
392    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
393    dns.clearQueries();
394    dns2.clearQueries();
395
396    EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
397    size_t found = GetNumQueries(dns, host_name);
398    EXPECT_LE(1U, found);
399    // Could be A or AAAA
400    std::string result_str = ToString(result);
401    EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
402        << ", result_str='" << result_str << "'";
403    // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
404    if (result) {
405        freeaddrinfo(result);
406        result = nullptr;
407    }
408
409    // Verify that the name is cached.
410    size_t old_found = found;
411    EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
412    found = GetNumQueries(dns, host_name);
413    EXPECT_LE(1U, found);
414    EXPECT_EQ(old_found, found);
415    result_str = ToString(result);
416    EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
417        << result_str;
418    if (result) {
419        freeaddrinfo(result);
420        result = nullptr;
421    }
422
423    // Change the DNS resolver, ensure that queries are still cached.
424    servers = { listen_addr2 };
425    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
426    dns.clearQueries();
427    dns2.clearQueries();
428
429    EXPECT_EQ(0, getaddrinfo("howdy", nullptr, nullptr, &result));
430    found = GetNumQueries(dns, host_name);
431    size_t found2 = GetNumQueries(dns2, host_name);
432    EXPECT_EQ(0U, found);
433    EXPECT_LE(0U, found2);
434
435    // Could be A or AAAA
436    result_str = ToString(result);
437    EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
438        << ", result_str='" << result_str << "'";
439    if (result) {
440        freeaddrinfo(result);
441        result = nullptr;
442    }
443
444    dns.stopServer();
445    dns2.stopServer();
446}
447
448TEST_F(ResolverTest, GetAddrInfoV4) {
449    addrinfo* result = nullptr;
450
451    const char* listen_addr = "127.0.0.5";
452    const char* listen_srv = "53";
453    const char* host_name = "hola.example.com.";
454    test::DNSResponder dns(listen_addr, listen_srv, 250,
455                           ns_rcode::ns_r_servfail, 1.0);
456    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.5");
457    ASSERT_TRUE(dns.startServer());
458    std::vector<std::string> servers = { listen_addr };
459    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
460
461    addrinfo hints;
462    memset(&hints, 0, sizeof(hints));
463    hints.ai_family = AF_INET;
464    EXPECT_EQ(0, getaddrinfo("hola", nullptr, &hints, &result));
465    EXPECT_EQ(1U, GetNumQueries(dns, host_name));
466    EXPECT_EQ("1.2.3.5", ToString(result));
467    if (result) {
468        freeaddrinfo(result);
469        result = nullptr;
470    }
471}
472
473TEST_F(ResolverTest, MultidomainResolution) {
474    std::vector<std::string> searchDomains = { "example1.com", "example2.com", "example3.com" };
475    const char* listen_addr = "127.0.0.6";
476    const char* listen_srv = "53";
477    const char* host_name = "nihao.example2.com.";
478    test::DNSResponder dns(listen_addr, listen_srv, 250,
479                           ns_rcode::ns_r_servfail, 1.0);
480    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
481    ASSERT_TRUE(dns.startServer());
482    std::vector<std::string> servers = { listen_addr };
483    ASSERT_TRUE(SetResolversForNetwork(searchDomains, servers, mDefaultParams));
484
485    dns.clearQueries();
486    const hostent* result = gethostbyname("nihao");
487    EXPECT_EQ(1U, GetNumQueriesForType(dns, ns_type::ns_t_a, host_name));
488    ASSERT_FALSE(result == nullptr);
489    ASSERT_EQ(4, result->h_length);
490    ASSERT_FALSE(result->h_addr_list[0] == nullptr);
491    EXPECT_EQ("1.2.3.3", ToString(result));
492    EXPECT_TRUE(result->h_addr_list[1] == nullptr);
493    dns.stopServer();
494}
495
496TEST_F(ResolverTest, GetAddrInfoV6_failing) {
497    addrinfo* result = nullptr;
498
499    const char* listen_addr0 = "127.0.0.7";
500    const char* listen_addr1 = "127.0.0.8";
501    const char* listen_srv = "53";
502    const char* host_name = "ohayou.example.com.";
503    test::DNSResponder dns0(listen_addr0, listen_srv, 250,
504                            ns_rcode::ns_r_servfail, 0.0);
505    test::DNSResponder dns1(listen_addr1, listen_srv, 250,
506                            ns_rcode::ns_r_servfail, 1.0);
507    dns0.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::5");
508    dns1.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::6");
509    ASSERT_TRUE(dns0.startServer());
510    ASSERT_TRUE(dns1.startServer());
511    std::vector<std::string> servers = { listen_addr0, listen_addr1 };
512    // <sample validity in s> <success threshold in percent> <min samples> <max samples>
513    unsigned sample_validity = 300;
514    int success_threshold = 25;
515    int sample_count = 8;
516    std::string params = StringPrintf("%u %d %d %d", sample_validity, success_threshold,
517            sample_count, sample_count);
518    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, params));
519
520    // Repeatedly perform resolutions for non-existing domains until MAXNSSAMPLES resolutions have
521    // reached the dns0, which is set to fail. No more requests should then arrive at that server
522    // for the next sample_lifetime seconds.
523    // TODO: This approach is implementation-dependent, change once metrics reporting is available.
524    addrinfo hints;
525    memset(&hints, 0, sizeof(hints));
526    hints.ai_family = AF_INET6;
527    for (int i = 0 ; i < sample_count ; ++i) {
528        std::string domain = StringPrintf("nonexistent%d", i);
529        getaddrinfo(domain.c_str(), nullptr, &hints, &result);
530        if (result) {
531            freeaddrinfo(result);
532            result = nullptr;
533        }
534    }
535    // Due to 100% errors for all possible samples, the server should be ignored from now on and
536    // only the second one used for all following queries, until NSSAMPLE_VALIDITY is reached.
537    dns0.clearQueries();
538    dns1.clearQueries();
539    EXPECT_EQ(0, getaddrinfo("ohayou", nullptr, &hints, &result));
540    EXPECT_EQ(0U, GetNumQueries(dns0, host_name));
541    EXPECT_EQ(1U, GetNumQueries(dns1, host_name));
542    if (result) {
543        freeaddrinfo(result);
544        result = nullptr;
545    }
546}
547
548TEST_F(ResolverTest, GetAddrInfoV6_concurrent) {
549    const char* listen_addr0 = "127.0.0.9";
550    const char* listen_addr1 = "127.0.0.10";
551    const char* listen_addr2 = "127.0.0.11";
552    const char* listen_srv = "53";
553    const char* host_name = "konbanha.example.com.";
554    test::DNSResponder dns0(listen_addr0, listen_srv, 250,
555                            ns_rcode::ns_r_servfail, 1.0);
556    test::DNSResponder dns1(listen_addr1, listen_srv, 250,
557                            ns_rcode::ns_r_servfail, 1.0);
558    test::DNSResponder dns2(listen_addr2, listen_srv, 250,
559                            ns_rcode::ns_r_servfail, 1.0);
560    dns0.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::5");
561    dns1.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::6");
562    dns2.addMapping(host_name, ns_type::ns_t_aaaa, "2001:db8::7");
563    ASSERT_TRUE(dns0.startServer());
564    ASSERT_TRUE(dns1.startServer());
565    ASSERT_TRUE(dns2.startServer());
566    const std::vector<std::string> servers = { listen_addr0, listen_addr1, listen_addr2 };
567    std::vector<std::thread> threads(10);
568    for (std::thread& thread : threads) {
569       thread = std::thread([this, &servers]() {
570            unsigned delay = arc4random_uniform(1*1000*1000); // <= 1s
571            usleep(delay);
572            std::vector<std::string> serverSubset;
573            for (const auto& server : servers) {
574                if (arc4random_uniform(2)) {
575                    serverSubset.push_back(server);
576                }
577            }
578            if (serverSubset.empty()) serverSubset = servers;
579            ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, serverSubset,
580                    mDefaultParams));
581            addrinfo hints;
582            memset(&hints, 0, sizeof(hints));
583            hints.ai_family = AF_INET6;
584            addrinfo* result = nullptr;
585            int rv = getaddrinfo("konbanha", nullptr, &hints, &result);
586            EXPECT_EQ(0, rv) << "error [" << rv << "] " << gai_strerror(rv);
587            if (result) {
588                freeaddrinfo(result);
589                result = nullptr;
590            }
591        });
592    }
593    for (std::thread& thread : threads) {
594        thread.join();
595    }
596}
597
598TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100) {
599    const unsigned num_hosts = 100;
600    const unsigned num_threads = 100;
601    const unsigned num_queries = 100;
602    ASSERT_NO_FATAL_FAILURE(RunGetAddrInfoStressTest_Binder(num_hosts, num_threads, num_queries));
603}
604
605TEST_F(ResolverTest, GetAddrInfoStressTest_Binder_100000) {
606    const unsigned num_hosts = 100000;
607    const unsigned num_threads = 100;
608    const unsigned num_queries = 100;
609    ASSERT_NO_FATAL_FAILURE(RunGetAddrInfoStressTest_Binder(num_hosts, num_threads, num_queries));
610}
611
612TEST_F(ResolverTest, EmptySetup) {
613    using android::net::INetd;
614    std::vector<std::string> servers;
615    std::vector<std::string> domains;
616    ASSERT_TRUE(SetResolversForNetwork(servers, domains, mDefaultParams_Binder));
617    std::vector<std::string> res_servers;
618    std::vector<std::string> res_domains;
619    __res_params res_params;
620    std::vector<ResolverStats> res_stats;
621    ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_params, &res_stats));
622    EXPECT_EQ(0U, res_servers.size());
623    EXPECT_EQ(0U, res_domains.size());
624    ASSERT_EQ(INetd::RESOLVER_PARAMS_COUNT, mDefaultParams_Binder.size());
625    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_SAMPLE_VALIDITY],
626            res_params.sample_validity);
627    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_SUCCESS_THRESHOLD],
628            res_params.success_threshold);
629    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_MIN_SAMPLES], res_params.min_samples);
630    EXPECT_EQ(mDefaultParams_Binder[INetd::RESOLVER_PARAMS_MAX_SAMPLES], res_params.max_samples);
631}
632
633TEST_F(ResolverTest, SearchPathChange) {
634    addrinfo* result = nullptr;
635
636    const char* listen_addr = "127.0.0.13";
637    const char* listen_srv = "53";
638    const char* host_name1 = "test13.domain1.org.";
639    const char* host_name2 = "test13.domain2.org.";
640    test::DNSResponder dns(listen_addr, listen_srv, 250,
641                           ns_rcode::ns_r_servfail, 1.0);
642    dns.addMapping(host_name1, ns_type::ns_t_aaaa, "2001:db8::13");
643    dns.addMapping(host_name2, ns_type::ns_t_aaaa, "2001:db8::1:13");
644    ASSERT_TRUE(dns.startServer());
645    std::vector<std::string> servers = { listen_addr };
646    std::vector<std::string> domains = { "domain1.org" };
647    ASSERT_TRUE(SetResolversForNetwork(domains, servers, mDefaultParams));
648
649    addrinfo hints;
650    memset(&hints, 0, sizeof(hints));
651    hints.ai_family = AF_INET6;
652    EXPECT_EQ(0, getaddrinfo("test13", nullptr, &hints, &result));
653    EXPECT_EQ(1U, dns.queries().size());
654    EXPECT_EQ(1U, GetNumQueries(dns, host_name1));
655    EXPECT_EQ("2001:db8::13", ToString(result));
656    if (result) freeaddrinfo(result);
657
658    // Test that changing the domain search path on its own works.
659    domains = { "domain2.org" };
660    ASSERT_TRUE(SetResolversForNetwork(domains, servers, mDefaultParams));
661    dns.clearQueries();
662
663    EXPECT_EQ(0, getaddrinfo("test13", nullptr, &hints, &result));
664    EXPECT_EQ(1U, dns.queries().size());
665    EXPECT_EQ(1U, GetNumQueries(dns, host_name2));
666    EXPECT_EQ("2001:db8::1:13", ToString(result));
667    if (result) freeaddrinfo(result);
668}
669
670TEST_F(ResolverTest, MaxServerPrune_Binder) {
671    using android::net::INetd;
672
673    std::vector<std::string> domains = { "example.com" };
674    std::vector<std::unique_ptr<test::DNSResponder>> dns;
675    std::vector<std::string> servers;
676    std::vector<Mapping> mappings;
677    ASSERT_NO_FATAL_FAILURE(SetupMappings(1, domains, &mappings));
678    ASSERT_NO_FATAL_FAILURE(SetupDNSServers(MAXNS + 1, mappings, &dns, &servers));
679
680    ASSERT_TRUE(SetResolversForNetwork(servers, domains,  mDefaultParams_Binder));
681
682    std::vector<std::string> res_servers;
683    std::vector<std::string> res_domains;
684    __res_params res_params;
685    std::vector<ResolverStats> res_stats;
686    ASSERT_TRUE(GetResolverInfo(&res_servers, &res_domains, &res_params, &res_stats));
687    EXPECT_EQ(static_cast<size_t>(MAXNS), res_servers.size());
688
689    ASSERT_NO_FATAL_FAILURE(ShutdownDNSServers(&dns));
690}
691
692static std::string base64Encode(const std::vector<uint8_t>& input) {
693    size_t out_len;
694    EXPECT_EQ(1, EVP_EncodedLength(&out_len, input.size()));
695    // out_len includes the trailing NULL.
696    uint8_t output_bytes[out_len];
697    EXPECT_EQ(out_len - 1, EVP_EncodeBlock(output_bytes, input.data(), input.size()));
698    return std::string(reinterpret_cast<char*>(output_bytes));
699}
700
701// Test what happens if the specified TLS server is nonexistent.
702TEST_F(ResolverTest, GetHostByName_TlsMissing) {
703    const char* listen_addr = "127.0.0.3";
704    const char* listen_srv = "53";
705    const char* host_name = "tlsmissing.example.com.";
706    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
707    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.3");
708    ASSERT_TRUE(dns.startServer());
709    std::vector<std::string> servers = { listen_addr };
710
711    // There's nothing listening on this address, so validation will either fail or
712    /// hang.  Either way, queries will continue to flow to the DNSResponder.
713    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", {});
714    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
715
716    const hostent* result;
717
718    result = gethostbyname("tlsmissing");
719    ASSERT_FALSE(result == nullptr);
720    EXPECT_EQ("1.2.3.3", ToString(result));
721
722    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
723    dns.stopServer();
724}
725
726// Test what happens if the specified TLS server replies with garbage.
727TEST_F(ResolverTest, GetHostByName_TlsBroken) {
728    const char* listen_addr = "127.0.0.3";
729    const char* listen_srv = "53";
730    const char* host_name1 = "tlsbroken1.example.com.";
731    const char* host_name2 = "tlsbroken2.example.com.";
732    test::DNSResponder dns(listen_addr, listen_srv, 250, ns_rcode::ns_r_servfail, 1.0);
733    dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
734    dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
735    ASSERT_TRUE(dns.startServer());
736    std::vector<std::string> servers = { listen_addr };
737
738    // Bind the specified private DNS socket but don't respond to any client sockets yet.
739    int s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
740    ASSERT_TRUE(s >= 0);
741    struct sockaddr_in tlsServer = {
742        .sin_family = AF_INET,
743        .sin_port = htons(853),
744    };
745    ASSERT_TRUE(inet_pton(AF_INET, listen_addr, &tlsServer.sin_addr));
746    ASSERT_FALSE(bind(s, reinterpret_cast<struct sockaddr*>(&tlsServer), sizeof(tlsServer)));
747    ASSERT_FALSE(listen(s, 1));
748
749    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", {});
750    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
751
752    // SetResolversForNetwork should have triggered a validation connection to this address.
753    struct sockaddr_storage cliaddr;
754    socklen_t sin_size = sizeof(cliaddr);
755    int new_fd = accept(s, reinterpret_cast<struct sockaddr *>(&cliaddr), &sin_size);
756    ASSERT_TRUE(new_fd > 0);
757
758    // We've received the new file descriptor but not written to it or closed, so the
759    // validation is still pending.  Queries should still flow correctly because the
760    // server is not used until validation succeeds.
761    const hostent* result;
762    result = gethostbyname("tlsbroken1");
763    ASSERT_FALSE(result == nullptr);
764    EXPECT_EQ("1.2.3.1", ToString(result));
765
766    // Now we cause the validation to fail.
767    std::string garbage = "definitely not a valid TLS ServerHello";
768    write(new_fd, garbage.data(), garbage.size());
769    close(new_fd);
770
771    // Validation failure shouldn't interfere with lookups, because lookups won't be sent
772    // to the TLS server unless validation succeeds.
773    result = gethostbyname("tlsbroken2");
774    ASSERT_FALSE(result == nullptr);
775    EXPECT_EQ("1.2.3.2", ToString(result));
776
777    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
778    dns.stopServer();
779    close(s);
780}
781
782TEST_F(ResolverTest, GetHostByName_Tls) {
783    const char* listen_addr = "127.0.0.3";
784    const char* listen_udp = "53";
785    const char* listen_tls = "853";
786    const char* host_name1 = "tls1.example.com.";
787    const char* host_name2 = "tls2.example.com.";
788    const char* host_name3 = "tls3.example.com.";
789    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
790    dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
791    dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
792    dns.addMapping(host_name3, ns_type::ns_t_a, "1.2.3.3");
793    ASSERT_TRUE(dns.startServer());
794    std::vector<std::string> servers = { listen_addr };
795
796    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
797    ASSERT_TRUE(tls.startServer());
798    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "", {});
799    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
800
801    const hostent* result;
802
803    // Wait for validation to complete.
804    EXPECT_TRUE(tls.waitForQueries(1, 5000));
805
806    result = gethostbyname("tls1");
807    ASSERT_FALSE(result == nullptr);
808    EXPECT_EQ("1.2.3.1", ToString(result));
809
810    // Wait for query to get counted.
811    EXPECT_TRUE(tls.waitForQueries(2, 5000));
812
813    // Stop the TLS server.  Since it's already been validated, queries will
814    // continue to be routed to it.
815    tls.stopServer();
816
817    result = gethostbyname("tls2");
818    EXPECT_TRUE(result == nullptr);
819    EXPECT_EQ(HOST_NOT_FOUND, h_errno);
820
821    // Remove the TLS server setting.  Queries should now be routed to the
822    // UDP endpoint.
823    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
824
825    result = gethostbyname("tls3");
826    ASSERT_FALSE(result == nullptr);
827    EXPECT_EQ("1.2.3.3", ToString(result));
828
829    dns.stopServer();
830}
831
832TEST_F(ResolverTest, GetHostByName_TlsFingerprint) {
833    const char* listen_addr = "127.0.0.3";
834    const char* listen_udp = "53";
835    const char* listen_tls = "853";
836    const char* host_name = "tlsfingerprint.example.com.";
837    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
838    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
839    ASSERT_TRUE(dns.startServer());
840    std::vector<std::string> servers = { listen_addr };
841
842    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
843    ASSERT_TRUE(tls.startServer());
844    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
845            { base64Encode(tls.fingerprint()) });
846    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
847
848    const hostent* result;
849
850    // Wait for validation to complete.
851    EXPECT_TRUE(tls.waitForQueries(1, 5000));
852
853    result = gethostbyname("tlsfingerprint");
854    ASSERT_FALSE(result == nullptr);
855    EXPECT_EQ("1.2.3.1", ToString(result));
856
857    // Wait for query to get counted.
858    EXPECT_TRUE(tls.waitForQueries(2, 5000));
859
860    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
861    tls.stopServer();
862    dns.stopServer();
863}
864
865TEST_F(ResolverTest, GetHostByName_BadTlsFingerprint) {
866    const char* listen_addr = "127.0.0.3";
867    const char* listen_udp = "53";
868    const char* listen_tls = "853";
869    const char* host_name = "badtlsfingerprint.example.com.";
870    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
871    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
872    ASSERT_TRUE(dns.startServer());
873    std::vector<std::string> servers = { listen_addr };
874
875    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
876    ASSERT_TRUE(tls.startServer());
877    std::vector<uint8_t> bad_fingerprint = tls.fingerprint();
878    bad_fingerprint[5] += 1;  // Corrupt the fingerprint.
879    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
880            { base64Encode(bad_fingerprint) });
881    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
882
883    const hostent* result;
884
885    // The initial validation should fail at the fingerprint check before
886    // issuing a query.
887    EXPECT_FALSE(tls.waitForQueries(1, 500));
888
889    result = gethostbyname("badtlsfingerprint");
890    ASSERT_FALSE(result == nullptr);
891    EXPECT_EQ("1.2.3.1", ToString(result));
892
893    // The query should have bypassed the TLS frontend, because validation
894    // failed.
895    EXPECT_FALSE(tls.waitForQueries(1, 500));
896
897    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
898    tls.stopServer();
899    dns.stopServer();
900}
901
902// Test that we can pass two different fingerprints, and connection succeeds as long as
903// at least one of them matches the server.
904TEST_F(ResolverTest, GetHostByName_TwoTlsFingerprints) {
905    const char* listen_addr = "127.0.0.3";
906    const char* listen_udp = "53";
907    const char* listen_tls = "853";
908    const char* host_name = "twotlsfingerprints.example.com.";
909    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
910    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.1");
911    ASSERT_TRUE(dns.startServer());
912    std::vector<std::string> servers = { listen_addr };
913
914    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
915    ASSERT_TRUE(tls.startServer());
916    std::vector<uint8_t> bad_fingerprint = tls.fingerprint();
917    bad_fingerprint[5] += 1;  // Corrupt the fingerprint.
918    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
919            { base64Encode(bad_fingerprint), base64Encode(tls.fingerprint()) });
920    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
921
922    const hostent* result;
923
924    // Wait for validation to complete.
925    EXPECT_TRUE(tls.waitForQueries(1, 5000));
926
927    result = gethostbyname("twotlsfingerprints");
928    ASSERT_FALSE(result == nullptr);
929    EXPECT_EQ("1.2.3.1", ToString(result));
930
931    // Wait for query to get counted.
932    EXPECT_TRUE(tls.waitForQueries(2, 5000));
933
934    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
935    tls.stopServer();
936    dns.stopServer();
937}
938
939TEST_F(ResolverTest, GetHostByName_TlsFingerprintGoesBad) {
940    const char* listen_addr = "127.0.0.3";
941    const char* listen_udp = "53";
942    const char* listen_tls = "853";
943    const char* host_name1 = "tlsfingerprintgoesbad1.example.com.";
944    const char* host_name2 = "tlsfingerprintgoesbad2.example.com.";
945    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
946    dns.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
947    dns.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
948    ASSERT_TRUE(dns.startServer());
949    std::vector<std::string> servers = { listen_addr };
950
951    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
952    ASSERT_TRUE(tls.startServer());
953    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
954            { base64Encode(tls.fingerprint()) });
955    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
956
957    const hostent* result;
958
959    // Wait for validation to complete.
960    EXPECT_TRUE(tls.waitForQueries(1, 5000));
961
962    result = gethostbyname("tlsfingerprintgoesbad1");
963    ASSERT_FALSE(result == nullptr);
964    EXPECT_EQ("1.2.3.1", ToString(result));
965
966    // Wait for query to get counted.
967    EXPECT_TRUE(tls.waitForQueries(2, 5000));
968
969    // Restart the TLS server.  This will generate a new certificate whose fingerprint
970    // no longer matches the stored fingerprint.
971    tls.stopServer();
972    tls.startServer();
973
974    result = gethostbyname("tlsfingerprintgoesbad2");
975    ASSERT_TRUE(result == nullptr);
976    EXPECT_EQ(HOST_NOT_FOUND, h_errno);
977
978    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
979    tls.stopServer();
980    dns.stopServer();
981}
982
983TEST_F(ResolverTest, GetHostByName_TlsFailover) {
984    const char* listen_addr1 = "127.0.0.3";
985    const char* listen_addr2 = "127.0.0.4";
986    const char* listen_udp = "53";
987    const char* listen_tls = "853";
988    const char* host_name1 = "tlsfailover1.example.com.";
989    const char* host_name2 = "tlsfailover2.example.com.";
990    test::DNSResponder dns1(listen_addr1, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
991    test::DNSResponder dns2(listen_addr2, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
992    dns1.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.1");
993    dns1.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.2");
994    dns2.addMapping(host_name1, ns_type::ns_t_a, "1.2.3.3");
995    dns2.addMapping(host_name2, ns_type::ns_t_a, "1.2.3.4");
996    ASSERT_TRUE(dns1.startServer());
997    ASSERT_TRUE(dns2.startServer());
998    std::vector<std::string> servers = { listen_addr1, listen_addr2 };
999
1000    test::DnsTlsFrontend tls1(listen_addr1, listen_tls, listen_addr1, listen_udp);
1001    test::DnsTlsFrontend tls2(listen_addr2, listen_tls, listen_addr2, listen_udp);
1002    ASSERT_TRUE(tls1.startServer());
1003    ASSERT_TRUE(tls2.startServer());
1004    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr1, 853, "SHA-256",
1005            { base64Encode(tls1.fingerprint()) });
1006    rv = mNetdSrv->addPrivateDnsServer(listen_addr2, 853, "SHA-256",
1007            { base64Encode(tls2.fingerprint()) });
1008    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
1009
1010    const hostent* result;
1011
1012    // Wait for validation to complete.
1013    EXPECT_TRUE(tls1.waitForQueries(1, 5000));
1014    EXPECT_TRUE(tls2.waitForQueries(1, 5000));
1015
1016    result = gethostbyname("tlsfailover1");
1017    ASSERT_FALSE(result == nullptr);
1018    EXPECT_EQ("1.2.3.1", ToString(result));
1019
1020    // Wait for query to get counted.
1021    EXPECT_TRUE(tls1.waitForQueries(2, 5000));
1022    // No new queries should have reached tls2.
1023    EXPECT_EQ(1, tls2.queries());
1024
1025    // Stop tls1.  Subsequent queries should attempt to reach tls1, fail, and retry to tls2.
1026    tls1.stopServer();
1027
1028    result = gethostbyname("tlsfailover2");
1029    EXPECT_EQ("1.2.3.4", ToString(result));
1030
1031    // Wait for query to get counted.
1032    EXPECT_TRUE(tls2.waitForQueries(2, 5000));
1033
1034    // No additional queries should have reached the insecure servers.
1035    EXPECT_EQ(2U, dns1.queries().size());
1036    EXPECT_EQ(2U, dns2.queries().size());
1037
1038    rv = mNetdSrv->removePrivateDnsServer(listen_addr1);
1039    rv = mNetdSrv->removePrivateDnsServer(listen_addr2);
1040    tls2.stopServer();
1041    dns1.stopServer();
1042    dns2.stopServer();
1043}
1044
1045TEST_F(ResolverTest, GetAddrInfo_Tls) {
1046    const char* listen_addr = "127.0.0.3";
1047    const char* listen_udp = "53";
1048    const char* listen_tls = "853";
1049    const char* host_name = "addrinfotls.example.com.";
1050    test::DNSResponder dns(listen_addr, listen_udp, 250, ns_rcode::ns_r_servfail, 1.0);
1051    dns.addMapping(host_name, ns_type::ns_t_a, "1.2.3.4");
1052    dns.addMapping(host_name, ns_type::ns_t_aaaa, "::1.2.3.4");
1053    ASSERT_TRUE(dns.startServer());
1054    std::vector<std::string> servers = { listen_addr };
1055
1056    test::DnsTlsFrontend tls(listen_addr, listen_tls, listen_addr, listen_udp);
1057    ASSERT_TRUE(tls.startServer());
1058    auto rv = mNetdSrv->addPrivateDnsServer(listen_addr, 853, "SHA-256",
1059            { base64Encode(tls.fingerprint()) });
1060    ASSERT_TRUE(SetResolversForNetwork(mDefaultSearchDomains, servers, mDefaultParams));
1061
1062    // Wait for validation to complete.
1063    EXPECT_TRUE(tls.waitForQueries(1, 5000));
1064
1065    dns.clearQueries();
1066    addrinfo* result = nullptr;
1067    EXPECT_EQ(0, getaddrinfo("addrinfotls", nullptr, nullptr, &result));
1068    size_t found = GetNumQueries(dns, host_name);
1069    EXPECT_LE(1U, found);
1070    // Could be A or AAAA
1071    std::string result_str = ToString(result);
1072    EXPECT_TRUE(result_str == "1.2.3.4" || result_str == "::1.2.3.4")
1073        << ", result_str='" << result_str << "'";
1074    // TODO: Use ScopedAddrinfo or similar once it is available in a common header file.
1075    if (result) {
1076        freeaddrinfo(result);
1077        result = nullptr;
1078    }
1079    // Wait for both A and AAAA queries to get counted.
1080    EXPECT_TRUE(tls.waitForQueries(3, 5000));
1081
1082    rv = mNetdSrv->removePrivateDnsServer(listen_addr);
1083    tls.stopServer();
1084    dns.stopServer();
1085}
1086