1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "media/cdm/json_web_key.h"
6
7#include "base/base64.h"
8#include "base/json/json_reader.h"
9#include "base/json/json_string_value_serializer.h"
10#include "base/json/string_escape.h"
11#include "base/logging.h"
12#include "base/memory/scoped_ptr.h"
13#include "base/strings/string_util.h"
14#include "base/values.h"
15
16namespace media {
17
// JSON member names used in JWK Sets and license requests.
const char kKeysTag[] = "keys";
const char kKeyTypeTag[] = "kty";
const char kKeyTag[] = "k";
const char kKeyIdTag[] = "kid";
const char kKeyIdsTag[] = "kids";
const char kTypeTag[] = "type";

// Member values recognized by this file: "oct" is the only accepted key type,
// and "persistent"/"temporary" are the only accepted session types.
const char kSymmetricKeyValue[] = "oct";
const char kPersistentType[] = "persistent";
const char kTemporaryType[] = "temporary";

// Padding character used by base64. It is stripped when encoding and re-added
// before decoding.
const char kBase64Padding = '=';
28
29// Encodes |input| into a base64 string without padding.
30static std::string EncodeBase64(const uint8* input, int input_length) {
31  std::string encoded_text;
32  base::Base64Encode(
33      std::string(reinterpret_cast<const char*>(input), input_length),
34      &encoded_text);
35
36  // Remove any padding characters added by Base64Encode().
37  size_t found = encoded_text.find_last_not_of(kBase64Padding);
38  if (found != std::string::npos)
39    encoded_text.erase(found + 1);
40
41  return encoded_text;
42}
43
44// Decodes an unpadded base64 string. Returns empty string on error.
45static std::string DecodeBase64(const std::string& encoded_text) {
46  // EME spec doesn't allow padding characters.
47  if (encoded_text.find_first_of(kBase64Padding) != std::string::npos)
48    return std::string();
49
50  // Since base::Base64Decode() requires padding characters, add them so length
51  // of |encoded_text| is exactly a multiple of 4.
52  size_t num_last_grouping_chars = encoded_text.length() % 4;
53  std::string modified_text = encoded_text;
54  if (num_last_grouping_chars > 0)
55    modified_text.append(4 - num_last_grouping_chars, kBase64Padding);
56
57  std::string decoded_text;
58  if (!base::Base64Decode(modified_text, &decoded_text))
59    return std::string();
60
61  return decoded_text;
62}
63
64std::string GenerateJWKSet(const uint8* key, int key_length,
65                           const uint8* key_id, int key_id_length) {
66  // Both |key| and |key_id| need to be base64 encoded strings in the JWK.
67  std::string key_base64 = EncodeBase64(key, key_length);
68  std::string key_id_base64 = EncodeBase64(key_id, key_id_length);
69
70  // Create the JWK, and wrap it into a JWK Set.
71  scoped_ptr<base::DictionaryValue> jwk(new base::DictionaryValue());
72  jwk->SetString(kKeyTypeTag, kSymmetricKeyValue);
73  jwk->SetString(kKeyTag, key_base64);
74  jwk->SetString(kKeyIdTag, key_id_base64);
75  scoped_ptr<base::ListValue> list(new base::ListValue());
76  list->Append(jwk.release());
77  base::DictionaryValue jwk_set;
78  jwk_set.Set(kKeysTag, list.release());
79
80  // Finally serialize |jwk_set| into a string and return it.
81  std::string serialized_jwk;
82  JSONStringValueSerializer serializer(&serialized_jwk);
83  serializer.Serialize(jwk_set);
84  return serialized_jwk;
85}
86
87// Processes a JSON Web Key to extract the key id and key value. Sets |jwk_key|
88// to the id/value pair and returns true on success.
89static bool ConvertJwkToKeyPair(const base::DictionaryValue& jwk,
90                                KeyIdAndKeyPair* jwk_key) {
91  // Have found a JWK, start by checking that it is a symmetric key.
92  std::string type;
93  if (!jwk.GetString(kKeyTypeTag, &type) || type != kSymmetricKeyValue) {
94    DVLOG(1) << "JWK is not a symmetric key";
95    return false;
96  }
97
98  // Get the key id and actual key parameters.
99  std::string encoded_key_id;
100  std::string encoded_key;
101  if (!jwk.GetString(kKeyIdTag, &encoded_key_id)) {
102    DVLOG(1) << "Missing '" << kKeyIdTag << "' parameter";
103    return false;
104  }
105  if (!jwk.GetString(kKeyTag, &encoded_key)) {
106    DVLOG(1) << "Missing '" << kKeyTag << "' parameter";
107    return false;
108  }
109
110  // Key ID and key are base64-encoded strings, so decode them.
111  std::string raw_key_id = DecodeBase64(encoded_key_id);
112  if (raw_key_id.empty()) {
113    DVLOG(1) << "Invalid '" << kKeyIdTag << "' value: " << encoded_key_id;
114    return false;
115  }
116
117  std::string raw_key = DecodeBase64(encoded_key);
118  if (raw_key.empty()) {
119    DVLOG(1) << "Invalid '" << kKeyTag << "' value: " << encoded_key;
120    return false;
121  }
122
123  // Add the decoded key ID and the decoded key to the list.
124  *jwk_key = std::make_pair(raw_key_id, raw_key);
125  return true;
126}
127
128bool ExtractKeysFromJWKSet(const std::string& jwk_set,
129                           KeyIdAndKeyPairs* keys,
130                           MediaKeys::SessionType* session_type) {
131  if (!base::IsStringASCII(jwk_set))
132    return false;
133
134  scoped_ptr<base::Value> root(base::JSONReader().ReadToValue(jwk_set));
135  if (!root.get() || root->GetType() != base::Value::TYPE_DICTIONARY)
136    return false;
137
138  // Locate the set from the dictionary.
139  base::DictionaryValue* dictionary =
140      static_cast<base::DictionaryValue*>(root.get());
141  base::ListValue* list_val = NULL;
142  if (!dictionary->GetList(kKeysTag, &list_val)) {
143    DVLOG(1) << "Missing '" << kKeysTag
144             << "' parameter or not a list in JWK Set";
145    return false;
146  }
147
148  // Create a local list of keys, so that |jwk_keys| only gets updated on
149  // success.
150  KeyIdAndKeyPairs local_keys;
151  for (size_t i = 0; i < list_val->GetSize(); ++i) {
152    base::DictionaryValue* jwk = NULL;
153    if (!list_val->GetDictionary(i, &jwk)) {
154      DVLOG(1) << "Unable to access '" << kKeysTag << "'[" << i
155               << "] in JWK Set";
156      return false;
157    }
158    KeyIdAndKeyPair key_pair;
159    if (!ConvertJwkToKeyPair(*jwk, &key_pair)) {
160      DVLOG(1) << "Error from '" << kKeysTag << "'[" << i << "]";
161      return false;
162    }
163    local_keys.push_back(key_pair);
164  }
165
166  // Successfully processed all JWKs in the set. Now check if "type" is
167  // specified.
168  base::Value* value = NULL;
169  std::string type_id;
170  if (!dictionary->Get(kTypeTag, &value)) {
171    // Not specified, so use the default type.
172    *session_type = MediaKeys::TEMPORARY_SESSION;
173  } else if (!value->GetAsString(&type_id)) {
174    DVLOG(1) << "Invalid '" << kTypeTag << "' value";
175    return false;
176  } else if (type_id == kPersistentType) {
177    *session_type = MediaKeys::PERSISTENT_SESSION;
178  } else if (type_id == kTemporaryType) {
179    *session_type = MediaKeys::TEMPORARY_SESSION;
180  } else {
181    DVLOG(1) << "Invalid '" << kTypeTag << "' value: " << type_id;
182    return false;
183  }
184
185  // All done.
186  keys->swap(local_keys);
187  return true;
188}
189
190void CreateLicenseRequest(const uint8* key_id,
191                          int key_id_length,
192                          MediaKeys::SessionType session_type,
193                          std::vector<uint8>* license) {
194  // Create the license request.
195  scoped_ptr<base::DictionaryValue> request(new base::DictionaryValue());
196  scoped_ptr<base::ListValue> list(new base::ListValue());
197  list->AppendString(EncodeBase64(key_id, key_id_length));
198  request->Set(kKeyIdsTag, list.release());
199
200  switch (session_type) {
201    case MediaKeys::TEMPORARY_SESSION:
202      request->SetString(kTypeTag, kTemporaryType);
203      break;
204    case MediaKeys::PERSISTENT_SESSION:
205      request->SetString(kTypeTag, kPersistentType);
206      break;
207  }
208
209  // Serialize the license request as a string.
210  std::string json;
211  JSONStringValueSerializer serializer(&json);
212  serializer.Serialize(*request);
213
214  // Convert the serialized license request into std::vector and return it.
215  std::vector<uint8> result(json.begin(), json.end());
216  license->swap(result);
217}
218
219bool ExtractFirstKeyIdFromLicenseRequest(const std::vector<uint8>& license,
220                                         std::vector<uint8>* first_key) {
221  const std::string license_as_str(
222      reinterpret_cast<const char*>(!license.empty() ? &license[0] : NULL),
223      license.size());
224  if (!base::IsStringASCII(license_as_str))
225    return false;
226
227  scoped_ptr<base::Value> root(base::JSONReader().ReadToValue(license_as_str));
228  if (!root.get() || root->GetType() != base::Value::TYPE_DICTIONARY)
229    return false;
230
231  // Locate the set from the dictionary.
232  base::DictionaryValue* dictionary =
233      static_cast<base::DictionaryValue*>(root.get());
234  base::ListValue* list_val = NULL;
235  if (!dictionary->GetList(kKeyIdsTag, &list_val)) {
236    DVLOG(1) << "Missing '" << kKeyIdsTag << "' parameter or not a list";
237    return false;
238  }
239
240  // Get the first key.
241  if (list_val->GetSize() < 1) {
242    DVLOG(1) << "Empty '" << kKeyIdsTag << "' list";
243    return false;
244  }
245
246  std::string encoded_key;
247  if (!list_val->GetString(0, &encoded_key)) {
248    DVLOG(1) << "First entry in '" << kKeyIdsTag << "' not a string";
249    return false;
250  }
251
252  std::string decoded_string = DecodeBase64(encoded_key);
253  if (decoded_string.empty()) {
254    DVLOG(1) << "Invalid '" << kKeyIdsTag << "' value: " << encoded_key;
255    return false;
256  }
257
258  std::vector<uint8> result(decoded_string.begin(), decoded_string.end());
259  first_key->swap(result);
260  return true;
261}
262
263}  // namespace media
264