1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "dns_tls_frontend.h"
18
#include <arpa/inet.h>
#include <errno.h>
#include <netdb.h>
#include <stdio.h>
#include <string.h>
#include <sys/poll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <string>
#include <vector>

#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/ssl.h>
29
30#define LOG_TAG "DnsTlsFrontend"
31#include <log/log.h>
32
33#include <unistd.h>
34
35namespace {
36
37const int SHA256_SIZE = 32;
38
39// Copied from DnsTlsTransport.
40bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) {
41    int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL);
42    unsigned char spki[spki_len];
43    unsigned char* temp = spki;
44    if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) {
45        ALOGE("SPKI length mismatch");
46        return false;
47    }
48    out->resize(SHA256_SIZE);
49    unsigned int digest_len = 0;
50    int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL);
51    if (ret != 1) {
52        ALOGE("Server cert digest extraction failed");
53        return false;
54    }
55    if (digest_len != out->size()) {
56        ALOGE("Wrong digest length: %d", digest_len);
57        return false;
58    }
59    return true;
60}
61
// strerror_r comes in two incompatible flavors. The XSI/POSIX one returns an
// int status and fills the caller's buffer. The GNU one (selected when
// _GNU_SOURCE is defined, as C++ compilers on glibc/bionic commonly do) returns
// a char* that need not point into the buffer. Testing the GNU result as an
// error code makes every call look like a failure. These overloads map either
// result to the message text, or to nullptr on failure.
[[maybe_unused]] const char* strerrorMessage(int rc, const char* buf) {
    return rc == 0 ? buf : nullptr;
}
[[maybe_unused]] const char* strerrorMessage(const char* msg, const char* /* buf */) {
    return msg;
}

// Returns a human-readable description of the current errno, or an empty
// string if none can be obtained. errno itself is left unchanged.
std::string errno2str() {
    const int saved_errno = errno;
    char error_msg[512] = { 0 };
    const char* msg = strerrorMessage(strerror_r(saved_errno, error_msg, sizeof(error_msg)),
                                      error_msg);
    // strerror_r may modify errno on failure on some implementations; restore
    // it, because APLOGI below also logs errno itself.
    errno = saved_errno;
    return msg ? std::string(msg) : std::string();
}

