1/*
2 * Copyright (C) 2014 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "authorization_set.h"
18
19#include <assert.h>
20#include <istream>
21#include <limits>
22#include <ostream>
23#include <stddef.h>
24#include <stdlib.h>
25#include <string.h>
26
27#include <new>
28
29namespace android {
30namespace hardware {
31namespace keymaster {
32namespace V3_0 {
33
34inline bool keyParamLess(const KeyParameter& a, const KeyParameter& b) {
35    if (a.tag != b.tag) return a.tag < b.tag;
36    int retval;
37    switch (typeFromTag(a.tag)) {
38    case TagType::INVALID:
39    case TagType::BOOL:
40        return false;
41    case TagType::ENUM:
42    case TagType::ENUM_REP:
43    case TagType::UINT:
44    case TagType::UINT_REP:
45        return a.f.integer < b.f.integer;
46    case TagType::ULONG:
47    case TagType::ULONG_REP:
48        return a.f.longInteger < b.f.longInteger;
49    case TagType::DATE:
50        return a.f.dateTime < b.f.dateTime;
51    case TagType::BIGNUM:
52    case TagType::BYTES:
53        // Handle the empty cases.
54        if (a.blob.size() == 0) return b.blob.size() != 0;
55        if (b.blob.size() == 0) return false;
56
57        retval = memcmp(&a.blob[0], &b.blob[0], std::min(a.blob.size(), b.blob.size()));
58        if (retval == 0) {
59            // One is the prefix of the other, so the longer wins
60            return a.blob.size() < b.blob.size();
61        } else {
62            return retval < 0;
63        }
64    }
65    return false;
66}
67
68inline bool keyParamEqual(const KeyParameter& a, const KeyParameter& b) {
69    if (a.tag != b.tag) return false;
70
71    switch (typeFromTag(a.tag)) {
72    case TagType::INVALID:
73    case TagType::BOOL:
74        return true;
75    case TagType::ENUM:
76    case TagType::ENUM_REP:
77    case TagType::UINT:
78    case TagType::UINT_REP:
79        return a.f.integer == b.f.integer;
80    case TagType::ULONG:
81    case TagType::ULONG_REP:
82        return a.f.longInteger == b.f.longInteger;
83    case TagType::DATE:
84        return a.f.dateTime == b.f.dateTime;
85    case TagType::BIGNUM:
86    case TagType::BYTES:
87        if (a.blob.size() != b.blob.size()) return false;
88        return a.blob.size() == 0 || memcmp(&a.blob[0], &b.blob[0], a.blob.size()) == 0;
89    }
90    return false;
91}
92
93void AuthorizationSet::Sort() {
94    std::sort(data_.begin(), data_.end(), keyParamLess);
95}
96
97void AuthorizationSet::Deduplicate() {
98    if (data_.empty()) return;
99
100    Sort();
101    std::vector<KeyParameter> result;
102
103    auto curr = data_.begin();
104    auto prev = curr++;
105    for (; curr != data_.end(); ++prev, ++curr) {
106        if (prev->tag == Tag::INVALID) continue;
107
108        if (!keyParamEqual(*prev, *curr)) {
109            result.emplace_back(std::move(*prev));
110        }
111    }
112    result.emplace_back(std::move(*prev));
113
114    std::swap(data_, result);
115}
116
117void AuthorizationSet::Union(const AuthorizationSet& other) {
118    data_.insert(data_.end(), other.data_.begin(), other.data_.end());
119    Deduplicate();
120}
121
122void AuthorizationSet::Subtract(const AuthorizationSet& other) {
123    Deduplicate();
124
125    auto i = other.begin();
126    while (i != other.end()) {
127        int pos = -1;
128        do {
129            pos = find(i->tag, pos);
130            if (pos != -1 && keyParamEqual(*i, data_[pos])) {
131                data_.erase(data_.begin() + pos);
132                break;
133            }
134        } while (pos != -1);
135        ++i;
136    }
137}
138
139int AuthorizationSet::find(Tag tag, int begin) const {
140    auto iter = data_.begin() + (1 + begin);
141
142    while (iter != data_.end() && iter->tag != tag)
143        ++iter;
144
145    if (iter != data_.end()) return iter - data_.begin();
146    return -1;
147}
148
149bool AuthorizationSet::erase(int index) {
150    auto pos = data_.begin() + index;
151    if (pos != data_.end()) {
152        data_.erase(pos);
153        return true;
154    }
155    return false;
156}
157
158KeyParameter& AuthorizationSet::operator[](int at) {
159    return data_[at];
160}
161
162const KeyParameter& AuthorizationSet::operator[](int at) const {
163    return data_[at];
164}
165
/**
 * Removes all entries from the set.
 */
