1/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "DnsTlsQueryMap"
18//#define LOG_NDEBUG 0
19
20#include "dns/DnsTlsTransport.h"
21
22#include "log/log.h"
23
24namespace android {
25namespace net {
26
27std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(const Slice query) {
28    std::lock_guard<std::mutex> guard(mLock);
29
30    // Store the query so it can be matched to the response or reissued.
31    if (query.size() < 2) {
32        ALOGW("Query is too short");
33        return nullptr;
34    }
35    int32_t newId = getFreeId();
36    if (newId < 0) {
37        ALOGW("All query IDs are in use");
38        return nullptr;
39    }
40    Query q = { .newId = static_cast<uint16_t>(newId), .query = query };
41    std::map<uint16_t, QueryPromise>::iterator it;
42    bool inserted;
43    std::tie(it, inserted) = mQueries.emplace(newId, q);
44    if (!inserted) {
45        ALOGE("Failed to store pending query");
46        return nullptr;
47    }
48    return std::make_unique<QueryFuture>(q, it->second.result.get_future());
49}
50
51void DnsTlsQueryMap::expire(QueryPromise* p) {
52    Result r = { .code = Response::network_error };
53    p->result.set_value(r);
54}
55
56void DnsTlsQueryMap::markTried(uint16_t newId) {
57    std::lock_guard<std::mutex> guard(mLock);
58    auto it = mQueries.find(newId);
59    if (it != mQueries.end()) {
60        it->second.tries++;
61    }
62}
63
64void DnsTlsQueryMap::cleanup() {
65    std::lock_guard<std::mutex> guard(mLock);
66    for (auto it = mQueries.begin(); it != mQueries.end();) {
67        auto& p = it->second;
68        if (p.tries >= kMaxTries) {
69            expire(&p);
70            it = mQueries.erase(it);
71        } else {
72            ++it;
73        }
74    }
75}
76
77int32_t DnsTlsQueryMap::getFreeId() {
78    if (mQueries.empty()) {
79        return 0;
80    }
81    uint16_t maxId = mQueries.rbegin()->first;
82    if (maxId < UINT16_MAX) {
83        return maxId + 1;
84    }
85    if (mQueries.size() == UINT16_MAX + 1) {
86        // Map is full.
87        return -1;
88    }
89    // Linear scan.
90    uint16_t nextId = 0;
91    for (auto& pair : mQueries) {
92        uint16_t id = pair.first;
93        if (id != nextId) {
94            // Found a gap.
95            return nextId;
96        }
97        nextId = id + 1;
98    }
99    // Unreachable (but the compiler isn't smart enough to prove it).
100    return -1;
101}
102
103std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
104    std::lock_guard<std::mutex> guard(mLock);
105    std::vector<DnsTlsQueryMap::Query> queries;
106    for (auto& q : mQueries) {
107        queries.push_back(q.second.query);
108    }
109    return queries;
110}
111
112bool DnsTlsQueryMap::empty() {
113    std::lock_guard<std::mutex> guard(mLock);
114    return mQueries.empty();
115}
116
117void DnsTlsQueryMap::clear() {
118    std::lock_guard<std::mutex> guard(mLock);
119    for (auto& q : mQueries) {
120        expire(&q.second);
121    }
122    mQueries.clear();
123}
124
125void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
126    ALOGV("Got response of size %zu", response.size());
127    if (response.size() < 2) {
128        ALOGW("Response is too short");
129        return;
130    }
131    uint16_t id = response[0] << 8 | response[1];
132    std::lock_guard<std::mutex> guard(mLock);
133    auto it = mQueries.find(id);
134    if (it == mQueries.end()) {
135        ALOGW("Discarding response: unknown ID %d", id);
136        return;
137    }
138    Result r = { .code = Response::success, .response = std::move(response) };
139    // Rewrite ID to match the query
140    const uint8_t* data = it->second.query.query.base();
141    r.response[0] = data[0];
142    r.response[1] = data[1];
143    ALOGV("Sending result to dispatcher");
144    it->second.result.set_value(std::move(r));
145    mQueries.erase(it);
146}
147
148}  // end of namespace net
149}  // end of namespace android
150