1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/ssl/default_server_bound_cert_store.h"
6
7#include "base/bind.h"
8#include "base/message_loop/message_loop.h"
9#include "base/metrics/histogram.h"
10#include "net/base/net_errors.h"
11
12namespace net {
13
14// --------------------------------------------------------------------------
15// Task
16class DefaultServerBoundCertStore::Task {
17 public:
18  virtual ~Task();
19
20  // Runs the task and invokes the client callback on the thread that
21  // originally constructed the task.
22  virtual void Run(DefaultServerBoundCertStore* store) = 0;
23
24 protected:
25  void InvokeCallback(base::Closure callback) const;
26};
27
28DefaultServerBoundCertStore::Task::~Task() {
29}
30
31void DefaultServerBoundCertStore::Task::InvokeCallback(
32    base::Closure callback) const {
33  if (!callback.is_null())
34    callback.Run();
35}
36
37// --------------------------------------------------------------------------
38// GetServerBoundCertTask
39class DefaultServerBoundCertStore::GetServerBoundCertTask
40    : public DefaultServerBoundCertStore::Task {
41 public:
42  GetServerBoundCertTask(const std::string& server_identifier,
43                         const GetCertCallback& callback);
44  virtual ~GetServerBoundCertTask();
45  virtual void Run(DefaultServerBoundCertStore* store) OVERRIDE;
46
47 private:
48  std::string server_identifier_;
49  GetCertCallback callback_;
50};
51
52DefaultServerBoundCertStore::GetServerBoundCertTask::GetServerBoundCertTask(
53    const std::string& server_identifier,
54    const GetCertCallback& callback)
55    : server_identifier_(server_identifier),
56      callback_(callback) {
57}
58
59DefaultServerBoundCertStore::GetServerBoundCertTask::~GetServerBoundCertTask() {
60}
61
62void DefaultServerBoundCertStore::GetServerBoundCertTask::Run(
63    DefaultServerBoundCertStore* store) {
64  base::Time expiration_time;
65  std::string private_key_result;
66  std::string cert_result;
67  int err = store->GetServerBoundCert(
68      server_identifier_, &expiration_time, &private_key_result,
69      &cert_result, GetCertCallback());
70  DCHECK(err != ERR_IO_PENDING);
71
72  InvokeCallback(base::Bind(callback_, err, server_identifier_,
73                            expiration_time, private_key_result, cert_result));
74}
75
76// --------------------------------------------------------------------------
77// SetServerBoundCertTask
78class DefaultServerBoundCertStore::SetServerBoundCertTask
79    : public DefaultServerBoundCertStore::Task {
80 public:
81  SetServerBoundCertTask(const std::string& server_identifier,
82                         base::Time creation_time,
83                         base::Time expiration_time,
84                         const std::string& private_key,
85                         const std::string& cert);
86  virtual ~SetServerBoundCertTask();
87  virtual void Run(DefaultServerBoundCertStore* store) OVERRIDE;
88
89 private:
90  std::string server_identifier_;
91  base::Time creation_time_;
92  base::Time expiration_time_;
93  std::string private_key_;
94  std::string cert_;
95};
96
97DefaultServerBoundCertStore::SetServerBoundCertTask::SetServerBoundCertTask(
98    const std::string& server_identifier,
99    base::Time creation_time,
100    base::Time expiration_time,
101    const std::string& private_key,
102    const std::string& cert)
103    : server_identifier_(server_identifier),
104      creation_time_(creation_time),
105      expiration_time_(expiration_time),
106      private_key_(private_key),
107      cert_(cert) {
108}
109
110DefaultServerBoundCertStore::SetServerBoundCertTask::~SetServerBoundCertTask() {
111}
112
113void DefaultServerBoundCertStore::SetServerBoundCertTask::Run(
114    DefaultServerBoundCertStore* store) {
115  store->SyncSetServerBoundCert(server_identifier_, creation_time_,
116                                expiration_time_, private_key_, cert_);
117}
118
119// --------------------------------------------------------------------------
120// DeleteServerBoundCertTask
121class DefaultServerBoundCertStore::DeleteServerBoundCertTask
122    : public DefaultServerBoundCertStore::Task {
123 public:
124  DeleteServerBoundCertTask(const std::string& server_identifier,
125                            const base::Closure& callback);
126  virtual ~DeleteServerBoundCertTask();
127  virtual void Run(DefaultServerBoundCertStore* store) OVERRIDE;
128
129 private:
130  std::string server_identifier_;
131  base::Closure callback_;
132};
133
134DefaultServerBoundCertStore::DeleteServerBoundCertTask::
135    DeleteServerBoundCertTask(
136        const std::string& server_identifier,
137        const base::Closure& callback)
138        : server_identifier_(server_identifier),
139          callback_(callback) {
140}
141
142DefaultServerBoundCertStore::DeleteServerBoundCertTask::
143    ~DeleteServerBoundCertTask() {
144}
145
146void DefaultServerBoundCertStore::DeleteServerBoundCertTask::Run(
147    DefaultServerBoundCertStore* store) {
148  store->SyncDeleteServerBoundCert(server_identifier_);
149
150  InvokeCallback(callback_);
151}
152
153// --------------------------------------------------------------------------
154// DeleteAllCreatedBetweenTask
155class DefaultServerBoundCertStore::DeleteAllCreatedBetweenTask
156    : public DefaultServerBoundCertStore::Task {
157 public:
158  DeleteAllCreatedBetweenTask(base::Time delete_begin,
159                              base::Time delete_end,
160                              const base::Closure& callback);
161  virtual ~DeleteAllCreatedBetweenTask();
162  virtual void Run(DefaultServerBoundCertStore* store) OVERRIDE;
163
164 private:
165  base::Time delete_begin_;
166  base::Time delete_end_;
167  base::Closure callback_;
168};
169
170DefaultServerBoundCertStore::DeleteAllCreatedBetweenTask::
171    DeleteAllCreatedBetweenTask(
172        base::Time delete_begin,
173        base::Time delete_end,
174        const base::Closure& callback)
175        : delete_begin_(delete_begin),
176          delete_end_(delete_end),
177          callback_(callback) {
178}
179
180DefaultServerBoundCertStore::DeleteAllCreatedBetweenTask::
181    ~DeleteAllCreatedBetweenTask() {
182}
183
184void DefaultServerBoundCertStore::DeleteAllCreatedBetweenTask::Run(
185    DefaultServerBoundCertStore* store) {
186  store->SyncDeleteAllCreatedBetween(delete_begin_, delete_end_);
187
188  InvokeCallback(callback_);
189}
190
191// --------------------------------------------------------------------------
192// GetAllServerBoundCertsTask
193class DefaultServerBoundCertStore::GetAllServerBoundCertsTask
194    : public DefaultServerBoundCertStore::Task {
195 public:
196  explicit GetAllServerBoundCertsTask(const GetCertListCallback& callback);
197  virtual ~GetAllServerBoundCertsTask();
198  virtual void Run(DefaultServerBoundCertStore* store) OVERRIDE;
199
200 private:
201  std::string server_identifier_;
202  GetCertListCallback callback_;
203};
204
205DefaultServerBoundCertStore::GetAllServerBoundCertsTask::
206    GetAllServerBoundCertsTask(const GetCertListCallback& callback)
207        : callback_(callback) {
208}
209
210DefaultServerBoundCertStore::GetAllServerBoundCertsTask::
211    ~GetAllServerBoundCertsTask() {
212}
213
214void DefaultServerBoundCertStore::GetAllServerBoundCertsTask::Run(
215    DefaultServerBoundCertStore* store) {
216  ServerBoundCertList cert_list;
217  store->SyncGetAllServerBoundCerts(&cert_list);
218
219  InvokeCallback(base::Bind(callback_, cert_list));
220}
221
222// --------------------------------------------------------------------------
223// DefaultServerBoundCertStore
224
225// static
226const size_t DefaultServerBoundCertStore::kMaxCerts = 3300;
227
228DefaultServerBoundCertStore::DefaultServerBoundCertStore(
229    PersistentStore* store)
230    : initialized_(false),
231      loaded_(false),
232      store_(store),
233      weak_ptr_factory_(this) {}
234
235int DefaultServerBoundCertStore::GetServerBoundCert(
236    const std::string& server_identifier,
237    base::Time* expiration_time,
238    std::string* private_key_result,
239    std::string* cert_result,
240    const GetCertCallback& callback) {
241  DCHECK(CalledOnValidThread());
242  InitIfNecessary();
243
244  if (!loaded_) {
245    EnqueueTask(scoped_ptr<Task>(
246        new GetServerBoundCertTask(server_identifier, callback)));
247    return ERR_IO_PENDING;
248  }
249
250  ServerBoundCertMap::iterator it = server_bound_certs_.find(server_identifier);
251
252  if (it == server_bound_certs_.end())
253    return ERR_FILE_NOT_FOUND;
254
255  ServerBoundCert* cert = it->second;
256  *expiration_time = cert->expiration_time();
257  *private_key_result = cert->private_key();
258  *cert_result = cert->cert();
259
260  return OK;
261}
262
263void DefaultServerBoundCertStore::SetServerBoundCert(
264    const std::string& server_identifier,
265    base::Time creation_time,
266    base::Time expiration_time,
267    const std::string& private_key,
268    const std::string& cert) {
269  RunOrEnqueueTask(scoped_ptr<Task>(new SetServerBoundCertTask(
270      server_identifier, creation_time, expiration_time, private_key,
271      cert)));
272}
273
274void DefaultServerBoundCertStore::DeleteServerBoundCert(
275    const std::string& server_identifier,
276    const base::Closure& callback) {
277  RunOrEnqueueTask(scoped_ptr<Task>(
278      new DeleteServerBoundCertTask(server_identifier, callback)));
279}
280
281void DefaultServerBoundCertStore::DeleteAllCreatedBetween(
282    base::Time delete_begin,
283    base::Time delete_end,
284    const base::Closure& callback) {
285  RunOrEnqueueTask(scoped_ptr<Task>(
286      new DeleteAllCreatedBetweenTask(delete_begin, delete_end, callback)));
287}
288
289void DefaultServerBoundCertStore::DeleteAll(
290    const base::Closure& callback) {
291  DeleteAllCreatedBetween(base::Time(), base::Time(), callback);
292}
293
294void DefaultServerBoundCertStore::GetAllServerBoundCerts(
295    const GetCertListCallback& callback) {
296  RunOrEnqueueTask(scoped_ptr<Task>(new GetAllServerBoundCertsTask(callback)));
297}
298
299int DefaultServerBoundCertStore::GetCertCount() {
300  DCHECK(CalledOnValidThread());
301
302  return server_bound_certs_.size();
303}
304
305void DefaultServerBoundCertStore::SetForceKeepSessionState() {
306  DCHECK(CalledOnValidThread());
307  InitIfNecessary();
308
309  if (store_.get())
310    store_->SetForceKeepSessionState();
311}
312
313DefaultServerBoundCertStore::~DefaultServerBoundCertStore() {
314  DeleteAllInMemory();
315}
316
317void DefaultServerBoundCertStore::DeleteAllInMemory() {
318  DCHECK(CalledOnValidThread());
319
320  for (ServerBoundCertMap::iterator it = server_bound_certs_.begin();
321       it != server_bound_certs_.end(); ++it) {
322    delete it->second;
323  }
324  server_bound_certs_.clear();
325}
326
327void DefaultServerBoundCertStore::InitStore() {
328  DCHECK(CalledOnValidThread());
329  DCHECK(store_.get()) << "Store must exist to initialize";
330  DCHECK(!loaded_);
331
332  store_->Load(base::Bind(&DefaultServerBoundCertStore::OnLoaded,
333                          weak_ptr_factory_.GetWeakPtr()));
334}
335
336void DefaultServerBoundCertStore::OnLoaded(
337    scoped_ptr<ScopedVector<ServerBoundCert> > certs) {
338  DCHECK(CalledOnValidThread());
339
340  for (std::vector<ServerBoundCert*>::const_iterator it = certs->begin();
341       it != certs->end(); ++it) {
342    DCHECK(server_bound_certs_.find((*it)->server_identifier()) ==
343           server_bound_certs_.end());
344    server_bound_certs_[(*it)->server_identifier()] = *it;
345  }
346  certs->weak_clear();
347
348  loaded_ = true;
349
350  base::TimeDelta wait_time;
351  if (!waiting_tasks_.empty())
352    wait_time = base::TimeTicks::Now() - waiting_tasks_start_time_;
353  DVLOG(1) << "Task delay " << wait_time.InMilliseconds();
354  UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.TaskMaxWaitTime",
355                             wait_time,
356                             base::TimeDelta::FromMilliseconds(1),
357                             base::TimeDelta::FromMinutes(1),
358                             50);
359  UMA_HISTOGRAM_COUNTS_100("DomainBoundCerts.TaskWaitCount",
360                           waiting_tasks_.size());
361
362
363  for (ScopedVector<Task>::iterator i = waiting_tasks_.begin();
364       i != waiting_tasks_.end(); ++i)
365    (*i)->Run(this);
366  waiting_tasks_.clear();
367}
368
369void DefaultServerBoundCertStore::SyncSetServerBoundCert(
370    const std::string& server_identifier,
371    base::Time creation_time,
372    base::Time expiration_time,
373    const std::string& private_key,
374    const std::string& cert) {
375  DCHECK(CalledOnValidThread());
376  DCHECK(loaded_);
377
378  InternalDeleteServerBoundCert(server_identifier);
379  InternalInsertServerBoundCert(
380      server_identifier,
381      new ServerBoundCert(
382          server_identifier, creation_time, expiration_time, private_key,
383          cert));
384}
385
386void DefaultServerBoundCertStore::SyncDeleteServerBoundCert(
387    const std::string& server_identifier) {
388  DCHECK(CalledOnValidThread());
389  DCHECK(loaded_);
390  InternalDeleteServerBoundCert(server_identifier);
391}
392
393void DefaultServerBoundCertStore::SyncDeleteAllCreatedBetween(
394    base::Time delete_begin,
395    base::Time delete_end) {
396  DCHECK(CalledOnValidThread());
397  DCHECK(loaded_);
398  for (ServerBoundCertMap::iterator it = server_bound_certs_.begin();
399       it != server_bound_certs_.end();) {
400    ServerBoundCertMap::iterator cur = it;
401    ++it;
402    ServerBoundCert* cert = cur->second;
403    if ((delete_begin.is_null() || cert->creation_time() >= delete_begin) &&
404        (delete_end.is_null() || cert->creation_time() < delete_end)) {
405      if (store_.get())
406        store_->DeleteServerBoundCert(*cert);
407      delete cert;
408      server_bound_certs_.erase(cur);
409    }
410  }
411}
412
413void DefaultServerBoundCertStore::SyncGetAllServerBoundCerts(
414    ServerBoundCertList* cert_list) {
415  DCHECK(CalledOnValidThread());
416  DCHECK(loaded_);
417  for (ServerBoundCertMap::iterator it = server_bound_certs_.begin();
418       it != server_bound_certs_.end(); ++it)
419    cert_list->push_back(*it->second);
420}
421
422void DefaultServerBoundCertStore::EnqueueTask(scoped_ptr<Task> task) {
423  DCHECK(CalledOnValidThread());
424  DCHECK(!loaded_);
425  if (waiting_tasks_.empty())
426    waiting_tasks_start_time_ = base::TimeTicks::Now();
427  waiting_tasks_.push_back(task.release());
428}
429
430void DefaultServerBoundCertStore::RunOrEnqueueTask(scoped_ptr<Task> task) {
431  DCHECK(CalledOnValidThread());
432  InitIfNecessary();
433
434  if (!loaded_) {
435    EnqueueTask(task.Pass());
436    return;
437  }
438
439  task->Run(this);
440}
441
442void DefaultServerBoundCertStore::InternalDeleteServerBoundCert(
443    const std::string& server_identifier) {
444  DCHECK(CalledOnValidThread());
445  DCHECK(loaded_);
446
447  ServerBoundCertMap::iterator it = server_bound_certs_.find(server_identifier);
448  if (it == server_bound_certs_.end())
449    return;  // There is nothing to delete.
450
451  ServerBoundCert* cert = it->second;
452  if (store_.get())
453    store_->DeleteServerBoundCert(*cert);
454  server_bound_certs_.erase(it);
455  delete cert;
456}
457
458void DefaultServerBoundCertStore::InternalInsertServerBoundCert(
459    const std::string& server_identifier,
460    ServerBoundCert* cert) {
461  DCHECK(CalledOnValidThread());
462  DCHECK(loaded_);
463
464  if (store_.get())
465    store_->AddServerBoundCert(*cert);
466  server_bound_certs_[server_identifier] = cert;
467}
468
469DefaultServerBoundCertStore::PersistentStore::PersistentStore() {}
470
471DefaultServerBoundCertStore::PersistentStore::~PersistentStore() {}
472
473}  // namespace net
474