1/*
2 * Copyright (C) 2014 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <keystore/authorization_set.h>
18
#include <assert.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <istream>
#include <limits>
#include <new>
#include <ostream>
#include <sstream>
28
29namespace keystore {
30
31inline bool keyParamLess(const KeyParameter& a, const KeyParameter& b) {
32    if (a.tag != b.tag) return a.tag < b.tag;
33    int retval;
34    switch (typeFromTag(a.tag)) {
35    case TagType::INVALID:
36    case TagType::BOOL:
37        return false;
38    case TagType::ENUM:
39    case TagType::ENUM_REP:
40    case TagType::UINT:
41    case TagType::UINT_REP:
42        return a.f.integer < b.f.integer;
43    case TagType::ULONG:
44    case TagType::ULONG_REP:
45        return a.f.longInteger < b.f.longInteger;
46    case TagType::DATE:
47        return a.f.dateTime < b.f.dateTime;
48    case TagType::BIGNUM:
49    case TagType::BYTES:
50        // Handle the empty cases.
51        if (a.blob.size() == 0)
52            return b.blob.size() != 0;
53        if (b.blob.size() == 0) return false;
54
55        retval = memcmp(&a.blob[0], &b.blob[0], std::min(a.blob.size(), b.blob.size()));
56        // if one is the prefix of the other the longer wins
57        if (retval == 0) return a.blob.size() < b.blob.size();
58        // Otherwise a is less if a is less.
59        else return retval < 0;
60    }
61    return false;
62}
63
64inline bool keyParamEqual(const KeyParameter& a, const KeyParameter& b) {
65    if (a.tag != b.tag) return false;
66
67    switch (typeFromTag(a.tag)) {
68    case TagType::INVALID:
69    case TagType::BOOL:
70        return true;
71    case TagType::ENUM:
72    case TagType::ENUM_REP:
73    case TagType::UINT:
74    case TagType::UINT_REP:
75        return a.f.integer == b.f.integer;
76    case TagType::ULONG:
77    case TagType::ULONG_REP:
78        return a.f.longInteger == b.f.longInteger;
79    case TagType::DATE:
80        return a.f.dateTime == b.f.dateTime;
81    case TagType::BIGNUM:
82    case TagType::BYTES:
83        if (a.blob.size() != b.blob.size()) return false;
84        return a.blob.size() == 0 ||
85                memcmp(&a.blob[0], &b.blob[0], a.blob.size()) == 0;
86    }
87    return false;
88}
89
90void AuthorizationSet::Sort() {
91    std::sort(data_.begin(), data_.end(), keyParamLess);
92}
93
94void AuthorizationSet::Deduplicate() {
95    if (data_.empty()) return;
96
97    Sort();
98    std::vector<KeyParameter> result;
99
100    auto curr = data_.begin();
101    auto prev = curr++;
102    for (; curr != data_.end(); ++prev, ++curr) {
103        if (prev->tag == Tag::INVALID) continue;
104
105        if (!keyParamEqual(*prev, *curr)) {
106            result.emplace_back(std::move(*prev));
107        }
108    }
109    result.emplace_back(std::move(*prev));
110
111    std::swap(data_, result);
112}
113
114void AuthorizationSet::Union(const AuthorizationSet& other) {
115    data_.insert(data_.end(), other.data_.begin(), other.data_.end());
116    Deduplicate();
117}
118
119void AuthorizationSet::Subtract(const AuthorizationSet& other) {
120    Deduplicate();
121
122    auto i = other.begin();
123    while (i != other.end()) {
124        int pos = -1;
125        do {
126            pos = find(i->tag, pos);
127            if (pos != -1 && keyParamEqual(*i, data_[pos])) {
128                data_.erase(data_.begin() + pos);
129                break;
130            }
131        } while (pos != -1);
132        ++i;
133    }
134}
135
136int AuthorizationSet::find(Tag tag, int begin) const {
137    auto iter = data_.begin() + (1 + begin);
138
139    while (iter != data_.end() && iter->tag != tag) ++iter;
140
141    if (iter != data_.end()) return iter - data_.begin();
142    return -1;
143}
144
145bool AuthorizationSet::erase(int index) {
146    auto pos = data_.begin() + index;
147    if (pos != data_.end()) {
148        data_.erase(pos);
149        return true;
150    }
151    return false;
152}
153
154KeyParameter& AuthorizationSet::operator[](int at) {
155    return data_[at];
156}
157
158const KeyParameter& AuthorizationSet::operator[](int at) const {
159    return data_[at];
160}
161
162void AuthorizationSet::Clear() {
163    data_.clear();
164}
165
166size_t AuthorizationSet::GetTagCount(Tag tag) const {
167    size_t count = 0;
168    for (int pos = -1; (pos = find(tag, pos)) != -1;)
169        ++count;
170    return count;
171}
172
173NullOr<const KeyParameter&> AuthorizationSet::GetEntry(Tag tag) const {
174    int pos = find(tag);
175    if (pos == -1) return {};
176    return data_[pos];
177}
178
179/**
180 * Persistent format is:
181 * | 32 bit indirect_size         |
182 * --------------------------------
183 * | indirect_size bytes of data  | this is where the blob data is stored
184 * --------------------------------
185 * | 32 bit element_count         | number of entries
186 * | 32 bit elements_size         | total bytes used by entries (entries have variable length)
187 * --------------------------------
 * | elements_size bytes of data  | where the elements are stored
189 */
190
191/**
192 * Persistent format of blobs and bignums:
193 * | 32 bit tag             |
194 * | 32 bit blob_length     |
195 * | 32 bit indirect_offset |
196 */
197
/**
 * The pair of streams written during serialization.
 * `indirect` receives the raw bytes of blob/bignum values. `elements` receives
 * the per-entry records: the tag followed by either the fixed-size value or,
 * for blobs, the length and the offset into `indirect`.
 * Member order matters: instances are aggregate-initialized as
 * { indirect, elements }.
 */
