1/*
2 * Copyright (C) 2015 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <gtest/gtest.h>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>

#include <memory>

#include <gatekeeper/gatekeeper_messages.h>
23
24using ::gatekeeper::SizedBuffer;
25using ::testing::Test;
26using ::gatekeeper::EnrollRequest;
27using ::gatekeeper::EnrollResponse;
28using ::gatekeeper::VerifyRequest;
29using ::gatekeeper::VerifyResponse;
30using std::cout;
31using std::endl;
32
// Arbitrary user id stamped into the messages built by the round-trip tests.
static constexpr uint32_t USER_ID = 3857;
34
35static SizedBuffer *make_buffer(uint32_t size) {
36    SizedBuffer *result = new SizedBuffer;
37    result->length = size;
38    uint8_t *buffer = new uint8_t[size];
39    srand(size);
40
41    for (uint32_t i = 0; i < size; i++) {
42        buffer[i] = rand();
43    }
44
45    result->buffer.reset(buffer);
46    return result;
47}
48
49TEST(RoundTripTest, EnrollRequestNullEnrolledNullHandle) {
50    const uint32_t password_size = 512;
51    SizedBuffer *provided_password = make_buffer(password_size);
52    const SizedBuffer *deserialized_password;
53    // create request, serialize, deserialize, and validate
54    EnrollRequest msg(USER_ID, NULL, provided_password, NULL);
55    SizedBuffer serialized_msg(msg.GetSerializedSize());
56    msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
57
58    EnrollRequest deserialized_msg;
59    deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get()
60            + serialized_msg.length);
61
62    ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE,
63            deserialized_msg.error);
64
65    deserialized_password = &deserialized_msg.provided_password;
66    ASSERT_EQ(USER_ID, deserialized_msg.user_id);
67    ASSERT_EQ((uint32_t) password_size, deserialized_password->length);
68    ASSERT_EQ(0, memcmp(msg.provided_password.buffer.get(), deserialized_password->buffer.get(), password_size));
69    ASSERT_EQ((uint32_t) 0, deserialized_msg.enrolled_password.length);
70    ASSERT_EQ(NULL, deserialized_msg.enrolled_password.buffer.get());
71    ASSERT_EQ((uint32_t) 0, deserialized_msg.password_handle.length);
72    ASSERT_EQ(NULL, deserialized_msg.password_handle.buffer.get());
73    delete provided_password;
74}
75
76TEST(RoundTripTest, EnrollRequestEmptyEnrolledEmptyHandle) {
77    const uint32_t password_size = 512;
78    SizedBuffer *provided_password = make_buffer(password_size);
79    SizedBuffer enrolled;
80    SizedBuffer handle;
81    const SizedBuffer *deserialized_password;
82    // create request, serialize, deserialize, and validate
83    EnrollRequest msg(USER_ID, &handle, provided_password, &enrolled);
84    SizedBuffer serialized_msg(msg.GetSerializedSize());
85    msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
86
87    EnrollRequest deserialized_msg;
88    deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get()
89            + serialized_msg.length);
90
91    ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE,
92            deserialized_msg.error);
93
94    deserialized_password = &deserialized_msg.provided_password;
95    ASSERT_EQ(USER_ID, deserialized_msg.user_id);
96    ASSERT_EQ((uint32_t) password_size, deserialized_password->length);
97    ASSERT_EQ(0, memcmp(msg.provided_password.buffer.get(), deserialized_password->buffer.get(), password_size));
98    ASSERT_EQ((uint32_t) 0, deserialized_msg.enrolled_password.length);
99    ASSERT_EQ(NULL, deserialized_msg.enrolled_password.buffer.get());
100    ASSERT_EQ((uint32_t) 0, deserialized_msg.password_handle.length);
101    ASSERT_EQ(NULL, deserialized_msg.password_handle.buffer.get());
102    delete provided_password;
103}
104
105TEST(RoundTripTest, EnrollRequestNonNullEnrolledOrHandle) {
106    const uint32_t password_size = 512;
107    SizedBuffer *provided_password = make_buffer(password_size);
108    SizedBuffer *enrolled_password = make_buffer(password_size);
109    SizedBuffer *password_handle = make_buffer(password_size);
110    const SizedBuffer *deserialized_password;
111    const SizedBuffer *deserialized_enrolled;
112    const SizedBuffer *deserialized_handle;
113    // create request, serialize, deserialize, and validate
114    EnrollRequest msg(USER_ID, password_handle, provided_password, enrolled_password);
115    SizedBuffer serialized_msg(msg.GetSerializedSize());
116    msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
117
118    EnrollRequest deserialized_msg;
119    deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get()
120            + serialized_msg.length);
121
122    ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE,
123            deserialized_msg.error);
124
125    deserialized_password = &deserialized_msg.provided_password;
126    deserialized_enrolled = &deserialized_msg.enrolled_password;
127    deserialized_handle = &deserialized_msg.password_handle;
128    ASSERT_EQ(USER_ID, deserialized_msg.user_id);
129    ASSERT_EQ((uint32_t) password_size, deserialized_password->length);
130    ASSERT_EQ(0, memcmp(msg.provided_password.buffer.get(), deserialized_password->buffer.get(), password_size));
131    ASSERT_EQ((uint32_t) password_size, deserialized_enrolled->length);
132    ASSERT_EQ(0, memcmp(msg.enrolled_password.buffer.get(), deserialized_enrolled->buffer.get(), password_size));
133    ASSERT_EQ((uint32_t) password_size, deserialized_handle->length);
134    ASSERT_EQ(0, memcmp(msg.password_handle.buffer.get(), deserialized_handle->buffer.get(), password_size));
135    delete provided_password;
136    delete enrolled_password;
137    delete password_handle;
138}
139
140
141TEST(RoundTripTest, EnrollResponse) {
142    const uint32_t password_size = 512;
143    SizedBuffer *enrolled_password = make_buffer(password_size);
144    const SizedBuffer *deserialized_password;
145    // create request, serialize, deserialize, and validate
146    EnrollResponse msg(USER_ID, enrolled_password);
147    SizedBuffer serialized_msg(msg.GetSerializedSize());
148    msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
149
150    EnrollResponse deserialized_msg;
151    deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get()
152            + serialized_msg.length);
153
154    ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE,
155            deserialized_msg.error);
156
157    deserialized_password = &deserialized_msg.enrolled_password_handle;
158    ASSERT_EQ(USER_ID, deserialized_msg.user_id);
159    ASSERT_EQ((uint32_t) password_size, deserialized_password->length);
160    ASSERT_EQ(0, memcmp(msg.enrolled_password_handle.buffer.get(),
161                deserialized_password->buffer.get(), password_size));
162}
163
164TEST(RoundTripTest, VerifyRequest) {
165    const uint32_t password_size = 512;
166    SizedBuffer *provided_password = make_buffer(password_size),
167          *password_handle = make_buffer(password_size);
168    const SizedBuffer *deserialized_password;
169    // create request, serialize, deserialize, and validate
170    VerifyRequest msg(USER_ID, 1, password_handle, provided_password);
171    SizedBuffer serialized_msg(msg.GetSerializedSize());
172    msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
173
174    VerifyRequest deserialized_msg;
175    deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get()
176            + serialized_msg.length);
177
178    ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE,
179            deserialized_msg.error);
180
181    ASSERT_EQ(USER_ID, deserialized_msg.user_id);
182    ASSERT_EQ((uint64_t) 1, deserialized_msg.challenge);
183    deserialized_password = &deserialized_msg.password_handle;
184    ASSERT_EQ((uint32_t) password_size, deserialized_password->length);
185    ASSERT_EQ(0, memcmp(msg.provided_password.buffer.get(), deserialized_password->buffer.get(),
186                password_size));
187
188    deserialized_password = &deserialized_msg.password_handle;
189    ASSERT_EQ((uint32_t) password_size, deserialized_password->length);
190    ASSERT_EQ(0, memcmp(msg.password_handle.buffer.get(), deserialized_password->buffer.get(),
191                password_size));
192}
193
194TEST(RoundTripTest, VerifyResponse) {
195    const uint32_t password_size = 512;
196    SizedBuffer *auth_token = make_buffer(password_size);
197    const SizedBuffer *deserialized_password;
198    // create request, serialize, deserialize, and validate
199    VerifyResponse msg(USER_ID, auth_token);
200    SizedBuffer serialized_msg(msg.GetSerializedSize());
201    msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
202
203    VerifyResponse deserialized_msg;
204    deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get()
205            + serialized_msg.length);
206
207    ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_NONE,
208            deserialized_msg.error);
209
210    ASSERT_EQ(USER_ID, deserialized_msg.user_id);
211    deserialized_password = &deserialized_msg.auth_token;
212    ASSERT_EQ((uint32_t) password_size, deserialized_password->length);
213    ASSERT_EQ(0, memcmp(msg.auth_token.buffer.get(), deserialized_password->buffer.get(),
214                password_size));
215}
216
217TEST(RoundTripTest, VerifyResponseError) {
218    VerifyResponse msg;
219    msg.error = gatekeeper::gatekeeper_error_t::ERROR_INVALID;
220    SizedBuffer serialized_msg(msg.GetSerializedSize());
221    msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
222    VerifyResponse deserialized_msg;
223    deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
224    ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_INVALID,
225            deserialized_msg.error);
226}
227
228TEST(RoundTripTest, VerifyRequestError) {
229    VerifyRequest msg;
230    msg.error = gatekeeper::gatekeeper_error_t::ERROR_INVALID;
231    SizedBuffer serialized_msg(msg.GetSerializedSize());
232    msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
233    VerifyRequest deserialized_msg;
234    deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
235    ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_INVALID,
236            deserialized_msg.error);
237}
238
239TEST(RoundTripTest, EnrollResponseError) {
240    EnrollResponse msg;
241    msg.error = gatekeeper::gatekeeper_error_t::ERROR_INVALID;
242    SizedBuffer serialized_msg(msg.GetSerializedSize());
243    msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
244    EnrollResponse deserialized_msg;
245    deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
246    ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_INVALID,
247            deserialized_msg.error);
248}
249
250TEST(RoundTripTest, EnrollRequestError) {
251    EnrollRequest msg;
252    msg.error = gatekeeper::gatekeeper_error_t::ERROR_INVALID;
253    SizedBuffer serialized_msg(msg.GetSerializedSize());
254    msg.Serialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
255    EnrollRequest deserialized_msg;
256    deserialized_msg.Deserialize(serialized_msg.buffer.get(), serialized_msg.buffer.get() + serialized_msg.length);
257    ASSERT_EQ(gatekeeper::gatekeeper_error_t::ERROR_INVALID,
258            deserialized_msg.error);
259}
260
// Arbitrary bytes fed to the message parsers by the garbage tests below.
// File-local and read-only: nothing writes to it. Internal linkage avoids
// symbol clashes with similarly named globals in other test files linked
// into the same binary.
static const uint8_t msgbuf[] = {
    220, 88,  183, 255, 71,  1,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   173, 0,   0,   0,   228, 174, 98,  187, 191, 135, 253, 200, 51,  230, 114, 247, 151, 109,
    237, 79,  87,  32,  94,  5,   204, 46,  154, 30,  91,  6,   103, 148, 254, 129, 65,  171, 228,
    167, 224, 163, 9,   15,  206, 90,  58,  11,  205, 55,  211, 33,  87,  178, 149, 91,  28,  236,
    218, 112, 231, 34,  82,  82,  134, 103, 137, 115, 27,  156, 102, 159, 220, 226, 89,  42,  25,
    37,  9,   84,  239, 76,  161, 198, 72,  167, 163, 39,  91,  148, 191, 17,  191, 87,  169, 179,
    136, 10,  194, 154, 4,   40,  107, 109, 61,  161, 20,  176, 247, 13,  214, 106, 229, 45,  17,
    5,   60,  189, 64,  39,  166, 208, 14,  57,  25,  140, 148, 25,  177, 246, 189, 43,  181, 88,
    204, 29,  126, 224, 100, 143, 93,  60,  57,  249, 55,  0,   87,  83,  227, 224, 166, 59,  214,
    81,  144, 129, 58,  6,   57,  46,  254, 232, 41,  220, 209, 230, 167, 138, 158, 94,  180, 125,
    247, 26,  162, 116, 238, 202, 187, 100, 65,  13,  180, 44,  245, 159, 83,  161, 176, 58,  72,
    236, 109, 105, 160, 0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
    0,   11,  0,   0,   0,   98,  0,   0,   0,   1,   0,   0,   32,  2,   0,   0,   0,   1,   0,
    0,   32,  3,   0,   0,   0,   2,   0,   0,   16,  1,   0,   0,   0,   3,   0,   0,   48,  0,
    1,   0,   0,   200, 0,   0,   80,  3,   0,   0,   0,   0,   0,   0,   0,   244, 1,   0,   112,
    1,   246, 1,   0,   112, 1,   189, 2,   0,   96,  144, 178, 236, 250, 255, 255, 255, 255, 145,
    1,   0,   96,  144, 226, 33,  60,  222, 2,   0,   0,   189, 2,   0,   96,  0,   0,   0,   0,
    0,   0,   0,   0,   190, 2,   0,   16,  1,   0,   0,   0,   12,  0,   0,   0,   0,   0,   0,
    0,   0,   0,   0,   0,   0,   0,   0,   0,   110, 0,   0,   0,   0,   0,   0,   0,   11,  0,
    0,   0,   98,  0,   0,   0,   1,   0,   0,   32,  2,   0,   0,   0,   1,   0,   0,   32,  3,
    0,   0,   0,   2,   0,   0,   16,  1,   0,   0,   0,   3,   0,   0,   48,  0,   1,   0,   0,
    200, 0,   0,   80,  3,   0,   0,   0,   0,   0,   0,   0,   244, 1,   0,   112, 1,   246, 1,
    0,   112, 1,   189, 2,   0,   96,  144, 178, 236, 250, 255, 255, 255, 255, 145, 1,   0,   96,
    144, 226, 33,  60,  222, 2,   0,   0,   189, 2,   0,   96,  0,   0,   0,   0,   0,   0,   0,
    0,   190, 2,   0,   16,  1,   0,   0,   0,
};
288
289
290/*
291 * These tests don't have any assertions or expectations. They just try to parse garbage, to see if
292 * the result will be a crash.  This is especially informative when run under Valgrind memcheck.
293 */
294
295template <typename Message> void parse_garbage() {
296    Message msg;
297    uint32_t array_length = sizeof(msgbuf) / sizeof(msgbuf[0]);
298    const uint8_t* end = msgbuf + array_length;
299    for (uint32_t i = 0; i < array_length; ++i) {
300        const uint8_t* begin = msgbuf + i;
301        const uint8_t* p = begin;
302        msg.Deserialize(p, end);
303    }
304}
305
306#define GARBAGE_TEST(Message)                                                                      \
307    TEST(GarbageTest, Message) { parse_garbage<Message>(); }
308
309GARBAGE_TEST(VerifyRequest);
310GARBAGE_TEST(VerifyResponse);
311GARBAGE_TEST(EnrollRequest);
312GARBAGE_TEST(EnrollResponse);
313