1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/ssl/default_channel_id_store.h"
6
7#include <map>
8#include <string>
9#include <vector>
10
11#include "base/bind.h"
12#include "base/compiler_specific.h"
13#include "base/logging.h"
14#include "base/memory/scoped_ptr.h"
15#include "base/message_loop/message_loop.h"
16#include "net/base/net_errors.h"
17#include "testing/gtest/include/gtest/gtest.h"
18
19namespace net {
20
21namespace {
22
// Adds one to the integer that |counter| points at. Bound as a completion
// callback so a test can count how many times the callback ran.
void CallCounter(int* counter) {
  *counter += 1;
}
26
27void GetChannelIDCallbackNotCalled(int err,
28                                   const std::string& server_identifier,
29                                   base::Time expiration_time,
30                                   const std::string& private_key_result,
31                                   const std::string& cert_result) {
32  ADD_FAILURE() << "Unexpected callback execution.";
33}
34
35class AsyncGetChannelIDHelper {
36 public:
37  AsyncGetChannelIDHelper() : called_(false) {}
38
39  void Callback(int err,
40                const std::string& server_identifier,
41                base::Time expiration_time,
42                const std::string& private_key_result,
43                const std::string& cert_result) {
44    err_ = err;
45    server_identifier_ = server_identifier;
46    expiration_time_ = expiration_time;
47    private_key_ = private_key_result;
48    cert_ = cert_result;
49    called_ = true;
50  }
51
52  int err_;
53  std::string server_identifier_;
54  base::Time expiration_time_;
55  std::string private_key_;
56  std::string cert_;
57  bool called_;
58};
59
60void GetAllCallback(
61    ChannelIDStore::ChannelIDList* dest,
62    const ChannelIDStore::ChannelIDList& result) {
63  *dest = result;
64}
65
66class MockPersistentStore
67    : public DefaultChannelIDStore::PersistentStore {
68 public:
69  MockPersistentStore();
70
71  // DefaultChannelIDStore::PersistentStore implementation.
72  virtual void Load(const LoadedCallback& loaded_callback) OVERRIDE;
73  virtual void AddChannelID(
74      const DefaultChannelIDStore::ChannelID& channel_id) OVERRIDE;
75  virtual void DeleteChannelID(
76      const DefaultChannelIDStore::ChannelID& channel_id) OVERRIDE;
77  virtual void SetForceKeepSessionState() OVERRIDE;
78
79 protected:
80  virtual ~MockPersistentStore();
81
82 private:
83  typedef std::map<std::string, DefaultChannelIDStore::ChannelID>
84      ChannelIDMap;
85
86  ChannelIDMap channel_ids_;
87};
88
89MockPersistentStore::MockPersistentStore() {}
90
91void MockPersistentStore::Load(const LoadedCallback& loaded_callback) {
92  scoped_ptr<ScopedVector<DefaultChannelIDStore::ChannelID> >
93      channel_ids(new ScopedVector<DefaultChannelIDStore::ChannelID>());
94  ChannelIDMap::iterator it;
95
96  for (it = channel_ids_.begin(); it != channel_ids_.end(); ++it) {
97    channel_ids->push_back(
98        new DefaultChannelIDStore::ChannelID(it->second));
99  }
100
101  base::MessageLoop::current()->PostTask(
102      FROM_HERE, base::Bind(loaded_callback, base::Passed(&channel_ids)));
103}
104
105void MockPersistentStore::AddChannelID(
106    const DefaultChannelIDStore::ChannelID& channel_id) {
107  channel_ids_[channel_id.server_identifier()] = channel_id;
108}
109
110void MockPersistentStore::DeleteChannelID(
111    const DefaultChannelIDStore::ChannelID& channel_id) {
112  channel_ids_.erase(channel_id.server_identifier());
113}
114
115void MockPersistentStore::SetForceKeepSessionState() {}
116
117MockPersistentStore::~MockPersistentStore() {}
118
119}  // namespace
120
121TEST(DefaultChannelIDStoreTest, TestLoading) {
122  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
123
124  persistent_store->AddChannelID(
125      DefaultChannelIDStore::ChannelID(
126          "google.com",
127          base::Time(),
128          base::Time(),
129          "a", "b"));
130  persistent_store->AddChannelID(
131      DefaultChannelIDStore::ChannelID(
132          "verisign.com",
133          base::Time(),
134          base::Time(),
135          "c", "d"));
136
137  // Make sure channel_ids load properly.
138  DefaultChannelIDStore store(persistent_store.get());
139  // Load has not occurred yet.
140  EXPECT_EQ(0, store.GetChannelIDCount());
141  store.SetChannelID(
142      "verisign.com",
143      base::Time(),
144      base::Time(),
145      "e", "f");
146  // Wait for load & queued set task.
147  base::MessageLoop::current()->RunUntilIdle();
148  EXPECT_EQ(2, store.GetChannelIDCount());
149  store.SetChannelID(
150      "twitter.com",
151      base::Time(),
152      base::Time(),
153      "g", "h");
154  // Set should be synchronous now that load is done.
155  EXPECT_EQ(3, store.GetChannelIDCount());
156}
157
// TODO(mattm): Add more tests that run without a persistent store?
159TEST(DefaultChannelIDStoreTest, TestSettingAndGetting) {
160  // No persistent store, all calls will be synchronous.
161  DefaultChannelIDStore store(NULL);
162  base::Time expiration_time;
163  std::string private_key, cert;
164  EXPECT_EQ(0, store.GetChannelIDCount());
165  EXPECT_EQ(ERR_FILE_NOT_FOUND,
166            store.GetChannelID("verisign.com",
167                               &expiration_time,
168                               &private_key,
169                               &cert,
170                               base::Bind(&GetChannelIDCallbackNotCalled)));
171  EXPECT_TRUE(private_key.empty());
172  EXPECT_TRUE(cert.empty());
173  store.SetChannelID(
174      "verisign.com",
175      base::Time::FromInternalValue(123),
176      base::Time::FromInternalValue(456),
177      "i", "j");
178  EXPECT_EQ(OK,
179            store.GetChannelID("verisign.com",
180                               &expiration_time,
181                               &private_key,
182                               &cert,
183                               base::Bind(&GetChannelIDCallbackNotCalled)));
184  EXPECT_EQ(456, expiration_time.ToInternalValue());
185  EXPECT_EQ("i", private_key);
186  EXPECT_EQ("j", cert);
187}
188
189TEST(DefaultChannelIDStoreTest, TestDuplicateChannelIds) {
190  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
191  DefaultChannelIDStore store(persistent_store.get());
192
193  base::Time expiration_time;
194  std::string private_key, cert;
195  EXPECT_EQ(0, store.GetChannelIDCount());
196  store.SetChannelID(
197      "verisign.com",
198      base::Time::FromInternalValue(123),
199      base::Time::FromInternalValue(1234),
200      "a", "b");
201  store.SetChannelID(
202      "verisign.com",
203      base::Time::FromInternalValue(456),
204      base::Time::FromInternalValue(4567),
205      "c", "d");
206
207  // Wait for load & queued set tasks.
208  base::MessageLoop::current()->RunUntilIdle();
209  EXPECT_EQ(1, store.GetChannelIDCount());
210  EXPECT_EQ(OK,
211            store.GetChannelID("verisign.com",
212                               &expiration_time,
213                               &private_key,
214                               &cert,
215                               base::Bind(&GetChannelIDCallbackNotCalled)));
216  EXPECT_EQ(4567, expiration_time.ToInternalValue());
217  EXPECT_EQ("c", private_key);
218  EXPECT_EQ("d", cert);
219}
220
221TEST(DefaultChannelIDStoreTest, TestAsyncGet) {
222  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
223  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
224      "verisign.com",
225      base::Time::FromInternalValue(123),
226      base::Time::FromInternalValue(1234),
227      "a", "b"));
228
229  DefaultChannelIDStore store(persistent_store.get());
230  AsyncGetChannelIDHelper helper;
231  base::Time expiration_time;
232  std::string private_key;
233  std::string cert = "not set";
234  EXPECT_EQ(0, store.GetChannelIDCount());
235  EXPECT_EQ(ERR_IO_PENDING,
236            store.GetChannelID("verisign.com",
237                               &expiration_time,
238                               &private_key,
239                               &cert,
240                               base::Bind(&AsyncGetChannelIDHelper::Callback,
241                                          base::Unretained(&helper))));
242
243  // Wait for load & queued get tasks.
244  base::MessageLoop::current()->RunUntilIdle();
245  EXPECT_EQ(1, store.GetChannelIDCount());
246  EXPECT_EQ("not set", cert);
247  EXPECT_TRUE(helper.called_);
248  EXPECT_EQ(OK, helper.err_);
249  EXPECT_EQ("verisign.com", helper.server_identifier_);
250  EXPECT_EQ(1234, helper.expiration_time_.ToInternalValue());
251  EXPECT_EQ("a", helper.private_key_);
252  EXPECT_EQ("b", helper.cert_);
253}
254
255TEST(DefaultChannelIDStoreTest, TestDeleteAll) {
256  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
257  DefaultChannelIDStore store(persistent_store.get());
258
259  store.SetChannelID(
260      "verisign.com",
261      base::Time(),
262      base::Time(),
263      "a", "b");
264  store.SetChannelID(
265      "google.com",
266      base::Time(),
267      base::Time(),
268      "c", "d");
269  store.SetChannelID(
270      "harvard.com",
271      base::Time(),
272      base::Time(),
273      "e", "f");
274  // Wait for load & queued set tasks.
275  base::MessageLoop::current()->RunUntilIdle();
276
277  EXPECT_EQ(3, store.GetChannelIDCount());
278  int delete_finished = 0;
279  store.DeleteAll(base::Bind(&CallCounter, &delete_finished));
280  ASSERT_EQ(1, delete_finished);
281  EXPECT_EQ(0, store.GetChannelIDCount());
282}
283
284TEST(DefaultChannelIDStoreTest, TestAsyncGetAndDeleteAll) {
285  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
286  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
287      "verisign.com",
288      base::Time(),
289      base::Time(),
290      "a", "b"));
291  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
292      "google.com",
293      base::Time(),
294      base::Time(),
295      "c", "d"));
296
297  ChannelIDStore::ChannelIDList pre_channel_ids;
298  ChannelIDStore::ChannelIDList post_channel_ids;
299  int delete_finished = 0;
300  DefaultChannelIDStore store(persistent_store.get());
301
302  store.GetAllChannelIDs(base::Bind(GetAllCallback, &pre_channel_ids));
303  store.DeleteAll(base::Bind(&CallCounter, &delete_finished));
304  store.GetAllChannelIDs(base::Bind(GetAllCallback, &post_channel_ids));
305  // Tasks have not run yet.
306  EXPECT_EQ(0u, pre_channel_ids.size());
307  // Wait for load & queued tasks.
308  base::MessageLoop::current()->RunUntilIdle();
309  EXPECT_EQ(0, store.GetChannelIDCount());
310  EXPECT_EQ(2u, pre_channel_ids.size());
311  EXPECT_EQ(0u, post_channel_ids.size());
312}
313
314TEST(DefaultChannelIDStoreTest, TestDelete) {
315  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
316  DefaultChannelIDStore store(persistent_store.get());
317
318  base::Time expiration_time;
319  std::string private_key, cert;
320  EXPECT_EQ(0, store.GetChannelIDCount());
321  store.SetChannelID(
322      "verisign.com",
323      base::Time(),
324      base::Time(),
325      "a", "b");
326  // Wait for load & queued set task.
327  base::MessageLoop::current()->RunUntilIdle();
328
329  store.SetChannelID(
330      "google.com",
331      base::Time(),
332      base::Time(),
333      "c", "d");
334
335  EXPECT_EQ(2, store.GetChannelIDCount());
336  int delete_finished = 0;
337  store.DeleteChannelID("verisign.com",
338                              base::Bind(&CallCounter, &delete_finished));
339  ASSERT_EQ(1, delete_finished);
340  EXPECT_EQ(1, store.GetChannelIDCount());
341  EXPECT_EQ(ERR_FILE_NOT_FOUND,
342            store.GetChannelID("verisign.com",
343                               &expiration_time,
344                               &private_key,
345                               &cert,
346                               base::Bind(&GetChannelIDCallbackNotCalled)));
347  EXPECT_EQ(OK,
348            store.GetChannelID("google.com",
349                               &expiration_time,
350                               &private_key,
351                               &cert,
352                               base::Bind(&GetChannelIDCallbackNotCalled)));
353  int delete2_finished = 0;
354  store.DeleteChannelID("google.com",
355                        base::Bind(&CallCounter, &delete2_finished));
356  ASSERT_EQ(1, delete2_finished);
357  EXPECT_EQ(0, store.GetChannelIDCount());
358  EXPECT_EQ(ERR_FILE_NOT_FOUND,
359            store.GetChannelID("google.com",
360                               &expiration_time,
361                               &private_key,
362                               &cert,
363                               base::Bind(&GetChannelIDCallbackNotCalled)));
364}
365
366TEST(DefaultChannelIDStoreTest, TestAsyncDelete) {
367  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
368  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
369      "a.com",
370      base::Time::FromInternalValue(1),
371      base::Time::FromInternalValue(2),
372      "a", "b"));
373  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
374      "b.com",
375      base::Time::FromInternalValue(3),
376      base::Time::FromInternalValue(4),
377      "c", "d"));
378  DefaultChannelIDStore store(persistent_store.get());
379  int delete_finished = 0;
380  store.DeleteChannelID("a.com",
381                        base::Bind(&CallCounter, &delete_finished));
382
383  AsyncGetChannelIDHelper a_helper;
384  AsyncGetChannelIDHelper b_helper;
385  base::Time expiration_time;
386  std::string private_key;
387  std::string cert = "not set";
388  EXPECT_EQ(0, store.GetChannelIDCount());
389  EXPECT_EQ(ERR_IO_PENDING,
390      store.GetChannelID(
391          "a.com", &expiration_time, &private_key, &cert,
392          base::Bind(&AsyncGetChannelIDHelper::Callback,
393                     base::Unretained(&a_helper))));
394  EXPECT_EQ(ERR_IO_PENDING,
395      store.GetChannelID(
396          "b.com", &expiration_time, &private_key, &cert,
397          base::Bind(&AsyncGetChannelIDHelper::Callback,
398                     base::Unretained(&b_helper))));
399
400  EXPECT_EQ(0, delete_finished);
401  EXPECT_FALSE(a_helper.called_);
402  EXPECT_FALSE(b_helper.called_);
403  // Wait for load & queued tasks.
404  base::MessageLoop::current()->RunUntilIdle();
405  EXPECT_EQ(1, delete_finished);
406  EXPECT_EQ(1, store.GetChannelIDCount());
407  EXPECT_EQ("not set", cert);
408  EXPECT_TRUE(a_helper.called_);
409  EXPECT_EQ(ERR_FILE_NOT_FOUND, a_helper.err_);
410  EXPECT_EQ("a.com", a_helper.server_identifier_);
411  EXPECT_EQ(0, a_helper.expiration_time_.ToInternalValue());
412  EXPECT_EQ("", a_helper.private_key_);
413  EXPECT_EQ("", a_helper.cert_);
414  EXPECT_TRUE(b_helper.called_);
415  EXPECT_EQ(OK, b_helper.err_);
416  EXPECT_EQ("b.com", b_helper.server_identifier_);
417  EXPECT_EQ(4, b_helper.expiration_time_.ToInternalValue());
418  EXPECT_EQ("c", b_helper.private_key_);
419  EXPECT_EQ("d", b_helper.cert_);
420}
421
422TEST(DefaultChannelIDStoreTest, TestGetAll) {
423  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
424  DefaultChannelIDStore store(persistent_store.get());
425
426  EXPECT_EQ(0, store.GetChannelIDCount());
427  store.SetChannelID(
428      "verisign.com",
429      base::Time(),
430      base::Time(),
431      "a", "b");
432  store.SetChannelID(
433      "google.com",
434      base::Time(),
435      base::Time(),
436      "c", "d");
437  store.SetChannelID(
438      "harvard.com",
439      base::Time(),
440      base::Time(),
441      "e", "f");
442  store.SetChannelID(
443      "mit.com",
444      base::Time(),
445      base::Time(),
446      "g", "h");
447  // Wait for load & queued set tasks.
448  base::MessageLoop::current()->RunUntilIdle();
449
450  EXPECT_EQ(4, store.GetChannelIDCount());
451  ChannelIDStore::ChannelIDList channel_ids;
452  store.GetAllChannelIDs(base::Bind(GetAllCallback, &channel_ids));
453  EXPECT_EQ(4u, channel_ids.size());
454}
455
456TEST(DefaultChannelIDStoreTest, TestInitializeFrom) {
457  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
458  DefaultChannelIDStore store(persistent_store.get());
459
460  store.SetChannelID(
461      "preexisting.com",
462      base::Time(),
463      base::Time(),
464      "a", "b");
465  store.SetChannelID(
466      "both.com",
467      base::Time(),
468      base::Time(),
469      "c", "d");
470  // Wait for load & queued set tasks.
471  base::MessageLoop::current()->RunUntilIdle();
472  EXPECT_EQ(2, store.GetChannelIDCount());
473
474  ChannelIDStore::ChannelIDList source_channel_ids;
475  source_channel_ids.push_back(ChannelIDStore::ChannelID(
476      "both.com",
477      base::Time(),
478      base::Time(),
479      // Key differs from above to test that existing entries are overwritten.
480      "e", "f"));
481  source_channel_ids.push_back(ChannelIDStore::ChannelID(
482      "copied.com",
483      base::Time(),
484      base::Time(),
485      "g", "h"));
486  store.InitializeFrom(source_channel_ids);
487  EXPECT_EQ(3, store.GetChannelIDCount());
488
489  ChannelIDStore::ChannelIDList channel_ids;
490  store.GetAllChannelIDs(base::Bind(GetAllCallback, &channel_ids));
491  ASSERT_EQ(3u, channel_ids.size());
492
493  ChannelIDStore::ChannelIDList::iterator channel_id = channel_ids.begin();
494  EXPECT_EQ("both.com", channel_id->server_identifier());
495  EXPECT_EQ("e", channel_id->private_key());
496
497  ++channel_id;
498  EXPECT_EQ("copied.com", channel_id->server_identifier());
499  EXPECT_EQ("g", channel_id->private_key());
500
501  ++channel_id;
502  EXPECT_EQ("preexisting.com", channel_id->server_identifier());
503  EXPECT_EQ("a", channel_id->private_key());
504}
505
506TEST(DefaultChannelIDStoreTest, TestAsyncInitializeFrom) {
507  scoped_refptr<MockPersistentStore> persistent_store(new MockPersistentStore);
508  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
509      "preexisting.com",
510      base::Time(),
511      base::Time(),
512      "a", "b"));
513  persistent_store->AddChannelID(ChannelIDStore::ChannelID(
514      "both.com",
515      base::Time(),
516      base::Time(),
517      "c", "d"));
518
519  DefaultChannelIDStore store(persistent_store.get());
520  ChannelIDStore::ChannelIDList source_channel_ids;
521  source_channel_ids.push_back(ChannelIDStore::ChannelID(
522      "both.com",
523      base::Time(),
524      base::Time(),
525      // Key differs from above to test that existing entries are overwritten.
526      "e", "f"));
527  source_channel_ids.push_back(ChannelIDStore::ChannelID(
528      "copied.com",
529      base::Time(),
530      base::Time(),
531      "g", "h"));
532  store.InitializeFrom(source_channel_ids);
533  EXPECT_EQ(0, store.GetChannelIDCount());
534  // Wait for load & queued tasks.
535  base::MessageLoop::current()->RunUntilIdle();
536  EXPECT_EQ(3, store.GetChannelIDCount());
537
538  ChannelIDStore::ChannelIDList channel_ids;
539  store.GetAllChannelIDs(base::Bind(GetAllCallback, &channel_ids));
540  ASSERT_EQ(3u, channel_ids.size());
541
542  ChannelIDStore::ChannelIDList::iterator channel_id = channel_ids.begin();
543  EXPECT_EQ("both.com", channel_id->server_identifier());
544  EXPECT_EQ("e", channel_id->private_key());
545
546  ++channel_id;
547  EXPECT_EQ("copied.com", channel_id->server_identifier());
548  EXPECT_EQ("g", channel_id->private_key());
549
550  ++channel_id;
551  EXPECT_EQ("preexisting.com", channel_id->server_identifier());
552  EXPECT_EQ("a", channel_id->private_key());
553}
554
555}  // namespace net
556