1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/ssl/default_channel_id_store.h"
6
7#include "base/bind.h"
8#include "base/message_loop/message_loop.h"
9#include "base/metrics/histogram.h"
10#include "net/base/net_errors.h"
11
12namespace net {
13
14// --------------------------------------------------------------------------
15// Task
16class DefaultChannelIDStore::Task {
17 public:
18  virtual ~Task();
19
20  // Runs the task and invokes the client callback on the thread that
21  // originally constructed the task.
22  virtual void Run(DefaultChannelIDStore* store) = 0;
23
24 protected:
25  void InvokeCallback(base::Closure callback) const;
26};
27
28DefaultChannelIDStore::Task::~Task() {
29}
30
31void DefaultChannelIDStore::Task::InvokeCallback(
32    base::Closure callback) const {
33  if (!callback.is_null())
34    callback.Run();
35}
36
37// --------------------------------------------------------------------------
38// GetChannelIDTask
39class DefaultChannelIDStore::GetChannelIDTask
40    : public DefaultChannelIDStore::Task {
41 public:
42  GetChannelIDTask(const std::string& server_identifier,
43                   const GetChannelIDCallback& callback);
44  virtual ~GetChannelIDTask();
45  virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
46
47 private:
48  std::string server_identifier_;
49  GetChannelIDCallback callback_;
50};
51
52DefaultChannelIDStore::GetChannelIDTask::GetChannelIDTask(
53    const std::string& server_identifier,
54    const GetChannelIDCallback& callback)
55    : server_identifier_(server_identifier),
56      callback_(callback) {
57}
58
59DefaultChannelIDStore::GetChannelIDTask::~GetChannelIDTask() {
60}
61
62void DefaultChannelIDStore::GetChannelIDTask::Run(
63    DefaultChannelIDStore* store) {
64  base::Time expiration_time;
65  std::string private_key_result;
66  std::string cert_result;
67  int err = store->GetChannelID(
68      server_identifier_, &expiration_time, &private_key_result,
69      &cert_result, GetChannelIDCallback());
70  DCHECK(err != ERR_IO_PENDING);
71
72  InvokeCallback(base::Bind(callback_, err, server_identifier_,
73                            expiration_time, private_key_result, cert_result));
74}
75
76// --------------------------------------------------------------------------
77// SetChannelIDTask
78class DefaultChannelIDStore::SetChannelIDTask
79    : public DefaultChannelIDStore::Task {
80 public:
81  SetChannelIDTask(const std::string& server_identifier,
82                   base::Time creation_time,
83                   base::Time expiration_time,
84                   const std::string& private_key,
85                   const std::string& cert);
86  virtual ~SetChannelIDTask();
87  virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
88
89 private:
90  std::string server_identifier_;
91  base::Time creation_time_;
92  base::Time expiration_time_;
93  std::string private_key_;
94  std::string cert_;
95};
96
97DefaultChannelIDStore::SetChannelIDTask::SetChannelIDTask(
98    const std::string& server_identifier,
99    base::Time creation_time,
100    base::Time expiration_time,
101    const std::string& private_key,
102    const std::string& cert)
103    : server_identifier_(server_identifier),
104      creation_time_(creation_time),
105      expiration_time_(expiration_time),
106      private_key_(private_key),
107      cert_(cert) {
108}
109
110DefaultChannelIDStore::SetChannelIDTask::~SetChannelIDTask() {
111}
112
113void DefaultChannelIDStore::SetChannelIDTask::Run(
114    DefaultChannelIDStore* store) {
115  store->SyncSetChannelID(server_identifier_, creation_time_,
116                          expiration_time_, private_key_, cert_);
117}
118
119// --------------------------------------------------------------------------
120// DeleteChannelIDTask
121class DefaultChannelIDStore::DeleteChannelIDTask
122    : public DefaultChannelIDStore::Task {
123 public:
124  DeleteChannelIDTask(const std::string& server_identifier,
125                      const base::Closure& callback);
126  virtual ~DeleteChannelIDTask();
127  virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
128
129 private:
130  std::string server_identifier_;
131  base::Closure callback_;
132};
133
134DefaultChannelIDStore::DeleteChannelIDTask::
135    DeleteChannelIDTask(
136        const std::string& server_identifier,
137        const base::Closure& callback)
138        : server_identifier_(server_identifier),
139          callback_(callback) {
140}
141
142DefaultChannelIDStore::DeleteChannelIDTask::
143    ~DeleteChannelIDTask() {
144}
145
146void DefaultChannelIDStore::DeleteChannelIDTask::Run(
147    DefaultChannelIDStore* store) {
148  store->SyncDeleteChannelID(server_identifier_);
149
150  InvokeCallback(callback_);
151}
152
153// --------------------------------------------------------------------------
154// DeleteAllCreatedBetweenTask
155class DefaultChannelIDStore::DeleteAllCreatedBetweenTask
156    : public DefaultChannelIDStore::Task {
157 public:
158  DeleteAllCreatedBetweenTask(base::Time delete_begin,
159                              base::Time delete_end,
160                              const base::Closure& callback);
161  virtual ~DeleteAllCreatedBetweenTask();
162  virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
163
164 private:
165  base::Time delete_begin_;
166  base::Time delete_end_;
167  base::Closure callback_;
168};
169
170DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
171    DeleteAllCreatedBetweenTask(
172        base::Time delete_begin,
173        base::Time delete_end,
174        const base::Closure& callback)
175        : delete_begin_(delete_begin),
176          delete_end_(delete_end),
177          callback_(callback) {
178}
179
180DefaultChannelIDStore::DeleteAllCreatedBetweenTask::
181    ~DeleteAllCreatedBetweenTask() {
182}
183
184void DefaultChannelIDStore::DeleteAllCreatedBetweenTask::Run(
185    DefaultChannelIDStore* store) {
186  store->SyncDeleteAllCreatedBetween(delete_begin_, delete_end_);
187
188  InvokeCallback(callback_);
189}
190
191// --------------------------------------------------------------------------
192// GetAllChannelIDsTask
193class DefaultChannelIDStore::GetAllChannelIDsTask
194    : public DefaultChannelIDStore::Task {
195 public:
196  explicit GetAllChannelIDsTask(const GetChannelIDListCallback& callback);
197  virtual ~GetAllChannelIDsTask();
198  virtual void Run(DefaultChannelIDStore* store) OVERRIDE;
199
200 private:
201  std::string server_identifier_;
202  GetChannelIDListCallback callback_;
203};
204
205DefaultChannelIDStore::GetAllChannelIDsTask::
206    GetAllChannelIDsTask(const GetChannelIDListCallback& callback)
207        : callback_(callback) {
208}
209
210DefaultChannelIDStore::GetAllChannelIDsTask::
211    ~GetAllChannelIDsTask() {
212}
213
214void DefaultChannelIDStore::GetAllChannelIDsTask::Run(
215    DefaultChannelIDStore* store) {
216  ChannelIDList cert_list;
217  store->SyncGetAllChannelIDs(&cert_list);
218
219  InvokeCallback(base::Bind(callback_, cert_list));
220}
221
222// --------------------------------------------------------------------------
223// DefaultChannelIDStore
224
225DefaultChannelIDStore::DefaultChannelIDStore(
226    PersistentStore* store)
227    : initialized_(false),
228      loaded_(false),
229      store_(store),
230      weak_ptr_factory_(this) {}
231
232int DefaultChannelIDStore::GetChannelID(
233    const std::string& server_identifier,
234    base::Time* expiration_time,
235    std::string* private_key_result,
236    std::string* cert_result,
237    const GetChannelIDCallback& callback) {
238  DCHECK(CalledOnValidThread());
239  InitIfNecessary();
240
241  if (!loaded_) {
242    EnqueueTask(scoped_ptr<Task>(
243        new GetChannelIDTask(server_identifier, callback)));
244    return ERR_IO_PENDING;
245  }
246
247  ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
248
249  if (it == channel_ids_.end())
250    return ERR_FILE_NOT_FOUND;
251
252  ChannelID* channel_id = it->second;
253  *expiration_time = channel_id->expiration_time();
254  *private_key_result = channel_id->private_key();
255  *cert_result = channel_id->cert();
256
257  return OK;
258}
259
260void DefaultChannelIDStore::SetChannelID(
261    const std::string& server_identifier,
262    base::Time creation_time,
263    base::Time expiration_time,
264    const std::string& private_key,
265    const std::string& cert) {
266  RunOrEnqueueTask(scoped_ptr<Task>(new SetChannelIDTask(
267      server_identifier, creation_time, expiration_time, private_key,
268      cert)));
269}
270
271void DefaultChannelIDStore::DeleteChannelID(
272    const std::string& server_identifier,
273    const base::Closure& callback) {
274  RunOrEnqueueTask(scoped_ptr<Task>(
275      new DeleteChannelIDTask(server_identifier, callback)));
276}
277
278void DefaultChannelIDStore::DeleteAllCreatedBetween(
279    base::Time delete_begin,
280    base::Time delete_end,
281    const base::Closure& callback) {
282  RunOrEnqueueTask(scoped_ptr<Task>(
283      new DeleteAllCreatedBetweenTask(delete_begin, delete_end, callback)));
284}
285
286void DefaultChannelIDStore::DeleteAll(
287    const base::Closure& callback) {
288  DeleteAllCreatedBetween(base::Time(), base::Time(), callback);
289}
290
291void DefaultChannelIDStore::GetAllChannelIDs(
292    const GetChannelIDListCallback& callback) {
293  RunOrEnqueueTask(scoped_ptr<Task>(new GetAllChannelIDsTask(callback)));
294}
295
296int DefaultChannelIDStore::GetChannelIDCount() {
297  DCHECK(CalledOnValidThread());
298
299  return channel_ids_.size();
300}
301
302void DefaultChannelIDStore::SetForceKeepSessionState() {
303  DCHECK(CalledOnValidThread());
304  InitIfNecessary();
305
306  if (store_.get())
307    store_->SetForceKeepSessionState();
308}
309
310DefaultChannelIDStore::~DefaultChannelIDStore() {
311  DeleteAllInMemory();
312}
313
314void DefaultChannelIDStore::DeleteAllInMemory() {
315  DCHECK(CalledOnValidThread());
316
317  for (ChannelIDMap::iterator it = channel_ids_.begin();
318       it != channel_ids_.end(); ++it) {
319    delete it->second;
320  }
321  channel_ids_.clear();
322}
323
324void DefaultChannelIDStore::InitStore() {
325  DCHECK(CalledOnValidThread());
326  DCHECK(store_.get()) << "Store must exist to initialize";
327  DCHECK(!loaded_);
328
329  store_->Load(base::Bind(&DefaultChannelIDStore::OnLoaded,
330                          weak_ptr_factory_.GetWeakPtr()));
331}
332
333void DefaultChannelIDStore::OnLoaded(
334    scoped_ptr<ScopedVector<ChannelID> > channel_ids) {
335  DCHECK(CalledOnValidThread());
336
337  for (std::vector<ChannelID*>::const_iterator it = channel_ids->begin();
338       it != channel_ids->end(); ++it) {
339    DCHECK(channel_ids_.find((*it)->server_identifier()) ==
340           channel_ids_.end());
341    channel_ids_[(*it)->server_identifier()] = *it;
342  }
343  channel_ids->weak_clear();
344
345  loaded_ = true;
346
347  base::TimeDelta wait_time;
348  if (!waiting_tasks_.empty())
349    wait_time = base::TimeTicks::Now() - waiting_tasks_start_time_;
350  DVLOG(1) << "Task delay " << wait_time.InMilliseconds();
351  UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.TaskMaxWaitTime",
352                             wait_time,
353                             base::TimeDelta::FromMilliseconds(1),
354                             base::TimeDelta::FromMinutes(1),
355                             50);
356  UMA_HISTOGRAM_COUNTS_100("DomainBoundCerts.TaskWaitCount",
357                           waiting_tasks_.size());
358
359
360  for (ScopedVector<Task>::iterator i = waiting_tasks_.begin();
361       i != waiting_tasks_.end(); ++i)
362    (*i)->Run(this);
363  waiting_tasks_.clear();
364}
365
366void DefaultChannelIDStore::SyncSetChannelID(
367    const std::string& server_identifier,
368    base::Time creation_time,
369    base::Time expiration_time,
370    const std::string& private_key,
371    const std::string& cert) {
372  DCHECK(CalledOnValidThread());
373  DCHECK(loaded_);
374
375  InternalDeleteChannelID(server_identifier);
376  InternalInsertChannelID(
377      server_identifier,
378      new ChannelID(
379          server_identifier, creation_time, expiration_time, private_key,
380          cert));
381}
382
383void DefaultChannelIDStore::SyncDeleteChannelID(
384    const std::string& server_identifier) {
385  DCHECK(CalledOnValidThread());
386  DCHECK(loaded_);
387  InternalDeleteChannelID(server_identifier);
388}
389
390void DefaultChannelIDStore::SyncDeleteAllCreatedBetween(
391    base::Time delete_begin,
392    base::Time delete_end) {
393  DCHECK(CalledOnValidThread());
394  DCHECK(loaded_);
395  for (ChannelIDMap::iterator it = channel_ids_.begin();
396       it != channel_ids_.end();) {
397    ChannelIDMap::iterator cur = it;
398    ++it;
399    ChannelID* channel_id = cur->second;
400    if ((delete_begin.is_null() ||
401         channel_id->creation_time() >= delete_begin) &&
402        (delete_end.is_null() || channel_id->creation_time() < delete_end)) {
403      if (store_.get())
404        store_->DeleteChannelID(*channel_id);
405      delete channel_id;
406      channel_ids_.erase(cur);
407    }
408  }
409}
410
411void DefaultChannelIDStore::SyncGetAllChannelIDs(
412    ChannelIDList* channel_id_list) {
413  DCHECK(CalledOnValidThread());
414  DCHECK(loaded_);
415  for (ChannelIDMap::iterator it = channel_ids_.begin();
416       it != channel_ids_.end(); ++it)
417    channel_id_list->push_back(*it->second);
418}
419
420void DefaultChannelIDStore::EnqueueTask(scoped_ptr<Task> task) {
421  DCHECK(CalledOnValidThread());
422  DCHECK(!loaded_);
423  if (waiting_tasks_.empty())
424    waiting_tasks_start_time_ = base::TimeTicks::Now();
425  waiting_tasks_.push_back(task.release());
426}
427
428void DefaultChannelIDStore::RunOrEnqueueTask(scoped_ptr<Task> task) {
429  DCHECK(CalledOnValidThread());
430  InitIfNecessary();
431
432  if (!loaded_) {
433    EnqueueTask(task.Pass());
434    return;
435  }
436
437  task->Run(this);
438}
439
440void DefaultChannelIDStore::InternalDeleteChannelID(
441    const std::string& server_identifier) {
442  DCHECK(CalledOnValidThread());
443  DCHECK(loaded_);
444
445  ChannelIDMap::iterator it = channel_ids_.find(server_identifier);
446  if (it == channel_ids_.end())
447    return;  // There is nothing to delete.
448
449  ChannelID* channel_id = it->second;
450  if (store_.get())
451    store_->DeleteChannelID(*channel_id);
452  channel_ids_.erase(it);
453  delete channel_id;
454}
455
456void DefaultChannelIDStore::InternalInsertChannelID(
457    const std::string& server_identifier,
458    ChannelID* channel_id) {
459  DCHECK(CalledOnValidThread());
460  DCHECK(loaded_);
461
462  if (store_.get())
463    store_->AddChannelID(*channel_id);
464  channel_ids_[server_identifier] = channel_id;
465}
466
467DefaultChannelIDStore::PersistentStore::PersistentStore() {}
468
469DefaultChannelIDStore::PersistentStore::~PersistentStore() {}
470
471}  // namespace net
472