1/* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "dns_tls_frontend.h" 18 19#include <netdb.h> 20#include <stdio.h> 21#include <unistd.h> 22#include <sys/poll.h> 23#include <sys/socket.h> 24#include <sys/types.h> 25#include <arpa/inet.h> 26#include <openssl/err.h> 27#include <openssl/evp.h> 28#include <openssl/ssl.h> 29 30#define LOG_TAG "DnsTlsFrontend" 31#include <log/log.h> 32 33#include <unistd.h> 34 35namespace { 36 37const int SHA256_SIZE = 32; 38 39// Copied from DnsTlsTransport. 40bool getSPKIDigest(const X509* cert, std::vector<uint8_t>* out) { 41 int spki_len = i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), NULL); 42 unsigned char spki[spki_len]; 43 unsigned char* temp = spki; 44 if (spki_len != i2d_X509_PUBKEY(X509_get_X509_PUBKEY(cert), &temp)) { 45 ALOGE("SPKI length mismatch"); 46 return false; 47 } 48 out->resize(SHA256_SIZE); 49 unsigned int digest_len = 0; 50 int ret = EVP_Digest(spki, spki_len, out->data(), &digest_len, EVP_sha256(), NULL); 51 if (ret != 1) { 52 ALOGE("Server cert digest extraction failed"); 53 return false; 54 } 55 if (digest_len != out->size()) { 56 ALOGE("Wrong digest length: %d", digest_len); 57 return false; 58 } 59 return true; 60} 61 62std::string errno2str() { 63 char error_msg[512] = { 0 }; 64 if (strerror_r(errno, error_msg, sizeof(error_msg))) 65 return std::string(); 66 return std::string(error_msg); 67} 68 69#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str()) 70 71std::string addr2str(const sockaddr* sa, socklen_t sa_len) { 72 char host_str[NI_MAXHOST] = { 0 }; 73 int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0, 74 NI_NUMERICHOST); 75 if (rv == 0) return std::string(host_str); 76 return std::string(); 77} 78 79bssl::UniquePtr<EVP_PKEY> make_private_key() { 80 bssl::UniquePtr<BIGNUM> e(BN_new()); 81 if (!e) { 82 ALOGE("BN_new failed"); 83 return nullptr; 84 } 85 if (!BN_set_word(e.get(), RSA_F4)) { 86 ALOGE("BN_set_word failed"); 87 return nullptr; 88 } 89 90 bssl::UniquePtr<RSA> rsa(RSA_new()); 91 if (!rsa) { 92 ALOGE("RSA_new failed"); 93 return nullptr; 94 } 95 if (!RSA_generate_key_ex(rsa.get(), 2048, e.get(), NULL)) { 96 ALOGE("RSA_generate_key_ex failed"); 97 return nullptr; 98 } 99 100 bssl::UniquePtr<EVP_PKEY> privkey(EVP_PKEY_new()); 101 if (!privkey) { 102 ALOGE("EVP_PKEY_new failed"); 103 return nullptr; 104 } 105 if(!EVP_PKEY_assign_RSA(privkey.get(), rsa.get())) { 106 ALOGE("EVP_PKEY_assign_RSA failed"); 107 return nullptr; 108 } 109 110 // |rsa| is now owned by |privkey|, so no need to free it. 111 rsa.release(); 112 return privkey; 113} 114 115bssl::UniquePtr<X509> make_cert(EVP_PKEY* privkey) { 116 bssl::UniquePtr<X509> cert(X509_new()); 117 if (!cert) { 118 ALOGE("X509_new failed"); 119 return nullptr; 120 } 121 122 ASN1_INTEGER_set(X509_get_serialNumber(cert.get()), 1); 123 124 // Set one hour expiration. 125 X509_gmtime_adj(X509_get_notBefore(cert.get()), 0); 126 X509_gmtime_adj(X509_get_notAfter(cert.get()), 60 * 60); 127 128 X509_set_pubkey(cert.get(), privkey); 129 130 if (!X509_sign(cert.get(), privkey, EVP_sha256())) { 131 ALOGE("X509_sign failed"); 132 return nullptr; 133 } 134 135 return cert; 136} 137 138} 139 140namespace test { 141 142bool DnsTlsFrontend::startServer() { 143 SSL_load_error_strings(); 144 OpenSSL_add_ssl_algorithms(); 145 146 ctx_.reset(SSL_CTX_new(TLS_server_method())); 147 if (!ctx_) { 148 ALOGE("SSL context creation failed"); 149 return false; 150 } 151 152 SSL_CTX_set_ecdh_auto(ctx_.get(), 1); 153 154 bssl::UniquePtr<EVP_PKEY> key(make_private_key()); 155 bssl::UniquePtr<X509> cert(make_cert(key.get())); 156 if (SSL_CTX_use_certificate(ctx_.get(), cert.get()) <= 0) { 157 ALOGE("SSL_CTX_use_certificate failed"); 158 return false; 159 } 160 161 if (!getSPKIDigest(cert.get(), &fingerprint_)) { 162 ALOGE("getSPKIDigest failed"); 163 return false; 164 } 165 166 if (SSL_CTX_use_PrivateKey(ctx_.get(), key.get()) <= 0 ) { 167 ALOGE("SSL_CTX_use_PrivateKey failed"); 168 return false; 169 } 170 171 // Set up TCP server socket for clients. 172 addrinfo frontend_ai_hints{ 173 .ai_family = AF_UNSPEC, 174 .ai_socktype = SOCK_STREAM, 175 .ai_flags = AI_PASSIVE 176 }; 177 addrinfo* frontend_ai_res; 178 int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), 179 &frontend_ai_hints, &frontend_ai_res); 180 if (rv) { 181 ALOGE("frontend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(), 182 listen_service_.c_str(), gai_strerror(rv)); 183 return false; 184 } 185 186 int s = -1; 187 for (const addrinfo* ai = frontend_ai_res ; ai ; ai = ai->ai_next) { 188 s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); 189 if (s < 0) continue; 190 const int one = 1; 191 setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)); 192 if (bind(s, ai->ai_addr, ai->ai_addrlen)) { 193 APLOGI("bind failed for socket %d", s); 194 close(s); 195 s = -1; 196 continue; 197 } 198 std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen); 199 ALOGI("bound to TCP %s:%s", host_str.c_str(), listen_service_.c_str()); 200 break; 201 } 202 freeaddrinfo(frontend_ai_res); 203 if (s < 0) { 204 ALOGE("server socket creation failed"); 205 return false; 206 } 207 208 if (listen(s, 1) < 0) { 209 ALOGE("listen failed"); 210 return false; 211 } 212 213 socket_ = s; 214 215 // Set up UDP client socket to backend. 216 addrinfo backend_ai_hints{ 217 .ai_family = AF_UNSPEC, 218 .ai_socktype = SOCK_DGRAM 219 }; 220 addrinfo* backend_ai_res; 221 rv = getaddrinfo(backend_address_.c_str(), backend_service_.c_str(), 222 &backend_ai_hints, &backend_ai_res); 223 if (rv) { 224 ALOGE("backend getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(), 225 listen_service_.c_str(), gai_strerror(rv)); 226 return false; 227 } 228 backend_socket_ = socket(backend_ai_res->ai_family, backend_ai_res->ai_socktype, 229 backend_ai_res->ai_protocol); 230 if (backend_socket_ < 0) { 231 ALOGE("backend socket creation failed"); 232 return false; 233 } 234 connect(backend_socket_, backend_ai_res->ai_addr, backend_ai_res->ai_addrlen); 235 freeaddrinfo(backend_ai_res); 236 237 { 238 std::lock_guard<std::mutex> lock(update_mutex_); 239 handler_thread_ = std::thread(&DnsTlsFrontend::requestHandler, this); 240 } 241 ALOGI("server started successfully"); 242 return true; 243} 244 245void DnsTlsFrontend::requestHandler() { 246 ALOGD("Request handler started"); 247 struct pollfd fds[1] = {{ .fd = socket_, .events = POLLIN }}; 248 249 while (!terminate_) { 250 int poll_code = poll(fds, 1, 10 /* ms */); 251 if (poll_code == 0) { 252 // Timeout. Poll again. 253 continue; 254 } else if (poll_code < 0) { 255 ALOGW("Poll failed with error %d", poll_code); 256 // Error. 257 break; 258 } 259 sockaddr_storage addr; 260 socklen_t len = sizeof(addr); 261 262 ALOGD("Trying to accept a client"); 263 int client = accept(socket_, reinterpret_cast<sockaddr*>(&addr), &len); 264 ALOGD("Got client socket %d", client); 265 if (client < 0) { 266 // Stop 267 break; 268 } 269 270 bssl::UniquePtr<SSL> ssl(SSL_new(ctx_.get())); 271 SSL_set_fd(ssl.get(), client); 272 273 ALOGD("Doing SSL handshake"); 274 bool success = false; 275 if (SSL_accept(ssl.get()) <= 0) { 276 ALOGI("SSL negotiation failure"); 277 } else { 278 ALOGD("SSL handshake complete"); 279 success = handleOneRequest(ssl.get()); 280 } 281 282 close(client); 283 284 if (success) { 285 // Increment queries_ as late as possible, because it represents 286 // a query that is fully processed, and the response returned to the 287 // client, including cleanup actions. 288 ++queries_; 289 } 290 } 291 ALOGD("Request handler terminating"); 292} 293 294bool DnsTlsFrontend::handleOneRequest(SSL* ssl) { 295 uint8_t queryHeader[2]; 296 if (SSL_read(ssl, &queryHeader, 2) != 2) { 297 ALOGI("Not enough header bytes"); 298 return false; 299 } 300 const uint16_t qlen = (queryHeader[0] << 8) | queryHeader[1]; 301 uint8_t query[qlen]; 302 if (SSL_read(ssl, &query, qlen) != qlen) { 303 ALOGI("Not enough query bytes"); 304 return false; 305 } 306 int sent = send(backend_socket_, query, qlen, 0); 307 if (sent != qlen) { 308 ALOGI("Failed to send query"); 309 return false; 310 } 311 const int max_size = 4096; 312 uint8_t recv_buffer[max_size]; 313 int rlen = recv(backend_socket_, recv_buffer, max_size, 0); 314 if (rlen <= 0) { 315 ALOGI("Failed to receive response"); 316 return false; 317 } 318 uint8_t responseHeader[2]; 319 responseHeader[0] = rlen >> 8; 320 responseHeader[1] = rlen; 321 if (SSL_write(ssl, responseHeader, 2) != 2) { 322 ALOGI("Failed to write response header"); 323 return false; 324 } 325 if (SSL_write(ssl, recv_buffer, rlen) != rlen) { 326 ALOGI("Failed to write response body"); 327 return false; 328 } 329 return true; 330} 331 332bool DnsTlsFrontend::stopServer() { 333 std::lock_guard<std::mutex> lock(update_mutex_); 334 if (!running()) { 335 ALOGI("server not running"); 336 return false; 337 } 338 if (terminate_) { 339 ALOGI("LOGIC ERROR"); 340 return false; 341 } 342 ALOGI("stopping frontend"); 343 terminate_ = true; 344 handler_thread_.join(); 345 close(socket_); 346 close(backend_socket_); 347 terminate_ = false; 348 socket_ = -1; 349 backend_socket_ = -1; 350 ctx_.reset(); 351 fingerprint_.clear(); 352 ALOGI("frontend stopped successfully"); 353 return true; 354} 355 356bool DnsTlsFrontend::waitForQueries(int number, int timeoutMs) const { 357 constexpr int intervalMs = 20; 358 int limit = timeoutMs / intervalMs; 359 for (int count = 0; count <= limit; ++count) { 360 bool done = queries_ >= number; 361 // Always sleep at least one more interval after we are done, to wait for 362 // any immediate post-query actions that the client may take (such as 363 // marking this server as reachable during validation). 364 usleep(intervalMs * 1000); 365 if (done) { 366 // For ensuring that calls have sufficient headroom for slow machines 367 ALOGD("Query arrived in %d/%d of allotted time", count, limit); 368 return true; 369 } 370 } 371 return false; 372} 373 374} // namespace test 375