1#include <openssl/bn.h>
2#include <openssl/evp.h>
3#include <sparse/sparse.h>
4
5#undef NDEBUG
6
7#include <assert.h>
8#include <errno.h>
9#include <getopt.h>
10#include <fcntl.h>
11#include <inttypes.h>
12#include <limits.h>
13#include <stdbool.h>
14#include <stdio.h>
15#include <stdlib.h>
16#include <string.h>
17#include <unistd.h>
18
19struct sparse_hash_ctx {
20    unsigned char *hashes;
21    const unsigned char *salt;
22    uint64_t salt_size;
23    uint64_t hash_size;
24    uint64_t block_size;
25    const unsigned char *zero_block_hash;
26    const EVP_MD *md;
27};
28
/*
 * Integer ceiling division: the smallest n with n * y >= x (x >= 0, y > 0).
 * Written as quotient + remainder flag instead of ((x) + (y) - 1) / (y) so it
 * cannot wrap around when x is close to the maximum of its type (e.g. a
 * --verity-size near UINT64_MAX). Both arguments may be evaluated more than
 * once: do not pass expressions with side effects.
 */
#define div_round_up(x,y) ((x) / (y) + ((x) % (y) != 0))

/* Round x up to the next multiple of y; arguments may be evaluated more than once. */
#define round_up(x,y) (div_round_up(x,y)*(y))
32
/*
 * Print a printf-style message to stderr and terminate with exit status 1.
 * Wrapped in do { } while (0) so the macro behaves as a single statement,
 * e.g. `if (c) FATAL("...\n"); else ...` parses as intended.
 */
#define FATAL(...) do { \
    fprintf(stderr, __VA_ARGS__); \
    exit(1); \
} while (0)
37
/*
 * Number of hash blocks at level `level` of the verity tree.
 *
 * Level 0 holds the hashes of the data blocks; every further level holds the
 * hashes of the level below it. Each step divides the block count by the
 * number of hashes per block, rounding up.
 *
 * @param data_size  size of the hashed data in bytes
 * @param block_size bytes per data/hash block (must be > 0)
 * @param hash_size  digest size in bytes (must be > 0)
 * @param level      0-based tree level (must be >= 0)
 * @return number of blocks at that level (0 when data_size is 0)
 */
size_t verity_tree_blocks(uint64_t data_size, size_t block_size, size_t hash_size,
                          int level)
{
    assert(block_size > 0 && hash_size > 0);
    /* A negative level would make the do/while below run until signed overflow. */
    assert(level >= 0);

    /*
     * NOTE(review): rounded up as before. This is exact for SHA-256 (32-byte
     * digests in 4096-byte blocks), but for digest sizes that do not divide
     * block_size it overstates how many hashes fit and can under-count blocks.
     */
    size_t hashes_per_block = block_size / hash_size + (block_size % hash_size != 0);

    /*
     * Ceiling division as quotient + remainder flag: (x + y - 1) / y wraps for
     * data_size close to UINT64_MAX and would report 0 blocks.
     */
    uint64_t level_blocks = data_size / block_size + (data_size % block_size != 0);

    do {
        level_blocks = level_blocks / hashes_per_block +
                       (level_blocks % hashes_per_block != 0);
    } while (level--);

    return (size_t)level_blocks;
}
50
51int hash_block(const EVP_MD *md,
52               const unsigned char *block, size_t len,
53               const unsigned char *salt, size_t salt_len,
54               unsigned char *out, size_t *out_size)
55{
56    EVP_MD_CTX *mdctx;
57    unsigned int s;
58    int ret = 1;
59
60    mdctx = EVP_MD_CTX_create();
61    assert(mdctx);
62    ret &= EVP_DigestInit_ex(mdctx, md, NULL);
63    ret &= EVP_DigestUpdate(mdctx, salt, salt_len);
64    ret &= EVP_DigestUpdate(mdctx, block, len);
65    ret &= EVP_DigestFinal_ex(mdctx, out, &s);
66    EVP_MD_CTX_destroy(mdctx);
67    assert(ret == 1);
68    if (out_size) {
69        *out_size = s;
70    }
71    return 0;
72}
73
74int hash_blocks(const EVP_MD *md,
75                const unsigned char *in, size_t in_size,
76                unsigned char *out, size_t *out_size,
77                const unsigned char *salt, size_t salt_size,
78                size_t block_size)
79{
80    size_t s;
81    *out_size = 0;
82    for (size_t i = 0; i < in_size; i += block_size) {
83        hash_block(md, in + i, block_size, salt, salt_size, out, &s);
84        out += s;
85        *out_size += s;
86    }
87
88    return 0;
89}
90
91int hash_chunk(void *priv, const void *data, int len)
92{
93    struct sparse_hash_ctx *ctx = (struct sparse_hash_ctx *)priv;
94    assert(len % ctx->block_size == 0);
95    if (data) {
96        size_t s;
97        hash_blocks(ctx->md, (const unsigned char *)data, len,
98                    ctx->hashes, &s,
99                    ctx->salt, ctx->salt_size, ctx->block_size);
100        ctx->hashes += s;
101    } else {
102        for (size_t i = 0; i < (size_t)len; i += ctx->block_size) {
103            memcpy(ctx->hashes, ctx->zero_block_hash, ctx->hash_size);
104            ctx->hashes += ctx->hash_size;
105        }
106    }
107    return 0;
108}
109
/*
 * Print the command-line help to stdout.
 * Lists every option accepted by main()'s getopt_long() table, short and long forms.
 */
