1#include "gtest/gtest.h"
2#include "avb_tools.h"
3#include "keymaster_tools.h"
4#include "nugget_tools.h"
5#include "nugget/app/keymaster/keymaster.pb.h"
6#include "nugget/app/keymaster/keymaster_defs.pb.h"
7#include "nugget/app/keymaster/keymaster_types.pb.h"
8#include "Keymaster.client.h"
9#include "util.h"
10
11#include "src/blob.h"
12#include "src/macros.h"
13#include "src/test-data/test-keys/rsa.h"
14
15#include "openssl/bn.h"
16#include "openssl/ec_key.h"
17#include "openssl/nid.h"
18
19#include <sstream>
20
21using std::cout;
22using std::string;
23using std::stringstream;
24using std::unique_ptr;
25
26using namespace nugget::app::keymaster;
27
28using namespace test_data;
29
30namespace {
31
32class ImportKeyTest: public testing::Test {
33 protected:
34  static unique_ptr<nos::NuggetClientInterface> client;
35  static unique_ptr<Keymaster> service;
36  static unique_ptr<test_harness::TestHarness> uart_printer;
37
38  static void SetUpTestCase();
39  static void TearDownTestCase();
40
41  void initRSARequest(ImportKeyRequest *request, Algorithm alg) {
42    KeyParameters *params = request->mutable_params();
43    KeyParameter *param = params->add_params();
44    param->set_tag(Tag::ALGORITHM);
45    param->set_integer((uint32_t)alg);
46  }
47
48  void initRSARequest(ImportKeyRequest *request, Algorithm alg, int key_size) {
49    initRSARequest(request, alg);
50
51    if (key_size >= 0) {
52      KeyParameters *params = request->mutable_params();
53      KeyParameter *param = params->add_params();
54      param->set_tag(Tag::KEY_SIZE);
55      param->set_integer(key_size);
56    }
57  }
58
59  void initRSARequest(ImportKeyRequest *request, Algorithm alg, int key_size,
60                   int public_exponent_tag) {
61    initRSARequest(request, alg, key_size);
62
63    if (public_exponent_tag >= 0) {
64      KeyParameters *params = request->mutable_params();
65      KeyParameter *param = params->add_params();
66      param->set_tag(Tag::RSA_PUBLIC_EXPONENT);
67      param->set_long_integer(public_exponent_tag);
68    }
69  }
70
71  void initRSARequest(ImportKeyRequest *request, Algorithm alg, int key_size,
72                   int public_exponent_tag, uint32_t public_exponent,
73                   const string& d, const string& n) {
74    initRSARequest(request, alg, key_size, public_exponent_tag);
75
76    request->mutable_rsa()->set_e(public_exponent);
77    request->mutable_rsa()->set_d(d);
78    request->mutable_rsa()->set_n(n);
79  }
80};
81
82unique_ptr<nos::NuggetClientInterface> ImportKeyTest::client;
83unique_ptr<Keymaster> ImportKeyTest::service;
84unique_ptr<test_harness::TestHarness> ImportKeyTest::uart_printer;
85
86void ImportKeyTest::SetUpTestCase() {
87  uart_printer = test_harness::TestHarness::MakeUnique();
88
89  client = nugget_tools::MakeNuggetClient();
90  client->Open();
91  EXPECT_TRUE(client->IsOpen()) << "Unable to connect";
92
93  service.reset(new Keymaster(*client));
94
95  // Do setup that is normally done by the bootloader.
96  keymaster_tools::SetRootOfTrust(client.get());
97}
98
99void ImportKeyTest::TearDownTestCase() {
100  client->Close();
101  client = unique_ptr<nos::NuggetClientInterface>();
102
103  uart_printer = nullptr;
104}
105
106// TODO: refactor into import key tests.
107
108// Failure cases.
109TEST_F(ImportKeyTest, AlgorithmMissingFails) {
110  ImportKeyRequest request;
111  ImportKeyResponse response;
112
113  KeyParameters *params = request.mutable_params();
114
115  /* Algorithm tag is unspecified, import should fail. */
116  KeyParameter *param = params->add_params();
117  param->set_tag(Tag::KEY_SIZE);
118  param->set_integer(512);
119
120  param = params->add_params();
121  param->set_tag(Tag::RSA_PUBLIC_EXPONENT);
122  param->set_long_integer(3);
123
124  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
125  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::INVALID_ARGUMENT);
126}
127
128// RSA
129
130TEST_F(ImportKeyTest, RSAInvalidKeySizeFails) {
131  ImportKeyRequest request;
132  ImportKeyResponse response;
133
134  initRSARequest(&request, Algorithm::RSA, 256, 3);
135
136  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
137  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::UNSUPPORTED_KEY_SIZE);
138}
139
140TEST_F(ImportKeyTest, RSAInvalidPublicExponentFails) {
141  ImportKeyRequest request;
142  ImportKeyResponse response;
143
144  // Unsupported exponent
145  initRSARequest(&request, Algorithm::RSA, 512, 2, 2,
146                 string(64, '\0'), string(64, '\0'));
147
148  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
149  EXPECT_EQ((ErrorCode)response.error_code(),
150            ErrorCode::UNSUPPORTED_KEY_SIZE);
151}
152
153TEST_F(ImportKeyTest, RSAKeySizeTagMisatchNFails) {
154  ImportKeyRequest request;
155  ImportKeyResponse response;
156
157  // N does not match KEY_SIZE.
158  initRSARequest(&request, Algorithm::RSA, 512, 3, 3,
159                 string(64, '\0'), string(63, '\0'));
160  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
161  EXPECT_EQ((ErrorCode)response.error_code(),
162            ErrorCode::IMPORT_PARAMETER_MISMATCH);
163}
164
165TEST_F(ImportKeyTest, RSAKeySizeTagMisatchDFails) {
166  ImportKeyRequest request;
167  ImportKeyResponse response;
168
169  // D does not match KEY_SIZE.
170  initRSARequest(&request, Algorithm::RSA, 512, 3, 3,
171                 string(63, '\0'), string(64, '\0'));
172  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
173  EXPECT_EQ((ErrorCode)response.error_code(),
174            ErrorCode::IMPORT_PARAMETER_MISMATCH);
175}
176
177TEST_F(ImportKeyTest, RSAPublicExponentTagMisatchFails) {
178  ImportKeyRequest request;
179  ImportKeyResponse response;
180
181  // e does not match PUBLIC_EXPONENT tag.
182  initRSARequest(&request, Algorithm::RSA, 512, 3, 2,
183                 string(64, '\0'), string(64, '\0'));
184  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
185  EXPECT_EQ((ErrorCode)response.error_code(),
186            ErrorCode::IMPORT_PARAMETER_MISMATCH);
187}
188
189TEST_F(ImportKeyTest, RSA1024BadEFails) {
190  ImportKeyRequest request;
191  ImportKeyResponse response;
192
193  // Mis-matched e.
194  const string d((const char *)RSA_1024_D, sizeof(RSA_1024_D));
195  const string N((const char *)RSA_1024_N, sizeof(RSA_1024_N));
196  initRSARequest(&request, Algorithm::RSA, 1024, 3, 3, d, N);
197
198  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
199  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::INVALID_ARGUMENT);
200}
201
202TEST_F(ImportKeyTest, RSA1024BadDFails) {
203  ImportKeyRequest request;
204  ImportKeyResponse response;
205
206  const string d(string("\x01") +  /* Twiddle LSB of D. */
207                 string((const char *)RSA_1024_D, sizeof(RSA_1024_D) - 1));
208  const string N((const char *)RSA_1024_N, sizeof(RSA_1024_N));
209  initRSARequest(&request, Algorithm::RSA, 1024, 65537, 65537, d, N);
210
211  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
212  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::INVALID_ARGUMENT);
213}
214
215TEST_F(ImportKeyTest, RSA1024BadNFails) {
216  ImportKeyRequest request;
217  ImportKeyResponse response;
218
219  const string d((const char *)RSA_1024_D, sizeof(RSA_1024_D));
220  const string N(string("\x01") +  /* Twiddle LSB of N. */
221                 string((const char *)RSA_1024_N, sizeof(RSA_1024_N) - 1));
222  initRSARequest(&request, Algorithm::RSA, 1024, 65537, 65537, d, N);
223
224  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
225  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::INVALID_ARGUMENT);
226}
227
228TEST_F(ImportKeyTest, RSASuccess) {
229  ImportKeyRequest request;
230  ImportKeyResponse response;
231
232  initRSARequest(&request, Algorithm::RSA);
233  KeyParameters *params = request.mutable_params();
234  KeyParameter *param = params->add_params();
235  for (size_t i = 0; i < ARRAYSIZE(TEST_RSA_KEYS); i++) {
236    param->set_tag(Tag::RSA_PUBLIC_EXPONENT);
237    param->set_long_integer(TEST_RSA_KEYS[i].e);
238
239    request.mutable_rsa()->set_e(TEST_RSA_KEYS[i].e);
240    request.mutable_rsa()->set_d(TEST_RSA_KEYS[i].d, TEST_RSA_KEYS[i].size);
241    request.mutable_rsa()->set_n(TEST_RSA_KEYS[i].n, TEST_RSA_KEYS[i].size);
242
243    stringstream ss;
244    ss << "Failed at TEST_RSA_KEYS[" << i << "]";
245    ASSERT_NO_ERROR(service->ImportKey(request, &response), ss.str());
246    EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::OK)
247        << "Failed at TEST_RSA_KEYS[" << i << "]";
248
249    /* TODO: add separate tests for blobs! */
250    EXPECT_EQ(sizeof(struct km_blob), response.blob().blob().size());
251    const struct km_blob *blob =
252        (const struct km_blob *)response.blob().blob().data();
253    EXPECT_EQ(memcmp(blob->b.key.rsa.N_bytes, TEST_RSA_KEYS[i].n,
254                     TEST_RSA_KEYS[i].size), 0);
255    EXPECT_EQ(memcmp(blob->b.key.rsa.d_bytes,
256                     TEST_RSA_KEYS[i].d, TEST_RSA_KEYS[i].size), 0);
257  }
258}
259
260TEST_F(ImportKeyTest, RSA1024OptionalParamsAbsentSuccess) {
261  ImportKeyRequest request;
262  ImportKeyResponse response;
263
264  initRSARequest(&request, Algorithm::RSA, -1, -1, 65537,
265                 string((const char *)RSA_1024_D, sizeof(RSA_1024_D)),
266                 string((const char *)RSA_1024_N, sizeof(RSA_1024_N)));
267
268  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
269  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::OK);
270
271  EXPECT_EQ(sizeof(struct km_blob), response.blob().blob().size());
272  const struct km_blob *blob =
273      (const struct km_blob *)response.blob().blob().data();
274  EXPECT_EQ(memcmp(blob->b.key.rsa.N_bytes, RSA_1024_N,
275                   sizeof(RSA_1024_N)), 0);
276  EXPECT_EQ(memcmp(blob->b.key.rsa.d_bytes,
277                   RSA_1024_D, sizeof(RSA_1024_D)), 0);
278}
279
280// EC
281
282TEST_F(ImportKeyTest, ECMissingCurveIdTagFails) {
283  ImportKeyRequest request;
284  ImportKeyResponse response;
285
286  KeyParameters *params = request.mutable_params();
287  KeyParameter *param = params->add_params();
288  param->set_tag(Tag::ALGORITHM);
289  param->set_integer((uint32_t)Algorithm::EC);
290
291  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
292  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::INVALID_ARGUMENT);
293}
294
295TEST_F(ImportKeyTest, ECMisMatchedCurveIdTagFails) {
296  ImportKeyRequest request;
297  ImportKeyResponse response;
298
299  KeyParameters *params = request.mutable_params();
300  KeyParameter *param = params->add_params();
301  param->set_tag(Tag::ALGORITHM);
302  param->set_integer((uint32_t)Algorithm::EC);
303
304  param = params->add_params();
305  param->set_tag(Tag::EC_CURVE);
306  param->set_integer((uint32_t)EcCurve::P_256);
307
308  request.mutable_ec()->set_curve_id(((uint32_t)EcCurve::P_256) + 1);
309
310  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
311  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::INVALID_ARGUMENT);
312}
313
314TEST_F(ImportKeyTest, ECMisMatchedKeySizeTagCurveTagFails) {
315  ImportKeyRequest request;
316  ImportKeyResponse response;
317
318  KeyParameters *params = request.mutable_params();
319  KeyParameter *param = params->add_params();
320  param->set_tag(Tag::ALGORITHM);
321  param->set_integer((uint32_t)Algorithm::EC);
322
323  param = params->add_params();
324  param->set_tag(Tag::EC_CURVE);
325  param->set_integer((uint32_t)EcCurve::P_256);
326
327  param = params->add_params();
328  param->set_tag(Tag::KEY_SIZE);
329  param->set_integer((uint32_t)384);  /* Should be 256 */
330
331  request.mutable_ec()->set_curve_id((uint32_t)EcCurve::P_256);
332
333  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
334  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::INVALID_ARGUMENT);
335}
336
337TEST_F(ImportKeyTest, ECMisMatchedP256KeySizeFails) {
338  ImportKeyRequest request;
339  ImportKeyResponse response;
340
341  KeyParameters *params = request.mutable_params();
342  KeyParameter *param = params->add_params();
343  param->set_tag(Tag::ALGORITHM);
344  param->set_integer((uint32_t)Algorithm::EC);
345
346  param = params->add_params();
347  param->set_tag(Tag::EC_CURVE);
348  param->set_integer((uint32_t)EcCurve::P_256);
349
350  request.mutable_ec()->set_curve_id((uint32_t)EcCurve::P_256);
351  request.mutable_ec()->set_d(string((256 >> 3) - 1, '\0'));
352  request.mutable_ec()->set_x(string((256 >> 3), '\0'));
353  request.mutable_ec()->set_y(string((256 >> 3), '\0'));
354
355  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
356  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::INVALID_ARGUMENT);
357}
358
// TODO: add more bad-key tests: invalid d, {x,y} not on the curve, and d / {x,y} mismatch.
360TEST_F(ImportKeyTest, ECP256BadKeyFails) {
361  ImportKeyRequest request;
362  ImportKeyResponse response;
363
364  KeyParameters *params = request.mutable_params();
365  KeyParameter *param = params->add_params();
366  param->set_tag(Tag::ALGORITHM);
367  param->set_integer((uint32_t)Algorithm::EC);
368
369  param = params->add_params();
370  param->set_tag(Tag::EC_CURVE);
371  param->set_integer((uint32_t)EcCurve::P_256);
372
373  request.mutable_ec()->set_curve_id((uint32_t)EcCurve::P_256);
374  request.mutable_ec()->set_d(string((256 >> 3), '\0'));
375  request.mutable_ec()->set_x(string((256 >> 3), '\0'));
376  request.mutable_ec()->set_y(string((256 >> 3), '\0'));
377
378  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
379  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::INVALID_ARGUMENT);
380}
381
382TEST_F (ImportKeyTest, ImportECP256KeySuccess) {
383  // Generate an EC key.
384  // TODO: just hardcode a test key.
385  bssl::UniquePtr<EC_KEY> ec(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1));
386  EXPECT_EQ(EC_KEY_generate_key(ec.get()), 1);
387  const EC_GROUP *group = EC_KEY_get0_group(ec.get());
388  const BIGNUM *d = EC_KEY_get0_private_key(ec.get());
389  const EC_POINT *point = EC_KEY_get0_public_key(ec.get());
390  bssl::UniquePtr<BIGNUM> x(BN_new());
391  bssl::UniquePtr<BIGNUM> y(BN_new());
392  EXPECT_EQ(EC_POINT_get_affine_coordinates_GFp(
393      group, point, x.get(), y.get(), NULL), 1);
394
395  // Turn d, x, y into binary strings.
396  const size_t field_size = (EC_GROUP_get_degree(group) + 7) >> 3;
397  std::unique_ptr<uint8_t []> dstr(new uint8_t[field_size]);
398  std::unique_ptr<uint8_t []> xstr(new uint8_t[field_size]);
399  std::unique_ptr<uint8_t []> ystr(new uint8_t[field_size]);
400
401  EXPECT_EQ(BN_bn2le_padded(dstr.get(), field_size, d), 1);
402  EXPECT_EQ(BN_bn2le_padded(xstr.get(), field_size, x.get()), 1);
403  EXPECT_EQ(BN_bn2le_padded(ystr.get(), field_size, y.get()), 1);
404
405  ImportKeyRequest request;
406  ImportKeyResponse response;
407
408  KeyParameters *params = request.mutable_params();
409  KeyParameter *param = params->add_params();
410  param->set_tag(Tag::ALGORITHM);
411  param->set_integer((uint32_t)Algorithm::EC);
412
413  param = params->add_params();
414  param->set_tag(Tag::EC_CURVE);
415  param->set_integer((uint32_t)EcCurve::P_256);
416
417  param = params->add_params();
418  param->set_tag(Tag::KEY_SIZE);
419  param->set_integer((uint32_t)256);
420
421  request.mutable_ec()->set_curve_id((uint32_t)EcCurve::P_256);
422  request.mutable_ec()->set_d(dstr.get(), field_size);
423  request.mutable_ec()->set_x(xstr.get(), field_size);
424  request.mutable_ec()->set_y(ystr.get(), field_size);
425
426  ASSERT_NO_ERROR(service->ImportKey(request, &response), "");
427  EXPECT_EQ((ErrorCode)response.error_code(), ErrorCode::OK);
428}
429
430// TODO: add tests for symmetric key import.
431
432}  // namespace
433