1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <audio_utils/sndfile.h>
#include <audio_utils/primitives.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
22
// Format codes stored in the first two bytes of a WAVE "fmt " chunk.
// sf_open_read switches on these; sf_open_write emits PCM or IEEE_FLOAT.
#define WAVE_FORMAT_PCM         1
#define WAVE_FORMAT_IEEE_FLOAT  3
#define WAVE_FORMAT_EXTENSIBLE  0xFFFE
26
// Private state behind the SNDFILE handle returned by sf_open().
struct SNDFILE_ {
    int mode;       // SFM_READ or SFM_WRITE, fixed when the handle is created
    uint8_t *temp;  // realloc buffer used for shrinking 16 bits to 8 bits and byte-swapping; freed by sf_close()
    FILE *stream;   // underlying file; closed by sf_close()
    size_t bytesPerFrame;   // (bits per sample / 8) * channels
    size_t remaining;   // frames unread for SFM_READ, frames written for SFM_WRITE
    SF_INFO info;   // parsed from the file header (read) or copied from the caller (write)
};
35
/* Decode an unsigned 16-bit value stored little-endian at ptr[0..1]. */
static unsigned little2u(unsigned char *ptr)
{
    unsigned low = ptr[0];
    unsigned high = ptr[1];
    return low | (high << 8);
}
40
/*
 * Decode an unsigned 32-bit value stored little-endian at ptr[0..3].
 *
 * Each byte is widened to unsigned before shifting: an unsigned char promotes
 * to int, and shifting a byte >= 0x80 left by 24 into the sign bit of an int is
 * undefined behavior. Header sizes read from a file can hit this case.
 */
static unsigned little4u(unsigned char *ptr)
{
    return ((unsigned) ptr[3] << 24) | ((unsigned) ptr[2] << 16) |
           ((unsigned) ptr[1] << 8) | (unsigned) ptr[0];
}
45
/* Return 1 if the host stores the least significant byte first, else 0. */
static int isLittleEndian(void)
{
    const union {
        short word;
        unsigned char bytes[sizeof(short)];
    } probe = { 1 };
    return probe.bytes[0] == 1;
}
51
// "swab" conflicts with OS X <string.h>
/*
 * Rewrite numToSwap 16-bit samples, in place, from little-endian storage order
 * to host order. On a little-endian host every value is left as-is; on a
 * big-endian host the two bytes of each sample are exchanged.
 */
