1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#undef NDEBUG
18#define _LARGEFILE64_SOURCE
19
20extern "C" {
21    #include <fec.h>
22}
23
#include <assert.h>
#include <android-base/file.h>
#include <errno.h>
#include <fcntl.h>
#include <getopt.h>
#include <openssl/sha.h>
#include <pthread.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sparse/sparse.h>

#include <new>
#include <vector>

#include "image.h"
38
39#if defined(__linux__)
40    #include <linux/fs.h>
41#elif defined(__APPLE__)
42    #include <sys/disk.h>
43    #define BLKGETSIZE64 DKIOCGETBLOCKCOUNT
44    #define O_LARGEFILE 0
45#endif
46
47void image_init(image *ctx)
48{
49    memset(ctx, 0, sizeof(*ctx));
50}
51
52void image_free(image *ctx)
53{
54    assert(ctx->input == ctx->output);
55
56    if (ctx->input) {
57        delete[] ctx->input;
58    }
59
60    if (ctx->fec) {
61        delete[] ctx->fec;
62    }
63
64    image_init(ctx);
65}
66
67static void calculate_rounds(uint64_t size, image *ctx)
68{
69    if (!size) {
70        FATAL("empty file?\n");
71    } else if (size % FEC_BLOCKSIZE) {
72        FATAL("file size %" PRIu64 " is not a multiple of %u bytes\n",
73            size, FEC_BLOCKSIZE);
74    }
75
76    ctx->inp_size = size;
77    ctx->blocks = fec_div_round_up(ctx->inp_size, FEC_BLOCKSIZE);
78    ctx->rounds = fec_div_round_up(ctx->blocks, ctx->rs_n);
79}
80
81static int process_chunk(void *priv, const void *data, int len)
82{
83    image *ctx = (image *)priv;
84    assert(len % FEC_BLOCKSIZE == 0);
85
86    if (data) {
87        memcpy(&ctx->input[ctx->pos], data, len);
88    }
89
90    ctx->pos += len;
91    return 0;
92}
93
94static void file_image_load(const std::vector<int>& fds, image *ctx)
95{
96    uint64_t size = 0;
97    std::vector<struct sparse_file *> files;
98
99    for (auto fd : fds) {
100        uint64_t len = 0;
101        struct sparse_file *file;
102
103        if (ctx->sparse) {
104            file = sparse_file_import(fd, false, false);
105        } else {
106            file = sparse_file_import_auto(fd, false, ctx->verbose);
107        }
108
109        if (!file) {
110            FATAL("failed to read file %s\n", ctx->fec_filename);
111        }
112
113        len = sparse_file_len(file, false, false);
114        files.push_back(file);
115
116        size += len;
117    }
118
119    calculate_rounds(size, ctx);
120
121    if (ctx->verbose) {
122        INFO("allocating %" PRIu64 " bytes of memory\n", ctx->inp_size);
123    }
124
125    ctx->input = new uint8_t[ctx->inp_size];
126
127    if (!ctx->input) {
128        FATAL("failed to allocate memory\n");
129    }
130
131    memset(ctx->input, 0, ctx->inp_size);
132    ctx->output = ctx->input;
133    ctx->pos = 0;
134
135    for (auto file : files) {
136        sparse_file_callback(file, false, false, process_chunk, ctx);
137        sparse_file_destroy(file);
138    }
139
140    for (auto fd : fds) {
141        close(fd);
142    }
143}
144
145bool image_load(const std::vector<std::string>& filenames, image *ctx)
146{
147    assert(ctx->roots > 0 && ctx->roots < FEC_RSM);
148    ctx->rs_n = FEC_RSM - ctx->roots;
149
150    int flags = O_RDONLY;
151
152    if (ctx->inplace) {
153        flags = O_RDWR;
154    }
155
156    std::vector<int> fds;
157
158    for (auto fn : filenames) {
159        int fd = TEMP_FAILURE_RETRY(open(fn.c_str(), flags | O_LARGEFILE));
160
161        if (fd < 0) {
162            FATAL("failed to open file '%s': %s\n", fn.c_str(), strerror(errno));
163        }
164
165        fds.push_back(fd);
166    }
167
168    file_image_load(fds, ctx);
169
170    return true;
171}
172
173bool image_save(const std::string& filename, image *ctx)
174{
175    /* TODO: support saving as a sparse file */
176    int fd = TEMP_FAILURE_RETRY(open(filename.c_str(),
177                O_WRONLY | O_CREAT | O_TRUNC, 0666));
178
179    if (fd < 0) {
180        FATAL("failed to open file '%s: %s'\n", filename.c_str(),
181            strerror(errno));
182    }
183
184    if (!android::base::WriteFully(fd, ctx->output, ctx->inp_size)) {
185        FATAL("failed to write to output: %s\n", strerror(errno));
186    }
187
188    close(fd);
189    return true;
190}
191
192bool image_ecc_new(const std::string& filename, image *ctx)
193{
194    assert(ctx->rounds > 0); /* image_load should be called first */
195
196    ctx->fec_filename = filename.c_str();
197    ctx->fec_size = ctx->rounds * ctx->roots * FEC_BLOCKSIZE;
198
199    if (ctx->verbose) {
200        INFO("allocating %u bytes of memory\n", ctx->fec_size);
201    }
202
203    ctx->fec = new uint8_t[ctx->fec_size];
204
205    if (!ctx->fec) {
206        FATAL("failed to allocate %u bytes\n", ctx->fec_size);
207    }
208
209    return true;
210}
211
212bool image_ecc_load(const std::string& filename, image *ctx)
213{
214    int fd = TEMP_FAILURE_RETRY(open(filename.c_str(), O_RDONLY));
215
216    if (fd < 0) {
217        FATAL("failed to open file '%s': %s\n", filename.c_str(),
218            strerror(errno));
219    }
220
221    if (lseek64(fd, -FEC_BLOCKSIZE, SEEK_END) < 0) {
222        FATAL("failed to seek to header in '%s': %s\n", filename.c_str(),
223            strerror(errno));
224    }
225
226    assert(sizeof(fec_header) <= FEC_BLOCKSIZE);
227
228    uint8_t header[FEC_BLOCKSIZE];
229    fec_header *p = (fec_header *)header;
230
231    if (!android::base::ReadFully(fd, header, sizeof(header))) {
232        FATAL("failed to read %zd bytes from '%s': %s\n", sizeof(header),
233            filename.c_str(), strerror(errno));
234    }
235
236    if (p->magic != FEC_MAGIC) {
237        FATAL("invalid magic in '%s': %08x\n", filename.c_str(), p->magic);
238    }
239
240    if (p->version != FEC_VERSION) {
241        FATAL("unsupported version in '%s': %u\n", filename.c_str(),
242            p->version);
243    }
244
245    if (p->size != sizeof(fec_header)) {
246        FATAL("unexpected header size in '%s': %u\n", filename.c_str(),
247            p->size);
248    }
249
250    if (p->roots == 0 || p->roots >= FEC_RSM) {
251        FATAL("invalid roots in '%s': %u\n", filename.c_str(), p->roots);
252    }
253
254    if (p->fec_size % p->roots || p->fec_size % FEC_BLOCKSIZE) {
255        FATAL("invalid length in '%s': %u\n", filename.c_str(), p->fec_size);
256    }
257
258    ctx->roots = (int)p->roots;
259    ctx->rs_n = FEC_RSM - ctx->roots;
260
261    calculate_rounds(p->inp_size, ctx);
262
263    if (!image_ecc_new(filename, ctx)) {
264        FATAL("failed to allocate ecc\n");
265    }
266
267    if (p->fec_size != ctx->fec_size) {
268        FATAL("inconsistent header in '%s'\n", filename.c_str());
269    }
270
271    if (lseek64(fd, 0, SEEK_SET) < 0) {
272        FATAL("failed to rewind '%s': %s", filename.c_str(), strerror(errno));
273    }
274
275    if (!android::base::ReadFully(fd, ctx->fec, ctx->fec_size)) {
276        FATAL("failed to read %u bytes from '%s': %s\n", ctx->fec_size,
277            filename.c_str(), strerror(errno));
278    }
279
280    close(fd);
281
282    uint8_t hash[SHA256_DIGEST_LENGTH];
283    SHA256(ctx->fec, ctx->fec_size, hash);
284
285    if (memcmp(hash, p->hash, SHA256_DIGEST_LENGTH) != 0) {
286        FATAL("invalid ecc data\n");
287    }
288
289    return true;
290}
291
292bool image_ecc_save(image *ctx)
293{
294    assert(2 * sizeof(fec_header) <= FEC_BLOCKSIZE);
295
296    uint8_t header[FEC_BLOCKSIZE] = {0};
297
298    fec_header *f = (fec_header *)header;
299
300    f->magic = FEC_MAGIC;
301    f->version = FEC_VERSION;
302    f->size = sizeof(fec_header);
303    f->roots = ctx->roots;
304    f->fec_size = ctx->fec_size;
305    f->inp_size = ctx->inp_size;
306
307    SHA256(ctx->fec, ctx->fec_size, f->hash);
308
309    /* store a copy of the fec_header at the end of the header block */
310    memcpy(&header[sizeof(header) - sizeof(fec_header)], header,
311        sizeof(fec_header));
312
313    assert(ctx->fec_filename);
314
315    int fd = TEMP_FAILURE_RETRY(open(ctx->fec_filename,
316                O_WRONLY | O_CREAT | O_TRUNC, 0666));
317
318    if (fd < 0) {
319        FATAL("failed to open file '%s': %s\n", ctx->fec_filename,
320            strerror(errno));
321    }
322
323    if (!android::base::WriteFully(fd, ctx->fec, ctx->fec_size)) {
324        FATAL("failed to write to output: %s\n", strerror(errno));
325    }
326
327    if (ctx->padding > 0) {
328        uint8_t padding[FEC_BLOCKSIZE] = {0};
329
330        for (uint32_t i = 0; i < ctx->padding; i += FEC_BLOCKSIZE) {
331            if (!android::base::WriteFully(fd, padding, FEC_BLOCKSIZE)) {
332                FATAL("failed to write padding: %s\n", strerror(errno));
333            }
334        }
335    }
336
337    if (!android::base::WriteFully(fd, header, sizeof(header))) {
338        FATAL("failed to write to header: %s\n", strerror(errno));
339    }
340
341    close(fd);
342
343    return true;
344}
345
346static void * process(void *cookie)
347{
348    image_proc_ctx *ctx = (image_proc_ctx *)cookie;
349    ctx->func(ctx);
350    return NULL;
351}
352
353bool image_process(image_proc_func func, image *ctx)
354{
355    int threads = ctx->threads;
356
357    if (threads < IMAGE_MIN_THREADS) {
358        threads = sysconf(_SC_NPROCESSORS_ONLN);
359
360        if (threads < IMAGE_MIN_THREADS) {
361            threads = IMAGE_MIN_THREADS;
362        }
363    }
364
365    assert(ctx->rounds > 0);
366
367    if ((uint64_t)threads > ctx->rounds) {
368        threads = (int)ctx->rounds;
369    }
370    if (threads > IMAGE_MAX_THREADS) {
371        threads = IMAGE_MAX_THREADS;
372    }
373
374    if (ctx->verbose) {
375        INFO("starting %d threads to compute RS(255, %d)\n", threads,
376            ctx->rs_n);
377    }
378
379    pthread_t pthreads[threads];
380    image_proc_ctx args[threads];
381
382    uint64_t current = 0;
383    uint64_t end = ctx->rounds * ctx->rs_n * FEC_BLOCKSIZE;
384    uint64_t rs_blocks_per_thread =
385        fec_div_round_up(ctx->rounds * FEC_BLOCKSIZE, threads);
386
387    if (ctx->verbose) {
388        INFO("computing %" PRIu64 " codes per thread\n", rs_blocks_per_thread);
389    }
390
391    for (int i = 0; i < threads; ++i) {
392        args[i].func = func;
393        args[i].id = i;
394        args[i].ctx = ctx;
395        args[i].rv = 0;
396        args[i].fec_pos = current * ctx->roots;
397        args[i].start = current * ctx->rs_n;
398        args[i].end = (current + rs_blocks_per_thread) * ctx->rs_n;
399
400        args[i].rs = init_rs_char(FEC_PARAMS(ctx->roots));
401
402        if (!args[i].rs) {
403            FATAL("failed to initialize encoder for thread %d\n", i);
404        }
405
406        if (args[i].end > end) {
407            args[i].end = end;
408        } else if (i == threads && args[i].end + rs_blocks_per_thread *
409                                        ctx->rs_n > end) {
410            args[i].end = end;
411        }
412
413        if (ctx->verbose) {
414            INFO("thread %d: [%" PRIu64 ", %" PRIu64 ")\n",
415                i, args[i].start, args[i].end);
416        }
417
418        assert(args[i].start < args[i].end);
419        assert((args[i].end - args[i].start) % ctx->rs_n == 0);
420
421        if (pthread_create(&pthreads[i], NULL, process, &args[i]) != 0) {
422            FATAL("failed to create thread %d\n", i);
423        }
424
425        current += rs_blocks_per_thread;
426    }
427
428    ctx->rv = 0;
429
430    for (int i = 0; i < threads; ++i) {
431        if (pthread_join(pthreads[i], NULL) != 0) {
432            FATAL("failed to join thread %d: %s\n", i, strerror(errno));
433        }
434
435        ctx->rv += args[i].rv;
436
437        if (args[i].rs) {
438            free_rs_char(args[i].rs);
439            args[i].rs = NULL;
440        }
441    }
442
443    return true;
444}
445