1/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "DnsTlsSocket"
18//#define LOG_NDEBUG 0
19
20#include "dns/DnsTlsSocket.h"
21
22#include <algorithm>
23#include <arpa/inet.h>
24#include <arpa/nameser.h>
25#include <errno.h>
26#include <linux/tcp.h>
27#include <openssl/err.h>
28#include <sys/poll.h>
29
30#include "dns/DnsTlsSessionCache.h"
31#include "dns/IDnsTlsSocketObserver.h"
32
33#include "log/log.h"
34#include "netdutils/SocketOption.h"
35#include "Fwmark.h"
36#undef ADD  // already defined in nameser.h
37#include "NetdConstants.h"
38#include "Permission.h"
39
40
41namespace android {
42
43using netdutils::enableSockopt;
44using netdutils::enableTcpKeepAlives;
45using netdutils::isOk;
46using netdutils::Status;
47
48namespace net {
49namespace {
50
// Directory passed to SSL_CTX_load_verify_locations() as the CA path.
constexpr char kCaCertDir[] = "/system/etc/security/cacerts";
52
// Blocks (no timeout) until |fd| reports POLLIN or an error condition.
// Returns the poll() result; callers in this file treat anything other
// than 1 as failure.  EINTR is retried via TEMP_FAILURE_RETRY.
int waitForReading(int fd) {
    struct pollfd readable = {};
    readable.fd = fd;
    readable.events = POLLIN;
    return TEMP_FAILURE_RETRY(poll(&readable, 1, -1));
}
58
// Blocks (no timeout) until |fd| reports POLLOUT or an error condition.
// Returns the poll() result; callers in this file treat anything other
// than 1 as failure.  EINTR is retried via TEMP_FAILURE_RETRY.
int waitForWriting(int fd) {
    struct pollfd writable = {};
    writable.fd = fd;
    writable.events = POLLOUT;
    return TEMP_FAILURE_RETRY(poll(&writable, 1, -1));
}
64
65}  // namespace
66
67Status DnsTlsSocket::tcpConnect() {
68    ALOGV("%u connecting TCP socket", mMark);
69    int type = SOCK_NONBLOCK | SOCK_CLOEXEC;
70    switch (mServer.protocol) {
71        case IPPROTO_TCP:
72            type |= SOCK_STREAM;
73            break;
74        default:
75            return Status(EPROTONOSUPPORT);
76    }
77
78    mSslFd.reset(socket(mServer.ss.ss_family, type, mServer.protocol));
79    if (mSslFd.get() == -1) {
80        ALOGE("Failed to create socket");
81        return Status(errno);
82    }
83
84    const socklen_t len = sizeof(mMark);
85    if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
86        ALOGE("Failed to set socket mark");
87        mSslFd.reset();
88        return Status(errno);
89    }
90
91    const Status tfo = enableSockopt(mSslFd.get(), SOL_TCP, TCP_FASTOPEN_CONNECT);
92    if (!isOk(tfo) && tfo.code() != ENOPROTOOPT) {
93        ALOGI("Failed to enable TFO: %s", tfo.msg().c_str());
94    }
95
96    // Send 5 keepalives, 3 seconds apart, after 15 seconds of inactivity.
97    enableTcpKeepAlives(mSslFd.get(), 15U, 5U, 3U);
98
99    if (connect(mSslFd.get(), reinterpret_cast<const struct sockaddr *>(&mServer.ss),
100                sizeof(mServer.ss)) != 0 &&
101            errno != EINPROGRESS) {
102        ALOGV("Socket failed to connect");
103        mSslFd.reset();
104        return Status(errno);
105    }
106
107    return netdutils::status::ok;
108}
109
110bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
111    int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
112    unsigned char spki[spki_len];
113    unsigned char* temp = spki;
114    if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
115        ALOGW("SPKI length mismatch");
116        return false;
117    }
118    out->resize(SHA256_SIZE);
119    unsigned int digest_len = 0;
120    int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
121    if (ret != 1) {
122        ALOGW("Server cert digest extraction failed");
123        return false;
124    }
125    if (digest_len != out->size()) {
126        ALOGW("Wrong digest length: %d", digest_len);
127        return false;
128    }
129    return true;
130}
131
132bool DnsTlsSocket::initialize() {
133    // This method should only be called once, at the beginning, so locking should be
134    // unnecessary.  This lock only serves to help catch bugs in code that calls this method.
135    std::lock_guard<std::mutex> guard(mLock);
136    if (mSslCtx) {
137        // This is a bug in the caller.
138        return false;
139    }
140    mSslCtx.reset(SSL_CTX_new(TLS_method()));
141    if (!mSslCtx) {
142        return false;
143    }
144
145    // Load system CA certs for hostname verification.
146    //
147    // For discussion of alternative, sustainable approaches see b/71909242.
148    if (SSL_CTX_load_verify_locations(mSslCtx.get(), nullptr, kCaCertDir) != 1) {
149        ALOGE("Failed to load CA cert dir: %s", kCaCertDir);
150        return false;
151    }
152
153    // Enable TLS false start
154    SSL_CTX_set_false_start_allowed_without_alpn(mSslCtx.get(), 1);
155    SSL_CTX_set_mode(mSslCtx.get(), SSL_MODE_ENABLE_FALSE_START);
156
157    // Enable session cache
158    mCache->prepareSslContext(mSslCtx.get());
159
160    // Connect
161    Status status = tcpConnect();
162    if (!status.ok()) {
163        return false;
164    }
165    mSsl = sslConnect(mSslFd.get());
166    if (!mSsl) {
167        return false;
168    }
169    int sv[2];
170    if (socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, sv)) {
171        return false;
172    }
173    // The two sockets are perfectly symmetrical, so the choice of which one is
174    // "in" and which one is "out" is arbitrary.
175    mIpcInFd.reset(sv[0]);
176    mIpcOutFd.reset(sv[1]);
177
178    // Start the I/O loop.
179    mLoopThread.reset(new std::thread(&DnsTlsSocket::loop, this));
180
181    return true;
182}
183
184bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) {
185    if (!mSslCtx) {
186        ALOGE("Internal error: context is null in sslConnect");
187        return nullptr;
188    }
189    if (!SSL_CTX_set_min_proto_version(mSslCtx.get(), TLS1_2_VERSION)) {
190        ALOGE("Failed to set minimum TLS version");
191        return nullptr;
192    }
193
194    bssl::UniquePtr<SSL> ssl(SSL_new(mSslCtx.get()));
195    // This file descriptor is owned by mSslFd, so don't let libssl close it.
196    bssl::UniquePtr<BIO> bio(BIO_new_socket(fd, BIO_NOCLOSE));
197    SSL_set_bio(ssl.get(), bio.get(), bio.get());
198    bio.release();
199
200    if (!mCache->prepareSsl(ssl.get())) {
201        return nullptr;
202    }
203
204    if (!mServer.name.empty()) {
205        if (SSL_set_tlsext_host_name(ssl.get(), mServer.name.c_str()) != 1) {
206            ALOGE("Failed to set SNI to %s", mServer.name.c_str());
207            return nullptr;
208        }
209        X509_VERIFY_PARAM* param = SSL_get0_param(ssl.get());
210        if (X509_VERIFY_PARAM_set1_host(param, mServer.name.data(), mServer.name.size()) != 1) {
211            ALOGE("Failed to set verify host param to %s", mServer.name.c_str());
212            return nullptr;
213        }
214        // This will cause the handshake to fail if certificate verification fails.
215        SSL_set_verify(ssl.get(), SSL_VERIFY_PEER, nullptr);
216    }
217
218    bssl::UniquePtr<SSL_SESSION> session = mCache->getSession();
219    if (session) {
220        ALOGV("Setting session");
221        SSL_set_session(ssl.get(), session.get());
222    } else {
223        ALOGV("No session available");
224    }
225
226    for (;;) {
227        ALOGV("%u Calling SSL_connect", mMark);
228        int ret = SSL_connect(ssl.get());
229        ALOGV("%u SSL_connect returned %d", mMark, ret);
230        if (ret == 1) break;  // SSL handshake complete;
231
232        const int ssl_err = SSL_get_error(ssl.get(), ret);
233        switch (ssl_err) {
234            case SSL_ERROR_WANT_READ:
235                if (waitForReading(fd) != 1) {
236                    ALOGW("SSL_connect read error: %d", errno);
237                    return nullptr;
238                }
239                break;
240            case SSL_ERROR_WANT_WRITE:
241                if (waitForWriting(fd) != 1) {
242                    ALOGW("SSL_connect write error");
243                    return nullptr;
244                }
245                break;
246            default:
247                ALOGW("SSL_connect error %d, errno=%d", ssl_err, errno);
248                return nullptr;
249        }
250    }
251
252    // TODO: Call SSL_shutdown before discarding the session if validation fails.
253    if (!mServer.fingerprints.empty()) {
254        ALOGV("Checking DNS over TLS fingerprint");
255
256        // We only care that the chain is internally self-consistent, not that
257        // it chains to a trusted root, so we can ignore some kinds of errors.
258        // TODO: Add a CA root verification mode that respects these errors.
259        int verify_result = SSL_get_verify_result(ssl.get());
260        switch (verify_result) {
261            case X509_V_OK:
262            case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT:
263            case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN:
264            case X509_V_ERR_CERT_UNTRUSTED:
265                break;
266            default:
267                ALOGW("Invalid certificate chain, error %d", verify_result);
268                return nullptr;
269        }
270
271        STACK_OF(X509) *chain = SSL_get_peer_cert_chain(ssl.get());
272        if (!chain) {
273            ALOGW("Server has null certificate");
274            return nullptr;
275        }
276        // Chain and its contents are owned by ssl, so we don't need to free explicitly.
277        bool matched = false;
278        for (size_t i = 0; i < sk_X509_num(chain); ++i) {
279            // This appears to be O(N^2), but there doesn't seem to be a straightforward
280            // way to walk a STACK_OF nondestructively in linear time.
281            X509* cert = sk_X509_value(chain, i);
282            std::vector<uint8_t> digest;
283            if (!getSPKIDigest(cert, &digest)) {
284                ALOGE("Digest computation failed");
285                return nullptr;
286            }
287
288            if (mServer.fingerprints.count(digest) > 0) {
289                matched = true;
290                break;
291            }
292        }
293
294        if (!matched) {
295            ALOGW("No matching fingerprint");
296            return nullptr;
297        }
298
299        ALOGV("DNS over TLS fingerprint is correct");
300    }
301
302    ALOGV("%u handshake complete", mMark);
303
304    return ssl;
305}
306
307void DnsTlsSocket::sslDisconnect() {
308    if (mSsl) {
309        SSL_shutdown(mSsl.get());
310        mSsl.reset();
311    }
312    mSslFd.reset();
313}
314
315bool DnsTlsSocket::sslWrite(const Slice buffer) {
316    ALOGV("%u Writing %zu bytes", mMark, buffer.size());
317    for (;;) {
318        int ret = SSL_write(mSsl.get(), buffer.base(), buffer.size());
319        if (ret == int(buffer.size())) break;  // SSL write complete;
320
321        if (ret < 1) {
322            const int ssl_err = SSL_get_error(mSsl.get(), ret);
323            switch (ssl_err) {
324                case SSL_ERROR_WANT_WRITE:
325                    if (waitForWriting(mSslFd.get()) != 1) {
326                        ALOGV("SSL_write error");
327                        return false;
328                    }
329                    continue;
330                case 0:
331                    break;  // SSL write complete;
332                default:
333                    ALOGV("SSL_write error %d", ssl_err);
334                    return false;
335            }
336        }
337    }
338    ALOGV("%u Wrote %zu bytes", mMark, buffer.size());
339    return true;
340}
341
342void DnsTlsSocket::loop() {
343    std::lock_guard<std::mutex> guard(mLock);
344    // Buffer at most one query.
345    Query q;
346
347    const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000;
348    while (true) {
349        // poll() ignores negative fds
350        struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } };
351        enum { SSLFD = 0, IPCFD = 1 };
352
353        // Always listen for a response from server.
354        fds[SSLFD].fd = mSslFd.get();
355        fds[SSLFD].events = POLLIN;
356
357        // If we have a pending query, also wait for space
358        // to write it, otherwise listen for a new query.
359        if (!q.query.empty()) {
360            fds[SSLFD].events |= POLLOUT;
361        } else {
362            fds[IPCFD].fd = mIpcOutFd.get();
363            fds[IPCFD].events = POLLIN;
364        }
365
366        const int s = TEMP_FAILURE_RETRY(poll(fds, ARRAY_SIZE(fds), timeout_msecs));
367        if (s == 0) {
368            ALOGV("Idle timeout");
369            break;
370        }
371        if (s < 0) {
372            ALOGV("Poll failed: %d", errno);
373            break;
374        }
375        if (fds[SSLFD].revents & (POLLIN | POLLERR)) {
376            if (!readResponse()) {
377                ALOGV("SSL remote close or read error.");
378                break;
379            }
380        }
381        if (fds[IPCFD].revents & (POLLIN | POLLERR)) {
382            int res = read(mIpcOutFd.get(), &q, sizeof(q));
383            if (res < 0) {
384                ALOGW("Error during IPC read");
385                break;
386            } else if (res == 0) {
387                ALOGV("IPC channel closed; disconnecting");
388                break;
389            } else if (res != sizeof(q)) {
390                ALOGE("Struct size mismatch: %d != %zu", res, sizeof(q));
391                break;
392            }
393        } else if (fds[SSLFD].revents & POLLOUT) {
394            // query cannot be null here.
395            if (!sendQuery(q)) {
396                break;
397            }
398            q = Query();  // Reset q to empty
399        }
400    }
401    ALOGV("Closing IPC read FD");
402    mIpcOutFd.reset();
403    ALOGV("Disconnecting");
404    sslDisconnect();
405    ALOGV("Calling onClosed");
406    mObserver->onClosed();
407    ALOGV("Ending loop");
408}
409
410DnsTlsSocket::~DnsTlsSocket() {
411    ALOGV("Destructor");
412    // This will trigger an orderly shutdown in loop().
413    mIpcInFd.reset();
414    {
415        // Wait for the orderly shutdown to complete.
416        std::lock_guard<std::mutex> guard(mLock);
417        if (mLoopThread && std::this_thread::get_id() == mLoopThread->get_id()) {
418            ALOGE("Violation of re-entrance precondition");
419            return;
420        }
421    }
422    if (mLoopThread) {
423        ALOGV("Waiting for loop thread to terminate");
424        mLoopThread->join();
425        mLoopThread.reset();
426    }
427    ALOGV("Destructor completed");
428}
429
430bool DnsTlsSocket::query(uint16_t id, const Slice query) {
431    const Query q = { .id = id, .query = query };
432    if (!mIpcInFd) {
433        return false;
434    }
435    int written = write(mIpcInFd.get(), &q, sizeof(q));
436    return written == sizeof(q);
437}
438
439// Read exactly len bytes into buffer or fail with an SSL error code
440int DnsTlsSocket::sslRead(const Slice buffer, bool wait) {
441    size_t remaining = buffer.size();
442    while (remaining > 0) {
443        int ret = SSL_read(mSsl.get(), buffer.limit() - remaining, remaining);
444        if (ret == 0) {
445            ALOGW_IF(remaining < buffer.size(), "SSL closed with %zu of %zu bytes remaining",
446                     remaining, buffer.size());
447            return SSL_ERROR_ZERO_RETURN;
448        }
449
450        if (ret < 0) {
451            const int ssl_err = SSL_get_error(mSsl.get(), ret);
452            if (wait && ssl_err == SSL_ERROR_WANT_READ) {
453                if (waitForReading(mSslFd.get()) != 1) {
454                    ALOGV("Poll failed in sslRead: %d", errno);
455                    return SSL_ERROR_SYSCALL;
456                }
457                continue;
458            } else {
459                ALOGV("SSL_read error %d", ssl_err);
460                return ssl_err;
461            }
462        }
463
464        remaining -= ret;
465        wait = true;  // Once a read is started, try to finish.
466    }
467    return SSL_ERROR_NONE;
468}
469
470bool DnsTlsSocket::sendQuery(const Query& q) {
471    ALOGV("sending query");
472    // Compose the entire message in a single buffer, so that it can be
473    // sent as a single TLS record.
474    std::vector<uint8_t> buf(q.query.size() + 4);
475    // Write 2-byte length
476    uint16_t len = q.query.size() + 2; // + 2 for the ID.
477    buf[0] = len >> 8;
478    buf[1] = len;
479    // Write 2-byte ID
480    buf[2] = q.id >> 8;
481    buf[3] = q.id;
482    // Copy body
483    std::memcpy(buf.data() + 4, q.query.base(), q.query.size());
484    if (!sslWrite(netdutils::makeSlice(buf))) {
485        return false;
486    }
487    ALOGV("%u SSL_write complete", mMark);
488    return true;
489}
490
491bool DnsTlsSocket::readResponse() {
492    ALOGV("reading response");
493    uint8_t responseHeader[2];
494    int err = sslRead(Slice(responseHeader, 2), false);
495    if (err == SSL_ERROR_WANT_READ) {
496        ALOGV("Ignoring spurious wakeup from server");
497        return true;
498    }
499    if (err != SSL_ERROR_NONE) {
500        return false;
501    }
502    // Truncate responses larger than MAX_SIZE.  This is safe because a DNS packet is
503    // always invalid when truncated, so the response will be treated as an error.
504    constexpr uint16_t MAX_SIZE = 8192;
505    const uint16_t responseSize = (responseHeader[0] << 8) | responseHeader[1];
506    ALOGV("%u Expecting response of size %i", mMark, responseSize);
507    std::vector<uint8_t> response(std::min(responseSize, MAX_SIZE));
508    if (sslRead(netdutils::makeSlice(response), true) != SSL_ERROR_NONE) {
509        ALOGV("%u Failed to read %zu bytes", mMark, response.size());
510        return false;
511    }
512    uint16_t remainingBytes = responseSize - response.size();
513    while (remainingBytes > 0) {
514        constexpr uint16_t CHUNK_SIZE = 2048;
515        std::vector<uint8_t> discard(std::min(remainingBytes, CHUNK_SIZE));
516        if (sslRead(netdutils::makeSlice(discard), true) != SSL_ERROR_NONE) {
517            ALOGV("%u Failed to discard %zu bytes", mMark, discard.size());
518            return false;
519        }
520        remainingBytes -= discard.size();
521    }
522    ALOGV("%u SSL_read complete", mMark);
523
524    mObserver->onResponse(std::move(response));
525    return true;
526}
527
528}  // end of namespace net
529}  // end of namespace android
530