1/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Base64.h"
18
19#include <string>
20
21namespace android {
22namespace hardware {
23namespace drm {
24namespace V1_1 {
25namespace clearkey {
26
27sp<Buffer> decodeBase64(const std::string &s) {
28    size_t n = s.size();
29
30    if ((n % 4) != 0) {
31        return nullptr;
32    }
33
34    size_t padding = 0;
35    if (n >= 1 && s.c_str()[n - 1] == '=') {
36        padding = 1;
37
38        if (n >= 2 && s.c_str()[n - 2] == '=') {
39            padding = 2;
40
41            if (n >= 3 && s.c_str()[n - 3] == '=') {
42                padding = 3;
43            }
44        }
45    }
46
47    // We divide first to avoid overflow. It's OK to do this because we
48    // already made sure that n % 4 == 0.
49    size_t outLen = (n / 4) * 3 - padding;
50
51    sp<Buffer> buffer = new Buffer(outLen);
52    uint8_t *out = buffer->data();
53    if (out == nullptr || buffer->size() < outLen) {
54        return nullptr;
55    }
56
57    size_t j = 0;
58    uint32_t accum = 0;
59    for (size_t i = 0; i < n; ++i) {
60        char c = s.c_str()[i];
61        unsigned value;
62        if (c >= 'A' && c <= 'Z') {
63            value = c - 'A';
64        } else if (c >= 'a' && c <= 'z') {
65            value = 26 + c - 'a';
66        } else if (c >= '0' && c <= '9') {
67            value = 52 + c - '0';
68        } else if (c == '+' || c == '-') {
69            value = 62;
70        } else if (c == '/' || c == '_') {
71            value = 63;
72        } else if (c != '=') {
73            return nullptr;
74        } else {
75            if (i < n - padding) {
76                return nullptr;
77            }
78
79            value = 0;
80        }
81
82        accum = (accum << 6) | value;
83
84        if (((i + 1) % 4) == 0) {
85            if (j < outLen) { out[j++] = (accum >> 16); }
86            if (j < outLen) { out[j++] = (accum >> 8) & 0xff; }
87            if (j < outLen) { out[j++] = accum & 0xff; }
88
89            accum = 0;
90        }
91    }
92
93    return buffer;
94}
95
// Maps a 6-bit value to its character in the standard base64 alphabet
// ("A-Z", "a-z", "0-9", '+', '/'). Any input of 63 or larger yields '/'.
static char encode6Bit(unsigned x) {
    static const char kAlphabet[] =
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz"
        "0123456789"
        "+/";
    // Clamp so out-of-range inputs map to the last entry ('/'), never past it.
    const unsigned index = (x < 63u) ? x : 63u;
    return kAlphabet[index];
}
109
110void encodeBase64(const void *_data, size_t size, std::string *out) {
111    out->clear();
112
113    const uint8_t *data = (const uint8_t *)_data;
114
115    size_t i;
116    for (i = 0; i < (size / 3) * 3; i += 3) {
117        uint8_t x1 = data[i];
118        uint8_t x2 = data[i + 1];
119        uint8_t x3 = data[i + 2];
120
121        out->push_back(encode6Bit(x1 >> 2));
122        out->push_back(encode6Bit((x1 << 4 | x2 >> 4) & 0x3f));
123        out->push_back(encode6Bit((x2 << 2 | x3 >> 6) & 0x3f));
124        out->push_back(encode6Bit(x3 & 0x3f));
125    }
126    switch (size % 3) {
127        case 0:
128            break;
129        case 2:
130        {
131            uint8_t x1 = data[i];
132            uint8_t x2 = data[i + 1];
133            out->push_back(encode6Bit(x1 >> 2));
134            out->push_back(encode6Bit((x1 << 4 | x2 >> 4) & 0x3f));
135            out->push_back(encode6Bit((x2 << 2) & 0x3f));
136            out->push_back('=');
137            break;
138        }
139        default:
140        {
141            uint8_t x1 = data[i];
142            out->push_back(encode6Bit(x1 >> 2));
143            out->push_back(encode6Bit((x1 << 4) & 0x3f));
144            out->append("==");
145            break;
146        }
147    }
148}
149
150void encodeBase64Url(const void *_data, size_t size, std::string *out) {
151    encodeBase64(_data, size, out);
152
153    if ((std::string::npos != out->find("+")) ||
154            (std::string::npos != out->find("/"))) {
155        size_t outLen = out->size();
156        char *base64url = new char[outLen];
157        for (size_t i = 0; i < outLen; ++i) {
158            if (out->c_str()[i] == '+')
159                base64url[i] = '-';
160            else if (out->c_str()[i] == '/')
161                base64url[i] = '_';
162            else
163                base64url[i] = out->c_str()[i];
164        }
165
166        out->assign(base64url, outLen);
167        delete[] base64url;
168    }
169}
170
171} // namespace clearkey
172} // namespace V1_1
173} // namespace drm
174} // namespace hardware
175} // namespace android
176