1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "dns/DnsTlsTransport.h"
18
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <errno.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <poll.h>
#include <stdlib.h>
25
26#define LOG_TAG "DnsTlsTransport"
27#define DBG 0
28
29#include "log/log.h"
30#include "Fwmark.h"
31#undef ADD  // already defined in nameser.h
32#include "NetdConstants.h"
33#include "Permission.h"
34
35
36namespace android {
37namespace net {
38
39namespace {
40
// Turns O_NONBLOCK on (|enabled| == true) or off for |fd|, leaving every other
// file status flag untouched. Returns false if either fcntl() call fails.
bool setNonBlocking(int fd, bool enabled) {
    const int current = fcntl(fd, F_GETFL);
    if (current < 0) {
        return false;
    }
    const int updated = enabled ? (current | O_NONBLOCK) : (current & ~O_NONBLOCK);
    return fcntl(fd, F_SETFL, updated) == 0;
}
52
53int waitForReading(int fd) {
54    fd_set fds;
55    FD_ZERO(&fds);
56    FD_SET(fd, &fds);
57    const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr));
58    if (DBG && ret <= 0) {
59        ALOGD("select");
60    }
61    return ret;
62}
63
64int waitForWriting(int fd) {
65    fd_set fds;
66    FD_ZERO(&fds);
67    FD_SET(fd, &fds);
68    const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr));
69    if (DBG && ret <= 0) {
70        ALOGD("select");
71    }
72    return ret;
73}
74
75}  // namespace
76
77android::base::unique_fd DnsTlsTransport::makeConnectedSocket() const {
78    android::base::unique_fd fd;
79    int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
80    switch (mProtocol) {
81        case IPPROTO_TCP:
82            type |= SOCK_STREAM;
83            break;
84        default:
85            errno = EPROTONOSUPPORT;
86            return fd;
87    }
88
89    fd.reset(socket(mAddr.ss_family, type, mProtocol));
90    if (fd.get() == -1) {
91        return fd;
92    }
93
94    const socklen_t len = sizeof(mMark);
95    if (setsockopt(fd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
96        fd.reset();
97    } else if (connect(fd.get(),
98            reinterpret_cast<const struct sockaddr *>(&mAddr), sizeof(mAddr)) != 0
99        && errno != EINPROGRESS) {
100        fd.reset();
101    }
102
103    return fd;
104}
105
106bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
107    int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
108    unsigned char spki[spki_len];
109    unsigned char* temp = spki;
110    if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
111        ALOGW("SPKI length mismatch");
112        return false;
113    }
114    out->resize(SHA256_SIZE);
115    unsigned int digest_len = 0;
116    int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
117    if (ret != 1) {
118        ALOGW("Server cert digest extraction failed");
119        return false;
120    }
121    if (digest_len != out->size()) {
122        ALOGW("Wrong digest length: %d", digest_len);
123        return false;
124    }
125    return true;
126}
127
128SSL* DnsTlsTransport::sslConnect(int fd) {
129    if (fd < 0) {
130        ALOGD("%u makeConnectedSocket() failed with: %s", mMark, strerror(errno));
131        return nullptr;
132    }
133
134    // Set up TLS context.
135    bssl::UniquePtr<SSL_CTX> ssl_ctx(SSL_CTX_new(TLS_method()));
136    if (!SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION) ||
137        !SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_1_VERSION)) {
138        ALOGD("failed to min/max TLS versions");
139        return nullptr;
140    }
141
142    bssl::UniquePtr<SSL> ssl(SSL_new(ssl_ctx.get()));
143    bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_CLOSE));
144    SSL_set_bio(ssl.get(), bio.get(), bio.get());
145    bio.release();
146
147    if (!setNonBlocking(fd, false)) {
148        ALOGE("Failed to disable nonblocking status on DNS-over-TLS fd");
149        return nullptr;
150    }
151
152    for (;;) {
153        if (DBG) {
154            ALOGD("%u Calling SSL_connect", mMark);
155        }
156        int ret = SSL_connect(ssl.get());
157        if (DBG) {
158            ALOGD("%u SSL_connect returned %d", mMark, ret);
159        }
160        if (ret == 1) break;  // SSL handshake complete;
161
162        const int ssl_err = SSL_get_error(ssl.get(), ret);
163        switch (ssl_err) {
164            case SSL_ERROR_WANT_READ:
165                if (waitForReading(fd) != 1) {
166                    ALOGW("SSL_connect read error");
167                    return nullptr;
168                }
169                break;
170            case SSL_ERROR_WANT_WRITE:
171                if (waitForWriting(fd) != 1) {
172                    ALOGW("SSL_connect write error");
173                    return nullptr;
174                }
175                break;
176            default:
177                ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
178                return nullptr;
179        }
180    }
181
182    if (!mFingerprints.empty()) {
183        if (DBG) {
184            ALOGD("Checking DNS over TLS fingerprint");
185        }
186        // TODO: Follow the cert chain and check all the way up.
187        bssl::UniquePtr<X509> cert(SSL_get_peer_certificate(ssl.get()));
188        if (!cert) {
189            ALOGW("Server has null certificate");
190            return nullptr;
191        }
192        std::vector<uint8_t> digest;
193        if (!getSPKIDigest(cert.get(), &digest)) {
194            ALOGE("Digest computation failed");
195            return nullptr;
196        }
197
198        if (mFingerprints.count(digest) == 0) {
199            ALOGW("No matching fingerprint");
200            return nullptr;
201        }
202        if (DBG) {
203            ALOGD("DNS over TLS fingerprint is correct");
204        }
205    }
206
207    if (DBG) {
208        ALOGD("%u handshake complete", mMark);
209    }
210    return ssl.release();
211}
212
213bool DnsTlsTransport::sslWrite(int fd, SSL *ssl, const uint8_t *buffer, int len) {
214    if (DBG) {
215        ALOGD("%u Writing %d bytes", mMark, len);
216    }
217    for (;;) {
218        int ret = SSL_write(ssl, buffer, len);
219        if (ret == len) break;  // SSL write complete;
220
221        if (ret < 1) {
222            const int ssl_err = SSL_get_error(ssl, ret);
223            switch (ssl_err) {
224                case SSL_ERROR_WANT_WRITE:
225                    if (waitForWriting(fd) != 1) {
226                        if (DBG) {
227                            ALOGW("SSL_write error");
228                        }
229                        return false;
230                    }
231                    continue;
232                case 0:
233                    break;  // SSL write complete;
234                default:
235                    if (DBG) {
236                        ALOGW("SSL_write error %d", ssl_err);
237                    }
238                    return false;
239            }
240        }
241    }
242    if (DBG) {
243        ALOGD("%u Wrote %d bytes", mMark, len);
244    }
245    return true;
246}
247
248// Read exactly len bytes into buffer or fail
249bool DnsTlsTransport::sslRead(int fd, SSL *ssl, uint8_t *buffer, int len) {
250    int remaining = len;
251    while (remaining > 0) {
252        int ret = SSL_read(ssl, buffer + (len - remaining), remaining);
253        if (ret == 0) {
254            ALOGE("SSL socket closed with %i of %i bytes remaining", remaining, len);
255            return false;
256        }
257
258        if (ret < 0) {
259            const int ssl_err = SSL_get_error(ssl, ret);
260            if (ssl_err == SSL_ERROR_WANT_READ) {
261                if (waitForReading(fd) != 1) {
262                    if (DBG) {
263                        ALOGW("SSL_read error");
264                    }
265                    return false;
266                }
267                continue;
268            } else {
269                if (DBG) {
270                    ALOGW("SSL_read error %d", ssl_err);
271                }
272                return false;
273            }
274        }
275
276        remaining -= ret;
277    }
278    return true;
279}
280
281DnsTlsTransport::Response DnsTlsTransport::doQuery(const uint8_t *query, size_t qlen,
282        uint8_t *response, size_t limit, int *resplen) {
283    *resplen = 0;  // Zero indicates an error.
284
285    if (DBG) {
286        ALOGD("%u connecting TCP socket", mMark);
287    }
288    android::base::unique_fd fd(makeConnectedSocket());
289    if (DBG) {
290        ALOGD("%u connecting SSL", mMark);
291    }
292    bssl::UniquePtr<SSL> ssl(sslConnect(fd));
293    if (ssl == nullptr) {
294        if (DBG) {
295            ALOGW("%u SSL connection failed", mMark);
296        }
297        return Response::network_error;
298    }
299
300    uint8_t queryHeader[2];
301    queryHeader[0] = qlen >> 8;
302    queryHeader[1] = qlen;
303    if (!sslWrite(fd.get(), ssl.get(), queryHeader, 2)) {
304        return Response::network_error;
305    }
306    if (!sslWrite(fd.get(), ssl.get(), query, qlen)) {
307        return Response::network_error;
308    }
309    if (DBG) {
310        ALOGD("%u SSL_write complete", mMark);
311    }
312
313    uint8_t responseHeader[2];
314    if (!sslRead(fd.get(), ssl.get(), responseHeader, 2)) {
315        if (DBG) {
316            ALOGW("%u Failed to read 2-byte length header", mMark);
317        }
318        return Response::network_error;
319    }
320    const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
321    if (DBG) {
322        ALOGD("%u Expecting response of size %i", mMark, responseSize);
323    }
324    if (responseSize > limit) {
325        ALOGE("%u Response doesn't fit in output buffer: %i", mMark, responseSize);
326        return Response::limit_error;
327    }
328    if (!sslRead(fd.get(), ssl.get(), response, responseSize)) {
329        if (DBG) {
330            ALOGW("%u Failed to read %i bytes", mMark, responseSize);
331        }
332        return Response::network_error;
333    }
334    if (DBG) {
335        ALOGD("%u SSL_read complete", mMark);
336    }
337
338    if (response[0] != query[0] || response[1] != query[1]) {
339        ALOGE("reply query ID != query ID");
340        return Response::internal_error;
341    }
342
343    SSL_shutdown(ssl.get());
344
345    *resplen = responseSize;
346    return Response::success;
347}
348
349bool validateDnsTlsServer(unsigned netid, const struct sockaddr_storage& ss,
350        const std::set<std::vector<uint8_t>>& fingerprints) {
351    if (DBG) {
352        ALOGD("Beginning validation on %u", netid);
353    }
354    // Generate "<random>-dnsotls-ds.metric.gstatic.com", which we will lookup through |ss| in
355    // order to prove that it is actually a working DNS over TLS server.
356    static const char kDnsSafeChars[] =
357            "abcdefhijklmnopqrstuvwxyz"
358            "ABCDEFHIJKLMNOPQRSTUVWXYZ"
359            "0123456789";
360    const auto c = [](uint8_t rnd) -> uint8_t {
361        return kDnsSafeChars[(rnd % ARRAY_SIZE(kDnsSafeChars))];
362    };
363    uint8_t rnd[8];
364    arc4random_buf(rnd, ARRAY_SIZE(rnd));
365    // We could try to use res_mkquery() here, but it's basically the same.
366    uint8_t query[] = {
367        rnd[6], rnd[7],  // [0-1]   query ID
368        1, 0,  // [2-3]   flags; query[2] = 1 for recursion desired (RD).
369        0, 1,  // [4-5]   QDCOUNT (number of queries)
370        0, 0,  // [6-7]   ANCOUNT (number of answers)
371        0, 0,  // [8-9]   NSCOUNT (number of name server records)
372        0, 0,  // [10-11] ARCOUNT (number of additional records)
373        17, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]),
374            '-', 'd', 'n', 's', 'o', 't', 'l', 's', '-', 'd', 's',
375        6, 'm', 'e', 't', 'r', 'i', 'c',
376        7, 'g', 's', 't', 'a', 't', 'i', 'c',
377        3, 'c', 'o', 'm',
378        0,  // null terminator of FQDN (root TLD)
379        0, ns_t_aaaa,  // QTYPE
380        0, ns_c_in     // QCLASS
381    };
382    const int qlen = ARRAY_SIZE(query);
383
384    const int kRecvBufSize = 4 * 1024;
385    uint8_t recvbuf[kRecvBufSize];
386
387    // At validation time, we only know the netId, so we have to guess/compute the
388    // corresponding socket mark.
389    Fwmark fwmark;
390    fwmark.permission = PERMISSION_SYSTEM;
391    fwmark.explicitlySelected = true;
392    fwmark.protectedFromVpn = true;
393    fwmark.netId = netid;
394    unsigned mark = fwmark.intValue;
395    DnsTlsTransport xport(mark, IPPROTO_TCP, ss, fingerprints);
396    int replylen = 0;
397    xport.doQuery(query, qlen, recvbuf, kRecvBufSize, &replylen);
398    if (replylen == 0) {
399        if (DBG) {
400            ALOGD("doQuery failed");
401        }
402        return false;
403    }
404
405    if (replylen < NS_HFIXEDSZ) {
406        if (DBG) {
407            ALOGW("short response: %d", replylen);
408        }
409        return false;
410    }
411
412    const int qdcount = (recvbuf[4] << 8) | recvbuf[5];
413    if (qdcount != 1) {
414        ALOGW("reply query count != 1: %d", qdcount);
415        return false;
416    }
417
418    const int ancount = (recvbuf[6] << 8) | recvbuf[7];
419    if (DBG) {
420        ALOGD("%u answer count: %d", netid, ancount);
421    }
422
423    // TODO: Further validate the response contents (check for valid AAAA record, ...).
424    // Note that currently, integration tests rely on this function accepting a
425    // response with zero records.
426#if 0
427    for (int i = 0; i < resplen; i++) {
428        ALOGD("recvbuf[%d] = %d %c", i, recvbuf[i], recvbuf[i]);
429    }
430#endif
431    return true;
432}
433
434}  // namespace net
435}  // namespace android
436