sqlite_channel_id_store.cc revision 1320f92c476a1ad9d19dba2a48c72b75566198e9
1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/extras/sqlite/sqlite_channel_id_store.h"
6
7#include <set>
8
9#include "base/basictypes.h"
10#include "base/bind.h"
11#include "base/files/file_path.h"
12#include "base/files/file_util.h"
13#include "base/location.h"
14#include "base/logging.h"
15#include "base/memory/scoped_ptr.h"
16#include "base/memory/scoped_vector.h"
17#include "base/metrics/histogram.h"
18#include "base/sequenced_task_runner.h"
19#include "base/strings/string_util.h"
20#include "net/cert/x509_certificate.h"
21#include "net/cookies/cookie_util.h"
22#include "net/ssl/ssl_client_cert_type.h"
23#include "sql/error_delegate_util.h"
24#include "sql/meta_table.h"
25#include "sql/statement.h"
26#include "sql/transaction.h"
27#include "url/gurl.h"
28
29namespace {
30
// Schema version written by this code; databases at older versions are
// upgraded in place to this one.
const int kCurrentVersionNumber = 4;
// Compatible version recorded in the meta table. A database whose recorded
// compatible version exceeds kCurrentVersionNumber is rejected as too new.
const int kCompatibleVersionNumber = 1;
34
35// Initializes the certs table, returning true on success.
36bool InitTable(sql::Connection* db) {
37  // The table is named "origin_bound_certs" for backwards compatability before
38  // we renamed this class to SQLiteChannelIDStore.  Likewise, the primary
39  // key is "origin", but now can be other things like a plain domain.
40  if (!db->DoesTableExist("origin_bound_certs")) {
41    if (!db->Execute(
42            "CREATE TABLE origin_bound_certs ("
43            "origin TEXT NOT NULL UNIQUE PRIMARY KEY,"
44            "private_key BLOB NOT NULL,"
45            "cert BLOB NOT NULL,"
46            "cert_type INTEGER,"
47            "expiration_time INTEGER,"
48            "creation_time INTEGER)")) {
49      return false;
50    }
51  }
52
53  return true;
54}
55
56}  // namespace
57
58namespace net {
59
60// This class is designed to be shared between any calling threads and the
61// background task runner. It batches operations and commits them on a timer.
62class SQLiteChannelIDStore::Backend
63    : public base::RefCountedThreadSafe<SQLiteChannelIDStore::Backend> {
64 public:
65  Backend(
66      const base::FilePath& path,
67      const scoped_refptr<base::SequencedTaskRunner>& background_task_runner)
68      : path_(path),
69        num_pending_(0),
70        force_keep_session_state_(false),
71        background_task_runner_(background_task_runner),
72        corruption_detected_(false) {}
73
74  // Creates or loads the SQLite database.
75  void Load(const LoadedCallback& loaded_callback);
76
77  // Batch a channel ID addition.
78  void AddChannelID(const DefaultChannelIDStore::ChannelID& channel_id);
79
80  // Batch a channel ID deletion.
81  void DeleteChannelID(const DefaultChannelIDStore::ChannelID& channel_id);
82
83  // Post background delete of all channel ids for |server_identifiers|.
84  void DeleteAllInList(const std::list<std::string>& server_identifiers);
85
86  // Commit any pending operations and close the database.  This must be called
87  // before the object is destructed.
88  void Close();
89
90  void SetForceKeepSessionState();
91
92 private:
93  friend class base::RefCountedThreadSafe<SQLiteChannelIDStore::Backend>;
94
95  // You should call Close() before destructing this object.
96  virtual ~Backend() {
97    DCHECK(!db_.get()) << "Close should have already been called.";
98    DCHECK_EQ(0u, num_pending_);
99    DCHECK(pending_.empty());
100  }
101
102  void LoadInBackground(
103      ScopedVector<DefaultChannelIDStore::ChannelID>* channel_ids);
104
105  // Database upgrade statements.
106  bool EnsureDatabaseVersion();
107
108  class PendingOperation {
109   public:
110    enum OperationType { CHANNEL_ID_ADD, CHANNEL_ID_DELETE };
111
112    PendingOperation(OperationType op,
113                     const DefaultChannelIDStore::ChannelID& channel_id)
114        : op_(op), channel_id_(channel_id) {}
115
116    OperationType op() const { return op_; }
117    const DefaultChannelIDStore::ChannelID& channel_id() const {
118      return channel_id_;
119    }
120
121   private:
122    OperationType op_;
123    DefaultChannelIDStore::ChannelID channel_id_;
124  };
125
126 private:
127  // Batch a channel id operation (add or delete).
128  void BatchOperation(PendingOperation::OperationType op,
129                      const DefaultChannelIDStore::ChannelID& channel_id);
130  // Commit our pending operations to the database.
131  void Commit();
132  // Close() executed on the background task runner.
133  void InternalBackgroundClose();
134
135  void BackgroundDeleteAllInList(
136      const std::list<std::string>& server_identifiers);
137
138  void DatabaseErrorCallback(int error, sql::Statement* stmt);
139  void KillDatabase();
140
141  const base::FilePath path_;
142  scoped_ptr<sql::Connection> db_;
143  sql::MetaTable meta_table_;
144
145  typedef std::list<PendingOperation*> PendingOperationsList;
146  PendingOperationsList pending_;
147  PendingOperationsList::size_type num_pending_;
148  // True if the persistent store should skip clear on exit rules.
149  bool force_keep_session_state_;
150  // Guard |pending_|, |num_pending_| and |force_keep_session_state_|.
151  base::Lock lock_;
152
153  scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
154
155  // Indicates if the kill-database callback has been scheduled.
156  bool corruption_detected_;
157
158  DISALLOW_COPY_AND_ASSIGN(Backend);
159};
160
161void SQLiteChannelIDStore::Backend::Load(
162    const LoadedCallback& loaded_callback) {
163  // This function should be called only once per instance.
164  DCHECK(!db_.get());
165  scoped_ptr<ScopedVector<DefaultChannelIDStore::ChannelID> > channel_ids(
166      new ScopedVector<DefaultChannelIDStore::ChannelID>());
167  ScopedVector<DefaultChannelIDStore::ChannelID>* channel_ids_ptr =
168      channel_ids.get();
169
170  background_task_runner_->PostTaskAndReply(
171      FROM_HERE,
172      base::Bind(&Backend::LoadInBackground, this, channel_ids_ptr),
173      base::Bind(loaded_callback, base::Passed(&channel_ids)));
174}
175
176void SQLiteChannelIDStore::Backend::LoadInBackground(
177    ScopedVector<DefaultChannelIDStore::ChannelID>* channel_ids) {
178  DCHECK(background_task_runner_->RunsTasksOnCurrentThread());
179
180  // This method should be called only once per instance.
181  DCHECK(!db_.get());
182
183  base::TimeTicks start = base::TimeTicks::Now();
184
185  // Ensure the parent directory for storing certs is created before reading
186  // from it.
187  const base::FilePath dir = path_.DirName();
188  if (!base::PathExists(dir) && !base::CreateDirectory(dir))
189    return;
190
191  int64 db_size = 0;
192  if (base::GetFileSize(path_, &db_size))
193    UMA_HISTOGRAM_COUNTS("DomainBoundCerts.DBSizeInKB", db_size / 1024);
194
195  db_.reset(new sql::Connection);
196  db_->set_histogram_tag("DomainBoundCerts");
197
198  // Unretained to avoid a ref loop with db_.
199  db_->set_error_callback(
200      base::Bind(&SQLiteChannelIDStore::Backend::DatabaseErrorCallback,
201                 base::Unretained(this)));
202
203  if (!db_->Open(path_)) {
204    NOTREACHED() << "Unable to open cert DB.";
205    if (corruption_detected_)
206      KillDatabase();
207    db_.reset();
208    return;
209  }
210
211  if (!EnsureDatabaseVersion() || !InitTable(db_.get())) {
212    NOTREACHED() << "Unable to open cert DB.";
213    if (corruption_detected_)
214      KillDatabase();
215    meta_table_.Reset();
216    db_.reset();
217    return;
218  }
219
220  db_->Preload();
221
222  // Slurp all the certs into the out-vector.
223  sql::Statement smt(db_->GetUniqueStatement(
224      "SELECT origin, private_key, cert, cert_type, expiration_time, "
225      "creation_time FROM origin_bound_certs"));
226  if (!smt.is_valid()) {
227    if (corruption_detected_)
228      KillDatabase();
229    meta_table_.Reset();
230    db_.reset();
231    return;
232  }
233
234  while (smt.Step()) {
235    SSLClientCertType type = static_cast<SSLClientCertType>(smt.ColumnInt(3));
236    if (type != CLIENT_CERT_ECDSA_SIGN)
237      continue;
238    std::string private_key_from_db, cert_from_db;
239    smt.ColumnBlobAsString(1, &private_key_from_db);
240    smt.ColumnBlobAsString(2, &cert_from_db);
241    scoped_ptr<DefaultChannelIDStore::ChannelID> channel_id(
242        new DefaultChannelIDStore::ChannelID(
243            smt.ColumnString(0),  // origin
244            base::Time::FromInternalValue(smt.ColumnInt64(5)),
245            base::Time::FromInternalValue(smt.ColumnInt64(4)),
246            private_key_from_db,
247            cert_from_db));
248    channel_ids->push_back(channel_id.release());
249  }
250
251  UMA_HISTOGRAM_COUNTS_10000(
252      "DomainBoundCerts.DBLoadedCount",
253      static_cast<base::HistogramBase::Sample>(channel_ids->size()));
254  base::TimeDelta load_time = base::TimeTicks::Now() - start;
255  UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.DBLoadTime",
256                             load_time,
257                             base::TimeDelta::FromMilliseconds(1),
258                             base::TimeDelta::FromMinutes(1),
259                             50);
260  DVLOG(1) << "loaded " << channel_ids->size() << " in "
261           << load_time.InMilliseconds() << " ms";
262}
263
264bool SQLiteChannelIDStore::Backend::EnsureDatabaseVersion() {
265  // Version check.
266  if (!meta_table_.Init(
267          db_.get(), kCurrentVersionNumber, kCompatibleVersionNumber)) {
268    return false;
269  }
270
271  if (meta_table_.GetCompatibleVersionNumber() > kCurrentVersionNumber) {
272    LOG(WARNING) << "Server bound cert database is too new.";
273    return false;
274  }
275
276  int cur_version = meta_table_.GetVersionNumber();
277  if (cur_version == 1) {
278    sql::Transaction transaction(db_.get());
279    if (!transaction.Begin())
280      return false;
281    if (!db_->Execute(
282            "ALTER TABLE origin_bound_certs ADD COLUMN cert_type "
283            "INTEGER")) {
284      LOG(WARNING) << "Unable to update server bound cert database to "
285                   << "version 2.";
286      return false;
287    }
288    // All certs in version 1 database are rsa_sign, which are unsupported.
289    // Just discard them all.
290    if (!db_->Execute("DELETE from origin_bound_certs")) {
291      LOG(WARNING) << "Unable to update server bound cert database to "
292                   << "version 2.";
293      return false;
294    }
295    ++cur_version;
296    meta_table_.SetVersionNumber(cur_version);
297    meta_table_.SetCompatibleVersionNumber(
298        std::min(cur_version, kCompatibleVersionNumber));
299    transaction.Commit();
300  }
301
302  if (cur_version <= 3) {
303    sql::Transaction transaction(db_.get());
304    if (!transaction.Begin())
305      return false;
306
307    if (cur_version == 2) {
308      if (!db_->Execute(
309              "ALTER TABLE origin_bound_certs ADD COLUMN "
310              "expiration_time INTEGER")) {
311        LOG(WARNING) << "Unable to update server bound cert database to "
312                     << "version 4.";
313        return false;
314      }
315    }
316
317    if (!db_->Execute(
318            "ALTER TABLE origin_bound_certs ADD COLUMN "
319            "creation_time INTEGER")) {
320      LOG(WARNING) << "Unable to update server bound cert database to "
321                   << "version 4.";
322      return false;
323    }
324
325    sql::Statement statement(
326        db_->GetUniqueStatement("SELECT origin, cert FROM origin_bound_certs"));
327    sql::Statement update_expires_statement(db_->GetUniqueStatement(
328        "UPDATE origin_bound_certs SET expiration_time = ? WHERE origin = ?"));
329    sql::Statement update_creation_statement(db_->GetUniqueStatement(
330        "UPDATE origin_bound_certs SET creation_time = ? WHERE origin = ?"));
331    if (!statement.is_valid() || !update_expires_statement.is_valid() ||
332        !update_creation_statement.is_valid()) {
333      LOG(WARNING) << "Unable to update server bound cert database to "
334                   << "version 4.";
335      return false;
336    }
337
338    while (statement.Step()) {
339      std::string origin = statement.ColumnString(0);
340      std::string cert_from_db;
341      statement.ColumnBlobAsString(1, &cert_from_db);
342      // Parse the cert and extract the real value and then update the DB.
343      scoped_refptr<X509Certificate> cert(X509Certificate::CreateFromBytes(
344          cert_from_db.data(), static_cast<int>(cert_from_db.size())));
345      if (cert.get()) {
346        if (cur_version == 2) {
347          update_expires_statement.Reset(true);
348          update_expires_statement.BindInt64(
349              0, cert->valid_expiry().ToInternalValue());
350          update_expires_statement.BindString(1, origin);
351          if (!update_expires_statement.Run()) {
352            LOG(WARNING) << "Unable to update server bound cert database to "
353                         << "version 4.";
354            return false;
355          }
356        }
357
358        update_creation_statement.Reset(true);
359        update_creation_statement.BindInt64(
360            0, cert->valid_start().ToInternalValue());
361        update_creation_statement.BindString(1, origin);
362        if (!update_creation_statement.Run()) {
363          LOG(WARNING) << "Unable to update server bound cert database to "
364                       << "version 4.";
365          return false;
366        }
367      } else {
368        // If there's a cert we can't parse, just leave it.  It'll get replaced
369        // with a new one if we ever try to use it.
370        LOG(WARNING) << "Error parsing cert for database upgrade for origin "
371                     << statement.ColumnString(0);
372      }
373    }
374
375    cur_version = 4;
376    meta_table_.SetVersionNumber(cur_version);
377    meta_table_.SetCompatibleVersionNumber(
378        std::min(cur_version, kCompatibleVersionNumber));
379    transaction.Commit();
380  }
381
382  // Put future migration cases here.
383
384  // When the version is too old, we just try to continue anyway, there should
385  // not be a released product that makes a database too old for us to handle.
386  LOG_IF(WARNING, cur_version < kCurrentVersionNumber)
387      << "Server bound cert database version " << cur_version
388      << " is too old to handle.";
389
390  return true;
391}
392
393void SQLiteChannelIDStore::Backend::DatabaseErrorCallback(
394    int error,
395    sql::Statement* stmt) {
396  DCHECK(background_task_runner_->RunsTasksOnCurrentThread());
397
398  if (!sql::IsErrorCatastrophic(error))
399    return;
400
401  // TODO(shess): Running KillDatabase() multiple times should be
402  // safe.
403  if (corruption_detected_)
404    return;
405
406  corruption_detected_ = true;
407
408  // TODO(shess): Consider just calling RazeAndClose() immediately.
409  // db_ may not be safe to reset at this point, but RazeAndClose()
410  // would cause the stack to unwind safely with errors.
411  background_task_runner_->PostTask(FROM_HERE,
412                                    base::Bind(&Backend::KillDatabase, this));
413}
414
415void SQLiteChannelIDStore::Backend::KillDatabase() {
416  DCHECK(background_task_runner_->RunsTasksOnCurrentThread());
417
418  if (db_) {
419    // This Backend will now be in-memory only. In a future run the database
420    // will be recreated. Hopefully things go better then!
421    bool success = db_->RazeAndClose();
422    UMA_HISTOGRAM_BOOLEAN("DomainBoundCerts.KillDatabaseResult", success);
423    meta_table_.Reset();
424    db_.reset();
425  }
426}
427
428void SQLiteChannelIDStore::Backend::AddChannelID(
429    const DefaultChannelIDStore::ChannelID& channel_id) {
430  BatchOperation(PendingOperation::CHANNEL_ID_ADD, channel_id);
431}
432
433void SQLiteChannelIDStore::Backend::DeleteChannelID(
434    const DefaultChannelIDStore::ChannelID& channel_id) {
435  BatchOperation(PendingOperation::CHANNEL_ID_DELETE, channel_id);
436}
437
438void SQLiteChannelIDStore::Backend::DeleteAllInList(
439    const std::list<std::string>& server_identifiers) {
440  if (server_identifiers.empty())
441    return;
442  // Perform deletion on background task runner.
443  background_task_runner_->PostTask(
444      FROM_HERE,
445      base::Bind(
446          &Backend::BackgroundDeleteAllInList, this, server_identifiers));
447}
448
449void SQLiteChannelIDStore::Backend::BatchOperation(
450    PendingOperation::OperationType op,
451    const DefaultChannelIDStore::ChannelID& channel_id) {
452  // Commit every 30 seconds.
453  static const int kCommitIntervalMs = 30 * 1000;
454  // Commit right away if we have more than 512 outstanding operations.
455  static const size_t kCommitAfterBatchSize = 512;
456
457  // We do a full copy of the cert here, and hopefully just here.
458  scoped_ptr<PendingOperation> po(new PendingOperation(op, channel_id));
459
460  PendingOperationsList::size_type num_pending;
461  {
462    base::AutoLock locked(lock_);
463    pending_.push_back(po.release());
464    num_pending = ++num_pending_;
465  }
466
467  if (num_pending == 1) {
468    // We've gotten our first entry for this batch, fire off the timer.
469    background_task_runner_->PostDelayedTask(
470        FROM_HERE,
471        base::Bind(&Backend::Commit, this),
472        base::TimeDelta::FromMilliseconds(kCommitIntervalMs));
473  } else if (num_pending == kCommitAfterBatchSize) {
474    // We've reached a big enough batch, fire off a commit now.
475    background_task_runner_->PostTask(FROM_HERE,
476                                      base::Bind(&Backend::Commit, this));
477  }
478}
479
480void SQLiteChannelIDStore::Backend::Commit() {
481  DCHECK(background_task_runner_->RunsTasksOnCurrentThread());
482
483  PendingOperationsList ops;
484  {
485    base::AutoLock locked(lock_);
486    pending_.swap(ops);
487    num_pending_ = 0;
488  }
489
490  // Maybe an old timer fired or we are already Close()'ed.
491  if (!db_.get() || ops.empty())
492    return;
493
494  sql::Statement add_statement(db_->GetCachedStatement(
495      SQL_FROM_HERE,
496      "INSERT INTO origin_bound_certs (origin, private_key, cert, cert_type, "
497      "expiration_time, creation_time) VALUES (?,?,?,?,?,?)"));
498  if (!add_statement.is_valid())
499    return;
500
501  sql::Statement del_statement(db_->GetCachedStatement(
502      SQL_FROM_HERE, "DELETE FROM origin_bound_certs WHERE origin=?"));
503  if (!del_statement.is_valid())
504    return;
505
506  sql::Transaction transaction(db_.get());
507  if (!transaction.Begin())
508    return;
509
510  for (PendingOperationsList::iterator it = ops.begin(); it != ops.end();
511       ++it) {
512    // Free the certs as we commit them to the database.
513    scoped_ptr<PendingOperation> po(*it);
514    switch (po->op()) {
515      case PendingOperation::CHANNEL_ID_ADD: {
516        add_statement.Reset(true);
517        add_statement.BindString(0, po->channel_id().server_identifier());
518        const std::string& private_key = po->channel_id().private_key();
519        add_statement.BindBlob(
520            1, private_key.data(), static_cast<int>(private_key.size()));
521        const std::string& cert = po->channel_id().cert();
522        add_statement.BindBlob(2, cert.data(), static_cast<int>(cert.size()));
523        add_statement.BindInt(3, CLIENT_CERT_ECDSA_SIGN);
524        add_statement.BindInt64(
525            4, po->channel_id().expiration_time().ToInternalValue());
526        add_statement.BindInt64(
527            5, po->channel_id().creation_time().ToInternalValue());
528        if (!add_statement.Run())
529          NOTREACHED() << "Could not add a server bound cert to the DB.";
530        break;
531      }
532      case PendingOperation::CHANNEL_ID_DELETE:
533        del_statement.Reset(true);
534        del_statement.BindString(0, po->channel_id().server_identifier());
535        if (!del_statement.Run())
536          NOTREACHED() << "Could not delete a server bound cert from the DB.";
537        break;
538
539      default:
540        NOTREACHED();
541        break;
542    }
543  }
544  transaction.Commit();
545}
546
547// Fire off a close message to the background task runner. We could still have a
548// pending commit timer that will be holding a reference on us, but if/when
549// this fires we will already have been cleaned up and it will be ignored.
550void SQLiteChannelIDStore::Backend::Close() {
551  // Must close the backend on the background task runner.
552  background_task_runner_->PostTask(
553      FROM_HERE, base::Bind(&Backend::InternalBackgroundClose, this));
554}
555
556void SQLiteChannelIDStore::Backend::InternalBackgroundClose() {
557  DCHECK(background_task_runner_->RunsTasksOnCurrentThread());
558  // Commit any pending operations
559  Commit();
560  db_.reset();
561}
562
563void SQLiteChannelIDStore::Backend::BackgroundDeleteAllInList(
564    const std::list<std::string>& server_identifiers) {
565  DCHECK(background_task_runner_->RunsTasksOnCurrentThread());
566
567  if (!db_.get())
568    return;
569
570  sql::Statement del_smt(db_->GetCachedStatement(
571      SQL_FROM_HERE, "DELETE FROM origin_bound_certs WHERE origin=?"));
572  if (!del_smt.is_valid()) {
573    LOG(WARNING) << "Unable to delete channel ids.";
574    return;
575  }
576
577  sql::Transaction transaction(db_.get());
578  if (!transaction.Begin()) {
579    LOG(WARNING) << "Unable to delete channel ids.";
580    return;
581  }
582
583  for (std::list<std::string>::const_iterator it = server_identifiers.begin();
584       it != server_identifiers.end();
585       ++it) {
586    del_smt.Reset(true);
587    del_smt.BindString(0, *it);
588    if (!del_smt.Run())
589      NOTREACHED() << "Could not delete a channel id from the DB.";
590  }
591
592  if (!transaction.Commit())
593    LOG(WARNING) << "Unable to delete channel ids.";
594}
595
596void SQLiteChannelIDStore::Backend::SetForceKeepSessionState() {
597  base::AutoLock locked(lock_);
598  force_keep_session_state_ = true;
599}
600
601SQLiteChannelIDStore::SQLiteChannelIDStore(
602    const base::FilePath& path,
603    const scoped_refptr<base::SequencedTaskRunner>& background_task_runner)
604    : backend_(new Backend(path, background_task_runner)) {
605}
606
607void SQLiteChannelIDStore::Load(const LoadedCallback& loaded_callback) {
608  backend_->Load(loaded_callback);
609}
610
611void SQLiteChannelIDStore::AddChannelID(
612    const DefaultChannelIDStore::ChannelID& channel_id) {
613  backend_->AddChannelID(channel_id);
614}
615
616void SQLiteChannelIDStore::DeleteChannelID(
617    const DefaultChannelIDStore::ChannelID& channel_id) {
618  backend_->DeleteChannelID(channel_id);
619}
620
621void SQLiteChannelIDStore::DeleteAllInList(
622    const std::list<std::string>& server_identifiers) {
623  backend_->DeleteAllInList(server_identifiers);
624}
625
626void SQLiteChannelIDStore::SetForceKeepSessionState() {
627  backend_->SetForceKeepSessionState();
628}
629
630SQLiteChannelIDStore::~SQLiteChannelIDStore() {
631  backend_->Close();
632  // We release our reference to the Backend, though it will probably still have
633  // a reference if the background task runner has not run Close() yet.
634}
635
636}  // namespace net
637