1#include <openssl/bn.h>
2#include <openssl/evp.h>
3#include <sparse/sparse.h>
4
5#undef NDEBUG
6
7#include <assert.h>
8#include <errno.h>
9#include <getopt.h>
10#include <fcntl.h>
11#include <inttypes.h>
12#include <limits.h>
13#include <stdbool.h>
14#include <stdio.h>
15#include <stdlib.h>
16#include <string.h>
17#include <unistd.h>
18
19#include <android-base/file.h>
20
21struct sparse_hash_ctx {
22    unsigned char *hashes;
23    const unsigned char *salt;
24    uint64_t salt_size;
25    uint64_t hash_size;
26    uint64_t block_size;
27    const unsigned char *zero_block_hash;
28    const EVP_MD *md;
29};
30
31#define div_round_up(x,y) (((x) + (y) - 1)/(y))
32
33#define round_up(x,y) (div_round_up(x,y)*(y))
34
35#define FATAL(x...) { \
36    fprintf(stderr, x); \
37    exit(1); \
38}
39
40size_t verity_tree_blocks(uint64_t data_size, size_t block_size, size_t hash_size,
41                          int level)
42{
43    size_t level_blocks = div_round_up(data_size, block_size);
44    int hashes_per_block = div_round_up(block_size, hash_size);
45
46    do {
47        level_blocks = div_round_up(level_blocks, hashes_per_block);
48    } while (level--);
49
50    return level_blocks;
51}
52
53int hash_block(const EVP_MD *md,
54               const unsigned char *block, size_t len,
55               const unsigned char *salt, size_t salt_len,
56               unsigned char *out, size_t *out_size)
57{
58    EVP_MD_CTX *mdctx;
59    unsigned int s;
60    int ret = 1;
61
62    mdctx = EVP_MD_CTX_create();
63    assert(mdctx);
64    ret &= EVP_DigestInit_ex(mdctx, md, NULL);
65    ret &= EVP_DigestUpdate(mdctx, salt, salt_len);
66    ret &= EVP_DigestUpdate(mdctx, block, len);
67    ret &= EVP_DigestFinal_ex(mdctx, out, &s);
68    EVP_MD_CTX_destroy(mdctx);
69    assert(ret == 1);
70    if (out_size) {
71        *out_size = s;
72    }
73    return 0;
74}
75
76int hash_blocks(const EVP_MD *md,
77                const unsigned char *in, size_t in_size,
78                unsigned char *out, size_t *out_size,
79                const unsigned char *salt, size_t salt_size,
80                size_t block_size)
81{
82    size_t s;
83    *out_size = 0;
84    for (size_t i = 0; i < in_size; i += block_size) {
85        hash_block(md, in + i, block_size, salt, salt_size, out, &s);
86        out += s;
87        *out_size += s;
88    }
89
90    return 0;
91}
92
93int hash_chunk(void *priv, const void *data, int len)
94{
95    struct sparse_hash_ctx *ctx = (struct sparse_hash_ctx *)priv;
96    assert(len % ctx->block_size == 0);
97    if (data) {
98        size_t s;
99        hash_blocks(ctx->md, (const unsigned char *)data, len,
100                    ctx->hashes, &s,
101                    ctx->salt, ctx->salt_size, ctx->block_size);
102        ctx->hashes += s;
103    } else {
104        for (size_t i = 0; i < (size_t)len; i += ctx->block_size) {
105            memcpy(ctx->hashes, ctx->zero_block_hash, ctx->hash_size);
106            ctx->hashes += ctx->hash_size;
107        }
108    }
109    return 0;
110}
111
112void usage(void)
113{
114    printf("usage: build_verity_tree [ <options> ] -s <size> | <data> <verity>\n"
115           "options:\n"
116           "  -a,--salt-str=<string>       set salt to <string>\n"
117           "  -A,--salt-hex=<hex digits>   set salt to <hex digits>\n"
118           "  -h                           show this help\n"
119           "  -s,--verity-size=<data size> print the size of the verity tree\n"
120           "  -v,                          enable verbose logging\n"
121           "  -S                           treat <data image> as a sparse file\n"
122        );
123}
124
125int main(int argc, char **argv)
126{
127    char *data_filename;
128    char *verity_filename;
129    unsigned char *salt = NULL;
130    size_t salt_size = 0;
131    bool sparse = false;
132    size_t block_size = 4096;
133    uint64_t calculate_size = 0;
134    bool verbose = false;
135
136    while (1) {
137        const static struct option long_options[] = {
138            {"salt-str", required_argument, 0, 'a'},
139            {"salt-hex", required_argument, 0, 'A'},
140            {"help", no_argument, 0, 'h'},
141            {"sparse", no_argument, 0, 'S'},
142            {"verity-size", required_argument, 0, 's'},
143            {"verbose", no_argument, 0, 'v'},
144            {NULL, 0, 0, 0}
145        };
146        int c = getopt_long(argc, argv, "a:A:hSs:v", long_options, NULL);
147        if (c < 0) {
148            break;
149        }
150
151        switch (c) {
152        case 'a':
153            salt_size = strlen(optarg);
154            salt = new unsigned char[salt_size]();
155            if (salt == NULL) {
156                FATAL("failed to allocate memory for salt\n");
157            }
158            memcpy(salt, optarg, salt_size);
159            break;
160        case 'A': {
161                BIGNUM *bn = NULL;
162                if(!BN_hex2bn(&bn, optarg)) {
163                    FATAL("failed to convert salt from hex\n");
164                }
165                salt_size = BN_num_bytes(bn);
166                salt = new unsigned char[salt_size]();
167                if (salt == NULL) {
168                    FATAL("failed to allocate memory for salt\n");
169                }
170                if((size_t)BN_bn2bin(bn, salt) != salt_size) {
171                    FATAL("failed to convert salt to bytes\n");
172                }
173            }
174            break;
175        case 'h':
176            usage();
177            return 1;
178        case 'S':
179            sparse = true;
180            break;
181        case 's': {
182                char* endptr;
183                errno = 0;
184                unsigned long long int inSize = strtoull(optarg, &endptr, 0);
185                if (optarg[0] == '\0' || *endptr != '\0' ||
186                        (errno == ERANGE && inSize == ULLONG_MAX)) {
187                    FATAL("invalid value of verity-size\n");
188                }
189                if (inSize > UINT64_MAX) {
190                    FATAL("invalid value of verity-size\n");
191                }
192                calculate_size = (uint64_t)inSize;
193            }
194            break;
195        case 'v':
196            verbose = true;
197            break;
198        case '?':
199            usage();
200            return 1;
201        default:
202            abort();
203        }
204    }
205
206    argc -= optind;
207    argv += optind;
208
209    const EVP_MD *md = EVP_sha256();
210    if (!md) {
211        FATAL("failed to get digest\n");
212    }
213
214    size_t hash_size = EVP_MD_size(md);
215    assert(hash_size * 2 < block_size);
216
217    if (!salt || !salt_size) {
218        salt_size = hash_size;
219        salt = new unsigned char[salt_size];
220        if (salt == NULL) {
221            FATAL("failed to allocate memory for salt\n");
222        }
223
224        int random_fd = open("/dev/urandom", O_RDONLY);
225        if (random_fd < 0) {
226            FATAL("failed to open /dev/urandom\n");
227        }
228
229        ssize_t ret = read(random_fd, salt, salt_size);
230        if (ret != (ssize_t)salt_size) {
231            FATAL("failed to read %zu bytes from /dev/urandom: %zd %d\n", salt_size, ret, errno);
232        }
233        close(random_fd);
234    }
235
236    if (calculate_size) {
237        if (argc != 0) {
238            usage();
239            return 1;
240        }
241        size_t verity_blocks = 0;
242        size_t level_blocks;
243        int levels = 0;
244        do {
245            level_blocks = verity_tree_blocks(calculate_size, block_size, hash_size, levels);
246            levels++;
247            verity_blocks += level_blocks;
248        } while (level_blocks > 1);
249
250        printf("%" PRIu64 "\n", (uint64_t)verity_blocks * block_size);
251        return 0;
252    }
253
254    if (argc != 2) {
255        usage();
256        return 1;
257    }
258
259    data_filename = argv[0];
260    verity_filename = argv[1];
261
262    int fd = open(data_filename, O_RDONLY);
263    if (fd < 0) {
264        FATAL("failed to open %s\n", data_filename);
265    }
266
267    struct sparse_file *file;
268    if (sparse) {
269        file = sparse_file_import(fd, false, false);
270    } else {
271        file = sparse_file_import_auto(fd, false, verbose);
272    }
273
274    if (!file) {
275        FATAL("failed to read file %s\n", data_filename);
276    }
277
278    int64_t len = sparse_file_len(file, false, false);
279    if (len % block_size != 0) {
280        FATAL("file size %" PRIu64 " is not a multiple of %zu bytes\n",
281                len, block_size);
282    }
283
284    int levels = 0;
285    size_t verity_blocks = 0;
286    size_t level_blocks;
287
288    do {
289        level_blocks = verity_tree_blocks(len, block_size, hash_size, levels);
290        levels++;
291        verity_blocks += level_blocks;
292    } while (level_blocks > 1);
293
294    unsigned char *verity_tree = new unsigned char[verity_blocks * block_size]();
295    unsigned char **verity_tree_levels = new unsigned char *[levels + 1]();
296    size_t *verity_tree_level_blocks = new size_t[levels]();
297    if (verity_tree == NULL || verity_tree_levels == NULL || verity_tree_level_blocks == NULL) {
298        FATAL("failed to allocate memory for verity tree\n");
299    }
300
301    unsigned char *ptr = verity_tree;
302    for (int i = levels - 1; i >= 0; i--) {
303        verity_tree_levels[i] = ptr;
304        verity_tree_level_blocks[i] = verity_tree_blocks(len, block_size, hash_size, i);
305        ptr += verity_tree_level_blocks[i] * block_size;
306    }
307    assert(ptr == verity_tree + verity_blocks * block_size);
308    assert(verity_tree_level_blocks[levels - 1] == 1);
309
310    unsigned char zero_block_hash[hash_size];
311    unsigned char zero_block[block_size];
312    memset(zero_block, 0, block_size);
313    hash_block(md, zero_block, block_size, salt, salt_size, zero_block_hash, NULL);
314
315    unsigned char root_hash[hash_size];
316    verity_tree_levels[levels] = root_hash;
317
318    struct sparse_hash_ctx ctx;
319    ctx.hashes = verity_tree_levels[0];
320    ctx.salt = salt;
321    ctx.salt_size = salt_size;
322    ctx.hash_size = hash_size;
323    ctx.block_size = block_size;
324    ctx.zero_block_hash = zero_block_hash;
325    ctx.md = md;
326
327    sparse_file_callback(file, false, false, hash_chunk, &ctx);
328
329    sparse_file_destroy(file);
330    close(fd);
331
332    for (int i = 0; i < levels; i++) {
333        size_t out_size;
334        hash_blocks(md,
335                verity_tree_levels[i], verity_tree_level_blocks[i] * block_size,
336                verity_tree_levels[i + 1], &out_size,
337                salt, salt_size, block_size);
338          if (i < levels - 1) {
339              assert(div_round_up(out_size, block_size) == verity_tree_level_blocks[i + 1]);
340          } else {
341              assert(out_size == hash_size);
342          }
343    }
344
345    for (size_t i = 0; i < hash_size; i++) {
346        printf("%02x", root_hash[i]);
347    }
348    printf(" ");
349    for (size_t i = 0; i < salt_size; i++) {
350        printf("%02x", salt[i]);
351    }
352    printf("\n");
353
354    fd = open(verity_filename, O_WRONLY|O_CREAT, 0666);
355    if (fd < 0) {
356        FATAL("failed to open output file '%s'\n", verity_filename);
357    }
358    if (!android::base::WriteFully(fd, verity_tree, verity_blocks * block_size)) {
359        FATAL("failed to write '%s'\n", verity_filename);
360    }
361    close(fd);
362
363    delete[] verity_tree_levels;
364    delete[] verity_tree_level_blocks;
365    delete[] verity_tree;
366    delete[] salt;
367}
368