struct OutStreams {
    std::ostream& indirect;
    std::ostream& elements;
};
202
203OutStreams& serializeParamValue(OutStreams& out, const hidl_vec<uint8_t>& blob) {
204    uint32_t buffer;
205
206    // write blob_length
207    auto blob_length = blob.size();
208    if (blob_length > std::numeric_limits<uint32_t>::max()) {
209        out.elements.setstate(std::ios_base::badbit);
210        return out;
211    }
212    buffer = blob_length;
213    out.elements.write(reinterpret_cast<const char*>(&buffer), sizeof(uint32_t));
214
215    // write indirect_offset
216    auto offset = out.indirect.tellp();
217    if (offset < 0 || offset > std::numeric_limits<uint32_t>::max() ||
218            uint32_t(offset) + uint32_t(blob_length) < uint32_t(offset)) { // overflow check
219        out.elements.setstate(std::ios_base::badbit);
220        return out;
221    }
222    buffer = offset;
223    out.elements.write(reinterpret_cast<const char*>(&buffer), sizeof(uint32_t));
224
225    // write blob to indirect stream
226    if(blob_length)
227        out.indirect.write(reinterpret_cast<const char*>(&blob[0]), blob_length);
228
229    return out;
230}
231
232template <typename T>
233OutStreams& serializeParamValue(OutStreams& out, const T& value) {
234    out.elements.write(reinterpret_cast<const char*>(&value), sizeof(T));
235    return out;
236}
237
238OutStreams& serialize(TAG_INVALID_t&&, OutStreams& out, const KeyParameter&) {
239    // skip invalid entries.
240    return out;
241}
242template <typename T>
243OutStreams& serialize(T ttag, OutStreams& out, const KeyParameter& param) {
244    out.elements.write(reinterpret_cast<const char*>(&param.tag), sizeof(int32_t));
245    return serializeParamValue(out, accessTagValue(ttag, param));
246}
247
248template <typename... T>
249struct choose_serializer;
250template <typename... Tags>
251struct choose_serializer<MetaList<Tags...>> {
252    static OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
253        return choose_serializer<Tags...>::serialize(out, param);
254    }
255};
256template <>
257struct choose_serializer<> {
258    static OutStreams& serialize(OutStreams& out, const KeyParameter&) {
259        return out;
260    }
261};
262template <TagType tag_type, Tag tag, typename... Tail>
263struct choose_serializer<TypedTag<tag_type, tag>, Tail...> {
264    static OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
265        if (param.tag == tag) {
266            return keystore::serialize(TypedTag<tag_type, tag>(), out, param);
267        } else {
268            return choose_serializer<Tail...>::serialize(out, param);
269        }
270    }
271};
272
273OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
274    return choose_serializer<all_tags_t>::serialize(out, param);
275}
276
277std::ostream& serialize(std::ostream& out, const std::vector<KeyParameter>& params) {
278    std::stringstream indirect;
279    std::stringstream elements;
280    OutStreams streams = { indirect, elements };
281    for (const auto& param: params) {
282        serialize(streams, param);
283    }
284    if (indirect.bad() || elements.bad()) {
285        out.setstate(std::ios_base::badbit);
286        return out;
287    }
288    auto pos = indirect.tellp();
289    if (pos < 0 || pos > std::numeric_limits<uint32_t>::max()) {
290        out.setstate(std::ios_base::badbit);
291        return out;
292    }
293    uint32_t indirect_size = pos;
294    pos = elements.tellp();
295    if (pos < 0 || pos > std::numeric_limits<uint32_t>::max()) {
296        out.setstate(std::ios_base::badbit);
297        return out;
298    }
299    uint32_t elements_size = pos;
300    uint32_t element_count = params.size();
301
302    out.write(reinterpret_cast<const char*>(&indirect_size), sizeof(uint32_t));
303
304    pos = out.tellp();
305    if (indirect_size)
306        out << indirect.rdbuf();
307    assert(out.tellp() - pos == indirect_size);
308
309    out.write(reinterpret_cast<const char*>(&element_count), sizeof(uint32_t));
310    out.write(reinterpret_cast<const char*>(&elements_size), sizeof(uint32_t));
311
312    pos = out.tellp();
313    if (elements_size)
314        out << elements.rdbuf();
315    assert(out.tellp() - pos == elements_size);
316
317    return out;
318}
319
/**
 * The pair of streams read during deserialization.
 * `indirect` holds the raw bytes of blob/bignum values, which are located by
 * seeking to a recorded offset. `elements` holds the per-entry records: the tag
 * followed by either the fixed-size value or, for blobs, the length and offset.
 * Member order matters: instances are aggregate-initialized as
 * { indirect, elements }.
 */
