1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "crypto/rsa_private_key.h"
6
7#include <list>
8
9#include "base/logging.h"
10#include "base/memory/scoped_ptr.h"
11#include "base/string_util.h"
12
// Helper for error handling while parsing key material: if |truth| does not
// hold, flags the failure via NOTREACHED() and returns false from the calling
// function (which must therefore return bool or an integer type).
//
// The body is wrapped in do { ... } while (0) so the macro expands to exactly
// one statement and stays correct inside an unbraced if/else. It is defined
// at file scope because the preprocessor ignores namespaces, so placing it in
// an anonymous namespace only suggested a scoping that never existed.
#define READ_ASSERT(truth) \
  do {                     \
    if (!(truth)) {        \
      NOTREACHED();        \
      return false;        \
    }                      \
  } while (0)
21
22namespace crypto {
23
24// static
25RSAPrivateKey* RSAPrivateKey::Create(uint16 num_bits) {
26  scoped_ptr<RSAPrivateKey> result(new RSAPrivateKey);
27  if (!result->InitProvider())
28    return NULL;
29
30  DWORD flags = CRYPT_EXPORTABLE;
31
32  // The size is encoded as the upper 16 bits of the flags. :: sigh ::.
33  flags |= (num_bits << 16);
34  if (!CryptGenKey(result->provider_, CALG_RSA_SIGN, flags,
35                   result->key_.receive()))
36    return NULL;
37
38  return result.release();
39}
40
41// static
42RSAPrivateKey* RSAPrivateKey::CreateSensitive(uint16 num_bits) {
43  NOTIMPLEMENTED();
44  return NULL;
45}
46
47// static
48RSAPrivateKey* RSAPrivateKey::CreateFromPrivateKeyInfo(
49    const std::vector<uint8>& input) {
50  scoped_ptr<RSAPrivateKey> result(new RSAPrivateKey);
51  if (!result->InitProvider())
52    return NULL;
53
54  PrivateKeyInfoCodec pki(false);  // Little-Endian
55  pki.Import(input);
56
57  int blob_size = sizeof(PUBLICKEYSTRUC) +
58                  sizeof(RSAPUBKEY) +
59                  pki.modulus()->size() +
60                  pki.prime1()->size() +
61                  pki.prime2()->size() +
62                  pki.exponent1()->size() +
63                  pki.exponent2()->size() +
64                  pki.coefficient()->size() +
65                  pki.private_exponent()->size();
66  scoped_array<BYTE> blob(new BYTE[blob_size]);
67
68  uint8* dest = blob.get();
69  PUBLICKEYSTRUC* public_key_struc = reinterpret_cast<PUBLICKEYSTRUC*>(dest);
70  public_key_struc->bType = PRIVATEKEYBLOB;
71  public_key_struc->bVersion = 0x02;
72  public_key_struc->reserved = 0;
73  public_key_struc->aiKeyAlg = CALG_RSA_SIGN;
74  dest += sizeof(PUBLICKEYSTRUC);
75
76  RSAPUBKEY* rsa_pub_key = reinterpret_cast<RSAPUBKEY*>(dest);
77  rsa_pub_key->magic = 0x32415352;
78  rsa_pub_key->bitlen = pki.modulus()->size() * 8;
79  int public_exponent_int = 0;
80  for (size_t i = pki.public_exponent()->size(); i > 0; --i) {
81    public_exponent_int <<= 8;
82    public_exponent_int |= (*pki.public_exponent())[i - 1];
83  }
84  rsa_pub_key->pubexp = public_exponent_int;
85  dest += sizeof(RSAPUBKEY);
86
87  memcpy(dest, &pki.modulus()->front(), pki.modulus()->size());
88  dest += pki.modulus()->size();
89  memcpy(dest, &pki.prime1()->front(), pki.prime1()->size());
90  dest += pki.prime1()->size();
91  memcpy(dest, &pki.prime2()->front(), pki.prime2()->size());
92  dest += pki.prime2()->size();
93  memcpy(dest, &pki.exponent1()->front(), pki.exponent1()->size());
94  dest += pki.exponent1()->size();
95  memcpy(dest, &pki.exponent2()->front(), pki.exponent2()->size());
96  dest += pki.exponent2()->size();
97  memcpy(dest, &pki.coefficient()->front(), pki.coefficient()->size());
98  dest += pki.coefficient()->size();
99  memcpy(dest, &pki.private_exponent()->front(),
100         pki.private_exponent()->size());
101  dest += pki.private_exponent()->size();
102
103  READ_ASSERT(dest == blob.get() + blob_size);
104  if (!CryptImportKey(result->provider_,
105                      reinterpret_cast<uint8*>(public_key_struc), blob_size, 0,
106                      CRYPT_EXPORTABLE, result->key_.receive()))
107    return NULL;
108
109  return result.release();
110}
111
112// static
113RSAPrivateKey* RSAPrivateKey::CreateSensitiveFromPrivateKeyInfo(
114    const std::vector<uint8>& input) {
115  NOTIMPLEMENTED();
116  return NULL;
117}
118
119// static
120RSAPrivateKey* RSAPrivateKey::FindFromPublicKeyInfo(
121    const std::vector<uint8>& input) {
122  NOTIMPLEMENTED();
123  return NULL;
124}
125
126RSAPrivateKey::RSAPrivateKey() : provider_(NULL), key_(NULL) {}
127
128RSAPrivateKey::~RSAPrivateKey() {}
129
130bool RSAPrivateKey::InitProvider() {
131  return FALSE != CryptAcquireContext(provider_.receive(), NULL, NULL,
132                                      PROV_RSA_FULL, CRYPT_VERIFYCONTEXT);
133}
134
135bool RSAPrivateKey::ExportPrivateKey(std::vector<uint8>* output) {
136  // Export the key
137  DWORD blob_length = 0;
138  if (!CryptExportKey(key_, 0, PRIVATEKEYBLOB, 0, NULL, &blob_length)) {
139    NOTREACHED();
140    return false;
141  }
142
143  scoped_array<uint8> blob(new uint8[blob_length]);
144  if (!CryptExportKey(key_, 0, PRIVATEKEYBLOB, 0, blob.get(), &blob_length)) {
145    NOTREACHED();
146    return false;
147  }
148
149  uint8* pos = blob.get();
150  PUBLICKEYSTRUC *publickey_struct = reinterpret_cast<PUBLICKEYSTRUC*>(pos);
151  pos += sizeof(PUBLICKEYSTRUC);
152
153  RSAPUBKEY *rsa_pub_key = reinterpret_cast<RSAPUBKEY*>(pos);
154  pos += sizeof(RSAPUBKEY);
155
156  int mod_size = rsa_pub_key->bitlen / 8;
157  int primes_size = rsa_pub_key->bitlen / 16;
158
159  PrivateKeyInfoCodec pki(false);  // Little-Endian
160
161  pki.modulus()->assign(pos, pos + mod_size);
162  pos += mod_size;
163
164  pki.prime1()->assign(pos, pos + primes_size);
165  pos += primes_size;
166  pki.prime2()->assign(pos, pos + primes_size);
167  pos += primes_size;
168
169  pki.exponent1()->assign(pos, pos + primes_size);
170  pos += primes_size;
171  pki.exponent2()->assign(pos, pos + primes_size);
172  pos += primes_size;
173
174  pki.coefficient()->assign(pos, pos + primes_size);
175  pos += primes_size;
176
177  pki.private_exponent()->assign(pos, pos + mod_size);
178  pos += mod_size;
179
180  pki.public_exponent()->assign(reinterpret_cast<uint8*>(&rsa_pub_key->pubexp),
181      reinterpret_cast<uint8*>(&rsa_pub_key->pubexp) + 4);
182
183  CHECK_EQ(pos - blob_length, reinterpret_cast<BYTE*>(publickey_struct));
184
185  return pki.Export(output);
186}
187
188bool RSAPrivateKey::ExportPublicKey(std::vector<uint8>* output) {
189  DWORD key_info_len;
190  if (!CryptExportPublicKeyInfo(
191      provider_, AT_SIGNATURE, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
192      NULL, &key_info_len)) {
193    NOTREACHED();
194    return false;
195  }
196
197  scoped_array<uint8> key_info(new uint8[key_info_len]);
198  if (!CryptExportPublicKeyInfo(
199      provider_, AT_SIGNATURE, X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
200      reinterpret_cast<CERT_PUBLIC_KEY_INFO*>(key_info.get()), &key_info_len)) {
201    NOTREACHED();
202    return false;
203  }
204
205  DWORD encoded_length;
206  if (!CryptEncodeObject(
207      X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, X509_PUBLIC_KEY_INFO,
208      reinterpret_cast<CERT_PUBLIC_KEY_INFO*>(key_info.get()), NULL,
209      &encoded_length)) {
210    NOTREACHED();
211    return false;
212  }
213
214  scoped_array<BYTE> encoded(new BYTE[encoded_length]);
215  if (!CryptEncodeObject(
216      X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, X509_PUBLIC_KEY_INFO,
217      reinterpret_cast<CERT_PUBLIC_KEY_INFO*>(key_info.get()), encoded.get(),
218      &encoded_length)) {
219    NOTREACHED();
220    return false;
221  }
222
223  for (size_t i = 0; i < encoded_length; ++i)
224    output->push_back(encoded[i]);
225
226  return true;
227}
228
229}  // namespace crypto
230