1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// Most of this code is copied from:
6//   src/chrome/browser/policy/asynchronous_policy_loader.{h,cc}
7
8#include "remoting/host/policy_hack/policy_watcher.h"
9
10#include "base/bind.h"
11#include "base/compiler_specific.h"
12#include "base/location.h"
13#include "base/memory/weak_ptr.h"
14#include "base/single_thread_task_runner.h"
15#include "base/synchronization/waitable_event.h"
16#include "base/time/time.h"
17#include "base/values.h"
18#include "remoting/host/dns_blackhole_checker.h"
19
20#if !defined(NDEBUG)
21#include "base/json/json_reader.h"
22#endif
23
24namespace remoting {
25namespace policy_hack {
26
27namespace {
28
29// The time interval for rechecking policy. This is our fallback in case the
30// delegate never reports a change to the ReloadObserver.
31const int kFallbackReloadDelayMinutes = 15;
32
33// Copies all policy values from one dictionary to another, using values from
34// |default| if they are not set in |from|, or values from |bad_type_values| if
35// the value in |from| has the wrong type.
36scoped_ptr<base::DictionaryValue> CopyGoodValuesAndAddDefaults(
37    const base::DictionaryValue* from,
38    const base::DictionaryValue* default_values,
39    const base::DictionaryValue* bad_type_values) {
40  scoped_ptr<base::DictionaryValue> to(default_values->DeepCopy());
41  for (base::DictionaryValue::Iterator i(*default_values);
42       !i.IsAtEnd(); i.Advance()) {
43
44    const base::Value* value = NULL;
45
46    // If the policy isn't in |from|, use the default.
47    if (!from->Get(i.key(), &value)) {
48      continue;
49    }
50
51    // If the policy is the wrong type, use the value from |bad_type_values|.
52    if (!value->IsType(i.value().GetType())) {
53      CHECK(bad_type_values->Get(i.key(), &value));
54    }
55
56    to->Set(i.key(), value->DeepCopy());
57  }
58
59#if !defined(NDEBUG)
60  // Replace values with those specified in DebugOverridePolicies, if present.
61  std::string policy_overrides;
62  if (from->GetString(PolicyWatcher::kHostDebugOverridePoliciesName,
63                      &policy_overrides)) {
64    scoped_ptr<base::Value> value(base::JSONReader::Read(policy_overrides));
65    const base::DictionaryValue* override_values;
66    if (value && value->GetAsDictionary(&override_values)) {
67      to->MergeDictionary(override_values);
68    }
69  }
70#endif  // defined(NDEBUG)
71
72  return to.Pass();
73}
74
75}  // namespace
76
77const char PolicyWatcher::kNatPolicyName[] =
78    "RemoteAccessHostFirewallTraversal";
79
80const char PolicyWatcher::kHostRequireTwoFactorPolicyName[] =
81    "RemoteAccessHostRequireTwoFactor";
82
83const char PolicyWatcher::kHostDomainPolicyName[] =
84    "RemoteAccessHostDomain";
85
86const char PolicyWatcher::kHostMatchUsernamePolicyName[] =
87    "RemoteAccessHostMatchUsername";
88
89const char PolicyWatcher::kHostTalkGadgetPrefixPolicyName[] =
90    "RemoteAccessHostTalkGadgetPrefix";
91
92const char PolicyWatcher::kHostRequireCurtainPolicyName[] =
93    "RemoteAccessHostRequireCurtain";
94
95const char PolicyWatcher::kHostTokenUrlPolicyName[] =
96    "RemoteAccessHostTokenUrl";
97
98const char PolicyWatcher::kHostTokenValidationUrlPolicyName[] =
99    "RemoteAccessHostTokenValidationUrl";
100
101const char PolicyWatcher::kHostTokenValidationCertIssuerPolicyName[] =
102    "RemoteAccessHostTokenValidationCertificateIssuer";
103
104const char PolicyWatcher::kHostAllowClientPairing[] =
105    "RemoteAccessHostAllowClientPairing";
106
107const char PolicyWatcher::kHostAllowGnubbyAuthPolicyName[] =
108    "RemoteAccessHostAllowGnubbyAuth";
109
110const char PolicyWatcher::kRelayPolicyName[] =
111    "RemoteAccessHostAllowRelayedConnection";
112
113const char PolicyWatcher::kUdpPortRangePolicyName[] =
114    "RemoteAccessHostUdpPortRange";
115
116const char PolicyWatcher::kHostDebugOverridePoliciesName[] =
117    "RemoteAccessHostDebugOverridePolicies";
118
119PolicyWatcher::PolicyWatcher(
120    scoped_refptr<base::SingleThreadTaskRunner> task_runner)
121    : task_runner_(task_runner),
122      old_policies_(new base::DictionaryValue()),
123      default_values_(new base::DictionaryValue()),
124      weak_factory_(this) {
125  // Initialize the default values for each policy.
126  default_values_->SetBoolean(kNatPolicyName, true);
127  default_values_->SetBoolean(kHostRequireTwoFactorPolicyName, false);
128  default_values_->SetBoolean(kHostRequireCurtainPolicyName, false);
129  default_values_->SetBoolean(kHostMatchUsernamePolicyName, false);
130  default_values_->SetString(kHostDomainPolicyName, std::string());
131  default_values_->SetString(kHostTalkGadgetPrefixPolicyName,
132                               kDefaultHostTalkGadgetPrefix);
133  default_values_->SetString(kHostTokenUrlPolicyName, std::string());
134  default_values_->SetString(kHostTokenValidationUrlPolicyName, std::string());
135  default_values_->SetString(kHostTokenValidationCertIssuerPolicyName,
136                             std::string());
137  default_values_->SetBoolean(kHostAllowClientPairing, true);
138  default_values_->SetBoolean(kHostAllowGnubbyAuthPolicyName, true);
139  default_values_->SetBoolean(kRelayPolicyName, true);
140  default_values_->SetString(kUdpPortRangePolicyName, "");
141#if !defined(NDEBUG)
142  default_values_->SetString(kHostDebugOverridePoliciesName, std::string());
143#endif
144
145  // Initialize the fall-back values to use for unreadable policies.
146  // For most policies these match the defaults.
147  bad_type_values_.reset(default_values_->DeepCopy());
148  bad_type_values_->SetBoolean(kNatPolicyName, false);
149  bad_type_values_->SetBoolean(kRelayPolicyName, false);
150}
151
152PolicyWatcher::~PolicyWatcher() {
153}
154
155void PolicyWatcher::StartWatching(const PolicyCallback& policy_callback) {
156  if (!OnPolicyWatcherThread()) {
157    task_runner_->PostTask(FROM_HERE,
158                           base::Bind(&PolicyWatcher::StartWatching,
159                                      base::Unretained(this),
160                                      policy_callback));
161    return;
162  }
163
164  policy_callback_ = policy_callback;
165  StartWatchingInternal();
166}
167
168void PolicyWatcher::StopWatching(base::WaitableEvent* done) {
169  if (!OnPolicyWatcherThread()) {
170    task_runner_->PostTask(FROM_HERE,
171                           base::Bind(&PolicyWatcher::StopWatching,
172                                      base::Unretained(this), done));
173    return;
174  }
175
176  StopWatchingInternal();
177  weak_factory_.InvalidateWeakPtrs();
178  policy_callback_.Reset();
179
180  done->Signal();
181}
182
183void PolicyWatcher::ScheduleFallbackReloadTask() {
184  DCHECK(OnPolicyWatcherThread());
185  ScheduleReloadTask(
186      base::TimeDelta::FromMinutes(kFallbackReloadDelayMinutes));
187}
188
189void PolicyWatcher::ScheduleReloadTask(const base::TimeDelta& delay) {
190  DCHECK(OnPolicyWatcherThread());
191  task_runner_->PostDelayedTask(
192      FROM_HERE,
193      base::Bind(&PolicyWatcher::Reload, weak_factory_.GetWeakPtr()),
194      delay);
195}
196
197const base::DictionaryValue& PolicyWatcher::Defaults() const {
198  return *default_values_;
199}
200
201bool PolicyWatcher::OnPolicyWatcherThread() const {
202  return task_runner_->BelongsToCurrentThread();
203}
204
205void PolicyWatcher::UpdatePolicies(
206    const base::DictionaryValue* new_policies_raw) {
207  DCHECK(OnPolicyWatcherThread());
208
209  // Use default values for any missing policies.
210  scoped_ptr<base::DictionaryValue> new_policies =
211      CopyGoodValuesAndAddDefaults(
212          new_policies_raw, default_values_.get(), bad_type_values_.get());
213
214  // Find the changed policies.
215  scoped_ptr<base::DictionaryValue> changed_policies(
216      new base::DictionaryValue());
217  base::DictionaryValue::Iterator iter(*new_policies);
218  while (!iter.IsAtEnd()) {
219    base::Value* old_policy;
220    if (!(old_policies_->Get(iter.key(), &old_policy) &&
221          old_policy->Equals(&iter.value()))) {
222      changed_policies->Set(iter.key(), iter.value().DeepCopy());
223    }
224    iter.Advance();
225  }
226
227  // Save the new policies.
228  old_policies_.swap(new_policies);
229
230  // Notify our client of the changed policies.
231  if (!changed_policies->empty()) {
232    policy_callback_.Run(changed_policies.Pass());
233  }
234}
235
236}  // namespace policy_hack
237}  // namespace remoting
238