authorization_set.cpp revision 62de26672193373972f2ce968b51cf8335f118f9
1/*
2 * Copyright (C) 2014 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <stdlib.h>
#include <string.h>
#include <stddef.h>

#include <assert.h>

#include <new>

#include "authorization_set.h"
#include "google_keymaster_utils.h"
25
26namespace keymaster {
27
28static inline bool is_blob_tag(keymaster_tag_t tag) {
29    return (keymaster_tag_get_type(tag) == KM_BYTES || keymaster_tag_get_type(tag) == KM_BIGNUM);
30}
31
// Element slots allocated by push_back() when it needs room and no capacity has been reserved yet.
const size_t STARTING_ELEMS_CAPACITY(8);
33
34AuthorizationSet::AuthorizationSet(const AuthorizationSet& set)
35    : Serializable(), elems_(NULL), indirect_data_(NULL) {
36    Reinitialize(set.elems_, set.elems_size_);
37}
38
39AuthorizationSet::~AuthorizationSet() {
40    FreeData();
41}
42
43bool AuthorizationSet::reserve_elems(size_t count) {
44    if (is_valid() != OK)
45        return false;
46
47    if (count >= elems_capacity_) {
48        keymaster_key_param_t* new_elems = new keymaster_key_param_t[count];
49        if (new_elems == NULL) {
50            set_invalid(ALLOCATION_FAILURE);
51            return false;
52        }
53        memcpy(new_elems, elems_, sizeof(*elems_) * elems_size_);
54        delete[] elems_;
55        elems_ = new_elems;
56        elems_capacity_ = count;
57    }
58    return true;
59}
60
61bool AuthorizationSet::reserve_indirect(size_t length) {
62    if (is_valid() != OK)
63        return false;
64
65    if (length > indirect_data_capacity_) {
66        uint8_t* new_data = new uint8_t[length];
67        if (new_data == NULL) {
68            set_invalid(ALLOCATION_FAILURE);
69            return false;
70        }
71        memcpy(new_data, indirect_data_, indirect_data_size_);
72
73        // Fix up the data pointers to point into the new region.
74        for (size_t i = 0; i < elems_size_; ++i) {
75            if (is_blob_tag(elems_[i].tag))
76                elems_[i].blob.data = new_data + (elems_[i].blob.data - indirect_data_);
77        }
78        delete[] indirect_data_;
79        indirect_data_ = new_data;
80        indirect_data_capacity_ = length;
81    }
82    return true;
83}
84
85bool AuthorizationSet::Reinitialize(const keymaster_key_param_t* elems, const size_t count) {
86    FreeData();
87
88    if (!reserve_elems(count))
89        return false;
90
91    if (!reserve_indirect(ComputeIndirectDataSize(elems, count)))
92        return false;
93
94    memcpy(elems_, elems, sizeof(keymaster_key_param_t) * count);
95    elems_size_ = count;
96    CopyIndirectData();
97    error_ = OK;
98    return true;
99}
100
101void AuthorizationSet::set_invalid(Error error) {
102    FreeData();
103    error_ = error;
104}
105
106int AuthorizationSet::find(keymaster_tag_t tag, int begin) const {
107    if (is_valid() != OK)
108        return -1;
109
110    int i = ++begin;
111    while (i < (int)elems_size_ && elems_[i].tag != tag)
112        ++i;
113    if (i == (int)elems_size_)
114        return -1;
115    else
116        return i;
117}
118
119keymaster_key_param_t empty;
120keymaster_key_param_t AuthorizationSet::operator[](int at) const {
121    if (is_valid() == OK && at < (int)elems_size_) {
122        return elems_[at];
123    }
124    memset(&empty, 0, sizeof(empty));
125    return empty;
126}
127
// Three-way comparison built from T's < and > operators: -1 if a < b, 1 if a > b, otherwise 0.
template <typename T> int comparator(const T& a, const T& b) {
    return (a < b) ? -1 : ((a > b) ? 1 : 0);
}
136
137static int param_comparator(const void* a, const void* b) {
138    const keymaster_key_param_t* lhs = static_cast<const keymaster_key_param_t*>(a);
139    const keymaster_key_param_t* rhs = static_cast<const keymaster_key_param_t*>(b);
140
141    if (lhs->tag < rhs->tag)
142        return -1;
143    else if (lhs->tag > rhs->tag)
144        return 1;
145    else
146        switch (keymaster_tag_get_type(lhs->tag)) {
147        default:
148        case KM_INVALID:
149            return 0;
150        case KM_ENUM:
151        case KM_ENUM_REP:
152            return comparator(lhs->enumerated, rhs->enumerated);
153        case KM_INT:
154        case KM_INT_REP:
155            return comparator(lhs->integer, rhs->integer);
156        case KM_LONG:
157            return comparator(lhs->long_integer, rhs->long_integer);
158        case KM_DATE:
159            return comparator(lhs->date_time, rhs->date_time);
160        case KM_BOOL:
161            return comparator(lhs->boolean, rhs->boolean);
162        case KM_BIGNUM:
163        case KM_BYTES: {
164            size_t min_len = lhs->blob.data_length;
165            if (rhs->blob.data_length < min_len)
166                min_len = rhs->blob.data_length;
167
168            if (lhs->blob.data_length == rhs->blob.data_length && min_len > 0)
169                return memcmp(lhs->blob.data, rhs->blob.data, min_len);
170            int cmp_result = memcmp(lhs->blob.data, rhs->blob.data, min_len);
171            if (cmp_result == 0) {
172                // The blobs are equal up to the length of the shortest (which may have length 0),
173                // so the shorter is less, the longer is greater and if they have the same length
174                // they're identical.
175                return comparator(lhs->blob.data_length, rhs->blob.data_length);
176            }
177            return cmp_result;
178        } break;
179        }
180}
181
182bool AuthorizationSet::push_back(const AuthorizationSet& set) {
183    if (is_valid() != OK)
184        return false;
185
186    if (!reserve_elems(elems_size_ + set.elems_size_))
187        return false;
188
189    if (!reserve_indirect(indirect_data_size_ + set.indirect_data_size_))
190        return false;
191
192    for (size_t i = 0; i < set.size(); ++i)
193        if (!push_back(set[i]))
194            return false;
195
196    return true;
197}
198
199bool AuthorizationSet::push_back(keymaster_key_param_t elem) {
200    if (is_valid() != OK)
201        return false;
202
203    if (elems_size_ >= elems_capacity_)
204        if (!reserve_elems(elems_capacity_ ? elems_capacity_ * 2 : STARTING_ELEMS_CAPACITY))
205            return false;
206
207    if (is_blob_tag(elem.tag)) {
208        if (indirect_data_capacity_ - indirect_data_size_ < elem.blob.data_length)
209            if (!reserve_indirect(2 * (indirect_data_capacity_ + elem.blob.data_length)))
210                return false;
211
212        memcpy(indirect_data_ + indirect_data_size_, elem.blob.data, elem.blob.data_length);
213        elem.blob.data = indirect_data_ + indirect_data_size_;
214        indirect_data_size_ += elem.blob.data_length;
215    }
216
217    elems_[elems_size_++] = elem;
218    return true;
219}
220
221static size_t serialized_size(const keymaster_key_param_t& param) {
222    switch (keymaster_tag_get_type(param.tag)) {
223    case KM_INVALID:
224    default:
225        return sizeof(uint32_t);
226    case KM_ENUM:
227    case KM_ENUM_REP:
228    case KM_INT:
229    case KM_INT_REP:
230        return sizeof(uint32_t) * 2;
231    case KM_LONG:
232    case KM_DATE:
233        return sizeof(uint32_t) + sizeof(uint64_t);
234    case KM_BOOL:
235        return sizeof(uint32_t) + 1;
236        break;
237    case KM_BIGNUM:
238    case KM_BYTES:
239        return sizeof(uint32_t) * 3;
240    }
241}
242
243static uint8_t* serialize(const keymaster_key_param_t& param, uint8_t* buf, const uint8_t* end,
244                          const uint8_t* indirect_base) {
245    buf = append_uint32_to_buf(buf, end, param.tag);
246    switch (keymaster_tag_get_type(param.tag)) {
247    case KM_INVALID:
248        break;
249    case KM_ENUM:
250    case KM_ENUM_REP:
251        buf = append_uint32_to_buf(buf, end, param.enumerated);
252        break;
253    case KM_INT:
254    case KM_INT_REP:
255        buf = append_uint32_to_buf(buf, end, param.integer);
256        break;
257    case KM_LONG:
258        buf = append_uint64_to_buf(buf, end, param.long_integer);
259        break;
260    case KM_DATE:
261        buf = append_uint64_to_buf(buf, end, param.date_time);
262        break;
263    case KM_BOOL:
264        if (buf < end)
265            *buf = static_cast<uint8_t>(param.boolean);
266        buf++;
267        break;
268    case KM_BIGNUM:
269    case KM_BYTES:
270        buf = append_uint32_to_buf(buf, end, param.blob.data_length);
271        buf = append_uint32_to_buf(buf, end, param.blob.data - indirect_base);
272        break;
273    }
274    return buf;
275}
276
277static bool deserialize(keymaster_key_param_t* param, const uint8_t** buf_ptr, const uint8_t* end,
278                        const uint8_t* indirect_base, const uint8_t* indirect_end) {
279    if (!copy_uint32_from_buf(buf_ptr, end, &param->tag))
280        return false;
281
282    switch (keymaster_tag_get_type(param->tag)) {
283    default:
284    case KM_INVALID:
285        return false;
286    case KM_ENUM:
287    case KM_ENUM_REP:
288        return copy_uint32_from_buf(buf_ptr, end, &param->enumerated);
289    case KM_INT:
290    case KM_INT_REP:
291        return copy_uint32_from_buf(buf_ptr, end, &param->integer);
292    case KM_LONG:
293        return copy_uint64_from_buf(buf_ptr, end, &param->long_integer);
294    case KM_DATE:
295        return copy_uint64_from_buf(buf_ptr, end, &param->date_time);
296        break;
297    case KM_BOOL:
298        if (*buf_ptr < end) {
299            param->boolean = static_cast<bool>(**buf_ptr);
300            (*buf_ptr)++;
301            return true;
302        }
303        return false;
304
305    case KM_BIGNUM:
306    case KM_BYTES: {
307        uint32_t offset;
308        if (!copy_uint32_from_buf(buf_ptr, end, &param->blob.data_length) ||
309            !copy_uint32_from_buf(buf_ptr, end, &offset))
310            return false;
311        if (static_cast<ptrdiff_t>(offset) > indirect_end - indirect_base ||
312            static_cast<ptrdiff_t>(offset + param->blob.data_length) > indirect_end - indirect_base)
313            return false;
314        param->blob.data = indirect_base + offset;
315        return true;
316    }
317    }
318}
319
320size_t AuthorizationSet::SerializedSizeOfElements() const {
321    size_t size = 0;
322    for (size_t i = 0; i < elems_size_; ++i) {
323        size += serialized_size(elems_[i]);
324    }
325    return size;
326}
327
328size_t AuthorizationSet::SerializedSize() const {
329    return sizeof(uint32_t) +           // Size of indirect_data_
330           indirect_data_size_ +        // indirect_data_
331           sizeof(uint32_t) +           // Number of elems_
332           sizeof(uint32_t) +           // Size of elems_
333           SerializedSizeOfElements();  // elems_
334}
335
336uint8_t* AuthorizationSet::Serialize(uint8_t* buf, const uint8_t* end) const {
337    buf = append_size_and_data_to_buf(buf, end, indirect_data_, indirect_data_size_);
338    buf = append_uint32_to_buf(buf, end, elems_size_);
339    buf = append_uint32_to_buf(buf, end, SerializedSizeOfElements());
340    for (size_t i = 0; i < elems_size_; ++i) {
341        buf = serialize(elems_[i], buf, end, indirect_data_);
342    }
343    return buf;
344}
345
346bool AuthorizationSet::DeserializeIndirectData(const uint8_t** buf_ptr, const uint8_t* end) {
347    if (!copy_size_and_data_from_buf(buf_ptr, end, &indirect_data_size_, &indirect_data_)) {
348        set_invalid(MALFORMED_DATA);
349        return false;
350    }
351    return true;
352}
353
354bool AuthorizationSet::DeserializeElementsData(const uint8_t** buf_ptr, const uint8_t* end) {
355    uint32_t elements_count;
356    uint32_t elements_size;
357    if (!copy_uint32_from_buf(buf_ptr, end, &elements_count) ||
358        !copy_uint32_from_buf(buf_ptr, end, &elements_size)) {
359        set_invalid(MALFORMED_DATA);
360        return false;
361    }
362
363    // Note that the following validation of elements_count is weak, but it prevents allocation of
364    // elems_ arrays which are clearly too large to be reasonable.
365    if (static_cast<ptrdiff_t>(elements_size) > end - *buf_ptr ||
366        elements_count * sizeof(uint32_t) > elements_size) {
367        set_invalid(MALFORMED_DATA);
368        return false;
369    }
370
371    if (!reserve_elems(elements_count))
372        return false;
373
374    uint8_t* indirect_end = indirect_data_ + indirect_data_size_;
375    const uint8_t* elements_end = *buf_ptr + elements_size;
376    for (size_t i = 0; i < elements_count; ++i) {
377        if (!deserialize(elems_ + i, buf_ptr, elements_end, indirect_data_, indirect_end)) {
378            set_invalid(MALFORMED_DATA);
379            return false;
380        }
381    }
382    elems_size_ = elements_count;
383    return true;
384}
385
386bool AuthorizationSet::Deserialize(const uint8_t** buf_ptr, const uint8_t* end) {
387    FreeData();
388
389    if (!DeserializeIndirectData(buf_ptr, end) || !DeserializeElementsData(buf_ptr, end))
390        return false;
391
392    if (indirect_data_size_ != ComputeIndirectDataSize(elems_, elems_size_)) {
393        set_invalid(MALFORMED_DATA);
394        return false;
395    }
396    return true;
397}
398
399void AuthorizationSet::FreeData() {
400    if (elems_ != NULL)
401        memset_s(elems_, 0, elems_size_ * sizeof(keymaster_key_param_t));
402    if (indirect_data_ != NULL)
403        memset_s(indirect_data_, 0, indirect_data_size_);
404
405    delete[] elems_;
406    delete[] indirect_data_;
407
408    elems_ = NULL;
409    indirect_data_ = NULL;
410    elems_size_ = 0;
411    elems_capacity_ = 0;
412    indirect_data_size_ = 0;
413    indirect_data_capacity_ = 0;
414    error_ = OK;
415}
416
417/* static */
418size_t AuthorizationSet::ComputeIndirectDataSize(const keymaster_key_param_t* elems, size_t count) {
419    size_t size = 0;
420    for (size_t i = 0; i < count; ++i) {
421        if (is_blob_tag(elems[i].tag)) {
422            size += elems[i].blob.data_length;
423        }
424    }
425    return size;
426}
427
428void AuthorizationSet::CopyIndirectData() {
429    memset_s(indirect_data_, 0, indirect_data_capacity_);
430
431    uint8_t* indirect_data_pos = indirect_data_;
432    for (size_t i = 0; i < elems_size_; ++i) {
433        assert(indirect_data_pos <= indirect_data_ + indirect_data_capacity_);
434        if (is_blob_tag(elems_[i].tag)) {
435            memcpy(indirect_data_pos, elems_[i].blob.data, elems_[i].blob.data_length);
436            elems_[i].blob.data = indirect_data_pos;
437            indirect_data_pos += elems_[i].blob.data_length;
438        }
439    }
440    assert(indirect_data_pos == indirect_data_ + indirect_data_capacity_);
441    indirect_data_size_ = indirect_data_pos - indirect_data_;
442}
443
444bool AuthorizationSet::GetTagValueEnum(keymaster_tag_t tag, uint32_t* val) const {
445    int pos = find(tag);
446    if (pos == -1) {
447        return false;
448    }
449    *val = elems_[pos].enumerated;
450    return true;
451}
452
453bool AuthorizationSet::GetTagValueEnumRep(keymaster_tag_t tag, size_t instance,
454                                          uint32_t* val) const {
455    size_t count = 0;
456    int pos = -1;
457    while (count <= instance) {
458        pos = find(tag, pos);
459        if (pos == -1) {
460            return false;
461        }
462        ++count;
463    }
464    *val = elems_[pos].enumerated;
465    return true;
466}
467
468bool AuthorizationSet::GetTagValueInt(keymaster_tag_t tag, uint32_t* val) const {
469    int pos = find(tag);
470    if (pos == -1) {
471        return false;
472    }
473    *val = elems_[pos].integer;
474    return true;
475}
476
477bool AuthorizationSet::GetTagValueIntRep(keymaster_tag_t tag, size_t instance,
478                                         uint32_t* val) const {
479    size_t count = 0;
480    int pos = -1;
481    while (count <= instance) {
482        pos = find(tag, pos);
483        if (pos == -1) {
484            return false;
485        }
486        ++count;
487    }
488    *val = elems_[pos].integer;
489    return true;
490}
491
492bool AuthorizationSet::GetTagValueLong(keymaster_tag_t tag, uint64_t* val) const {
493    int pos = find(tag);
494    if (pos == -1) {
495        return false;
496    }
497    *val = elems_[pos].long_integer;
498    return true;
499}
500
501bool AuthorizationSet::GetTagValueDate(keymaster_tag_t tag, uint64_t* val) const {
502    int pos = find(tag);
503    if (pos == -1) {
504        return false;
505    }
506    *val = elems_[pos].date_time;
507    return true;
508}
509
510bool AuthorizationSet::GetTagValueBlob(keymaster_tag_t tag, keymaster_blob_t* val) const {
511    int pos = find(tag);
512    if (pos == -1) {
513        return false;
514    }
515    *val = elems_[pos].blob;
516    return true;
517}
518
519}  // namespace keymaster
520