static void my_swab(short *ptr, size_t numToSwap)
{
    size_t i;
    for (i = 0; i < numToSwap; ++i) {
        const unsigned char *bytes = (const unsigned char *) &ptr[i];
        unsigned value = ((unsigned) bytes[1] << 8) + bytes[0];
        ptr[i] = value;
    }
}
61
62static SNDFILE *sf_open_read(const char *path, SF_INFO *info)
63{
64    FILE *stream = fopen(path, "rb");
65    if (stream == NULL) {
66        fprintf(stderr, "fopen %s failed errno %d\n", path, errno);
67        return NULL;
68    }
69
70    SNDFILE *handle = (SNDFILE *) malloc(sizeof(SNDFILE));
71    handle->mode = SFM_READ;
72    handle->temp = NULL;
73    handle->stream = stream;
74    handle->info.format = SF_FORMAT_WAV;
75
76    // don't attempt to parse all valid forms, just the most common ones
77    unsigned char wav[12];
78    size_t actual;
79    actual = fread(wav, sizeof(char), sizeof(wav), stream);
80    if (actual < 12) {
81        fprintf(stderr, "actual %u < 44\n", actual);
82        goto close;
83    }
84    if (memcmp(wav, "RIFF", 4)) {
85        fprintf(stderr, "wav != RIFF\n");
86        goto close;
87    }
88    unsigned riffSize = little4u(&wav[4]);
89    if (riffSize < 4) {
90        fprintf(stderr, "riffSize %u < 4\n", riffSize);
91        goto close;
92    }
93    if (memcmp(&wav[8], "WAVE", 4)) {
94        fprintf(stderr, "missing WAVE\n");
95        goto close;
96    }
97    size_t remaining = riffSize - 4;
98    int hadFmt = 0;
99    int hadData = 0;
100    while (remaining >= 8) {
101        unsigned char chunk[8];
102        actual = fread(chunk, sizeof(char), sizeof(chunk), stream);
103        if (actual != sizeof(chunk)) {
104            fprintf(stderr, "actual %u != %u\n", actual, sizeof(chunk));
105            goto close;
106        }
107        remaining -= 8;
108        unsigned chunkSize = little4u(&chunk[4]);
109        if (chunkSize > remaining) {
110            fprintf(stderr, "chunkSize %u > remaining %u\n", chunkSize, remaining);
111            goto close;
112        }
113        if (!memcmp(&chunk[0], "fmt ", 4)) {
114            if (hadFmt) {
115                fprintf(stderr, "multiple fmt\n");
116                goto close;
117            }
118            if (chunkSize < 2) {
119                fprintf(stderr, "chunkSize %u < 2\n", chunkSize);
120                goto close;
121            }
122            unsigned char fmt[40];
123            actual = fread(fmt, sizeof(char), 2, stream);
124            if (actual != 2) {
125                fprintf(stderr, "actual %u != 2\n", actual);
126                goto close;
127            }
128            unsigned format = little2u(&fmt[0]);
129            size_t minSize = 0;
130            switch (format) {
131            case WAVE_FORMAT_PCM:
132            case WAVE_FORMAT_IEEE_FLOAT:
133                minSize = 16;
134                break;
135            case WAVE_FORMAT_EXTENSIBLE:
136                minSize = 40;
137                break;
138            default:
139                fprintf(stderr, "unsupported format %u\n", format);
140                goto close;
141            }
142            if (chunkSize < minSize) {
143                fprintf(stderr, "chunkSize %u < minSize %u\n", chunkSize, minSize);
144                goto close;
145            }
146            actual = fread(&fmt[2], sizeof(char), minSize - 2, stream);
147            if (actual != minSize - 2) {
148                fprintf(stderr, "actual %u != %u\n", actual, minSize - 16);
149                goto close;
150            }
151            if (chunkSize > minSize) {
152                fseek(stream, (long) (chunkSize - minSize), SEEK_CUR);
153            }
154            unsigned channels = little2u(&fmt[2]);
155            if (channels != 1 && channels != 2) {
156                fprintf(stderr, "channels %u != 1 or 2\n", channels);
157                goto close;
158            }
159            unsigned samplerate = little4u(&fmt[4]);
160            if (samplerate == 0) {
161                fprintf(stderr, "samplerate %u == 0\n", samplerate);
162                goto close;
163            }
164            // ignore byte rate
165            // ignore block alignment
166            unsigned bitsPerSample = little2u(&fmt[14]);
167            if (bitsPerSample != 8 && bitsPerSample != 16 && bitsPerSample != 32) {
168                fprintf(stderr, "bitsPerSample %u != 8 or 16 or 32\n", bitsPerSample);
169                goto close;
170            }
171            unsigned bytesPerFrame = (bitsPerSample >> 3) * channels;
172            handle->bytesPerFrame = bytesPerFrame;
173            handle->info.samplerate = samplerate;
174            handle->info.channels = channels;
175            switch (bitsPerSample) {
176            case 8:
177                handle->info.format |= SF_FORMAT_PCM_U8;
178                break;
179            case 16:
180                handle->info.format |= SF_FORMAT_PCM_16;
181                break;
182            case 32:
183                if (format == WAVE_FORMAT_IEEE_FLOAT)
184                    handle->info.format |= SF_FORMAT_FLOAT;
185                else
186                    handle->info.format |= SF_FORMAT_PCM_32;
187                break;
188            }
189            hadFmt = 1;
190        } else if (!memcmp(&chunk[0], "data", 4)) {
191            if (!hadFmt) {
192                fprintf(stderr, "data not preceded by fmt\n");
193                goto close;
194            }
195            if (hadData) {
196                fprintf(stderr, "multiple data\n");
197                goto close;
198            }
199            handle->remaining = chunkSize / handle->bytesPerFrame;
200            handle->info.frames = handle->remaining;
201            hadData = 1;
202        } else if (!memcmp(&chunk[0], "fact", 4)) {
203            // ignore fact
204            if (chunkSize > 0) {
205                fseek(stream, (long) chunkSize, SEEK_CUR);
206            }
207        } else {
208            // ignore unknown chunk
209            fprintf(stderr, "ignoring unknown chunk %c%c%c%c\n",
210                    chunk[0], chunk[1], chunk[2], chunk[3]);
211            if (chunkSize > 0) {
212                fseek(stream, (long) chunkSize, SEEK_CUR);
213            }
214        }
215        remaining -= chunkSize;
216    }
217    if (remaining > 0) {
218        fprintf(stderr, "partial chunk at end of RIFF, remaining %u\n", remaining);
219        goto close;
220    }
221    if (!hadData) {
222        fprintf(stderr, "missing data\n");
223        goto close;
224    }
225    *info = handle->info;
226    return handle;
227
228close:
229    free(handle);
230    fclose(stream);
231    return NULL;
232}
233
/* Store u at ptr[0..3] in little-endian byte order. */
static void write4u(unsigned char *ptr, unsigned u)
{
    int i;
    for (i = 0; i < 4; i++) {
        ptr[i] = (unsigned char) (u >> (8 * i));
    }
}
241
242static SNDFILE *sf_open_write(const char *path, SF_INFO *info)
243{
244    int sub = info->format & SF_FORMAT_SUBMASK;
245    if (!(
246            (info->samplerate > 0) &&
247            (info->channels == 1 || info->channels == 2) &&
248            ((info->format & SF_FORMAT_TYPEMASK) == SF_FORMAT_WAV) &&
249            (sub == SF_FORMAT_PCM_16 || sub == SF_FORMAT_PCM_U8 || sub == SF_FORMAT_FLOAT)
250          )) {
251        return NULL;
252    }
253    FILE *stream = fopen(path, "w+b");
254    unsigned char wav[58];
255    memset(wav, 0, sizeof(wav));
256    memcpy(wav, "RIFF", 4);
257    memcpy(&wav[8], "WAVEfmt ", 8);
258    if (sub == SF_FORMAT_FLOAT) {
259        wav[4] = 50;    // riffSize
260        wav[16] = 18;   // fmtSize
261        wav[20] = WAVE_FORMAT_IEEE_FLOAT;
262    } else {
263        wav[4] = 36;    // riffSize
264        wav[16] = 16;   // fmtSize
265        wav[20] = WAVE_FORMAT_PCM;
266    }
267    wav[22] = info->channels;
268    write4u(&wav[24], info->samplerate);
269    unsigned bitsPerSample;
270    switch (sub) {
271    case SF_FORMAT_PCM_16:
272        bitsPerSample = 16;
273        break;
274    case SF_FORMAT_PCM_U8:
275        bitsPerSample = 8;
276        break;
277    case SF_FORMAT_FLOAT:
278        bitsPerSample = 32;
279        break;
280    default:    // not reachable
281        bitsPerSample = 0;
282        break;
283    }
284    unsigned blockAlignment = (bitsPerSample >> 3) * info->channels;
285    unsigned byteRate = info->samplerate * blockAlignment;
286    write4u(&wav[28], byteRate);
287    wav[32] = blockAlignment;
288    wav[34] = bitsPerSample;
289    if (sub == SF_FORMAT_FLOAT) {
290        memcpy(&wav[38], "fact", 4);
291        wav[42] = 4;
292        memcpy(&wav[50], "data", 4);
293    } else
294        memcpy(&wav[36], "data", 4);
295    // dataSize is initially zero
296    (void) fwrite(wav, sizeof(wav), 1, stream);
297    SNDFILE *handle = (SNDFILE *) malloc(sizeof(SNDFILE));
298    handle->mode = SFM_WRITE;
299    handle->temp = NULL;
300    handle->stream = stream;
301    handle->bytesPerFrame = blockAlignment;
302    handle->remaining = 0;
303    handle->info = *info;
304    return handle;
305}
306
307SNDFILE *sf_open(const char *path, int mode, SF_INFO *info)
308{
309    if (path == NULL || info == NULL) {
310        fprintf(stderr, "path=%p info=%p\n", path, info);
311        return NULL;
312    }
313    switch (mode) {
314    case SFM_READ:
315        return sf_open_read(path, info);
316    case SFM_WRITE:
317        return sf_open_write(path, info);
318    default:
319        fprintf(stderr, "mode=%d\n", mode);
320        return NULL;
321    }
322}
323
324void sf_close(SNDFILE *handle)
325{
326    if (handle == NULL)
327        return;
328    free(handle->temp);
329    if (handle->mode == SFM_WRITE) {
330        (void) fflush(handle->stream);
331        rewind(handle->stream);
332        unsigned char wav[58];
333        size_t extra = (handle->info.format & SF_FORMAT_SUBMASK) == SF_FORMAT_FLOAT ? 14 : 0;
334        (void) fread(wav, 44 + extra, 1, handle->stream);
335        unsigned dataSize = handle->remaining * handle->bytesPerFrame;
336        write4u(&wav[4], dataSize + 36 + extra);    // riffSize
337        write4u(&wav[40 + extra], dataSize);        // dataSize
338        rewind(handle->stream);
339        (void) fwrite(wav, 44 + extra, 1, handle->stream);
340    }
341    (void) fclose(handle->stream);
342    free(handle);
343}
344
345sf_count_t sf_readf_short(SNDFILE *handle, short *ptr, sf_count_t desiredFrames)
346{
347    if (handle == NULL || handle->mode != SFM_READ || ptr == NULL || !handle->remaining ||
348            desiredFrames <= 0) {
349        return 0;
350    }
351    if (handle->remaining < (size_t) desiredFrames)
352        desiredFrames = handle->remaining;
353    // does not check for numeric overflow
354    size_t desiredBytes = desiredFrames * handle->bytesPerFrame;
355    size_t actualBytes;
356    void *temp = NULL;
357    unsigned format = handle->info.format & SF_FORMAT_SUBMASK;
358    if (format == SF_FORMAT_PCM_32 || format == SF_FORMAT_FLOAT) {
359        temp = malloc(desiredBytes);
360        actualBytes = fread(temp, sizeof(char), desiredBytes, handle->stream);
361    } else {
362        actualBytes = fread(ptr, sizeof(char), desiredBytes, handle->stream);
363    }
364    size_t actualFrames = actualBytes / handle->bytesPerFrame;
365    handle->remaining -= actualFrames;
366    switch (format) {
367    case SF_FORMAT_PCM_U8:
368        memcpy_to_i16_from_u8(ptr, (unsigned char *) ptr, actualFrames * handle->info.channels);
369        break;
370    case SF_FORMAT_PCM_16:
371        if (!isLittleEndian())
372            my_swab(ptr, actualFrames * handle->info.channels);
373        break;
374    case SF_FORMAT_PCM_32:
375        memcpy_to_i16_from_i32(ptr, (const int *) temp, actualFrames * handle->info.channels);
376        free(temp);
377        break;
378    case SF_FORMAT_FLOAT:
379        memcpy_to_i16_from_float(ptr, (const float *) temp, actualFrames * handle->info.channels);
380        free(temp);
381        break;
382    default:
383        memset(ptr, 0, actualFrames * handle->info.channels * sizeof(short));
384        break;
385    }
386    return actualFrames;
387}
388
389sf_count_t sf_readf_float(SNDFILE *handle, float *ptr, sf_count_t desiredFrames)
390{
391    return 0;
392}
393
394sf_count_t sf_writef_short(SNDFILE *handle, const short *ptr, sf_count_t desiredFrames)
395{
396    if (handle == NULL || handle->mode != SFM_WRITE || ptr == NULL || desiredFrames <= 0)
397        return 0;
398    size_t desiredBytes = desiredFrames * handle->bytesPerFrame;
399    size_t actualBytes = 0;
400    switch (handle->info.format & SF_FORMAT_SUBMASK) {
401    case SF_FORMAT_PCM_U8:
402        handle->temp = realloc(handle->temp, desiredBytes);
403        memcpy_to_u8_from_i16(handle->temp, ptr, desiredBytes);
404        actualBytes = fwrite(handle->temp, sizeof(char), desiredBytes, handle->stream);
405        break;
406    case SF_FORMAT_PCM_16:
407        // does not check for numeric overflow
408        if (isLittleEndian()) {
409            actualBytes = fwrite(ptr, sizeof(char), desiredBytes, handle->stream);
410        } else {
411            handle->temp = realloc(handle->temp, desiredBytes);
412            memcpy(handle->temp, ptr, desiredBytes);
413            my_swab((short *) handle->temp, desiredFrames * handle->info.channels);
414            actualBytes = fwrite(handle->temp, sizeof(char), desiredBytes, handle->stream);
415        }
416        break;
417    case SF_FORMAT_FLOAT:   // transcoding from short to float not yet implemented
418    default:
419        break;
420    }
421    size_t actualFrames = actualBytes / handle->bytesPerFrame;
422    handle->remaining += actualFrames;
423    return actualFrames;
424}
425
426sf_count_t sf_writef_float(SNDFILE *handle, const float *ptr, sf_count_t desiredFrames)
427{
428    if (handle == NULL || handle->mode != SFM_WRITE || ptr == NULL || desiredFrames <= 0)
429        return 0;
430    size_t desiredBytes = desiredFrames * handle->bytesPerFrame;
431    size_t actualBytes = 0;
432    switch (handle->info.format & SF_FORMAT_SUBMASK) {
433    case SF_FORMAT_FLOAT:
434        actualBytes = fwrite(ptr, sizeof(char), desiredBytes, handle->stream);
435        break;
436    case SF_FORMAT_PCM_U8:  // transcoding from float to byte/short not yet implemented
437    case SF_FORMAT_PCM_16:
438    default:
439        break;
440    }
441    size_t actualFrames = actualBytes / handle->bytesPerFrame;
442    handle->remaining += actualFrames;
443    return actualFrames;
444}
445