1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "net/http/transport_security_persister.h"
6
7#include "base/base64.h"
8#include "base/bind.h"
9#include "base/files/file_path.h"
10#include "base/files/file_util.h"
11#include "base/json/json_reader.h"
12#include "base/json/json_writer.h"
13#include "base/message_loop/message_loop.h"
14#include "base/message_loop/message_loop_proxy.h"
15#include "base/sequenced_task_runner.h"
16#include "base/task_runner_util.h"
17#include "base/values.h"
18#include "crypto/sha2.h"
19#include "net/cert/x509_certificate.h"
20#include "net/http/transport_security_state.h"
21
22using net::HashValue;
23using net::HashValueTag;
24using net::HashValueVector;
25using net::TransportSecurityState;
26
27namespace {
28
29base::ListValue* SPKIHashesToListValue(const HashValueVector& hashes) {
30  base::ListValue* pins = new base::ListValue;
31  for (size_t i = 0; i != hashes.size(); i++)
32    pins->Append(new base::StringValue(hashes[i].ToString()));
33  return pins;
34}
35
36void SPKIHashesFromListValue(const base::ListValue& pins,
37                             HashValueVector* hashes) {
38  size_t num_pins = pins.GetSize();
39  for (size_t i = 0; i < num_pins; ++i) {
40    std::string type_and_base64;
41    HashValue fingerprint;
42    if (pins.GetString(i, &type_and_base64) &&
43        fingerprint.FromString(type_and_base64)) {
44      hashes->push_back(fingerprint);
45    }
46  }
47}
48
49// This function converts the binary hashes to a base64 string which we can
50// include in a JSON file.
51std::string HashedDomainToExternalString(const std::string& hashed) {
52  std::string out;
53  base::Base64Encode(hashed, &out);
54  return out;
55}
56
57// This inverts |HashedDomainToExternalString|, above. It turns an external
58// string (from a JSON file) into an internal (binary) string.
59std::string ExternalStringToHashedDomain(const std::string& external) {
60  std::string out;
61  if (!base::Base64Decode(external, &out) ||
62      out.size() != crypto::kSHA256Length) {
63    return std::string();
64  }
65
66  return out;
67}
68
// JSON keys written for every serialized entry.
const char kStsIncludeSubdomains[] = "sts_include_subdomains";
const char kPkpIncludeSubdomains[] = "pkp_include_subdomains";
const char kMode[] = "mode";
const char kExpiry[] = "expiry";
const char kDynamicSPKIHashesExpiry[] = "dynamic_spki_hashes_expiry";
const char kDynamicSPKIHashes[] = "dynamic_spki_hashes";
const char kStsObserved[] = "sts_observed";
const char kPkpObserved[] = "pkp_observed";

// Legacy keys: still accepted when reading, but no longer written.
// kIncludeSubdomains stands for both kStsIncludeSubdomains and
// kPkpIncludeSubdomains; kCreated stands for both kStsObserved and
// kPkpObserved.
const char kIncludeSubdomains[] = "include_subdomains";
const char kCreated[] = "created";

// Values stored under kMode. kStrict and kPinningOnly are only recognized when
// reading; the serializer writes kForceHTTPS or kDefault.
const char kForceHTTPS[] = "force-https";
const char kDefault[] = "default";
const char kStrict[] = "strict";
const char kPinningOnly[] = "pinning-only";
83
84std::string LoadState(const base::FilePath& path) {
85  std::string result;
86  if (!base::ReadFileToString(path, &result)) {
87    return "";
88  }
89  return result;
90}
91
92}  // namespace
93
94
95namespace net {
96
97TransportSecurityPersister::TransportSecurityPersister(
98    TransportSecurityState* state,
99    const base::FilePath& profile_path,
100    const scoped_refptr<base::SequencedTaskRunner>& background_runner,
101    bool readonly)
102    : transport_security_state_(state),
103      writer_(profile_path.AppendASCII("TransportSecurity"), background_runner),
104      foreground_runner_(base::MessageLoop::current()->message_loop_proxy()),
105      background_runner_(background_runner),
106      readonly_(readonly),
107      weak_ptr_factory_(this) {
108  transport_security_state_->SetDelegate(this);
109
110  base::PostTaskAndReplyWithResult(
111      background_runner_.get(),
112      FROM_HERE,
113      base::Bind(&::LoadState, writer_.path()),
114      base::Bind(&TransportSecurityPersister::CompleteLoad,
115                 weak_ptr_factory_.GetWeakPtr()));
116}
117
118TransportSecurityPersister::~TransportSecurityPersister() {
119  DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
120
121  if (writer_.HasPendingWrite())
122    writer_.DoScheduledWrite();
123
124  transport_security_state_->SetDelegate(NULL);
125}
126
127void TransportSecurityPersister::StateIsDirty(
128    TransportSecurityState* state) {
129  DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
130  DCHECK_EQ(transport_security_state_, state);
131
132  if (!readonly_)
133    writer_.ScheduleWrite(this);
134}
135
136bool TransportSecurityPersister::SerializeData(std::string* output) {
137  DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
138
139  base::DictionaryValue toplevel;
140  base::Time now = base::Time::Now();
141  TransportSecurityState::Iterator state(*transport_security_state_);
142  for (; state.HasNext(); state.Advance()) {
143    const std::string& hostname = state.hostname();
144    const TransportSecurityState::DomainState& domain_state =
145        state.domain_state();
146
147    base::DictionaryValue* serialized = new base::DictionaryValue;
148    serialized->SetBoolean(kStsIncludeSubdomains,
149                           domain_state.sts.include_subdomains);
150    serialized->SetBoolean(kPkpIncludeSubdomains,
151                           domain_state.pkp.include_subdomains);
152    serialized->SetDouble(kStsObserved,
153                          domain_state.sts.last_observed.ToDoubleT());
154    serialized->SetDouble(kPkpObserved,
155                          domain_state.pkp.last_observed.ToDoubleT());
156    serialized->SetDouble(kExpiry, domain_state.sts.expiry.ToDoubleT());
157    serialized->SetDouble(kDynamicSPKIHashesExpiry,
158                          domain_state.pkp.expiry.ToDoubleT());
159
160    switch (domain_state.sts.upgrade_mode) {
161      case TransportSecurityState::DomainState::MODE_FORCE_HTTPS:
162        serialized->SetString(kMode, kForceHTTPS);
163        break;
164      case TransportSecurityState::DomainState::MODE_DEFAULT:
165        serialized->SetString(kMode, kDefault);
166        break;
167      default:
168        NOTREACHED() << "DomainState with unknown mode";
169        delete serialized;
170        continue;
171    }
172
173    if (now < domain_state.pkp.expiry) {
174      serialized->Set(kDynamicSPKIHashes,
175                      SPKIHashesToListValue(domain_state.pkp.spki_hashes));
176    }
177
178    toplevel.Set(HashedDomainToExternalString(hostname), serialized);
179  }
180
181  base::JSONWriter::WriteWithOptions(&toplevel,
182                                     base::JSONWriter::OPTIONS_PRETTY_PRINT,
183                                     output);
184  return true;
185}
186
187bool TransportSecurityPersister::LoadEntries(const std::string& serialized,
188                                             bool* dirty) {
189  DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
190
191  transport_security_state_->ClearDynamicData();
192  return Deserialize(serialized, dirty, transport_security_state_);
193}
194
195// static
196bool TransportSecurityPersister::Deserialize(const std::string& serialized,
197                                             bool* dirty,
198                                             TransportSecurityState* state) {
199  scoped_ptr<base::Value> value(base::JSONReader::Read(serialized));
200  base::DictionaryValue* dict_value = NULL;
201  if (!value.get() || !value->GetAsDictionary(&dict_value))
202    return false;
203
204  const base::Time current_time(base::Time::Now());
205  bool dirtied = false;
206
207  for (base::DictionaryValue::Iterator i(*dict_value);
208       !i.IsAtEnd(); i.Advance()) {
209    const base::DictionaryValue* parsed = NULL;
210    if (!i.value().GetAsDictionary(&parsed)) {
211      LOG(WARNING) << "Could not parse entry " << i.key() << "; skipping entry";
212      continue;
213    }
214
215    TransportSecurityState::DomainState domain_state;
216
217    // kIncludeSubdomains is a legacy synonym for kStsIncludeSubdomains and
218    // kPkpIncludeSubdomains. Parse at least one of these properties,
219    // preferably the new ones.
220    bool include_subdomains = false;
221    bool parsed_include_subdomains = parsed->GetBoolean(kIncludeSubdomains,
222                                                        &include_subdomains);
223    domain_state.sts.include_subdomains = include_subdomains;
224    domain_state.pkp.include_subdomains = include_subdomains;
225    if (parsed->GetBoolean(kStsIncludeSubdomains, &include_subdomains)) {
226      domain_state.sts.include_subdomains = include_subdomains;
227      parsed_include_subdomains = true;
228    }
229    if (parsed->GetBoolean(kPkpIncludeSubdomains, &include_subdomains)) {
230      domain_state.pkp.include_subdomains = include_subdomains;
231      parsed_include_subdomains = true;
232    }
233
234    std::string mode_string;
235    double expiry = 0;
236    if (!parsed_include_subdomains ||
237        !parsed->GetString(kMode, &mode_string) ||
238        !parsed->GetDouble(kExpiry, &expiry)) {
239      LOG(WARNING) << "Could not parse some elements of entry " << i.key()
240                   << "; skipping entry";
241      continue;
242    }
243
244    // Don't fail if this key is not present.
245    double dynamic_spki_hashes_expiry = 0;
246    parsed->GetDouble(kDynamicSPKIHashesExpiry,
247                      &dynamic_spki_hashes_expiry);
248
249    const base::ListValue* pins_list = NULL;
250    if (parsed->GetList(kDynamicSPKIHashes, &pins_list)) {
251      SPKIHashesFromListValue(*pins_list, &domain_state.pkp.spki_hashes);
252    }
253
254    if (mode_string == kForceHTTPS || mode_string == kStrict) {
255      domain_state.sts.upgrade_mode =
256          TransportSecurityState::DomainState::MODE_FORCE_HTTPS;
257    } else if (mode_string == kDefault || mode_string == kPinningOnly) {
258      domain_state.sts.upgrade_mode =
259          TransportSecurityState::DomainState::MODE_DEFAULT;
260    } else {
261      LOG(WARNING) << "Unknown TransportSecurityState mode string "
262                   << mode_string << " found for entry " << i.key()
263                   << "; skipping entry";
264      continue;
265    }
266
267    domain_state.sts.expiry = base::Time::FromDoubleT(expiry);
268    domain_state.pkp.expiry =
269        base::Time::FromDoubleT(dynamic_spki_hashes_expiry);
270
271    double sts_observed;
272    double pkp_observed;
273    if (parsed->GetDouble(kStsObserved, &sts_observed)) {
274      domain_state.sts.last_observed = base::Time::FromDoubleT(sts_observed);
275    } else if (parsed->GetDouble(kCreated, &sts_observed)) {
276      // kCreated is a legacy synonym for both kStsObserved and kPkpObserved.
277      domain_state.sts.last_observed = base::Time::FromDoubleT(sts_observed);
278    } else {
279      // We're migrating an old entry with no observation date. Make sure we
280      // write the new date back in a reasonable time frame.
281      dirtied = true;
282      domain_state.sts.last_observed = base::Time::Now();
283    }
284    if (parsed->GetDouble(kPkpObserved, &pkp_observed)) {
285      domain_state.pkp.last_observed = base::Time::FromDoubleT(pkp_observed);
286    } else if (parsed->GetDouble(kCreated, &pkp_observed)) {
287      domain_state.pkp.last_observed = base::Time::FromDoubleT(pkp_observed);
288    } else {
289      dirtied = true;
290      domain_state.pkp.last_observed = base::Time::Now();
291    }
292
293    if (domain_state.sts.expiry <= current_time &&
294        domain_state.pkp.expiry <= current_time) {
295      // Make sure we dirty the state if we drop an entry.
296      dirtied = true;
297      continue;
298    }
299
300    std::string hashed = ExternalStringToHashedDomain(i.key());
301    if (hashed.empty()) {
302      dirtied = true;
303      continue;
304    }
305
306    state->AddOrUpdateEnabledHosts(hashed, domain_state);
307  }
308
309  *dirty = dirtied;
310  return true;
311}
312
313void TransportSecurityPersister::CompleteLoad(const std::string& state) {
314  DCHECK(foreground_runner_->RunsTasksOnCurrentThread());
315
316  if (state.empty())
317    return;
318
319  bool dirty = false;
320  if (!LoadEntries(state, &dirty)) {
321    LOG(ERROR) << "Failed to deserialize state: " << state;
322    return;
323  }
324  if (dirty)
325    StateIsDirty(transport_security_state_);
326}
327
328}  // namespace net
329