1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// Most of this code is copied from various classes in
6// src/chrome/browser/policy. In particular, look at
7//
8//   configuration_policy_provider_delegate_win.{h,cc}
9//   configuration_policy_loader_win.{h,cc}
10//
11// This is a reduction of the functionality in those classes.
12
13#include "remoting/host/policy_hack/policy_watcher.h"
14
15#include <userenv.h>
16
17#include "base/compiler_specific.h"
18#include "base/memory/scoped_ptr.h"
19#include "base/message_loop/message_loop_proxy.h"
20#include "base/strings/string16.h"
21#include "base/strings/utf_string_conversions.h"
22#include "base/synchronization/waitable_event.h"
23#include "base/values.h"
24#include "base/win/object_watcher.h"
25#include "base/win/registry.h"
26
27// userenv.dll is required for RegisterGPNotification().
28#pragma comment(lib, "userenv.lib")
29
30using base::win::RegKey;
31
32namespace remoting {
33namespace policy_hack {
34
namespace {

// Registry subkey that holds the policy values. When policies are read it is
// opened under HKEY_LOCAL_MACHINE first and then under HKEY_CURRENT_USER.
const wchar_t kRegistrySubKey[] = L"SOFTWARE\\Policies\\Google\\Chrome";

}  // namespace
40
41class PolicyWatcherWin :
42  public PolicyWatcher,
43  public base::win::ObjectWatcher::Delegate {
44 public:
45  explicit PolicyWatcherWin(
46      scoped_refptr<base::SingleThreadTaskRunner> task_runner)
47      : PolicyWatcher(task_runner),
48        user_policy_changed_event_(false, false),
49        machine_policy_changed_event_(false, false),
50        user_policy_watcher_failed_(false),
51        machine_policy_watcher_failed_(false) {
52  }
53
54  virtual ~PolicyWatcherWin() {
55  }
56
57  virtual void StartWatchingInternal() OVERRIDE {
58    DCHECK(OnPolicyWatcherThread());
59
60    if (!RegisterGPNotification(user_policy_changed_event_.handle(), false)) {
61      PLOG(WARNING) << "Failed to register user group policy notification";
62      user_policy_watcher_failed_ = true;
63    }
64
65    if (!RegisterGPNotification(machine_policy_changed_event_.handle(), true)) {
66      PLOG(WARNING) << "Failed to register machine group policy notification.";
67      machine_policy_watcher_failed_ = true;
68    }
69
70    Reload();
71  }
72
73  virtual void StopWatchingInternal() OVERRIDE {
74    DCHECK(OnPolicyWatcherThread());
75
76    if (!UnregisterGPNotification(user_policy_changed_event_.handle())) {
77      PLOG(WARNING) << "Failed to unregister user group policy notification";
78    }
79
80    if (!UnregisterGPNotification(machine_policy_changed_event_.handle())) {
81      PLOG(WARNING) <<
82          "Failed to unregister machine group policy notification.";
83    }
84
85    user_policy_watcher_.StopWatching();
86    machine_policy_watcher_.StopWatching();
87  }
88
89 private:
90  // Updates the watchers and schedules the reload task if appropriate.
91  void SetupWatches() {
92    DCHECK(OnPolicyWatcherThread());
93
94    if (!user_policy_watcher_failed_ &&
95        !user_policy_watcher_.GetWatchedObject() &&
96        !user_policy_watcher_.StartWatching(
97            user_policy_changed_event_.handle(), this)) {
98      LOG(WARNING) << "Failed to start watch for user policy change event";
99      user_policy_watcher_failed_ = true;
100    }
101
102    if (!machine_policy_watcher_failed_ &&
103        !machine_policy_watcher_.GetWatchedObject() &&
104        !machine_policy_watcher_.StartWatching(
105            machine_policy_changed_event_.handle(), this)) {
106      LOG(WARNING) << "Failed to start watch for machine policy change event";
107      machine_policy_watcher_failed_ = true;
108     }
109
110    if (user_policy_watcher_failed_ || machine_policy_watcher_failed_) {
111      ScheduleFallbackReloadTask();
112    }
113  }
114
115  bool GetRegistryPolicyString(const std::string& value_name,
116                               std::string* result) const {
117    // presubmit: allow wstring
118    std::wstring value_name_wide = UTF8ToWide(value_name);
119    // presubmit: allow wstring
120    std::wstring value;
121    RegKey policy_key(HKEY_LOCAL_MACHINE, kRegistrySubKey, KEY_READ);
122    if (policy_key.ReadValue(value_name_wide.c_str(), &value) ==
123        ERROR_SUCCESS) {
124      *result = WideToUTF8(value);
125      return true;
126    }
127
128    if (policy_key.Open(HKEY_CURRENT_USER, kRegistrySubKey, KEY_READ) ==
129      ERROR_SUCCESS) {
130      if (policy_key.ReadValue(value_name_wide.c_str(), &value) ==
131          ERROR_SUCCESS) {
132        *result = WideToUTF8(value);
133        return true;
134      }
135    }
136    return false;
137  }
138
139  bool GetRegistryPolicyInteger(const std::string& value_name,
140                                uint32* result) const {
141    // presubmit: allow wstring
142    std::wstring value_name_wide = UTF8ToWide(value_name);
143    DWORD value = 0;
144    RegKey policy_key(HKEY_LOCAL_MACHINE, kRegistrySubKey, KEY_READ);
145    if (policy_key.ReadValueDW(value_name_wide.c_str(), &value) ==
146        ERROR_SUCCESS) {
147      *result = value;
148      return true;
149    }
150
151    if (policy_key.Open(HKEY_CURRENT_USER, kRegistrySubKey, KEY_READ) ==
152        ERROR_SUCCESS) {
153      if (policy_key.ReadValueDW(value_name_wide.c_str(), &value) ==
154          ERROR_SUCCESS) {
155        *result = value;
156        return true;
157      }
158    }
159    return false;
160  }
161
162  bool GetRegistryPolicyBoolean(const std::string& value_name,
163                                bool* result) const {
164    uint32 local_result = 0;
165    bool ret = GetRegistryPolicyInteger(value_name, &local_result);
166    if (ret)
167      *result = local_result != 0;
168    return ret;
169  }
170
171  scoped_ptr<base::DictionaryValue> Load() {
172    scoped_ptr<base::DictionaryValue> policy(new base::DictionaryValue());
173
174    for (base::DictionaryValue::Iterator i(Defaults());
175         !i.IsAtEnd(); i.Advance()) {
176      const std::string& policy_name = i.key();
177      if (i.value().GetType() == base::DictionaryValue::TYPE_BOOLEAN) {
178        bool bool_value;
179        if (GetRegistryPolicyBoolean(policy_name, &bool_value)) {
180          policy->SetBoolean(policy_name, bool_value);
181        }
182      }
183      if (i.value().GetType() == base::DictionaryValue::TYPE_STRING) {
184        std::string string_value;
185        if (GetRegistryPolicyString(policy_name, &string_value)) {
186          policy->SetString(policy_name, string_value);
187        }
188      }
189    }
190    return policy.Pass();
191  }
192
193  // Post a reload notification and update the watch machinery.
194  void Reload() {
195    DCHECK(OnPolicyWatcherThread());
196    SetupWatches();
197    scoped_ptr<DictionaryValue> new_policy(Load());
198    UpdatePolicies(new_policy.get());
199  }
200
201  // ObjectWatcher::Delegate overrides:
202  virtual void OnObjectSignaled(HANDLE object) {
203    DCHECK(OnPolicyWatcherThread());
204    DCHECK(object == user_policy_changed_event_.handle() ||
205           object == machine_policy_changed_event_.handle())
206        << "unexpected object signaled policy reload, obj = "
207        << std::showbase << std::hex << object;
208    Reload();
209  }
210
211  base::WaitableEvent user_policy_changed_event_;
212  base::WaitableEvent machine_policy_changed_event_;
213  base::win::ObjectWatcher user_policy_watcher_;
214  base::win::ObjectWatcher machine_policy_watcher_;
215  bool user_policy_watcher_failed_;
216  bool machine_policy_watcher_failed_;
217};
218
219PolicyWatcher* PolicyWatcher::Create(
220    scoped_refptr<base::SingleThreadTaskRunner> task_runner) {
221  return new PolicyWatcherWin(task_runner);
222}
223
224}  // namespace policy_hack
225}  // namespace remoting
226