1
2#include <gtest/gtest.h>
3#include <openssl/aes.h>
4
5#include <algorithm>
6#include <fstream>
7#include <iostream>
8#include <memory>
9#include <random>
10
11#include "nugget/app/protoapi/control.pb.h"
12#include "nugget/app/protoapi/header.pb.h"
13#include "nugget/app/protoapi/testing_api.pb.h"
14#include "src/util.h"
15
16#ifdef ANDROID
17#define FLAGS_nos_test_dump_protos false
18#else
19#include <gflags/gflags.h>
20
21DEFINE_bool(nos_test_dump_protos, false, "Dump binary protobufs to a file.");
22#endif  // ANDROID
23
24using nugget::app::protoapi::AesCbcEncryptTest;
25using nugget::app::protoapi::AesCbcEncryptTestResult;
26using nugget::app::protoapi::APImessageID;
27using nugget::app::protoapi::DcryptError;
28using nugget::app::protoapi::KeySize;
29using nugget::app::protoapi::Notice;
30using nugget::app::protoapi::NoticeCode;
31using nugget::app::protoapi::OneofTestParametersCase;
32using nugget::app::protoapi::OneofTestResultsCase;
33using nugget::app::protoapi::TrngTest;
34using nugget::app::protoapi::TrngTestResult;
35using std::cout;
36using std::vector;
37using std::unique_ptr;
38using test_harness::TestHarness;
39
// Fatal assertion that a test_harness call returned
// test_harness::error_codes::NO_ERROR.
//
// `code` is evaluated exactly once. Callers pass side-effecting calls such as
// harness->SendProto(...) or harness->GetData(...). Re-evaluating them while
// building the failure message would resend or reread data on the device
// link. The original expression text is appended to the failure message so
// the diagnostic still shows which call failed.
#define ASSERT_NO_TH_ERROR(code) \
  do { \
    const auto th_error_code_ = (code); \
    ASSERT_EQ(th_error_code_, test_harness::error_codes::NO_ERROR) \
        << th_error_code_ << " is " \
        << test_harness::error_codes_name(th_error_code_) \
        << "\n  from: " << #code; \
  } while (0)
43
// Fatal assertion that the raw message `msg` has API message ID `type_`.
//
// If a NOTICE arrived while something else was expected, the notice payload is
// parsed and its DebugString() is appended to the failure message, so the
// notice contents are visible in the test log.
//
// Both macro arguments are parenthesized so that expressions (e.g. `*msg_ptr`)
// expand correctly; the enum conversion uses a named cast.
#define ASSERT_MSG_TYPE(msg, type_) \
  do { \
    if ((type_) != APImessageID::NOTICE && \
        (msg).type == APImessageID::NOTICE) { \
      Notice received; \
      received.ParseFromArray(reinterpret_cast<char *>((msg).data), \
                              (msg).data_len); \
      ASSERT_EQ((msg).type, (type_)) \
          << (msg).type << " is " \
          << APImessageID_Name(static_cast<APImessageID>((msg).type)) \
          << "\n" << received.DebugString(); \
    } else { \
      ASSERT_EQ((msg).type, (type_)) \
          << (msg).type << " is " \
          << APImessageID_Name(static_cast<APImessageID>((msg).type)); \
    } \
  } while (0)
55
// Fatal assertion about a TESTING_API_RESPONSE raw message `msg`. It checks
// that the message is longer than its two-byte big-endian subtype header, and
// that the header equals `type_`.
//
// Both checks are fatal, because callers go on to parse `msg.data + 2` with
// length `msg.data_len - 2`. That is an underflow if the header is missing,
// and meaningless if the subtype is wrong.
//
// Wrapped in do/while(0) so the temporary stays local. This lets the macro be
// used twice in one scope and keeps it safe under an unbraced `if`.
#define ASSERT_SUBTYPE(msg, type_) \
  do { \
    ASSERT_GT((msg).data_len, 2); \
    const uint16_t subtype_ = \
        static_cast<uint16_t>(((msg).data[0] << 8) | (msg).data[1]); \
    ASSERT_EQ(subtype_, (type_)); \
  } while (0)