void AuthorizationSet::Clear() {
    data_.clear();
}
169
170size_t AuthorizationSet::GetTagCount(Tag tag) const {
171    size_t count = 0;
172    for (int pos = -1; (pos = find(tag, pos)) != -1;)
173        ++count;
174    return count;
175}
176
177NullOr<const KeyParameter&> AuthorizationSet::GetEntry(Tag tag) const {
178    int pos = find(tag);
179    if (pos == -1) return {};
180    return data_[pos];
181}
182
183/**
184 * Persistent format is:
185 * | 32 bit indirect_size         |
186 * --------------------------------
187 * | indirect_size bytes of data  | this is where the blob data is stored
188 * --------------------------------
189 * | 32 bit element_count         | number of entries
190 * | 32 bit elements_size         | total bytes used by entries (entries have variable length)
191 * --------------------------------
 * | elements_size bytes of data  | where the elements are stored
193 */
194
195/**
196 * Persistent format of blobs and bignums:
197 * | 32 bit tag             |
198 * | 32 bit blob_length     |
199 * | 32 bit indirect_offset |
200 */
201
/**
 * Pair of output streams used while serializing a parameter list.
 * Member order matters: instances are brace-initialized as {indirect, elements}.
 */
struct OutStreams {
    // Receives the raw bytes of blob values.
    std::ostream& indirect;
    // Receives the per-parameter records (tag plus value, or blob length and offset).
    std::ostream& elements;
};
206
207OutStreams& serializeParamValue(OutStreams& out, const hidl_vec<uint8_t>& blob) {
208    uint32_t buffer;
209
210    // write blob_length
211    auto blob_length = blob.size();
212    if (blob_length > std::numeric_limits<uint32_t>::max()) {
213        out.elements.setstate(std::ios_base::badbit);
214        return out;
215    }
216    buffer = blob_length;
217    out.elements.write(reinterpret_cast<const char*>(&buffer), sizeof(uint32_t));
218
219    // write indirect_offset
220    auto offset = out.indirect.tellp();
221    if (offset < 0 || offset > std::numeric_limits<uint32_t>::max() ||
222        static_cast<uint32_t>((std::numeric_limits<uint32_t>::max() - offset)) <
223            blob_length) {  // overflow check
224        out.elements.setstate(std::ios_base::badbit);
225        return out;
226    }
227    buffer = offset;
228    out.elements.write(reinterpret_cast<const char*>(&buffer), sizeof(uint32_t));
229
230    // write blob to indirect stream
231    if (blob_length) out.indirect.write(reinterpret_cast<const char*>(&blob[0]), blob_length);
232
233    return out;
234}
235
236template <typename T> OutStreams& serializeParamValue(OutStreams& out, const T& value) {
237    out.elements.write(reinterpret_cast<const char*>(&value), sizeof(T));
238    return out;
239}
240
/**
 * Serializer overload for the INVALID tag: writes nothing to either stream.
 */
OutStreams& serialize(TAG_INVALID_t&&, OutStreams& out, const KeyParameter&) {
    // skip invalid entries.
    return out;
}
245template <typename T> OutStreams& serialize(T ttag, OutStreams& out, const KeyParameter& param) {
246    out.elements.write(reinterpret_cast<const char*>(&param.tag), sizeof(int32_t));
247    return serializeParamValue(out, accessTagValue(ttag, param));
248}
249
250template <typename... T> struct choose_serializer;
251template <typename... Tags> struct choose_serializer<MetaList<Tags...>> {
252    static OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
253        return choose_serializer<Tags...>::serialize(out, param);
254    }
255};
256template <> struct choose_serializer<> {
257    static OutStreams& serialize(OutStreams& out, const KeyParameter&) { return out; }
258};
259template <TagType tag_type, Tag tag, typename... Tail>
260struct choose_serializer<TypedTag<tag_type, tag>, Tail...> {
261    static OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
262        if (param.tag == tag) {
263            return V3_0::serialize(TypedTag<tag_type, tag>(), out, param);
264        } else {
265            return choose_serializer<Tail...>::serialize(out, param);
266        }
267    }
268};
269
/**
 * Serializes a single parameter by dispatching on its tag over all_tags_t.
 */
OutStreams& serialize(OutStreams& out, const KeyParameter& param) {
    return choose_serializer<all_tags_t>::serialize(out, param);
}
273
274std::ostream& serialize(std::ostream& out, const std::vector<KeyParameter>& params) {
275    std::stringstream indirect;
276    std::stringstream elements;
277    OutStreams streams = {indirect, elements};
278    for (const auto& param : params) {
279        serialize(streams, param);
280    }
281    if (indirect.bad() || elements.bad()) {
282        out.setstate(std::ios_base::badbit);
283        return out;
284    }
285    auto pos = indirect.tellp();
286    if (pos < 0 || pos > std::numeric_limits<uint32_t>::max()) {
287        out.setstate(std::ios_base::badbit);
288        return out;
289    }
290    uint32_t indirect_size = pos;
291    pos = elements.tellp();
292    if (pos < 0 || pos > std::numeric_limits<uint32_t>::max()) {
293        out.setstate(std::ios_base::badbit);
294        return out;
295    }
296    uint32_t elements_size = pos;
297    uint32_t element_count = params.size();
298
299    out.write(reinterpret_cast<const char*>(&indirect_size), sizeof(uint32_t));
300
301    pos = out.tellp();
302    if (indirect_size) out << indirect.rdbuf();
303    assert(out.tellp() - pos == indirect_size);
304
305    out.write(reinterpret_cast<const char*>(&element_count), sizeof(uint32_t));
306    out.write(reinterpret_cast<const char*>(&elements_size), sizeof(uint32_t));
307
308    pos = out.tellp();
309    if (elements_size) out << elements.rdbuf();
310    assert(out.tellp() - pos == elements_size);
311
312    return out;
313}
314
/**
 * Pair of input streams used while deserializing a parameter list.
 * Member order matters: instances are brace-initialized as {indirect, elements}.
 */
