1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
#include "crypto/encryptor.h"

#include <cryptohi.h>
#include <limits>
#include <vector>

#include "base/logging.h"
#include "crypto/nss_util.h"
#include "crypto/symmetric_key.h"
13
14namespace crypto {
15
16namespace {
17
18inline CK_MECHANISM_TYPE GetMechanism(Encryptor::Mode mode) {
19  switch (mode) {
20    case Encryptor::CBC:
21      return CKM_AES_CBC_PAD;
22    case Encryptor::CTR:
23      // AES-CTR encryption uses ECB encryptor as a building block since
24      // NSS doesn't support CTR encryption mode.
25      return CKM_AES_ECB;
26    default:
27      NOTREACHED() << "Unsupported mode of operation";
28      break;
29  }
30  return static_cast<CK_MECHANISM_TYPE>(-1);
31}
32
33}  // namespace
34
35Encryptor::Encryptor()
36    : key_(NULL),
37      mode_(CBC) {
38  EnsureNSSInit();
39}
40
41Encryptor::~Encryptor() {
42}
43
44bool Encryptor::Init(SymmetricKey* key,
45                     Mode mode,
46                     const base::StringPiece& iv) {
47  DCHECK(key);
48  DCHECK(CBC == mode || CTR == mode) << "Unsupported mode of operation";
49
50  key_ = key;
51  mode_ = mode;
52
53  if (mode == CBC && iv.size() != AES_BLOCK_SIZE)
54    return false;
55
56  switch (mode) {
57    case CBC:
58      SECItem iv_item;
59      iv_item.type = siBuffer;
60      iv_item.data = reinterpret_cast<unsigned char*>(
61          const_cast<char *>(iv.data()));
62      iv_item.len = iv.size();
63
64      param_.reset(PK11_ParamFromIV(GetMechanism(mode), &iv_item));
65      break;
66    case CTR:
67      param_.reset(PK11_ParamFromIV(GetMechanism(mode), NULL));
68      break;
69  }
70
71  return param_ != NULL;
72}
73
74bool Encryptor::Encrypt(const base::StringPiece& plaintext,
75                        std::string* ciphertext) {
76  CHECK(!plaintext.empty() || (mode_ == CBC));
77  ScopedPK11Context context(PK11_CreateContextBySymKey(GetMechanism(mode_),
78                                                       CKA_ENCRYPT,
79                                                       key_->key(),
80                                                       param_.get()));
81  if (!context.get())
82    return false;
83
84  return (mode_ == CTR) ?
85      CryptCTR(context.get(), plaintext, ciphertext) :
86      Crypt(context.get(), plaintext, ciphertext);
87}
88
89bool Encryptor::Decrypt(const base::StringPiece& ciphertext,
90                        std::string* plaintext) {
91  CHECK(!ciphertext.empty());
92  ScopedPK11Context context(PK11_CreateContextBySymKey(
93      GetMechanism(mode_), (mode_ == CTR ? CKA_ENCRYPT : CKA_DECRYPT),
94      key_->key(), param_.get()));
95  if (!context.get())
96    return false;
97
98  if (mode_ == CTR)
99    return CryptCTR(context.get(), ciphertext, plaintext);
100
101  if (ciphertext.size() % AES_BLOCK_SIZE != 0) {
102    // Decryption will fail if the input is not a multiple of the block size.
103    // PK11_CipherOp has a bug where it will do an invalid memory access before
104    // the start of the input, so avoid calling it. (NSS bug 922780).
105    plaintext->clear();
106    return false;
107  }
108
109  return Crypt(context.get(), ciphertext, plaintext);
110}
111
112bool Encryptor::Crypt(PK11Context* context,
113                      const base::StringPiece& input,
114                      std::string* output) {
115  size_t output_len = input.size() + AES_BLOCK_SIZE;
116  CHECK_GT(output_len, input.size());
117
118  output->resize(output_len);
119  uint8* output_data =
120      reinterpret_cast<uint8*>(const_cast<char*>(output->data()));
121
122  int input_len = input.size();
123  uint8* input_data =
124      reinterpret_cast<uint8*>(const_cast<char*>(input.data()));
125
126  int op_len;
127  SECStatus rv = PK11_CipherOp(context,
128                               output_data,
129                               &op_len,
130                               output_len,
131                               input_data,
132                               input_len);
133
134  if (SECSuccess != rv) {
135    output->clear();
136    return false;
137  }
138
139  unsigned int digest_len;
140  rv = PK11_DigestFinal(context,
141                        output_data + op_len,
142                        &digest_len,
143                        output_len - op_len);
144  if (SECSuccess != rv) {
145    output->clear();
146    return false;
147  }
148
149  output->resize(op_len + digest_len);
150  return true;
151}
152
153bool Encryptor::CryptCTR(PK11Context* context,
154                         const base::StringPiece& input,
155                         std::string* output) {
156  if (!counter_.get()) {
157    LOG(ERROR) << "Counter value not set in CTR mode.";
158    return false;
159  }
160
161  size_t output_len = ((input.size() + AES_BLOCK_SIZE - 1) / AES_BLOCK_SIZE) *
162      AES_BLOCK_SIZE;
163  CHECK_GE(output_len, input.size());
164  output->resize(output_len);
165  uint8* output_data =
166      reinterpret_cast<uint8*>(const_cast<char*>(output->data()));
167
168  size_t mask_len;
169  bool ret = GenerateCounterMask(input.size(), output_data, &mask_len);
170  if (!ret)
171    return false;
172
173  CHECK_EQ(mask_len, output_len);
174  int op_len;
175  SECStatus rv = PK11_CipherOp(context,
176                               output_data,
177                               &op_len,
178                               output_len,
179                               output_data,
180                               mask_len);
181  if (SECSuccess != rv)
182    return false;
183  CHECK_EQ(static_cast<int>(mask_len), op_len);
184
185  unsigned int digest_len;
186  rv = PK11_DigestFinal(context,
187                        NULL,
188                        &digest_len,
189                        0);
190  if (SECSuccess != rv)
191    return false;
192  CHECK(!digest_len);
193
194  // Use |output_data| to mask |input|.
195  MaskMessage(
196      reinterpret_cast<uint8*>(const_cast<char*>(input.data())),
197      input.length(), output_data, output_data);
198  output->resize(input.length());
199  return true;
200}
201
202}  // namespace crypto
203