struct InStreams {
    std::istream& indirect;
    std::istream& elements;
};
324
325InStreams& deserializeParamValue(InStreams& in, hidl_vec<uint8_t>* blob) {
326    uint32_t blob_length = 0;
327    uint32_t offset = 0;
328    in.elements.read(reinterpret_cast<char*>(&blob_length), sizeof(uint32_t));
329    blob->resize(blob_length);
330    in.elements.read(reinterpret_cast<char*>(&offset), sizeof(uint32_t));
331    in.indirect.seekg(offset);
332    in.indirect.read(reinterpret_cast<char*>(&(*blob)[0]), blob->size());
333    return in;
334}
335
336template <typename T>
337InStreams& deserializeParamValue(InStreams& in, T* value) {
338    in.elements.read(reinterpret_cast<char*>(value), sizeof(T));
339    return in;
340}
341
342InStreams& deserialize(TAG_INVALID_t&&, InStreams& in, KeyParameter*) {
343    // there should be no invalid KeyParamaters but if handle them as zero sized.
344    return in;
345}
346
347template <typename T>
348InStreams& deserialize(T&& ttag, InStreams& in, KeyParameter* param) {
349    return deserializeParamValue(in, &accessTagValue(ttag, *param));
350}
351
352template <typename... T>
353struct choose_deserializer;
354template <typename... Tags>
355struct choose_deserializer<MetaList<Tags...>> {
356    static InStreams& deserialize(InStreams& in, KeyParameter* param) {
357        return choose_deserializer<Tags...>::deserialize(in, param);
358    }
359};
360template <>
361struct choose_deserializer<> {
362    static InStreams& deserialize(InStreams& in, KeyParameter*) {
363        // encountered an unknown tag -> fail parsing
364        in.elements.setstate(std::ios_base::badbit);
365        return in;
366    }
367};
368template <TagType tag_type, Tag tag, typename... Tail>
369struct choose_deserializer<TypedTag<tag_type, tag>, Tail...> {
370    static InStreams& deserialize(InStreams& in, KeyParameter* param) {
371        if (param->tag == tag) {
372            return keystore::deserialize(TypedTag<tag_type, tag>(), in, param);
373        } else {
374            return choose_deserializer<Tail...>::deserialize(in, param);
375        }
376    }
377};
378
379InStreams& deserialize(InStreams& in, KeyParameter* param) {
380    in.elements.read(reinterpret_cast<char*>(&param->tag), sizeof(Tag));
381    return choose_deserializer<all_tags_t>::deserialize(in, param);
382}
383
384std::istream& deserialize(std::istream& in, std::vector<KeyParameter>* params) {
385    uint32_t indirect_size = 0;
386    in.read(reinterpret_cast<char*>(&indirect_size), sizeof(uint32_t));
387    std::string indirect_buffer(indirect_size, '\0');
388    if (indirect_buffer.size() != indirect_size) {
389        in.setstate(std::ios_base::badbit);
390        return in;
391    }
392    in.read(&indirect_buffer[0], indirect_buffer.size());
393
394    uint32_t element_count = 0;
395    in.read(reinterpret_cast<char*>(&element_count), sizeof(uint32_t));
396    uint32_t elements_size = 0;
397    in.read(reinterpret_cast<char*>(&elements_size), sizeof(uint32_t));
398
399    std::string elements_buffer(elements_size, '\0');
400    if(elements_buffer.size() != elements_size) {
401        in.setstate(std::ios_base::badbit);
402        return in;
403    }
404    in.read(&elements_buffer[0], elements_buffer.size());
405
406    if (in.bad()) return in;
407
408    // TODO write one-shot stream buffer to avoid copying here
409    std::stringstream indirect(indirect_buffer);
410    std::stringstream elements(elements_buffer);
411    InStreams streams = { indirect, elements };
412
413    params->resize(element_count);
414
415    for (uint32_t i = 0; i < element_count; ++i) {
416        deserialize(streams, &(*params)[i]);
417    }
418    return in;
419}
420void AuthorizationSet::Serialize(std::ostream* out) const {
421    serialize(*out, data_);
422}
423void AuthorizationSet::Deserialize(std::istream* in) {
424    deserialize(*in, &data_);
425}
426
427}  // namespace keystore
428