1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "DnsTlsTransport"
18//#define LOG_NDEBUG 0
19
20#include "dns/DnsTlsTransport.h"
21
22#include <arpa/inet.h>
23#include <arpa/nameser.h>
24
25#include "dns/DnsTlsServer.h"
26#include "dns/DnsTlsSocketFactory.h"
27#include "dns/IDnsTlsSocketFactory.h"
28
29#include "log/log.h"
30#include "Fwmark.h"
31#undef ADD  // already defined in nameser.h
32#include "NetdConstants.h"
33#include "Permission.h"
34
35namespace android {
36namespace net {
37
38std::future<DnsTlsTransport::Result> DnsTlsTransport::query(const netdutils::Slice query) {
39    std::lock_guard<std::mutex> guard(mLock);
40
41    auto record = mQueries.recordQuery(query);
42    if (!record) {
43        return std::async(std::launch::deferred, []{
44            return (Result) { .code = Response::internal_error };
45        });
46    }
47
48    if (!mSocket) {
49        ALOGV("No socket for query.  Opening socket and sending.");
50        doConnect();
51    } else {
52        sendQuery(record->query);
53    }
54
55    return std::move(record->result);
56}
57
58bool DnsTlsTransport::sendQuery(const DnsTlsQueryMap::Query q) {
59    // Strip off the ID number and send the new ID instead.
60    bool sent = mSocket->query(q.newId, netdutils::drop(q.query, 2));
61    if (sent) {
62        mQueries.markTried(q.newId);
63    }
64    return sent;
65}
66
67void DnsTlsTransport::doConnect() {
68    ALOGV("Constructing new socket");
69    mSocket = mFactory->createDnsTlsSocket(mServer, mMark, this, &mCache);
70
71    if (mSocket) {
72        auto queries = mQueries.getAll();
73        ALOGV("Initialization succeeded.  Reissuing %zu queries.", queries.size());
74        for(auto& q : queries) {
75            if (!sendQuery(q)) {
76                break;
77            }
78        }
79    } else {
80        ALOGV("Initialization failed.");
81        mSocket.reset();
82        ALOGV("Failing all pending queries.");
83        mQueries.clear();
84    }
85}
86
87void DnsTlsTransport::onResponse(std::vector<uint8_t> response) {
88    mQueries.onResponse(std::move(response));
89}
90
91void DnsTlsTransport::onClosed() {
92    std::lock_guard<std::mutex> guard(mLock);
93    if (mClosing) {
94        return;
95    }
96    // Move remaining operations to a new thread.
97    // This is necessary because
98    // 1. onClosed is currently running on a thread that blocks mSocket's destructor
99    // 2. doReconnect will call that destructor
100    if (mReconnectThread) {
101        // Complete cleanup of a previous reconnect thread, if present.
102        mReconnectThread->join();
103        // Joining a thread that is trying to acquire mLock, while holding mLock,
104        // looks like it risks a deadlock.  However, a deadlock will not occur because
105        // once onClosed is called, it cannot be called again until after doReconnect
106        // acquires mLock.
107    }
108    mReconnectThread.reset(new std::thread(&DnsTlsTransport::doReconnect, this));
109}
110
111void DnsTlsTransport::doReconnect() {
112    std::lock_guard<std::mutex> guard(mLock);
113    if (mClosing) {
114        return;
115    }
116    mQueries.cleanup();
117    if (!mQueries.empty()) {
118        ALOGV("Fast reconnect to retry remaining queries");
119        doConnect();
120    } else {
121        ALOGV("No pending queries.  Going idle.");
122        mSocket.reset();
123    }
124}
125
126DnsTlsTransport::~DnsTlsTransport() {
127    ALOGV("Destructor");
128    {
129        std::lock_guard<std::mutex> guard(mLock);
130        ALOGV("Locked destruction procedure");
131        mQueries.clear();
132        mClosing = true;
133    }
134    // It's possible that a reconnect thread was spawned and waiting for mLock.
135    // It's safe for that thread to run now because mClosing is true (and mQueries is empty),
136    // but we need to wait for it to finish before allowing destruction to proceed.
137    if (mReconnectThread) {
138        ALOGV("Waiting for reconnect thread to terminate");
139        mReconnectThread->join();
140        mReconnectThread.reset();
141    }
142    // Ensure that the socket is destroyed, and can clean up its callback threads,
143    // before any of this object's fields become invalid.
144    mSocket.reset();
145    ALOGV("Destructor completed");
146}
147
148// static
149// TODO: Use this function to preheat the session cache.
150// That may require moving it to DnsTlsDispatcher.
151bool DnsTlsTransport::validate(const DnsTlsServer& server, unsigned netid) {
152    ALOGV("Beginning validation on %u", netid);
153    // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
154    // order to prove that it is actually a working DNS over TLS server.
155    static const char kDnsSafeChars[] =
156            "abcdefhijklmnopqrstuvwxyz"
157            "ABCDEFHIJKLMNOPQRSTUVWXYZ"
158            "0123456789";
159    const auto c = [](uint8_t rnd) -> uint8_t {
160        return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
161    };
162    uint8_t rnd[8];
163    arc4random_buf(rnd, ARRAY_SIZE(rnd));
164    // We could try to use res_mkquery() here, but it's basically the same.
165    uint8_t query[] = {
166        rnd[6], rnd[7],  // [0-1]   query ID
167        1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
168        0, 1,  // [4-5]   QDCOUNT (number of queries)
169        0, 0,  // [6-7]   ANCOUNT (number of answers)
170        0, 0,  // [8-9]   NSCOUNT (number of name server records)
171        0, 0,  // [10-11] ARCOUNT (number of additional records)
172        17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
173            '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
174        6, 'm', 'e', 't', 'r', 'i', 'c',
175        7, 'g', 's', 't', 'a', 't', 'i', 'c',
176        3, 'c', 'o', 'm',
177        0,  // null terminator of FQDN (root TLD)
178        0, ns_t_aaaa,  // QTYPE
179        0, ns_c_in     // QCLASS
180    };
181    const int qlen = ARRAY_SIZE(query);
182
183    // At validation time, we only know the netId, so we have to guess/compute the
184    // corresponding socket mark.
185    Fwmark fwmark;
186    fwmark.permission = PERMISSION_SYSTEM;
187    fwmark.explicitlySelected = true;
188    fwmark.protectedFromVpn = true;
189    fwmark.netId = netid;
190    unsigned mark = fwmark.intValue;
191    int replylen = 0;
192    DnsTlsSocketFactory factory;
193    DnsTlsTransport transport(server, mark, &factory);
194    auto r = transport.query(Slice(query, qlen)).get();
195    if (r.code != Response::success) {
196        ALOGV("query failed");
197        return false;
198    }
199
200    const std::vector<uint8_t>& recvbuf = r.response;
201    if (recvbuf.size() < NS_HFIXEDSZ) {
202        ALOGW("short response: %d", replylen);
203        return false;
204    }
205
206    const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
207    if (qdcount != 1) {
208        ALOGW("reply query count != 1: %d", qdcount);
209        return false;
210    }
211
212    const int ancount = (recvbuf[6] << 8) | recvbuf[7];
213    ALOGV("%u answer count: %d", netid, ancount);
214
215    // TODO: Further validate the response contents (check for valid AAAA record, ...).
216    // Note that currently, integration tests rely on this function accepting a
217    // response with zero records.
218#if 0
219    for (int i = 0; i < resplen; i++) {
220        ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
221    }
222#endif
223    return true;
224}
225
226}  // end of namespace net
227}  // end of namespace android
228