1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "media/cdm/aes_decryptor.h"
6
7#include <list>
8#include <vector>
9
10#include "base/logging.h"
11#include "base/stl_util.h"
12#include "base/strings/string_number_conversions.h"
13#include "crypto/encryptor.h"
14#include "crypto/symmetric_key.h"
15#include "media/base/audio_decoder_config.h"
16#include "media/base/cdm_promise.h"
17#include "media/base/decoder_buffer.h"
18#include "media/base/decrypt_config.h"
19#include "media/base/video_decoder_config.h"
20#include "media/base/video_frame.h"
21#include "media/cdm/json_web_key.h"
22
23namespace media {
24
25// Keeps track of the session IDs and DecryptionKeys. The keys are ordered by
26// insertion time (last insertion is first). It takes ownership of the
27// DecryptionKeys.
28class AesDecryptor::SessionIdDecryptionKeyMap {
29  // Use a std::list to actually hold the data. Insertion is always done
30  // at the front, so the "latest" decryption key is always the first one
31  // in the list.
32  typedef std::list<std::pair<std::string, DecryptionKey*> > KeyList;
33
34 public:
35  SessionIdDecryptionKeyMap() {}
36  ~SessionIdDecryptionKeyMap() { STLDeleteValues(&key_list_); }
37
38  // Replaces value if |session_id| is already present, or adds it if not.
39  // This |decryption_key| becomes the latest until another insertion or
40  // |session_id| is erased.
41  void Insert(const std::string& web_session_id,
42              scoped_ptr<DecryptionKey> decryption_key);
43
44  // Deletes the entry for |session_id| if present.
45  void Erase(const std::string& web_session_id);
46
47  // Returns whether the list is empty
48  bool Empty() const { return key_list_.empty(); }
49
50  // Returns the last inserted DecryptionKey.
51  DecryptionKey* LatestDecryptionKey() {
52    DCHECK(!key_list_.empty());
53    return key_list_.begin()->second;
54  }
55
56 private:
57  // Searches the list for an element with |web_session_id|.
58  KeyList::iterator Find(const std::string& web_session_id);
59
60  // Deletes the entry pointed to by |position|.
61  void Erase(KeyList::iterator position);
62
63  KeyList key_list_;
64
65  DISALLOW_COPY_AND_ASSIGN(SessionIdDecryptionKeyMap);
66};
67
68void AesDecryptor::SessionIdDecryptionKeyMap::Insert(
69    const std::string& web_session_id,
70    scoped_ptr<DecryptionKey> decryption_key) {
71  KeyList::iterator it = Find(web_session_id);
72  if (it != key_list_.end())
73    Erase(it);
74  DecryptionKey* raw_ptr = decryption_key.release();
75  key_list_.push_front(std::make_pair(web_session_id, raw_ptr));
76}
77
78void AesDecryptor::SessionIdDecryptionKeyMap::Erase(
79    const std::string& web_session_id) {
80  KeyList::iterator it = Find(web_session_id);
81  if (it == key_list_.end())
82    return;
83  Erase(it);
84}
85
86AesDecryptor::SessionIdDecryptionKeyMap::KeyList::iterator
87AesDecryptor::SessionIdDecryptionKeyMap::Find(
88    const std::string& web_session_id) {
89  for (KeyList::iterator it = key_list_.begin(); it != key_list_.end(); ++it) {
90    if (it->first == web_session_id)
91      return it;
92  }
93  return key_list_.end();
94}
95
96void AesDecryptor::SessionIdDecryptionKeyMap::Erase(
97    KeyList::iterator position) {
98  DCHECK(position->second);
99  delete position->second;
100  key_list_.erase(position);
101}
102
103uint32 AesDecryptor::next_web_session_id_ = 1;
104
// Tells CopySubsamples() which of its two buffers contains each subsample's
// clear bytes and therefore must be advanced past them. The other buffer holds
// only the cypher bytes, packed back to back.
enum ClearBytesBufferSel {
  kSrcContainsClearBytes = 0,
  kDstContainsClearBytes = 1
};
109
110static void CopySubsamples(const std::vector<SubsampleEntry>& subsamples,
111                           const ClearBytesBufferSel sel,
112                           const uint8* src,
113                           uint8* dst) {
114  for (size_t i = 0; i < subsamples.size(); i++) {
115    const SubsampleEntry& subsample = subsamples[i];
116    if (sel == kSrcContainsClearBytes) {
117      src += subsample.clear_bytes;
118    } else {
119      dst += subsample.clear_bytes;
120    }
121    memcpy(dst, src, subsample.cypher_bytes);
122    src += subsample.cypher_bytes;
123    dst += subsample.cypher_bytes;
124  }
125}
126
127// Decrypts |input| using |key|.  Returns a DecoderBuffer with the decrypted
128// data if decryption succeeded or NULL if decryption failed.
129static scoped_refptr<DecoderBuffer> DecryptData(const DecoderBuffer& input,
130                                                crypto::SymmetricKey* key) {
131  CHECK(input.data_size());
132  CHECK(input.decrypt_config());
133  CHECK(key);
134
135  crypto::Encryptor encryptor;
136  if (!encryptor.Init(key, crypto::Encryptor::CTR, "")) {
137    DVLOG(1) << "Could not initialize decryptor.";
138    return NULL;
139  }
140
141  DCHECK_EQ(input.decrypt_config()->iv().size(),
142            static_cast<size_t>(DecryptConfig::kDecryptionKeySize));
143  if (!encryptor.SetCounter(input.decrypt_config()->iv())) {
144    DVLOG(1) << "Could not set counter block.";
145    return NULL;
146  }
147
148  const char* sample = reinterpret_cast<const char*>(input.data());
149  size_t sample_size = static_cast<size_t>(input.data_size());
150
151  DCHECK_GT(sample_size, 0U) << "No sample data to be decrypted.";
152  if (sample_size == 0)
153    return NULL;
154
155  if (input.decrypt_config()->subsamples().empty()) {
156    std::string decrypted_text;
157    base::StringPiece encrypted_text(sample, sample_size);
158    if (!encryptor.Decrypt(encrypted_text, &decrypted_text)) {
159      DVLOG(1) << "Could not decrypt data.";
160      return NULL;
161    }
162
163    // TODO(xhwang): Find a way to avoid this data copy.
164    return DecoderBuffer::CopyFrom(
165        reinterpret_cast<const uint8*>(decrypted_text.data()),
166        decrypted_text.size());
167  }
168
169  const std::vector<SubsampleEntry>& subsamples =
170      input.decrypt_config()->subsamples();
171
172  size_t total_clear_size = 0;
173  size_t total_encrypted_size = 0;
174  for (size_t i = 0; i < subsamples.size(); i++) {
175    total_clear_size += subsamples[i].clear_bytes;
176    total_encrypted_size += subsamples[i].cypher_bytes;
177    // Check for overflow. This check is valid because *_size is unsigned.
178    DCHECK(total_clear_size >= subsamples[i].clear_bytes);
179    if (total_encrypted_size < subsamples[i].cypher_bytes)
180      return NULL;
181  }
182  size_t total_size = total_clear_size + total_encrypted_size;
183  if (total_size < total_clear_size || total_size != sample_size) {
184    DVLOG(1) << "Subsample sizes do not equal input size";
185    return NULL;
186  }
187
188  // No need to decrypt if there is no encrypted data.
189  if (total_encrypted_size <= 0) {
190    return DecoderBuffer::CopyFrom(reinterpret_cast<const uint8*>(sample),
191                                   sample_size);
192  }
193
194  // The encrypted portions of all subsamples must form a contiguous block,
195  // such that an encrypted subsample that ends away from a block boundary is
196  // immediately followed by the start of the next encrypted subsample. We
197  // copy all encrypted subsamples to a contiguous buffer, decrypt them, then
198  // copy the decrypted bytes over the encrypted bytes in the output.
199  // TODO(strobe): attempt to reduce number of memory copies
200  scoped_ptr<uint8[]> encrypted_bytes(new uint8[total_encrypted_size]);
201  CopySubsamples(subsamples, kSrcContainsClearBytes,
202                 reinterpret_cast<const uint8*>(sample), encrypted_bytes.get());
203
204  base::StringPiece encrypted_text(
205      reinterpret_cast<const char*>(encrypted_bytes.get()),
206      total_encrypted_size);
207  std::string decrypted_text;
208  if (!encryptor.Decrypt(encrypted_text, &decrypted_text)) {
209    DVLOG(1) << "Could not decrypt data.";
210    return NULL;
211  }
212  DCHECK_EQ(decrypted_text.size(), encrypted_text.size());
213
214  scoped_refptr<DecoderBuffer> output = DecoderBuffer::CopyFrom(
215      reinterpret_cast<const uint8*>(sample), sample_size);
216  CopySubsamples(subsamples, kDstContainsClearBytes,
217                 reinterpret_cast<const uint8*>(decrypted_text.data()),
218                 output->writable_data());
219  return output;
220}
221
222AesDecryptor::AesDecryptor(const SessionMessageCB& session_message_cb,
223                           const SessionClosedCB& session_closed_cb)
224    : session_message_cb_(session_message_cb),
225      session_closed_cb_(session_closed_cb) {
226  DCHECK(!session_message_cb_.is_null());
227  DCHECK(!session_closed_cb_.is_null());
228}
229
AesDecryptor::~AesDecryptor() {
  // Explicitly destroys every per-key-ID SessionIdDecryptionKeyMap (and, via
  // their destructors, the DecryptionKeys they own) before the remaining
  // members are torn down. No lock is taken here.
  key_map_.clear();
}
233
234void AesDecryptor::CreateSession(const std::string& init_data_type,
235                                 const uint8* init_data,
236                                 int init_data_length,
237                                 SessionType session_type,
238                                 scoped_ptr<NewSessionCdmPromise> promise) {
239  std::string web_session_id(base::UintToString(next_web_session_id_++));
240  valid_sessions_.insert(web_session_id);
241
242  // For now, the AesDecryptor does not care about |init_data_type| or
243  // |session_type|; just resolve the promise and then fire a message event
244  // with the |init_data| as the request.
245  // TODO(jrummell): Validate |init_data_type| and |session_type|.
246  std::vector<uint8> message;
247  if (init_data && init_data_length)
248    message.assign(init_data, init_data + init_data_length);
249
250  promise->resolve(web_session_id);
251
252  session_message_cb_.Run(web_session_id, message, GURL());
253}
254
// Loading a stored session is not supported by AesDecryptor; |promise| is
// always rejected with NOT_SUPPORTED_ERROR.
void AesDecryptor::LoadSession(const std::string& web_session_id,
                               scoped_ptr<NewSessionCdmPromise> promise) {
  // TODO(xhwang): Change this to NOTREACHED() when blink checks for key systems
  // that do not support loadSession. See http://crbug.com/342481
  promise->reject(NOT_SUPPORTED_ERROR, 0, "LoadSession() is not supported.");
}
261
262void AesDecryptor::UpdateSession(const std::string& web_session_id,
263                                 const uint8* response,
264                                 int response_length,
265                                 scoped_ptr<SimpleCdmPromise> promise) {
266  CHECK(response);
267  CHECK_GT(response_length, 0);
268
269  // TODO(jrummell): Convert back to a DCHECK once prefixed EME is removed.
270  if (valid_sessions_.find(web_session_id) == valid_sessions_.end()) {
271    promise->reject(INVALID_ACCESS_ERROR, 0, "Session does not exist.");
272    return;
273  }
274
275  std::string key_string(reinterpret_cast<const char*>(response),
276                         response_length);
277
278  KeyIdAndKeyPairs keys;
279  if (!ExtractKeysFromJWKSet(key_string, &keys)) {
280    promise->reject(
281        INVALID_ACCESS_ERROR, 0, "response is not a valid JSON Web Key Set.");
282    return;
283  }
284
285  // Make sure that at least one key was extracted.
286  if (keys.empty()) {
287    promise->reject(
288        INVALID_ACCESS_ERROR, 0, "response does not contain any keys.");
289    return;
290  }
291
292  for (KeyIdAndKeyPairs::iterator it = keys.begin(); it != keys.end(); ++it) {
293    if (it->second.length() !=
294        static_cast<size_t>(DecryptConfig::kDecryptionKeySize)) {
295      DVLOG(1) << "Invalid key length: " << key_string.length();
296      promise->reject(INVALID_ACCESS_ERROR, 0, "Invalid key length.");
297      return;
298    }
299    if (!AddDecryptionKey(web_session_id, it->first, it->second)) {
300      promise->reject(INVALID_ACCESS_ERROR, 0, "Unable to add key.");
301      return;
302    }
303  }
304
305  {
306    base::AutoLock auto_lock(new_key_cb_lock_);
307
308    if (!new_audio_key_cb_.is_null())
309      new_audio_key_cb_.Run();
310
311    if (!new_video_key_cb_.is_null())
312      new_video_key_cb_.Run();
313  }
314
315  promise->resolve();
316}
317
318void AesDecryptor::ReleaseSession(const std::string& web_session_id,
319                                  scoped_ptr<SimpleCdmPromise> promise) {
320  // Validate that this is a reference to an active session and then forget it.
321  std::set<std::string>::iterator it = valid_sessions_.find(web_session_id);
322  // TODO(jrummell): Convert back to a DCHECK once prefixed EME is removed.
323  if (it == valid_sessions_.end()) {
324    promise->reject(INVALID_ACCESS_ERROR, 0, "Session does not exist.");
325    return;
326  }
327
328  valid_sessions_.erase(it);
329
330  // Close the session.
331  DeleteKeysForSession(web_session_id);
332  promise->resolve();
333  session_closed_cb_.Run(web_session_id);
334}
335
// AesDecryptor acts as its own Decryptor, so it simply returns itself.
Decryptor* AesDecryptor::GetDecryptor() {
  return this;
}
339
340void AesDecryptor::RegisterNewKeyCB(StreamType stream_type,
341                                    const NewKeyCB& new_key_cb) {
342  base::AutoLock auto_lock(new_key_cb_lock_);
343
344  switch (stream_type) {
345    case kAudio:
346      new_audio_key_cb_ = new_key_cb;
347      break;
348    case kVideo:
349      new_video_key_cb_ = new_key_cb;
350      break;
351    default:
352      NOTREACHED();
353  }
354}
355
356void AesDecryptor::Decrypt(StreamType stream_type,
357                           const scoped_refptr<DecoderBuffer>& encrypted,
358                           const DecryptCB& decrypt_cb) {
359  CHECK(encrypted->decrypt_config());
360
361  scoped_refptr<DecoderBuffer> decrypted;
362  // An empty iv string signals that the frame is unencrypted.
363  if (encrypted->decrypt_config()->iv().empty()) {
364    decrypted = DecoderBuffer::CopyFrom(encrypted->data(),
365                                        encrypted->data_size());
366  } else {
367    const std::string& key_id = encrypted->decrypt_config()->key_id();
368    DecryptionKey* key = GetKey(key_id);
369    if (!key) {
370      DVLOG(1) << "Could not find a matching key for the given key ID.";
371      decrypt_cb.Run(kNoKey, NULL);
372      return;
373    }
374
375    crypto::SymmetricKey* decryption_key = key->decryption_key();
376    decrypted = DecryptData(*encrypted.get(), decryption_key);
377    if (!decrypted.get()) {
378      DVLOG(1) << "Decryption failed.";
379      decrypt_cb.Run(kError, NULL);
380      return;
381    }
382  }
383
384  decrypted->set_timestamp(encrypted->timestamp());
385  decrypted->set_duration(encrypted->duration());
386  decrypt_cb.Run(kSuccess, decrypted);
387}
388
// No-op: there is never a pending decrypt to cancel.
void AesDecryptor::CancelDecrypt(StreamType stream_type) {
  // Decrypt() calls the DecryptCB synchronously so there's nothing to cancel.
}
392
// Always reports initialization failure through |init_cb|; |config| is
// ignored.
void AesDecryptor::InitializeAudioDecoder(const AudioDecoderConfig& config,
                                          const DecoderInitCB& init_cb) {
  // AesDecryptor does not support audio decoding.
  init_cb.Run(false);
}
398
// Always reports initialization failure through |init_cb|; |config| is
// ignored.
void AesDecryptor::InitializeVideoDecoder(const VideoDecoderConfig& config,
                                          const DecoderInitCB& init_cb) {
  // AesDecryptor does not support video decoding.
  init_cb.Run(false);
}
404
// Must never be called, since InitializeAudioDecoder() always fails. Note that
// |audio_decode_cb| is not run here.
void AesDecryptor::DecryptAndDecodeAudio(
    const scoped_refptr<DecoderBuffer>& encrypted,
    const AudioDecodeCB& audio_decode_cb) {
  NOTREACHED() << "AesDecryptor does not support audio decoding";
}
410
// Must never be called, since InitializeVideoDecoder() always fails. Note that
// |video_decode_cb| is not run here.
void AesDecryptor::DecryptAndDecodeVideo(
    const scoped_refptr<DecoderBuffer>& encrypted,
    const VideoDecodeCB& video_decode_cb) {
  NOTREACHED() << "AesDecryptor does not support video decoding";
}
416
// Must never be called: no decoder is ever initialized.
void AesDecryptor::ResetDecoder(StreamType stream_type) {
  NOTREACHED() << "AesDecryptor does not support audio/video decoding";
}
420
// Must never be called: no decoder is ever initialized.
void AesDecryptor::DeinitializeDecoder(StreamType stream_type) {
  NOTREACHED() << "AesDecryptor does not support audio/video decoding";
}
424
425bool AesDecryptor::AddDecryptionKey(const std::string& web_session_id,
426                                    const std::string& key_id,
427                                    const std::string& key_string) {
428  scoped_ptr<DecryptionKey> decryption_key(new DecryptionKey(key_string));
429  if (!decryption_key->Init()) {
430    DVLOG(1) << "Could not initialize decryption key.";
431    return false;
432  }
433
434  base::AutoLock auto_lock(key_map_lock_);
435  KeyIdToSessionKeysMap::iterator key_id_entry = key_map_.find(key_id);
436  if (key_id_entry != key_map_.end()) {
437    key_id_entry->second->Insert(web_session_id, decryption_key.Pass());
438    return true;
439  }
440
441  // |key_id| not found, so need to create new entry.
442  scoped_ptr<SessionIdDecryptionKeyMap> inner_map(
443      new SessionIdDecryptionKeyMap());
444  inner_map->Insert(web_session_id, decryption_key.Pass());
445  key_map_.add(key_id, inner_map.Pass());
446  return true;
447}
448
449AesDecryptor::DecryptionKey* AesDecryptor::GetKey(
450    const std::string& key_id) const {
451  base::AutoLock auto_lock(key_map_lock_);
452  KeyIdToSessionKeysMap::const_iterator key_id_found = key_map_.find(key_id);
453  if (key_id_found == key_map_.end())
454    return NULL;
455
456  // Return the key from the "latest" session_id entry.
457  return key_id_found->second->LatestDecryptionKey();
458}
459
460void AesDecryptor::DeleteKeysForSession(const std::string& web_session_id) {
461  base::AutoLock auto_lock(key_map_lock_);
462
463  // Remove all keys associated with |web_session_id|. Since the data is
464  // optimized for access in GetKey(), we need to look at each entry in
465  // |key_map_|.
466  KeyIdToSessionKeysMap::iterator it = key_map_.begin();
467  while (it != key_map_.end()) {
468    it->second->Erase(web_session_id);
469    if (it->second->Empty()) {
470      // Need to get rid of the entry for this key_id. This will mess up the
471      // iterator, so we need to increment it first.
472      KeyIdToSessionKeysMap::iterator current = it;
473      ++it;
474      key_map_.erase(current);
475    } else {
476      ++it;
477    }
478  }
479}
480
// Stores the raw key bytes. Init() converts them into a crypto::SymmetricKey.
AesDecryptor::DecryptionKey::DecryptionKey(const std::string& secret)
    : secret_(secret) {
}
484
// |decryption_key_| is released by its owning smart pointer.
AesDecryptor::DecryptionKey::~DecryptionKey() {}
486
487bool AesDecryptor::DecryptionKey::Init() {
488  CHECK(!secret_.empty());
489  decryption_key_.reset(crypto::SymmetricKey::Import(
490      crypto::SymmetricKey::AES, secret_));
491  if (!decryption_key_)
492    return false;
493  return true;
494}
495
496}  // namespace media
497