1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//#define LOG_NDEBUG 0
18#define LOG_TAG "HlsSampleDecryptor"
19
20#include "HlsSampleDecryptor.h"
21
22#include <media/stagefright/foundation/ABuffer.h>
23#include <media/stagefright/foundation/ADebug.h>
24#include <media/stagefright/Utils.h>
25
26
27namespace android {
28
29HlsSampleDecryptor::HlsSampleDecryptor()
30    : mValidKeyInfo(false) {
31}
32
33HlsSampleDecryptor::HlsSampleDecryptor(const sp<AMessage> &sampleAesKeyItem)
34    : mValidKeyInfo(false) {
35
36    signalNewSampleAesKey(sampleAesKeyItem);
37}
38
39void HlsSampleDecryptor::signalNewSampleAesKey(const sp<AMessage> &sampleAesKeyItem) {
40
41    if (sampleAesKeyItem == NULL) {
42        mValidKeyInfo = false;
43        ALOGW("signalNewSampleAesKey: sampleAesKeyItem is NULL");
44        return;
45    }
46
47    sp<ABuffer> keyDataBuffer, initVecBuffer;
48    sampleAesKeyItem->findBuffer("keyData", &keyDataBuffer);
49    sampleAesKeyItem->findBuffer("initVec", &initVecBuffer);
50
51    if (keyDataBuffer != NULL && keyDataBuffer->size() == AES_BLOCK_SIZE &&
52        initVecBuffer != NULL && initVecBuffer->size() == AES_BLOCK_SIZE) {
53
54        ALOGV("signalNewSampleAesKey: Key: %s  IV: %s",
55              aesBlockToStr(keyDataBuffer->data()).c_str(),
56              aesBlockToStr(initVecBuffer->data()).c_str());
57
58        uint8_t KeyData[AES_BLOCK_SIZE];
59        memcpy(KeyData, keyDataBuffer->data(), AES_BLOCK_SIZE);
60        memcpy(mAESInitVec, initVecBuffer->data(), AES_BLOCK_SIZE);
61
62        mValidKeyInfo = (AES_set_decrypt_key(KeyData, 8*AES_BLOCK_SIZE/*128*/, &mAesKey) == 0);
63        if (!mValidKeyInfo) {
64            ALOGE("signalNewSampleAesKey: failed to set AES decryption key.");
65        }
66
67    } else {
68        // Media scanner might try extract/parse the TS files without knowing the key.
69        // Otherwise, shouldn't get here (unless an invalid playlist has swaped SAMPLE-AES with
70        // NONE method while still sample-encrypted stream is parsed).
71
72        mValidKeyInfo = false;
73        ALOGE("signalNewSampleAesKey Can't decrypt; keyDataBuffer: %p(%zu) initVecBuffer: %p(%zu)",
74              keyDataBuffer.get(), (keyDataBuffer.get() == NULL)? -1 : keyDataBuffer->size(),
75              initVecBuffer.get(), (initVecBuffer.get() == NULL)? -1 : initVecBuffer->size());
76    }
77}
78
79size_t HlsSampleDecryptor::processNal(uint8_t *nalData, size_t nalSize) {
80
81    unsigned nalType = nalData[0] & 0x1f;
82    if (!mValidKeyInfo) {
83        ALOGV("processNal[%d]: (%p)/%zu Skipping due to invalid key", nalType, nalData, nalSize);
84        return nalSize;
85    }
86
87    bool isEncrypted = (nalSize > VIDEO_CLEAR_LEAD + AES_BLOCK_SIZE);
88    ALOGV("processNal[%d]: (%p)/%zu isEncrypted: %d", nalType, nalData, nalSize, isEncrypted);
89
90    if (isEncrypted) {
91        // Encrypted NALUs have extra start code emulation prevention that must be
92        // stripped out before we can decrypt it.
93        size_t newSize = unescapeStream(nalData, nalSize);
94
95        ALOGV("processNal:unescapeStream[%d]: %zu -> %zu", nalType, nalSize, newSize);
96        nalSize = newSize;
97
98        //Encrypted_nal_unit () {
99        //    nal_unit_type_byte                // 1 byte
100        //    unencrypted_leader                // 31 bytes
101        //    while (bytes_remaining() > 0) {
102        //        if (bytes_remaining() > 16) {
103        //            encrypted_block           // 16 bytes
104        //        }
105        //        unencrypted_block           // MIN(144, bytes_remaining()) bytes
106        //    }
107        //}
108
109        size_t offset = VIDEO_CLEAR_LEAD;
110        size_t remainingBytes = nalSize - VIDEO_CLEAR_LEAD;
111
112        // a copy of initVec as decryptBlock updates it
113        unsigned char AESInitVec[AES_BLOCK_SIZE];
114        memcpy(AESInitVec, mAESInitVec, AES_BLOCK_SIZE);
115
116        while (remainingBytes > 0) {
117            // encrypted_block: protected block uses 10% skip encryption
118            if (remainingBytes > AES_BLOCK_SIZE) {
119                uint8_t *encrypted = nalData + offset;
120                status_t ret = decryptBlock(encrypted, AES_BLOCK_SIZE, AESInitVec);
121                if (ret != OK) {
122                    ALOGE("processNal failed with %d", ret);
123                    return nalSize; // revisit this
124                }
125
126                offset += AES_BLOCK_SIZE;
127                remainingBytes -= AES_BLOCK_SIZE;
128            }
129
130            // unencrypted_block
131            size_t clearBytes = std::min(remainingBytes, (size_t)(9 * AES_BLOCK_SIZE));
132
133            offset += clearBytes;
134            remainingBytes -= clearBytes;
135        } // while
136
137    } else { // isEncrypted == false
138        ALOGV("processNal[%d]: Unencrypted NALU  (%p)/%zu", nalType, nalData, nalSize);
139    }
140
141    return nalSize;
142}
143
144void HlsSampleDecryptor::processAAC(size_t adtsHdrSize, uint8_t *data, size_t size) {
145
146    if (!mValidKeyInfo) {
147        ALOGV("processAAC: (%p)/%zu Skipping due to invalid key", data, size);
148        return;
149    }
150
151    // ADTS header is included in the size
152    size_t offset = adtsHdrSize;
153    size_t remainingBytes = size - adtsHdrSize;
154
155    bool isEncrypted = (remainingBytes >= AUDIO_CLEAR_LEAD + AES_BLOCK_SIZE);
156    ALOGV("processAAC: header: %zu data: %p(%zu) isEncrypted: %d",
157          adtsHdrSize, data, size, isEncrypted);
158
159    //Encrypted_AAC_Frame () {
160    //    ADTS_Header                        // 7 or 9 bytes
161    //    unencrypted_leader                 // 16 bytes
162    //    while (bytes_remaining() >= 16) {
163    //        encrypted_block                // 16 bytes
164    //    }
165    //    unencrypted_trailer                // 0-15 bytes
166    //}
167
168    // with lead bytes
169    if (remainingBytes >= AUDIO_CLEAR_LEAD) {
170        offset += AUDIO_CLEAR_LEAD;
171        remainingBytes -= AUDIO_CLEAR_LEAD;
172
173        // encrypted_block
174        if (remainingBytes >= AES_BLOCK_SIZE) {
175
176            size_t encryptedBytes = (remainingBytes / AES_BLOCK_SIZE) * AES_BLOCK_SIZE;
177            unsigned char AESInitVec[AES_BLOCK_SIZE];
178            memcpy(AESInitVec, mAESInitVec, AES_BLOCK_SIZE);
179
180            // decrypting all blocks at once
181            uint8_t *encrypted = data + offset;
182            status_t ret = decryptBlock(encrypted, encryptedBytes, AESInitVec);
183            if (ret != OK) {
184                ALOGE("processAAC: decryptBlock failed with %d", ret);
185                return;
186            }
187
188            offset += encryptedBytes;
189            remainingBytes -= encryptedBytes;
190        } // encrypted
191
192        // unencrypted_trailer
193        size_t clearBytes = remainingBytes;
194        if (clearBytes > 0) {
195            CHECK(clearBytes < AES_BLOCK_SIZE);
196        }
197
198    } else { // without lead bytes
199        ALOGV("processAAC: Unencrypted frame (without lead bytes) size %zu = %zu (hdr) + %zu (rem)",
200              size, adtsHdrSize, remainingBytes);
201    }
202
203}
204
205void HlsSampleDecryptor::processAC3(uint8_t *data, size_t size) {
206
207    if (!mValidKeyInfo) {
208        ALOGV("processAC3: (%p)/%zu Skipping due to invalid key", data, size);
209        return;
210    }
211
212    bool isEncrypted = (size >= AUDIO_CLEAR_LEAD + AES_BLOCK_SIZE);
213    ALOGV("processAC3 %p(%zu) isEncrypted: %d", data, size, isEncrypted);
214
215    //Encrypted_AC3_Frame () {
216    //    unencrypted_leader                 // 16 bytes
217    //    while (bytes_remaining() >= 16) {
218    //        encrypted_block                // 16 bytes
219    //    }
220    //    unencrypted_trailer                // 0-15 bytes
221    //}
222
223    if (size >= AUDIO_CLEAR_LEAD) {
224        // unencrypted_leader
225        size_t offset = AUDIO_CLEAR_LEAD;
226        size_t remainingBytes = size - AUDIO_CLEAR_LEAD;
227
228        if (remainingBytes >= AES_BLOCK_SIZE) {
229
230            size_t encryptedBytes = (remainingBytes / AES_BLOCK_SIZE) * AES_BLOCK_SIZE;
231
232            // encrypted_block
233            unsigned char AESInitVec[AES_BLOCK_SIZE];
234            memcpy(AESInitVec, mAESInitVec, AES_BLOCK_SIZE);
235
236            // decrypting all blocks at once
237            uint8_t *encrypted = data + offset;
238            status_t ret = decryptBlock(encrypted, encryptedBytes, AESInitVec);
239            if (ret != OK) {
240                ALOGE("processAC3: decryptBlock failed with %d", ret);
241                return;
242            }
243
244            offset += encryptedBytes;
245            remainingBytes -= encryptedBytes;
246        } // encrypted
247
248        // unencrypted_trailer
249        size_t clearBytes = remainingBytes;
250        if (clearBytes > 0) {
251            CHECK(clearBytes < AES_BLOCK_SIZE);
252        }
253
254    } else {
255        ALOGV("processAC3: Unencrypted frame (without lead bytes) size %zu", size);
256    }
257}
258
259// Unescapes data replacing occurrences of [0, 0, 3] with [0, 0] and returns the new size
260size_t HlsSampleDecryptor::unescapeStream(uint8_t *data, size_t limit) const {
261    Vector<size_t> scratchEscapePositions;
262    size_t position = 0;
263
264    while (position < limit) {
265        position = findNextUnescapeIndex(data, position, limit);
266        if (position < limit) {
267            scratchEscapePositions.add(position);
268            position += 3;
269        }
270    }
271
272    size_t scratchEscapeCount = scratchEscapePositions.size();
273    size_t escapedPosition = 0; // The position being read from.
274    size_t unescapedPosition = 0; // The position being written to.
275    for (size_t i = 0; i < scratchEscapeCount; i++) {
276        size_t nextEscapePosition = scratchEscapePositions[i];
277        //TODO: add 2 and get rid of the later = 0 assignments
278        size_t copyLength = nextEscapePosition - escapedPosition;
279        memmove(data+unescapedPosition, data+escapedPosition, copyLength);
280        unescapedPosition += copyLength;
281        data[unescapedPosition++] = 0;
282        data[unescapedPosition++] = 0;
283        escapedPosition += copyLength + 3;
284    }
285
286    size_t unescapedLength = limit - scratchEscapeCount;
287    size_t remainingLength = unescapedLength - unescapedPosition;
288    memmove(data+unescapedPosition, data+escapedPosition, remainingLength);
289
290    return unescapedLength;
291}
292
293size_t HlsSampleDecryptor::findNextUnescapeIndex(uint8_t *data, size_t offset, size_t limit) const {
294    for (size_t i = offset; i < limit - 2; i++) {
295        //TODO: speed
296        if (data[i] == 0x00 && data[i + 1] == 0x00 && data[i + 2] == 0x03) {
297            return i;
298        }
299    }
300    return limit;
301}
302
303status_t HlsSampleDecryptor::decryptBlock(uint8_t *buffer, size_t size,
304        uint8_t AESInitVec[AES_BLOCK_SIZE]) {
305    if (size == 0) {
306        return OK;
307    }
308
309    if ((size % AES_BLOCK_SIZE) != 0) {
310        ALOGE("decryptBlock: size (%zu) not a multiple of block size", size);
311        return ERROR_MALFORMED;
312    }
313
314    ALOGV("decryptBlock: %p (%zu)", buffer, size);
315
316    AES_cbc_encrypt(buffer, buffer, size, &mAesKey, AESInitVec, AES_DECRYPT);
317
318    return OK;
319}
320
321AString HlsSampleDecryptor::aesBlockToStr(uint8_t block[AES_BLOCK_SIZE]) {
322    AString result;
323
324    if (block == NULL) {
325        result = AString("null");
326    } else {
327        result = AStringPrintf("0x%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X",
328            block[0], block[1], block[2], block[3], block[4], block[5], block[6], block[7],
329            block[8], block[9], block[10], block[11], block[12], block[13], block[14], block[15]);
330    }
331
332    return result;
333}
334
335
336}  // namespace android
337