60
61namespace {
62
63using test_harness::BYTE_TIME;
64
65class NuggetOsTest: public testing::Test {
66 protected:
67  static void SetUpTestCase();
68  static void TearDownTestCase();
69
70 public:
71  static unique_ptr<TestHarness> harness;
72  static std::random_device random_number_generator;
73};
74
75unique_ptr<TestHarness> NuggetOsTest::harness;
76std::random_device NuggetOsTest::random_number_generator;
77
78void NuggetOsTest::SetUpTestCase() {
79  harness = TestHarness::MakeUnique();
80
81#ifndef CONFIG_NO_UART
82  if (!harness->UsingSpi()) {
83    EXPECT_TRUE(harness->SwitchFromConsoleToProtoApi());
84    EXPECT_TRUE(harness->ttyState());
85  }
86#endif  // CONFIG_NO_UART
87}
88
89void NuggetOsTest::TearDownTestCase() {
90#ifndef CONFIG_NO_UART
91  if (!harness->UsingSpi()) {
92    harness->ReadUntil(test_harness::BYTE_TIME * 1024);
93    EXPECT_TRUE(harness->SwitchFromProtoApiToConsole(NULL));
94  }
95#endif  // CONFIG_NO_UART
96  harness = unique_ptr<TestHarness>();
97}
98
99TEST_F(NuggetOsTest, NoticePing) {
100  Notice ping_msg;
101  ping_msg.set_notice_code(NoticeCode::PING);
102  Notice pong_msg;
103
104  ASSERT_NO_TH_ERROR(harness->SendProto(APImessageID::NOTICE, ping_msg));
105  if (harness->getVerbosity() >= TestHarness::VerbosityLevels::INFO) {
106    cout << ping_msg.DebugString();
107  }
108  test_harness::raw_message receive_msg;
109  ASSERT_NO_TH_ERROR(harness->GetData(&receive_msg, 4096 * BYTE_TIME));
110  ASSERT_MSG_TYPE(receive_msg, APImessageID::NOTICE);
111  pong_msg.set_notice_code(NoticeCode::PING);
112  ASSERT_TRUE(pong_msg.ParseFromArray(
113      reinterpret_cast<char *>(receive_msg.data), receive_msg.data_len));
114  if (harness->getVerbosity() >= TestHarness::VerbosityLevels::INFO) {
115    cout << pong_msg.DebugString() << std::endl;
116  }
117  EXPECT_EQ(pong_msg.notice_code(), NoticeCode::PONG);
118
119  ASSERT_NO_TH_ERROR(harness->SendProto(APImessageID::NOTICE, ping_msg));
120  if (harness->getVerbosity() >= TestHarness::VerbosityLevels::INFO) {
121    cout << ping_msg.DebugString();
122  }
123  ASSERT_NO_TH_ERROR(harness->GetData(&receive_msg, 4096 * BYTE_TIME));
124  ASSERT_MSG_TYPE(receive_msg, APImessageID::NOTICE);
125  pong_msg.set_notice_code(NoticeCode::PING);
126  ASSERT_TRUE(pong_msg.ParseFromArray(
127      reinterpret_cast<char *>(receive_msg.data), receive_msg.data_len));
128  if (harness->getVerbosity() >= TestHarness::VerbosityLevels::INFO) {
129    cout << pong_msg.DebugString() << std::endl;
130  }
131  EXPECT_EQ(pong_msg.notice_code(), NoticeCode::PONG);
132
133  ASSERT_NO_TH_ERROR(harness->SendProto(APImessageID::NOTICE, ping_msg));
134  if (harness->getVerbosity() >= TestHarness::VerbosityLevels::INFO) {
135    cout << ping_msg.DebugString();
136  }
137  ASSERT_NO_TH_ERROR(harness->GetData(&receive_msg, 4096 * BYTE_TIME));
138  ASSERT_MSG_TYPE(receive_msg, APImessageID::NOTICE);
139  pong_msg.set_notice_code(NoticeCode::PING);
140  ASSERT_TRUE(pong_msg.ParseFromArray(
141      reinterpret_cast<char *>(receive_msg.data), receive_msg.data_len));
142  if (harness->getVerbosity() >= TestHarness::VerbosityLevels::INFO) {
143    cout << pong_msg.DebugString() << std::endl;
144  }
145  EXPECT_EQ(pong_msg.notice_code(), NoticeCode::PONG);
146}
147
148TEST_F(NuggetOsTest, InvalidMessageType) {
149  const char content[] = "This is a test message.";
150
151  test_harness::raw_message msg;
152  msg.type = 0;
153  std::copy(content, content + sizeof(content), msg.data);
154  msg.data_len = sizeof(content);
155
156  ASSERT_NO_TH_ERROR(harness->SendData(msg));
157  ASSERT_NO_TH_ERROR(harness->GetData(&msg, 4096 * BYTE_TIME));
158  ASSERT_MSG_TYPE(msg, APImessageID::NOTICE);
159
160  Notice notice_msg;
161  ASSERT_TRUE(notice_msg.ParseFromArray(reinterpret_cast<char *>(msg.data),
162                                        msg.data_len));
163  if (harness->getVerbosity() >= TestHarness::VerbosityLevels::INFO) {
164    cout << notice_msg.DebugString() << std::endl;
165  }
166  EXPECT_EQ(notice_msg.notice_code(), NoticeCode::UNRECOGNIZED_MESSAGE);
167}
168
169TEST_F(NuggetOsTest, Sequence) {
170  test_harness::raw_message msg;
171  msg.type = APImessageID::SEND_SEQUENCE;
172  msg.data_len = 256;
173  for (size_t x = 0; x < msg.data_len; ++x) {
174    msg.data[x] = x;
175  }
176
177  ASSERT_NO_TH_ERROR(harness->SendData(msg));
178  ASSERT_NO_TH_ERROR(harness->GetData(&msg, 4096 * BYTE_TIME));
179  ASSERT_MSG_TYPE(msg, APImessageID::SEND_SEQUENCE);
180  for (size_t x = 0; x < msg.data_len; ++x) {
181    ASSERT_EQ(msg.data[x], x) << "Inconsistency at index " << x;
182  }
183}
184
185TEST_F(NuggetOsTest, Echo) {
186  test_harness::raw_message msg;
187  msg.type = APImessageID::ECHO_THIS;
188  // Leave some room for bytes which need escaping
189  msg.data_len = sizeof(msg.data) - 64;
190  for (size_t x = 0; x < msg.data_len; ++x) {
191    msg.data[x] = random_number_generator();
192  }
193
194  ASSERT_NO_TH_ERROR(harness->SendData(msg));
195
196  test_harness::raw_message receive_msg;
197  ASSERT_NO_TH_ERROR(harness->GetData(&receive_msg, 4096 * BYTE_TIME));
198  ASSERT_MSG_TYPE(msg, APImessageID::ECHO_THIS);
199  ASSERT_EQ(receive_msg.data_len, msg.data_len);
200
201  for (size_t x = 0; x < msg.data_len; ++x) {
202    ASSERT_EQ(msg.data[x], receive_msg.data[x])
203        << "Inconsistency at index " << x;
204  }
205}
206
207TEST_F(NuggetOsTest, AesCbc) {
208  const size_t number_of_blocks = 3;
209
210  for (auto key_size : {KeySize::s128b, KeySize::s192b, KeySize::s256b}) {
211    if (harness->getVerbosity() >= TestHarness::VerbosityLevels::INFO) {
212      cout << "Testing with a key size of: " << std::dec << (key_size * 8)
213           << std::endl;
214    }
215    AesCbcEncryptTest request;
216    request.set_key_size(key_size);
217    request.set_number_of_blocks(number_of_blocks);
218
219    vector<int> key_data(key_size / sizeof(int));
220    for (auto &part : key_data) {
221      part = random_number_generator();
222    }
223    request.set_key(key_data.data(), key_data.size() * sizeof(int));
224
225
226    if (FLAGS_nos_test_dump_protos) {
227      std::ofstream outfile;
228      outfile.open("AesCbcEncryptTest_" + std::to_string(key_size * 8) +
229                   ".proto.bin", std::ios_base::binary);
230      outfile << request.SerializeAsString();
231      outfile.close();
232    }
233
234    ASSERT_NO_TH_ERROR(harness->SendOneofProto(
235        APImessageID::TESTING_API_CALL,
236        OneofTestParametersCase::kAesCbcEncryptTest,
237        request));
238
239    test_harness::raw_message msg;
240    ASSERT_NO_TH_ERROR(harness->GetData(&msg, 4096 * BYTE_TIME));
241    ASSERT_MSG_TYPE(msg, APImessageID::TESTING_API_RESPONSE);
242    ASSERT_SUBTYPE(msg, OneofTestResultsCase::kAesCbcEncryptTestResult);
243
244    AesCbcEncryptTestResult result;
245    ASSERT_TRUE(result.ParseFromArray(reinterpret_cast<char *>(msg.data + 2),
246                                      msg.data_len - 2));
247    EXPECT_EQ(result.result_code(), DcryptError::DE_NO_ERROR)
248        << result.result_code() << " is "
249        << DcryptError_Name(result.result_code());
250    ASSERT_EQ(result.cipher_text().size(), number_of_blocks * AES_BLOCK_SIZE)
251        << "\n" << result.DebugString();
252
253    uint32_t in[4] = {0, 0, 0, 0};
254    uint8_t sw_out[AES_BLOCK_SIZE];
255    uint8_t iv[AES_BLOCK_SIZE];
256    memset(&iv, 0, sizeof(iv));
257    AES_KEY aes_key;
258    AES_set_encrypt_key(reinterpret_cast<uint8_t *>(key_data.data()),
259                        key_size * 8, &aes_key);
260    for (size_t x = 0; x < number_of_blocks; ++x) {
261      AES_cbc_encrypt(reinterpret_cast<uint8_t *>(in),
262                      reinterpret_cast<uint8_t *>(sw_out), AES_BLOCK_SIZE,
263                      &aes_key, reinterpret_cast<uint8_t *>(iv), true);
264      for (size_t y = 0; y < AES_BLOCK_SIZE; ++y) {
265        size_t index = x * AES_BLOCK_SIZE + y;
266        ASSERT_EQ(result.cipher_text()[index] & 0x00ff,
267                  sw_out[y] & 0x00ff) << "Inconsistency at index " << index;
268      }
269    }
270
271    ASSERT_EQ(result.initialization_vector().size(), (size_t) AES_BLOCK_SIZE)
272        << "\n" << result.DebugString();
273    for (size_t x = 0; x < AES_BLOCK_SIZE; ++x) {
274      ASSERT_EQ(result.initialization_vector()[x] & 0x00ff, iv[x] & 0x00ff)
275                << "Inconsistency at index " << x;
276    }
277  }
278}
279
280TEST_F(NuggetOsTest, Trng) {
281  // Have a bin for every possible byte value.
282  std::vector<size_t> counts(256, 0);
283
284  // Use most of the available space while leaving room for the transport
285  // header, escape sequences, etc.
286  const size_t request_size = 475;
287  const size_t repeats = 10;
288
289  TrngTest request;
290  request.set_number_of_bytes(request_size);
291
292  int verbosity = harness->getVerbosity();
293  for (size_t x = 0; x < repeats; ++x) {
294    ASSERT_NO_TH_ERROR(harness->SendOneofProto(
295        APImessageID::TESTING_API_CALL,
296        OneofTestParametersCase::kTrngTest,
297        request));
298    test_harness::raw_message msg;
299    ASSERT_NO_TH_ERROR(harness->GetData(&msg, 4096 * BYTE_TIME));
300    ASSERT_MSG_TYPE(msg, APImessageID::TESTING_API_RESPONSE);
301    ASSERT_SUBTYPE(msg, OneofTestResultsCase::kTrngTestResult);
302
303    TrngTestResult result;
304    ASSERT_TRUE(result.ParseFromArray(reinterpret_cast<char *>(msg.data + 2),
305                                      msg.data_len - 2));
306    ASSERT_EQ(result.random_bytes().size(), request_size);
307    for (const auto rand_byte : result.random_bytes()) {
308      ++counts[0x00ff & rand_byte];
309    }
310
311    // Print the first exchange only for debugging.
312    if (x == 0) {
313      harness->setVerbosity(harness->getVerbosity() - 1);
314    }
315  }
316  harness->setVerbosity(verbosity);
317
318  double kl_divergence = 0;
319  double ratio = (double) counts.size() / (repeats * request_size);
320  for (const auto count : counts) {
321    ASSERT_NE(count, 0u);
322    kl_divergence += count * log2(count * ratio);
323  }
324  kl_divergence *= ratio;
325  if (harness->getVerbosity() >= TestHarness::VerbosityLevels::INFO) {
326    cout << "K.L. Divergence: " << kl_divergence << "\n";
327    cout.flush();
328  }
329  ASSERT_LT(kl_divergence, 15.0);
330}
331
332}  // namespace
333