void usage(void)
{
    printf("usage: build_verity_tree [ <options> ] -s <size> | <data> <verity>\n"
           "options:\n"
           "  -a,--salt-str=<string>       set salt to <string>\n"
           "  -A,--salt-hex=<hex digits>   set salt to <hex digits>\n"
           "  -h,--help                    show this help\n"
           "  -s,--verity-size=<data size> print the size of the verity tree\n"
           "  -v,--verbose                 enable verbose logging\n"
           "  -S,--sparse                  treat <data> as a sparse file\n"
        );
}
122
123int main(int argc, char **argv)
124{
125    char *data_filename;
126    char *verity_filename;
127    unsigned char *salt = NULL;
128    size_t salt_size = 0;
129    bool sparse = false;
130    size_t block_size = 4096;
131    uint64_t calculate_size = 0;
132    bool verbose = false;
133
134    while (1) {
135        const static struct option long_options[] = {
136            {"salt-str", required_argument, 0, 'a'},
137            {"salt-hex", required_argument, 0, 'A'},
138            {"help", no_argument, 0, 'h'},
139            {"sparse", no_argument, 0, 'S'},
140            {"verity-size", required_argument, 0, 's'},
141            {"verbose", no_argument, 0, 'v'},
142            {NULL, 0, 0, 0}
143        };
144        int c = getopt_long(argc, argv, "a:A:hSs:v", long_options, NULL);
145        if (c < 0) {
146            break;
147        }
148
149        switch (c) {
150        case 'a':
151            salt_size = strlen(optarg);
152            salt = new unsigned char[salt_size]();
153            if (salt == NULL) {
154                FATAL("failed to allocate memory for salt\n");
155            }
156            memcpy(salt, optarg, salt_size);
157            break;
158        case 'A': {
159                BIGNUM *bn = NULL;
160                if(!BN_hex2bn(&bn, optarg)) {
161                    FATAL("failed to convert salt from hex\n");
162                }
163                salt_size = BN_num_bytes(bn);
164                salt = new unsigned char[salt_size]();
165                if (salt == NULL) {
166                    FATAL("failed to allocate memory for salt\n");
167                }
168                if((size_t)BN_bn2bin(bn, salt) != salt_size) {
169                    FATAL("failed to convert salt to bytes\n");
170                }
171            }
172            break;
173        case 'h':
174            usage();
175            return 1;
176        case 'S':
177            sparse = true;
178            break;
179        case 's': {
180                char* endptr;
181                errno = 0;
182                unsigned long long int inSize = strtoull(optarg, &endptr, 0);
183                if (optarg[0] == '\0' || *endptr != '\0' ||
184                        (errno == ERANGE && inSize == ULLONG_MAX)) {
185                    FATAL("invalid value of verity-size\n");
186                }
187                if (inSize > UINT64_MAX) {
188                    FATAL("invalid value of verity-size\n");
189                }
190                calculate_size = (uint64_t)inSize;
191            }
192            break;
193        case 'v':
194            verbose = true;
195            break;
196        case '?':
197            usage();
198            return 1;
199        default:
200            abort();
201        }
202    }
203
204    argc -= optind;
205    argv += optind;
206
207    const EVP_MD *md = EVP_sha256();
208    if (!md) {
209        FATAL("failed to get digest\n");
210    }
211
212    size_t hash_size = EVP_MD_size(md);
213    assert(hash_size * 2 < block_size);
214
215    if (!salt || !salt_size) {
216        salt_size = hash_size;
217        salt = new unsigned char[salt_size];
218        if (salt == NULL) {
219            FATAL("failed to allocate memory for salt\n");
220        }
221
222        int random_fd = open("/dev/urandom", O_RDONLY);
223        if (random_fd < 0) {
224            FATAL("failed to open /dev/urandom\n");
225        }
226
227        ssize_t ret = read(random_fd, salt, salt_size);
228        if (ret != (ssize_t)salt_size) {
229            FATAL("failed to read %zu bytes from /dev/urandom: %zd %d\n", salt_size, ret, errno);
230        }
231        close(random_fd);
232    }
233
234    if (calculate_size) {
235        if (argc != 0) {
236            usage();
237            return 1;
238        }
239        size_t verity_blocks = 0;
240        size_t level_blocks;
241        int levels = 0;
242        do {
243            level_blocks = verity_tree_blocks(calculate_size, block_size, hash_size, levels);
244            levels++;
245            verity_blocks += level_blocks;
246        } while (level_blocks > 1);
247
248        printf("%" PRIu64 "\n", (uint64_t)verity_blocks * block_size);
249        return 0;
250    }
251
252    if (argc != 2) {
253        usage();
254        return 1;
255    }
256
257    data_filename = argv[0];
258    verity_filename = argv[1];
259
260    int fd = open(data_filename, O_RDONLY);
261    if (fd < 0) {
262        FATAL("failed to open %s\n", data_filename);
263    }
264
265    struct sparse_file *file;
266    if (sparse) {
267        file = sparse_file_import(fd, false, false);
268    } else {
269        file = sparse_file_import_auto(fd, false, verbose);
270    }
271
272    if (!file) {
273        FATAL("failed to read file %s\n", data_filename);
274    }
275
276    int64_t len = sparse_file_len(file, false, false);
277    if (len % block_size != 0) {
278        FATAL("file size %" PRIu64 " is not a multiple of %zu bytes\n",
279                len, block_size);
280    }
281
282    int levels = 0;
283    size_t verity_blocks = 0;
284    size_t level_blocks;
285
286    do {
287        level_blocks = verity_tree_blocks(len, block_size, hash_size, levels);
288        levels++;
289        verity_blocks += level_blocks;
290    } while (level_blocks > 1);
291
292    unsigned char *verity_tree = new unsigned char[verity_blocks * block_size]();
293    unsigned char **verity_tree_levels = new unsigned char *[levels + 1]();
294    size_t *verity_tree_level_blocks = new size_t[levels]();
295    if (verity_tree == NULL || verity_tree_levels == NULL || verity_tree_level_blocks == NULL) {
296        FATAL("failed to allocate memory for verity tree\n");
297    }
298
299    unsigned char *ptr = verity_tree;
300    for (int i = levels - 1; i >= 0; i--) {
301        verity_tree_levels[i] = ptr;
302        verity_tree_level_blocks[i] = verity_tree_blocks(len, block_size, hash_size, i);
303        ptr += verity_tree_level_blocks[i] * block_size;
304    }
305    assert(ptr == verity_tree + verity_blocks * block_size);
306    assert(verity_tree_level_blocks[levels - 1] == 1);
307
308    unsigned char zero_block_hash[hash_size];
309    unsigned char zero_block[block_size];
310    memset(zero_block, 0, block_size);
311    hash_block(md, zero_block, block_size, salt, salt_size, zero_block_hash, NULL);
312
313    unsigned char root_hash[hash_size];
314    verity_tree_levels[levels] = root_hash;
315
316    struct sparse_hash_ctx ctx;
317    ctx.hashes = verity_tree_levels[0];
318    ctx.salt = salt;
319    ctx.salt_size = salt_size;
320    ctx.hash_size = hash_size;
321    ctx.block_size = block_size;
322    ctx.zero_block_hash = zero_block_hash;
323    ctx.md = md;
324
325    sparse_file_callback(file, false, false, hash_chunk, &ctx);
326
327    sparse_file_destroy(file);
328    close(fd);
329
330    for (int i = 0; i < levels; i++) {
331        size_t out_size;
332        hash_blocks(md,
333                verity_tree_levels[i], verity_tree_level_blocks[i] * block_size,
334                verity_tree_levels[i + 1], &out_size,
335                salt, salt_size, block_size);
336          if (i < levels - 1) {
337              assert(div_round_up(out_size, block_size) == verity_tree_level_blocks[i + 1]);
338          } else {
339              assert(out_size == hash_size);
340          }
341    }
342
343    for (size_t i = 0; i < hash_size; i++) {
344        printf("%02x", root_hash[i]);
345    }
346    printf(" ");
347    for (size_t i = 0; i < salt_size; i++) {
348        printf("%02x", salt[i]);
349    }
350    printf("\n");
351
352    fd = open(verity_filename, O_WRONLY|O_CREAT, 0666);
353    if (fd < 0) {
354        FATAL("failed to open output file '%s'\n", verity_filename);
355    }
356    write(fd, verity_tree, verity_blocks * block_size);
357    close(fd);
358
359    delete[] verity_tree_levels;
360    delete[] verity_tree_level_blocks;
361    delete[] verity_tree;
362    delete[] salt;
363}
364