struct InStreams {
    // Source of the raw blob bytes, addressed by offset.
    std::istream& indirect;
    // Source of the per-parameter records, read sequentially.
    std::istream& elements;
};
319
320InStreams& deserializeParamValue(InStreams& in, hidl_vec<uint8_t>* blob) {
321    uint32_t blob_length = 0;
322    uint32_t offset = 0;
323    in.elements.read(reinterpret_cast<char*>(&blob_length), sizeof(uint32_t));
324    blob->resize(blob_length);
325    in.elements.read(reinterpret_cast<char*>(&offset), sizeof(uint32_t));
326    in.indirect.seekg(offset);
327    in.indirect.read(reinterpret_cast<char*>(&(*blob)[0]), blob->size());
328    return in;
329}
330
331template <typename T> InStreams& deserializeParamValue(InStreams& in, T* value) {
332    in.elements.read(reinterpret_cast<char*>(value), sizeof(T));
333    return in;
334}
335
/**
 * Deserializer overload for the INVALID tag: reads no value bytes.
 */
InStreams& deserialize(TAG_INVALID_t&&, InStreams& in, KeyParameter*) {
    // There should be no invalid KeyParameters, but if one is encountered it is handled
    // as zero sized.
    return in;
}
340
341template <typename T> InStreams& deserialize(T&& ttag, InStreams& in, KeyParameter* param) {
342    return deserializeParamValue(in, &accessTagValue(ttag, *param));
343}
344
345template <typename... T> struct choose_deserializer;
346template <typename... Tags> struct choose_deserializer<MetaList<Tags...>> {
347    static InStreams& deserialize(InStreams& in, KeyParameter* param) {
348        return choose_deserializer<Tags...>::deserialize(in, param);
349    }
350};
351template <> struct choose_deserializer<> {
352    static InStreams& deserialize(InStreams& in, KeyParameter*) {
353        // encountered an unknown tag -> fail parsing
354        in.elements.setstate(std::ios_base::badbit);
355        return in;
356    }
357};
358template <TagType tag_type, Tag tag, typename... Tail>
359struct choose_deserializer<TypedTag<tag_type, tag>, Tail...> {
360    static InStreams& deserialize(InStreams& in, KeyParameter* param) {
361        if (param->tag == tag) {
362            return V3_0::deserialize(TypedTag<tag_type, tag>(), in, param);
363        } else {
364            return choose_deserializer<Tail...>::deserialize(in, param);
365        }
366    }
367};
368
369InStreams& deserialize(InStreams& in, KeyParameter* param) {
370    in.elements.read(reinterpret_cast<char*>(&param->tag), sizeof(Tag));
371    return choose_deserializer<all_tags_t>::deserialize(in, param);
372}
373
374std::istream& deserialize(std::istream& in, std::vector<KeyParameter>* params) {
375    uint32_t indirect_size = 0;
376    in.read(reinterpret_cast<char*>(&indirect_size), sizeof(uint32_t));
377    std::string indirect_buffer(indirect_size, '\0');
378    if (indirect_buffer.size() != indirect_size) {
379        in.setstate(std::ios_base::badbit);
380        return in;
381    }
382    in.read(&indirect_buffer[0], indirect_buffer.size());
383
384    uint32_t element_count = 0;
385    in.read(reinterpret_cast<char*>(&element_count), sizeof(uint32_t));
386    uint32_t elements_size = 0;
387    in.read(reinterpret_cast<char*>(&elements_size), sizeof(uint32_t));
388
389    std::string elements_buffer(elements_size, '\0');
390    if (elements_buffer.size() != elements_size) {
391        in.setstate(std::ios_base::badbit);
392        return in;
393    }
394    in.read(&elements_buffer[0], elements_buffer.size());
395
396    if (in.bad()) return in;
397
398    // TODO write one-shot stream buffer to avoid copying here
399    std::stringstream indirect(indirect_buffer);
400    std::stringstream elements(elements_buffer);
401    InStreams streams = {indirect, elements};
402
403    params->resize(element_count);
404
405    for (uint32_t i = 0; i < element_count; ++i) {
406        deserialize(streams, &(*params)[i]);
407    }
408    return in;
409}
410
411void AuthorizationSet::Serialize(std::ostream* out) const {
412    serialize(*out, data_);
413}
414
415void AuthorizationSet::Deserialize(std::istream* in) {
416    deserialize(*in, &data_);
417}
418
419}  // namespace V3_0
420}  // namespace keymaster
421}  // namespace hardware
422}  // namespace android
423