1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/ssl/channel_id_service.h"
6
7#include <algorithm>
8#include <limits>
9
10#include "base/bind.h"
11#include "base/bind_helpers.h"
12#include "base/callback_helpers.h"
13#include "base/compiler_specific.h"
14#include "base/location.h"
15#include "base/logging.h"
16#include "base/memory/ref_counted.h"
17#include "base/memory/scoped_ptr.h"
18#include "base/message_loop/message_loop_proxy.h"
19#include "base/metrics/histogram.h"
20#include "base/rand_util.h"
21#include "base/stl_util.h"
22#include "base/task_runner.h"
23#include "crypto/ec_private_key.h"
24#include "net/base/net_errors.h"
25#include "net/base/registry_controlled_domains/registry_controlled_domain.h"
26#include "net/cert/x509_certificate.h"
27#include "net/cert/x509_util.h"
28#include "url/gurl.h"
29
30#if defined(USE_NSS)
31#include <private/pprthred.h>  // PR_DetachThread
32#endif
33
34namespace net {
35
36namespace {
37
// Lifetime, in days, of every channel ID certificate generated by this file.
const int kValidityPeriodInDays = 365;

// Extra days appended to the end of the window checked against the system
// clock, so that the result of the check still holds after Chrome has been
// running for a long while.
const int kSystemTimeValidityBufferInDays = 90;
43
// Final outcome of a GetChannelID or GetOrCreateChannelID call, reported to
// the "DomainBoundCerts.GetDomainBoundCertResult" histogram. The numeric values
// are recorded, so never renumber or re-use one; GET_CHANNEL_ID_RESULT_MAX is
// the histogram boundary and must stay last.
enum GetChannelIDResult {
  // An existing domain bound cert was found and returned synchronously.
  SYNC_SUCCESS = 0,
  // A domain bound cert was retrieved or generated, and returned, async.
  ASYNC_SUCCESS = 1,
  // The request was cancelled before retrieval/generation completed.
  ASYNC_CANCELLED = 2,
  // Asynchronous generation failed: key generation step.
  ASYNC_FAILURE_KEYGEN = 3,
  // Asynchronous generation failed: certificate creation step.
  ASYNC_FAILURE_CREATE_CERT = 4,
  // Asynchronous generation failed: private key export step.
  ASYNC_FAILURE_EXPORT_KEY = 5,
  // Asynchronous failure with any other error code.
  ASYNC_FAILURE_UNKNOWN = 6,
  // The call was made with invalid arguments.
  INVALID_ARGUMENT = 7,
  // None of the cert types requested by the server is supported.
  UNSUPPORTED_TYPE = 8,
  // The server asked for a different cert type while one was being generated.
  TYPE_MISMATCH = 9,
  // A worker to generate the cert could not be started.
  WORKER_FAILURE = 10,
  GET_CHANNEL_ID_RESULT_MAX
};
71
72void RecordGetChannelIDResult(GetChannelIDResult result) {
73  UMA_HISTOGRAM_ENUMERATION("DomainBoundCerts.GetDomainBoundCertResult", result,
74                            GET_CHANNEL_ID_RESULT_MAX);
75}
76
77void RecordGetChannelIDTime(base::TimeDelta request_time) {
78  UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.GetCertTime",
79                             request_time,
80                             base::TimeDelta::FromMilliseconds(1),
81                             base::TimeDelta::FromMinutes(5),
82                             50);
83}
84
85// On success, returns a ChannelID object and sets |*error| to OK.
86// Otherwise, returns NULL, and |*error| will be set to a net error code.
87// |serial_number| is passed in because base::RandInt cannot be called from an
88// unjoined thread, due to relying on a non-leaked LazyInstance
89scoped_ptr<ChannelIDStore::ChannelID> GenerateChannelID(
90    const std::string& server_identifier,
91    uint32 serial_number,
92    int* error) {
93  scoped_ptr<ChannelIDStore::ChannelID> result;
94
95  base::TimeTicks start = base::TimeTicks::Now();
96  base::Time not_valid_before = base::Time::Now();
97  base::Time not_valid_after =
98      not_valid_before + base::TimeDelta::FromDays(kValidityPeriodInDays);
99  std::string der_cert;
100  std::vector<uint8> private_key_info;
101  scoped_ptr<crypto::ECPrivateKey> key;
102  if (!x509_util::CreateKeyAndChannelIDEC(server_identifier,
103                                          serial_number,
104                                          not_valid_before,
105                                          not_valid_after,
106                                          &key,
107                                          &der_cert)) {
108    DLOG(ERROR) << "Unable to create x509 cert for client";
109    *error = ERR_ORIGIN_BOUND_CERT_GENERATION_FAILED;
110    return result.Pass();
111  }
112
113  if (!key->ExportEncryptedPrivateKey(ChannelIDService::kEPKIPassword,
114                                      1, &private_key_info)) {
115    DLOG(ERROR) << "Unable to export private key";
116    *error = ERR_PRIVATE_KEY_EXPORT_FAILED;
117    return result.Pass();
118  }
119
120  // TODO(rkn): Perhaps ExportPrivateKey should be changed to output a
121  // std::string* to prevent this copying.
122  std::string key_out(private_key_info.begin(), private_key_info.end());
123
124  result.reset(new ChannelIDStore::ChannelID(
125      server_identifier,
126      not_valid_before,
127      not_valid_after,
128      key_out,
129      der_cert));
130  UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.GenerateCertTime",
131                             base::TimeTicks::Now() - start,
132                             base::TimeDelta::FromMilliseconds(1),
133                             base::TimeDelta::FromMinutes(5),
134                             50);
135  *error = OK;
136  return result.Pass();
137}
138
139}  // namespace
140
141// Represents the output and result callback of a request.
142class ChannelIDServiceRequest {
143 public:
144  ChannelIDServiceRequest(base::TimeTicks request_start,
145                          const CompletionCallback& callback,
146                          std::string* private_key,
147                          std::string* cert)
148      : request_start_(request_start),
149        callback_(callback),
150        private_key_(private_key),
151        cert_(cert) {
152  }
153
154  // Ensures that the result callback will never be made.
155  void Cancel() {
156    RecordGetChannelIDResult(ASYNC_CANCELLED);
157    callback_.Reset();
158    private_key_ = NULL;
159    cert_ = NULL;
160  }
161
162  // Copies the contents of |private_key| and |cert| to the caller's output
163  // arguments and calls the callback.
164  void Post(int error,
165            const std::string& private_key,
166            const std::string& cert) {
167    switch (error) {
168      case OK: {
169        base::TimeDelta request_time = base::TimeTicks::Now() - request_start_;
170        UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.GetCertTimeAsync",
171                                   request_time,
172                                   base::TimeDelta::FromMilliseconds(1),
173                                   base::TimeDelta::FromMinutes(5),
174                                   50);
175        RecordGetChannelIDTime(request_time);
176        RecordGetChannelIDResult(ASYNC_SUCCESS);
177        break;
178      }
179      case ERR_KEY_GENERATION_FAILED:
180        RecordGetChannelIDResult(ASYNC_FAILURE_KEYGEN);
181        break;
182      case ERR_ORIGIN_BOUND_CERT_GENERATION_FAILED:
183        RecordGetChannelIDResult(ASYNC_FAILURE_CREATE_CERT);
184        break;
185      case ERR_PRIVATE_KEY_EXPORT_FAILED:
186        RecordGetChannelIDResult(ASYNC_FAILURE_EXPORT_KEY);
187        break;
188      case ERR_INSUFFICIENT_RESOURCES:
189        RecordGetChannelIDResult(WORKER_FAILURE);
190        break;
191      default:
192        RecordGetChannelIDResult(ASYNC_FAILURE_UNKNOWN);
193        break;
194    }
195    if (!callback_.is_null()) {
196      *private_key_ = private_key;
197      *cert_ = cert;
198      callback_.Run(error);
199    }
200    delete this;
201  }
202
203  bool canceled() const { return callback_.is_null(); }
204
205 private:
206  base::TimeTicks request_start_;
207  CompletionCallback callback_;
208  std::string* private_key_;
209  std::string* cert_;
210};
211
212// ChannelIDServiceWorker runs on a worker thread and takes care of the
213// blocking process of performing key generation. Will take care of deleting
214// itself once Start() is called.
215class ChannelIDServiceWorker {
216 public:
217  typedef base::Callback<void(
218      const std::string&,
219      int,
220      scoped_ptr<ChannelIDStore::ChannelID>)> WorkerDoneCallback;
221
222  ChannelIDServiceWorker(
223      const std::string& server_identifier,
224      const WorkerDoneCallback& callback)
225      : server_identifier_(server_identifier),
226        serial_number_(base::RandInt(0, std::numeric_limits<int>::max())),
227        origin_loop_(base::MessageLoopProxy::current()),
228        callback_(callback) {
229  }
230
231  // Starts the worker on |task_runner|. If the worker fails to start, such as
232  // if the task runner is shutting down, then it will take care of deleting
233  // itself.
234  bool Start(const scoped_refptr<base::TaskRunner>& task_runner) {
235    DCHECK(origin_loop_->RunsTasksOnCurrentThread());
236
237    return task_runner->PostTask(
238        FROM_HERE,
239        base::Bind(&ChannelIDServiceWorker::Run, base::Owned(this)));
240  }
241
242 private:
243  void Run() {
244    // Runs on a worker thread.
245    int error = ERR_FAILED;
246    scoped_ptr<ChannelIDStore::ChannelID> cert =
247        GenerateChannelID(server_identifier_, serial_number_, &error);
248    DVLOG(1) << "GenerateCert " << server_identifier_ << " returned " << error;
249#if defined(USE_NSS)
250    // Detach the thread from NSPR.
251    // Calling NSS functions attaches the thread to NSPR, which stores
252    // the NSPR thread ID in thread-specific data.
253    // The threads in our thread pool terminate after we have called
254    // PR_Cleanup. Unless we detach them from NSPR, net_unittests gets
255    // segfaults on shutdown when the threads' thread-specific data
256    // destructors run.
257    PR_DetachThread();
258#endif
259    origin_loop_->PostTask(FROM_HERE,
260                           base::Bind(callback_, server_identifier_, error,
261                                      base::Passed(&cert)));
262  }
263
264  const std::string server_identifier_;
265  // Note that serial_number_ must be initialized on a non-worker thread
266  // (see documentation for GenerateCert).
267  uint32 serial_number_;
268  scoped_refptr<base::SequencedTaskRunner> origin_loop_;
269  WorkerDoneCallback callback_;
270
271  DISALLOW_COPY_AND_ASSIGN(ChannelIDServiceWorker);
272};
273
274// A ChannelIDServiceJob is a one-to-one counterpart of an
275// ChannelIDServiceWorker. It lives only on the ChannelIDService's
276// origin message loop.
277class ChannelIDServiceJob {
278 public:
279  ChannelIDServiceJob(bool create_if_missing)
280      : create_if_missing_(create_if_missing) {
281  }
282
283  ~ChannelIDServiceJob() {
284    if (!requests_.empty())
285      DeleteAllCanceled();
286  }
287
288  void AddRequest(ChannelIDServiceRequest* request,
289                  bool create_if_missing = false) {
290    create_if_missing_ |= create_if_missing;
291    requests_.push_back(request);
292  }
293
294  void HandleResult(int error,
295                    const std::string& private_key,
296                    const std::string& cert) {
297    PostAll(error, private_key, cert);
298  }
299
300  bool CreateIfMissing() const { return create_if_missing_; }
301
302 private:
303  void PostAll(int error,
304               const std::string& private_key,
305               const std::string& cert) {
306    std::vector<ChannelIDServiceRequest*> requests;
307    requests_.swap(requests);
308
309    for (std::vector<ChannelIDServiceRequest*>::iterator
310         i = requests.begin(); i != requests.end(); i++) {
311      (*i)->Post(error, private_key, cert);
312      // Post() causes the ChannelIDServiceRequest to delete itself.
313    }
314  }
315
316  void DeleteAllCanceled() {
317    for (std::vector<ChannelIDServiceRequest*>::iterator
318         i = requests_.begin(); i != requests_.end(); i++) {
319      if ((*i)->canceled()) {
320        delete *i;
321      } else {
322        LOG(DFATAL) << "ChannelIDServiceRequest leaked!";
323      }
324    }
325  }
326
327  std::vector<ChannelIDServiceRequest*> requests_;
328  bool create_if_missing_;
329};
330
331// static
332const char ChannelIDService::kEPKIPassword[] = "";
333
334ChannelIDService::RequestHandle::RequestHandle()
335    : service_(NULL),
336      request_(NULL) {}
337
338ChannelIDService::RequestHandle::~RequestHandle() {
339  Cancel();
340}
341
342void ChannelIDService::RequestHandle::Cancel() {
343  if (request_) {
344    service_->CancelRequest(request_);
345    request_ = NULL;
346    callback_.Reset();
347  }
348}
349
350void ChannelIDService::RequestHandle::RequestStarted(
351    ChannelIDService* service,
352    ChannelIDServiceRequest* request,
353    const CompletionCallback& callback) {
354  DCHECK(request_ == NULL);
355  service_ = service;
356  request_ = request;
357  callback_ = callback;
358}
359
360void ChannelIDService::RequestHandle::OnRequestComplete(int result) {
361  request_ = NULL;
362  // Running the callback might delete |this|, so we can't touch any of our
363  // members afterwards. Reset callback_ first.
364  base::ResetAndReturn(&callback_).Run(result);
365}
366
367ChannelIDService::ChannelIDService(
368    ChannelIDStore* channel_id_store,
369    const scoped_refptr<base::TaskRunner>& task_runner)
370    : channel_id_store_(channel_id_store),
371      task_runner_(task_runner),
372      requests_(0),
373      cert_store_hits_(0),
374      inflight_joins_(0),
375      workers_created_(0),
376      weak_ptr_factory_(this) {
377  base::Time start = base::Time::Now();
378  base::Time end = start + base::TimeDelta::FromDays(
379      kValidityPeriodInDays + kSystemTimeValidityBufferInDays);
380  is_system_time_valid_ = x509_util::IsSupportedValidityRange(start, end);
381}
382
383ChannelIDService::~ChannelIDService() {
384  STLDeleteValues(&inflight_);
385}
386
387//static
388std::string ChannelIDService::GetDomainForHost(const std::string& host) {
389  std::string domain =
390      registry_controlled_domains::GetDomainAndRegistry(
391          host, registry_controlled_domains::INCLUDE_PRIVATE_REGISTRIES);
392  if (domain.empty())
393    return host;
394  return domain;
395}
396
397int ChannelIDService::GetOrCreateChannelID(
398    const std::string& host,
399    std::string* private_key,
400    std::string* cert,
401    const CompletionCallback& callback,
402    RequestHandle* out_req) {
403  DVLOG(1) << __FUNCTION__ << " " << host;
404  DCHECK(CalledOnValidThread());
405  base::TimeTicks request_start = base::TimeTicks::Now();
406
407  if (callback.is_null() || !private_key || !cert || host.empty()) {
408    RecordGetChannelIDResult(INVALID_ARGUMENT);
409    return ERR_INVALID_ARGUMENT;
410  }
411
412  std::string domain = GetDomainForHost(host);
413  if (domain.empty()) {
414    RecordGetChannelIDResult(INVALID_ARGUMENT);
415    return ERR_INVALID_ARGUMENT;
416  }
417
418  requests_++;
419
420  // See if a request for the same domain is currently in flight.
421  bool create_if_missing = true;
422  if (JoinToInFlightRequest(request_start, domain, private_key, cert,
423                            create_if_missing, callback, out_req)) {
424    return ERR_IO_PENDING;
425  }
426
427  int err = LookupChannelID(request_start, domain, private_key, cert,
428                                  create_if_missing, callback, out_req);
429  if (err == ERR_FILE_NOT_FOUND) {
430    // Sync lookup did not find a valid cert.  Start generating a new one.
431    workers_created_++;
432    ChannelIDServiceWorker* worker = new ChannelIDServiceWorker(
433        domain,
434        base::Bind(&ChannelIDService::GeneratedChannelID,
435                   weak_ptr_factory_.GetWeakPtr()));
436    if (!worker->Start(task_runner_)) {
437      // TODO(rkn): Log to the NetLog.
438      LOG(ERROR) << "ChannelIDServiceWorker couldn't be started.";
439      RecordGetChannelIDResult(WORKER_FAILURE);
440      return ERR_INSUFFICIENT_RESOURCES;
441    }
442    // We are waiting for cert generation.  Create a job & request to track it.
443    ChannelIDServiceJob* job = new ChannelIDServiceJob(create_if_missing);
444    inflight_[domain] = job;
445
446    ChannelIDServiceRequest* request = new ChannelIDServiceRequest(
447        request_start,
448        base::Bind(&RequestHandle::OnRequestComplete,
449                   base::Unretained(out_req)),
450        private_key,
451        cert);
452    job->AddRequest(request);
453    out_req->RequestStarted(this, request, callback);
454    return ERR_IO_PENDING;
455  }
456
457  return err;
458}
459
460int ChannelIDService::GetChannelID(
461    const std::string& host,
462    std::string* private_key,
463    std::string* cert,
464    const CompletionCallback& callback,
465    RequestHandle* out_req) {
466  DVLOG(1) << __FUNCTION__ << " " << host;
467  DCHECK(CalledOnValidThread());
468  base::TimeTicks request_start = base::TimeTicks::Now();
469
470  if (callback.is_null() || !private_key || !cert || host.empty()) {
471    RecordGetChannelIDResult(INVALID_ARGUMENT);
472    return ERR_INVALID_ARGUMENT;
473  }
474
475  std::string domain = GetDomainForHost(host);
476  if (domain.empty()) {
477    RecordGetChannelIDResult(INVALID_ARGUMENT);
478    return ERR_INVALID_ARGUMENT;
479  }
480
481  requests_++;
482
483  // See if a request for the same domain currently in flight.
484  bool create_if_missing = false;
485  if (JoinToInFlightRequest(request_start, domain, private_key, cert,
486                            create_if_missing, callback, out_req)) {
487    return ERR_IO_PENDING;
488  }
489
490  int err = LookupChannelID(request_start, domain, private_key, cert,
491                            create_if_missing, callback, out_req);
492  return err;
493}
494
495void ChannelIDService::GotChannelID(
496    int err,
497    const std::string& server_identifier,
498    base::Time expiration_time,
499    const std::string& key,
500    const std::string& cert) {
501  DCHECK(CalledOnValidThread());
502
503  std::map<std::string, ChannelIDServiceJob*>::iterator j;
504  j = inflight_.find(server_identifier);
505  if (j == inflight_.end()) {
506    NOTREACHED();
507    return;
508  }
509
510  if (err == OK) {
511    // Async DB lookup found a valid cert.
512    DVLOG(1) << "Cert store had valid cert for " << server_identifier;
513    cert_store_hits_++;
514    // ChannelIDServiceRequest::Post will do the histograms and stuff.
515    HandleResult(OK, server_identifier, key, cert);
516    return;
517  }
518  // Async lookup failed or the certificate was missing. Return the error
519  // directly, unless the certificate was missing and a request asked to create
520  // one.
521  if (err != ERR_FILE_NOT_FOUND || !j->second->CreateIfMissing()) {
522    HandleResult(err, server_identifier, key, cert);
523    return;
524  }
525  // At least one request asked to create a cert => start generating a new one.
526  workers_created_++;
527  ChannelIDServiceWorker* worker = new ChannelIDServiceWorker(
528      server_identifier,
529      base::Bind(&ChannelIDService::GeneratedChannelID,
530                 weak_ptr_factory_.GetWeakPtr()));
531  if (!worker->Start(task_runner_)) {
532    // TODO(rkn): Log to the NetLog.
533    LOG(ERROR) << "ChannelIDServiceWorker couldn't be started.";
534    HandleResult(ERR_INSUFFICIENT_RESOURCES,
535                 server_identifier,
536                 std::string(),
537                 std::string());
538  }
539}
540
541ChannelIDStore* ChannelIDService::GetChannelIDStore() {
542  return channel_id_store_.get();
543}
544
545void ChannelIDService::CancelRequest(ChannelIDServiceRequest* req) {
546  DCHECK(CalledOnValidThread());
547  req->Cancel();
548}
549
550void ChannelIDService::GeneratedChannelID(
551    const std::string& server_identifier,
552    int error,
553    scoped_ptr<ChannelIDStore::ChannelID> cert) {
554  DCHECK(CalledOnValidThread());
555
556  if (error == OK) {
557    // TODO(mattm): we should just Pass() the cert object to
558    // SetChannelID().
559    channel_id_store_->SetChannelID(
560        cert->server_identifier(),
561        cert->creation_time(),
562        cert->expiration_time(),
563        cert->private_key(),
564        cert->cert());
565
566    HandleResult(error, server_identifier, cert->private_key(), cert->cert());
567  } else {
568    HandleResult(error, server_identifier, std::string(), std::string());
569  }
570}
571
572void ChannelIDService::HandleResult(
573    int error,
574    const std::string& server_identifier,
575    const std::string& private_key,
576    const std::string& cert) {
577  DCHECK(CalledOnValidThread());
578
579  std::map<std::string, ChannelIDServiceJob*>::iterator j;
580  j = inflight_.find(server_identifier);
581  if (j == inflight_.end()) {
582    NOTREACHED();
583    return;
584  }
585  ChannelIDServiceJob* job = j->second;
586  inflight_.erase(j);
587
588  job->HandleResult(error, private_key, cert);
589  delete job;
590}
591
592bool ChannelIDService::JoinToInFlightRequest(
593    const base::TimeTicks& request_start,
594    const std::string& domain,
595    std::string* private_key,
596    std::string* cert,
597    bool create_if_missing,
598    const CompletionCallback& callback,
599    RequestHandle* out_req) {
600  ChannelIDServiceJob* job = NULL;
601  std::map<std::string, ChannelIDServiceJob*>::const_iterator j =
602      inflight_.find(domain);
603  if (j != inflight_.end()) {
604    // A request for the same domain is in flight already. We'll attach our
605    // callback, but we'll also mark it as requiring a cert if one's mising.
606    job = j->second;
607    inflight_joins_++;
608
609    ChannelIDServiceRequest* request = new ChannelIDServiceRequest(
610        request_start,
611        base::Bind(&RequestHandle::OnRequestComplete,
612                   base::Unretained(out_req)),
613        private_key,
614        cert);
615    job->AddRequest(request, create_if_missing);
616    out_req->RequestStarted(this, request, callback);
617    return true;
618  }
619  return false;
620}
621
622int ChannelIDService::LookupChannelID(
623    const base::TimeTicks& request_start,
624    const std::string& domain,
625    std::string* private_key,
626    std::string* cert,
627    bool create_if_missing,
628    const CompletionCallback& callback,
629    RequestHandle* out_req) {
630  // Check if a domain bound cert already exists for this domain. Note that
631  // |expiration_time| is ignored, and expired certs are considered valid.
632  base::Time expiration_time;
633  int err = channel_id_store_->GetChannelID(
634      domain,
635      &expiration_time  /* ignored */,
636      private_key,
637      cert,
638      base::Bind(&ChannelIDService::GotChannelID,
639                 weak_ptr_factory_.GetWeakPtr()));
640
641  if (err == OK) {
642    // Sync lookup found a valid cert.
643    DVLOG(1) << "Cert store had valid cert for " << domain;
644    cert_store_hits_++;
645    RecordGetChannelIDResult(SYNC_SUCCESS);
646    base::TimeDelta request_time = base::TimeTicks::Now() - request_start;
647    UMA_HISTOGRAM_TIMES("DomainBoundCerts.GetCertTimeSync", request_time);
648    RecordGetChannelIDTime(request_time);
649    return OK;
650  }
651
652  if (err == ERR_IO_PENDING) {
653    // We are waiting for async DB lookup.  Create a job & request to track it.
654    ChannelIDServiceJob* job = new ChannelIDServiceJob(create_if_missing);
655    inflight_[domain] = job;
656
657    ChannelIDServiceRequest* request = new ChannelIDServiceRequest(
658        request_start,
659        base::Bind(&RequestHandle::OnRequestComplete,
660                   base::Unretained(out_req)),
661        private_key,
662        cert);
663    job->AddRequest(request);
664    out_req->RequestStarted(this, request, callback);
665    return ERR_IO_PENDING;
666  }
667
668  return err;
669}
670
671int ChannelIDService::cert_count() {
672  return channel_id_store_->GetChannelIDCount();
673}
674
675}  // namespace net
676