// Like ALOGI, but appends the current errno and its description to the message.
// Requires at least one argument after |fmt|.
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
70
// Renders |sa| (|sa_len| bytes long) as a numeric host string such as
// "192.0.2.1" or "2001:db8::1"; NI_NUMERICHOST means no name lookup is done.
// Returns an empty string if getnameinfo() fails.
std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
    char numeric_host[NI_MAXHOST] = {};
    const int status = getnameinfo(sa, sa_len, numeric_host, sizeof numeric_host,
                                   nullptr, 0, NI_NUMERICHOST);
    if (status != 0) {
        return {};
    }
    return numeric_host;
}
78
79bssl::UniquePtr<EVP_PKEY> make_private_key() {
80    bssl::UniquePtr<BIGNUM> e(BN_new());
81    if (!e) {
82        ALOGE("BN_new failed");
83        return nullptr;
84    }
85    if (!BN_set_word(e.get(), RSA_F4)) {
86        ALOGE("BN_set_word failed");
87        return nullptr;
88    }
89
90    bssl::UniquePtr<RSA> rsa(RSA_new());
91    if (!rsa) {
92        ALOGE("RSA_new failed");
93        return nullptr;
94    }
95    if (!RSA_generate_key_ex(rsa.get(), 2048, e.get(), NULL)) {
96        ALOGE("RSA_generate_key_ex failed");
97        return nullptr;
98    }
99
100    bssl::UniquePtr<EVP_PKEY> privkey(EVP_PKEY_new());
101    if (!privkey) {
102        ALOGE("EVP_PKEY_new failed");
103        return nullptr;
104    }
105    if(!EVP_PKEY_assign_RSA(privkey.get(), rsa.get())) {
106        ALOGE("EVP_PKEY_assign_RSA failed");
107        return nullptr;
108    }
109
110    // |rsa| is now owned by |privkey|, so no need to free it.
111    rsa.release();
112    return privkey;
113}
114
115bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey) {
116    bssl::UniquePtr<X509> cert(X509_new());
117    if (!cert) {
118        ALOGE("X509_new failed");
119        return nullptr;
120    }
121
122    ASN1_INTEGER_set(X509_get_serialNumber(cert.get()), 1);
123
124    // Set one hour expiration.
125    X509_gmtime_adj(X509_get_notBefore(cert.get()), 0);
126    X509_gmtime_adj(X509_get_notAfter(cert.get()), 60 * 60);
127
128    X509_set_pubkey(cert.get(), privkey);
129
130    if (!X509_sign(cert.get(), privkey, EVP_sha256())) {
131        ALOGE("X509_sign failed");
132        return nullptr;
133    }
134
135    return cert;
136}
137
138}
139
140namespace test {
141
// Starts the frontend. The steps, in order:
//   - builds a TLS server context with a freshly generated self-signed
//     key/certificate, recording the certificate's SPKI SHA-256 digest in
//     |fingerprint_|;
//   - binds and listens on the first usable address for
//     |listen_address_|:|listen_service_|;
//   - opens a connected UDP socket to |backend_address_|:|backend_service_|;
//   - spawns the request handler thread.
// Returns false (after logging) if any step fails. In that case no socket
// descriptor is left open and |socket_|/|backend_socket_| are not modified.
bool DnsTlsFrontend::startServer() {
    SSL_load_error_strings();
    OpenSSL_add_ssl_algorithms();

    ctx_.reset(SSL_CTX_new(TLS_server_method()));
    if (!ctx_) {
        ALOGE("SSL context creation failed");
        return false;
    }

    SSL_CTX_set_ecdh_auto(ctx_.get(), 1);

    // Both helpers return nullptr on failure. Check before use so a null key is
    // never passed on to make_cert() or SSL_CTX_use_PrivateKey().
    bssl::UniquePtr<EVP_PKEY> key(make_private_key());
    if (!key) {
        ALOGE("private key generation failed");
        return false;
    }
    bssl::UniquePtr<X509> cert(make_cert(key.get()));
    if (!cert) {
        ALOGE("certificate generation failed");
        return false;
    }
    if (SSL_CTX_use_certificate(ctx_.get(), cert.get()) <= 0) {
        ALOGE("SSL_CTX_use_certificate failed");
        return false;
    }

    if (!getSPKIDigest(cert.get(), &fingerprint_)) {
        ALOGE("getSPKIDigest failed");
        return false;
    }

    if (SSL_CTX_use_PrivateKey(ctx_.get(), key.get()) <= 0) {
        ALOGE("SSL_CTX_use_PrivateKey failed");
        return false;
    }

    // Set up TCP server socket for clients. Fields are assigned one by one:
    // designated initializers are only standard from C++20, and even then must
    // follow declaration order (ai_flags is declared first on glibc and bionic).
    addrinfo frontend_ai_hints{};
    frontend_ai_hints.ai_flags = AI_PASSIVE;
    frontend_ai_hints.ai_family = AF_UNSPEC;
    frontend_ai_hints.ai_socktype = SOCK_STREAM;
    addrinfo* frontend_ai_res = nullptr;
    int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
                         &frontend_ai_hints, &frontend_ai_res);
    if (rv) {
        ALOGE("frontend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
              listen_service_.c_str(), gai_strerror(rv));
        return false;
    }

    // Use the first resolved address that can be bound.
    int s = -1;
    for (const addrinfo* ai = frontend_ai_res; ai; ai = ai->ai_next) {
        s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
        if (s < 0) continue;
        const int one = 1;
        setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
        if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
            APLOGI("bind failed for socket %d", s);
            close(s);
            s = -1;
            continue;
        }
        std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
        ALOGI("bound to TCP %s:%s", host_str.c_str(), listen_service_.c_str());
        break;
    }
    freeaddrinfo(frontend_ai_res);
    if (s < 0) {
        ALOGE("server socket creation failed");
        return false;
    }

    if (listen(s, 1) < 0) {
        ALOGE("listen failed");
        close(s);
        return false;
    }

    // Set up UDP client socket to backend.
    addrinfo backend_ai_hints{};
    backend_ai_hints.ai_family = AF_UNSPEC;
    backend_ai_hints.ai_socktype = SOCK_DGRAM;
    addrinfo* backend_ai_res = nullptr;
    rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(),
                     &backend_ai_hints, &backend_ai_res);
    if (rv) {
        ALOGE("backend getaddrinfo(%s, %s) failed: %s", backend_address_.c_str(),
              backend_service_.c_str(), gai_strerror(rv));
        close(s);
        return false;
    }
    const int backend = socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype,
                               backend_ai_res->ai_protocol);
    if (backend < 0) {
        ALOGE("backend socket creation failed");
        freeaddrinfo(backend_ai_res);
        close(s);
        return false;
    }
    // The backend socket is later used with plain send()/recv(), which need a
    // connected peer; an unconnected socket could never forward a query.
    if (connect(backend, backend_ai_res->ai_addr, backend_ai_res->ai_addrlen) < 0) {
        APLOGI("backend connect failed for socket %d", backend);
        freeaddrinfo(backend_ai_res);
        close(backend);
        close(s);
        return false;
    }
    freeaddrinfo(backend_ai_res);

    // Publish the descriptors only once all setup has succeeded, so a failed
    // start never leaves half-initialized sockets in the members.
    socket_ = s;
    backend_socket_ = backend;

    {
        std::lock_guard<std::mutex> lock(update_mutex_);
        handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this);
    }
    ALOGI("server started successfully");
    return true;
}
244
245void DnsTlsFrontend::requestHandler() {
246    ALOGD("Request handler started");
247    struct pollfd fds[1] = {{ .fd = socket_, .events = POLLIN }};
248
249    while (!terminate_) {
250        int poll_code = poll(fds, 1, 10 /* ms */);
251        if (poll_code == 0) {
252            // Timeout.  Poll again.
253            continue;
254        } else if (poll_code < 0) {
255            ALOGW("Poll failed with error %d", poll_code);
256            // Error.
257            break;
258        }
259        sockaddr_storage addr;
260        socklen_t len = sizeof(addr);
261
262        ALOGD("Trying to accept a client");
263        int client = accept(socket_, reinterpret_cast<sockaddr*>(&addr), &len);
264        ALOGD("Got client socket %d", client);
265        if (client < 0) {
266            // Stop
267            break;
268        }
269
270        bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get()));
271        SSL_set_fd(ssl.get(), client);
272
273        ALOGD("Doing SSL handshake");
274        bool success = false;
275        if (SSL_accept(ssl.get()) <= 0) {
276            ALOGI("SSL negotiation failure");
277        } else {
278            ALOGD("SSL handshake complete");
279            success = handleOneRequest(ssl.get());
280        }
281
282        close(client);
283
284        if (success) {
285            // Increment queries_ as late as possible, because it represents
286            // a query that is fully processed, and the response returned to the
287            // client, including cleanup actions.
288            ++queries_;
289        }
290    }
291    ALOGD("Request handler terminating");
292}
293
294bool DnsTlsFrontend::handleOneRequest(SSL* ssl) {
295    uint8_t queryHeader[2];
296    if (SSL_read(ssl, &queryHeader, 2) != 2) {
297        ALOGI("Not enough header bytes");
298        return false;
299    }
300    const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1];
301    uint8_t query[qlen];
302    if (SSL_read(ssl, &query, qlen) != qlen) {
303        ALOGI("Not enough query bytes");
304        return false;
305    }
306    int sent = send(backend_socket_, query, qlen, 0);
307    if (sent != qlen) {
308        ALOGI("Failed to send query");
309        return false;
310    }
311    const int max_size = 4096;
312    uint8_t recv_buffer[max_size];
313    int rlen = recv(backend_socket_, recv_buffer, max_size, 0);
314    if (rlen <= 0) {
315        ALOGI("Failed to receive response");
316        return false;
317    }
318    uint8_t responseHeader[2];
319    responseHeader[0] = rlen >> 8;
320    responseHeader[1] = rlen;
321    if (SSL_write(ssl, responseHeader, 2) != 2) {
322        ALOGI("Failed to write response header");
323        return false;
324    }
325    if (SSL_write(ssl, recv_buffer, rlen) != rlen) {
326        ALOGI("Failed to write response body");
327        return false;
328    }
329    return true;
330}
331
332bool DnsTlsFrontend::stopServer() {
333    std::lock_guard<std::mutex> lock(update_mutex_);
334    if (!running()) {
335        ALOGI("server not running");
336        return false;
337    }
338    if (terminate_) {
339        ALOGI("LOGIC ERROR");
340        return false;
341    }
342    ALOGI("stopping frontend");
343    terminate_ = true;
344    handler_thread_.join();
345    close(socket_);
346    close(backend_socket_);
347    terminate_ = false;
348    socket_ = -1;
349    backend_socket_ = -1;
350    ctx_.reset();
351    fingerprint_.clear();
352    ALOGI("frontend stopped successfully");
353    return true;
354}
355
356bool DnsTlsFrontend::waitForQueries(int number, int timeoutMs) const {
357    constexpr int intervalMs = 20;
358    int limit = timeoutMs / intervalMs;
359    for (int count = 0; count <= limit; ++count) {
360        bool done = queries_ >= number;
361        // Always sleep at least one more interval after we are done, to wait for
362        // any immediate post-query actions that the client may take (such as
363        // marking this server as reachable during validation).
364        usleep(intervalMs * 1000);
365        if (done) {
366            // For ensuring that calls have sufficient headroom for slow machines
367            ALOGD("Query arrived in %d/%d of allotted time", count, limit);
368            return true;
369        }
370    }
371    return false;
372}
373
374}  // namespace test
375