1/*
2 * Copyright (C) 2010 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "include/DRMExtractor.h"
18
19#include <arpa/inet.h>
20#include <utils/String8.h>
21#include <media/stagefright/foundation/ADebug.h>
22#include <media/stagefright/Utils.h>
23#include <media/stagefright/DataSource.h>
24#include <media/stagefright/MediaSource.h>
25#include <media/stagefright/MediaDefs.h>
26#include <media/stagefright/MetaData.h>
27#include <media/stagefright/MediaErrors.h>
28#include <media/stagefright/MediaBuffer.h>
29
30#include <drm/drm_framework_common.h>
31#include <utils/Errors.h>
32
33
34namespace android {
35
36class DRMSource : public MediaSource {
37public:
38    DRMSource(const sp<IMediaSource> &mediaSource,
39            const sp<DecryptHandle> &decryptHandle,
40            DrmManagerClient *managerClient,
41            int32_t trackId, DrmBuffer *ipmpBox);
42
43    virtual status_t start(MetaData *params = NULL);
44    virtual status_t stop();
45    virtual sp<MetaData> getFormat();
46    virtual status_t read(
47            MediaBuffer **buffer, const ReadOptions *options = NULL);
48
49protected:
50    virtual ~DRMSource();
51
52private:
53    sp<IMediaSource> mOriginalMediaSource;
54    sp<DecryptHandle> mDecryptHandle;
55    DrmManagerClient* mDrmManagerClient;
56    size_t mTrackId;
57    mutable Mutex mDRMLock;
58    size_t mNALLengthSize;
59    bool mWantsNALFragments;
60
61    DRMSource(const DRMSource &);
62    DRMSource &operator=(const DRMSource &);
63};
64
65////////////////////////////////////////////////////////////////////////////////
66
67DRMSource::DRMSource(const sp<IMediaSource> &mediaSource,
68        const sp<DecryptHandle> &decryptHandle,
69        DrmManagerClient *managerClient,
70        int32_t trackId, DrmBuffer *ipmpBox)
71    : mOriginalMediaSource(mediaSource),
72      mDecryptHandle(decryptHandle),
73      mDrmManagerClient(managerClient),
74      mTrackId(trackId),
75      mNALLengthSize(0),
76      mWantsNALFragments(false) {
77    CHECK(mDrmManagerClient);
78    mDrmManagerClient->initializeDecryptUnit(
79            mDecryptHandle, trackId, ipmpBox);
80
81    const char *mime;
82    bool success = getFormat()->findCString(kKeyMIMEType, &mime);
83    CHECK(success);
84
85    if (!strcasecmp(mime, MEDIA_MIMETYPE_VIDEO_AVC)) {
86        uint32_t type;
87        const void *data;
88        size_t size;
89        CHECK(getFormat()->findData(kKeyAVCC, &type, &data, &size));
90
91        const uint8_t *ptr = (const uint8_t *)data;
92
93        CHECK(size >= 7);
94        CHECK_EQ(ptr[0], 1);  // configurationVersion == 1
95
96        // The number of bytes used to encode the length of a NAL unit.
97        mNALLengthSize = 1 + (ptr[4] & 3);
98    }
99}
100
101DRMSource::~DRMSource() {
102    Mutex::Autolock autoLock(mDRMLock);
103    mDrmManagerClient->finalizeDecryptUnit(mDecryptHandle, mTrackId);
104}
105
106status_t DRMSource::start(MetaData *params) {
107    int32_t val;
108    if (params && params->findInt32(kKeyWantsNALFragments, &val)
109        && val != 0) {
110        mWantsNALFragments = true;
111    } else {
112        mWantsNALFragments = false;
113    }
114
115   return mOriginalMediaSource->start(params);
116}
117
118status_t DRMSource::stop() {
119    return mOriginalMediaSource->stop();
120}
121
122sp<MetaData> DRMSource::getFormat() {
123    return mOriginalMediaSource->getFormat();
124}
125
126status_t DRMSource::read(MediaBuffer **buffer, const ReadOptions *options) {
127    Mutex::Autolock autoLock(mDRMLock);
128    status_t err;
129    if ((err = mOriginalMediaSource->read(buffer, options)) != OK) {
130        return err;
131    }
132
133    size_t len = (*buffer)->range_length();
134
135    char *src = (char *)(*buffer)->data() + (*buffer)->range_offset();
136
137    DrmBuffer encryptedDrmBuffer(src, len);
138    DrmBuffer decryptedDrmBuffer;
139    decryptedDrmBuffer.length = len;
140    decryptedDrmBuffer.data = new char[len];
141    DrmBuffer *pDecryptedDrmBuffer = &decryptedDrmBuffer;
142
143    if ((err = mDrmManagerClient->decrypt(mDecryptHandle, mTrackId,
144            &encryptedDrmBuffer, &pDecryptedDrmBuffer)) != NO_ERROR) {
145
146        if (decryptedDrmBuffer.data) {
147            delete [] decryptedDrmBuffer.data;
148            decryptedDrmBuffer.data = NULL;
149        }
150
151        return err;
152    }
153    CHECK(pDecryptedDrmBuffer == &decryptedDrmBuffer);
154
155    const char *mime;
156    CHECK(getFormat()->findCString(kKeyMIMEType, &mime));
157
158    if (!strcasecmp(mime, MEDIA_MIMETYPE_VIDEO_AVC) && !mWantsNALFragments) {
159        uint8_t *dstData = (uint8_t*)src;
160        size_t srcOffset = 0;
161        size_t dstOffset = 0;
162
163        len = decryptedDrmBuffer.length;
164        while (srcOffset < len) {
165            CHECK(srcOffset + mNALLengthSize <= len);
166            size_t nalLength = 0;
167            const uint8_t* data = (const uint8_t*)(&decryptedDrmBuffer.data[srcOffset]);
168
169            switch (mNALLengthSize) {
170                case 1:
171                    nalLength = *data;
172                    break;
173                case 2:
174                    nalLength = U16_AT(data);
175                    break;
176                case 3:
177                    nalLength = ((size_t)data[0] << 16) | U16_AT(&data[1]);
178                    break;
179                case 4:
180                    nalLength = U32_AT(data);
181                    break;
182                default:
183                    CHECK(!"Should not be here.");
184                    break;
185            }
186
187            srcOffset += mNALLengthSize;
188
189            size_t end = srcOffset + nalLength;
190            if (end > len || end < srcOffset) {
191                if (decryptedDrmBuffer.data) {
192                    delete [] decryptedDrmBuffer.data;
193                    decryptedDrmBuffer.data = NULL;
194                }
195
196                return ERROR_MALFORMED;
197            }
198
199            if (nalLength == 0) {
200                continue;
201            }
202
203            if (dstOffset > SIZE_MAX - 4 ||
204                dstOffset + 4 > SIZE_MAX - nalLength ||
205                dstOffset + 4 + nalLength > (*buffer)->size()) {
206                (*buffer)->release();
207                (*buffer) = NULL;
208                if (decryptedDrmBuffer.data) {
209                    delete [] decryptedDrmBuffer.data;
210                    decryptedDrmBuffer.data = NULL;
211                }
212                return ERROR_MALFORMED;
213            }
214
215            dstData[dstOffset++] = 0;
216            dstData[dstOffset++] = 0;
217            dstData[dstOffset++] = 0;
218            dstData[dstOffset++] = 1;
219            memcpy(&dstData[dstOffset], &decryptedDrmBuffer.data[srcOffset], nalLength);
220            srcOffset += nalLength;
221            dstOffset += nalLength;
222        }
223
224        CHECK_EQ(srcOffset, len);
225        (*buffer)->set_range((*buffer)->range_offset(), dstOffset);
226
227    } else {
228        memcpy(src, decryptedDrmBuffer.data, decryptedDrmBuffer.length);
229        (*buffer)->set_range((*buffer)->range_offset(), decryptedDrmBuffer.length);
230    }
231
232    if (decryptedDrmBuffer.data) {
233        delete [] decryptedDrmBuffer.data;
234        decryptedDrmBuffer.data = NULL;
235    }
236
237    return OK;
238}
239
240////////////////////////////////////////////////////////////////////////////////
241
242DRMExtractor::DRMExtractor(const sp<DataSource> &source, const char* mime)
243    : mDataSource(source),
244      mDecryptHandle(NULL),
245      mDrmManagerClient(NULL) {
246    mOriginalExtractor = MediaExtractor::Create(source, mime);
247    mOriginalExtractor->setDrmFlag(true);
248    mOriginalExtractor->getMetaData()->setInt32(kKeyIsDRM, 1);
249
250    source->getDrmInfo(mDecryptHandle, &mDrmManagerClient);
251}
252
253DRMExtractor::~DRMExtractor() {
254}
255
256size_t DRMExtractor::countTracks() {
257    return mOriginalExtractor->countTracks();
258}
259
260sp<IMediaSource> DRMExtractor::getTrack(size_t index) {
261    sp<IMediaSource> originalMediaSource = mOriginalExtractor->getTrack(index);
262    originalMediaSource->getFormat()->setInt32(kKeyIsDRM, 1);
263
264    int32_t trackID;
265    CHECK(getTrackMetaData(index, 0)->findInt32(kKeyTrackID, &trackID));
266
267    DrmBuffer ipmpBox;
268    ipmpBox.data = mOriginalExtractor->getDrmTrackInfo(trackID, &(ipmpBox.length));
269    CHECK(ipmpBox.length > 0);
270
271    return interface_cast<IMediaSource>(
272            new DRMSource(originalMediaSource, mDecryptHandle, mDrmManagerClient,
273            trackID, &ipmpBox));
274}
275
276sp<MetaData> DRMExtractor::getTrackMetaData(size_t index, uint32_t flags) {
277    return mOriginalExtractor->getTrackMetaData(index, flags);
278}
279
280sp<MetaData> DRMExtractor::getMetaData() {
281    return mOriginalExtractor->getMetaData();
282}
283
284bool SniffDRM(
285    const sp<DataSource> &source, String8 *mimeType, float *confidence,
286        sp<AMessage> *) {
287    sp<DecryptHandle> decryptHandle = source->DrmInitialization();
288
289    if (decryptHandle != NULL) {
290        if (decryptHandle->decryptApiType == DecryptApiType::CONTAINER_BASED) {
291            *mimeType = String8("drm+container_based+") + decryptHandle->mimeType;
292            *confidence = 10.0f;
293        } else if (decryptHandle->decryptApiType == DecryptApiType::ELEMENTARY_STREAM_BASED) {
294            *mimeType = String8("drm+es_based+") + decryptHandle->mimeType;
295            *confidence = 10.0f;
296        } else {
297            return false;
298        }
299
300        return true;
301    }
302
303    return false;
304}
305} //namespace android
306
307