1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/host/pairing_registry_delegate_win.h"
6
7#include "base/json/json_string_value_serializer.h"
8#include "base/logging.h"
9#include "base/strings/utf_string_conversions.h"
10#include "base/values.h"
11#include "base/win/registry.h"
12
13namespace remoting {
14
15namespace {
16
17// Duplicates a registry key handle (returned by RegCreateXxx/RegOpenXxx).
18// The returned handle cannot be inherited and has the same permissions as
19// the source one.
20bool DuplicateKeyHandle(HKEY source, base::win::RegKey* dest) {
21  HANDLE handle;
22  if (!DuplicateHandle(GetCurrentProcess(),
23                       source,
24                       GetCurrentProcess(),
25                       &handle,
26                       0,
27                       FALSE,
28                       DUPLICATE_SAME_ACCESS)) {
29    PLOG(ERROR) << "Failed to duplicate a registry key handle";
30    return false;
31  }
32
33  dest->Set(reinterpret_cast<HKEY>(handle));
34  return true;
35}
36
37// Reads value |value_name| from |key| as a JSON string and returns it as
38// |base::Value|.
39scoped_ptr<base::DictionaryValue> ReadValue(const base::win::RegKey& key,
40                                            const wchar_t* value_name) {
41  // presubmit: allow wstring
42  std::wstring value_json;
43  LONG result = key.ReadValue(value_name, &value_json);
44  if (result != ERROR_SUCCESS) {
45    SetLastError(result);
46    PLOG(ERROR) << "Cannot read value '" << value_name << "'";
47    return scoped_ptr<base::DictionaryValue>();
48  }
49
50  // Parse the value.
51  std::string value_json_utf8 = base::WideToUTF8(value_json);
52  JSONStringValueSerializer serializer(&value_json_utf8);
53  int error_code;
54  std::string error_message;
55  scoped_ptr<base::Value> value(serializer.Deserialize(&error_code,
56                                                       &error_message));
57  if (!value) {
58    LOG(ERROR) << "Failed to parse '" << value_name << "': " << error_message
59               << " (" << error_code << ").";
60    return scoped_ptr<base::DictionaryValue>();
61  }
62
63  if (value->GetType() != base::Value::TYPE_DICTIONARY) {
64    LOG(ERROR) << "Failed to parse '" << value_name << "': not a dictionary.";
65    return scoped_ptr<base::DictionaryValue>();
66  }
67
68  return scoped_ptr<base::DictionaryValue>(
69      static_cast<base::DictionaryValue*>(value.release()));
70}
71
72// Serializes |value| into a JSON string and writes it as value |value_name|
73// under |key|.
74bool WriteValue(base::win::RegKey& key,
75                const wchar_t* value_name,
76                scoped_ptr<base::DictionaryValue> value) {
77  std::string value_json_utf8;
78  JSONStringValueSerializer serializer(&value_json_utf8);
79  if (!serializer.Serialize(*value)) {
80    LOG(ERROR) << "Failed to serialize '" << value_name << "'";
81    return false;
82  }
83
84  // presubmit: allow wstring
85  std::wstring value_json = base::UTF8ToWide(value_json_utf8);
86  LONG result = key.WriteValue(value_name, value_json.c_str());
87  if (result != ERROR_SUCCESS) {
88    SetLastError(result);
89    PLOG(ERROR) << "Cannot write value '" << value_name << "'";
90    return false;
91  }
92
93  return true;
94}
95
96}  // namespace
97
98using protocol::PairingRegistry;
99
100PairingRegistryDelegateWin::PairingRegistryDelegateWin() {
101}
102
103PairingRegistryDelegateWin::~PairingRegistryDelegateWin() {
104}
105
106bool PairingRegistryDelegateWin::SetRootKeys(HKEY privileged,
107                                             HKEY unprivileged) {
108  DCHECK(!privileged_.Valid());
109  DCHECK(!unprivileged_.Valid());
110  DCHECK(unprivileged);
111
112  if (!DuplicateKeyHandle(unprivileged, &unprivileged_))
113    return false;
114
115  if (privileged) {
116    if (!DuplicateKeyHandle(privileged, &privileged_))
117      return false;
118  }
119
120  return true;
121}
122
123scoped_ptr<base::ListValue> PairingRegistryDelegateWin::LoadAll() {
124  scoped_ptr<base::ListValue> pairings(new base::ListValue());
125
126  // Enumerate and parse all values under the unprivileged key.
127  DWORD count = unprivileged_.GetValueCount();
128  for (DWORD index = 0; index < count; ++index) {
129    // presubmit: allow wstring
130    std::wstring value_name;
131    LONG result = unprivileged_.GetValueNameAt(index, &value_name);
132    if (result != ERROR_SUCCESS) {
133      SetLastError(result);
134      PLOG(ERROR) << "Cannot get the name of value " << index;
135      continue;
136    }
137
138    PairingRegistry::Pairing pairing = Load(base::WideToUTF8(value_name));
139    if (pairing.is_valid())
140      pairings->Append(pairing.ToValue().release());
141  }
142
143  return pairings.Pass();
144}
145
146bool PairingRegistryDelegateWin::DeleteAll() {
147  if (!privileged_.Valid()) {
148    LOG(ERROR) << "Cannot delete pairings: the delegate is read-only.";
149    return false;
150  }
151
152  // Enumerate and delete the values in the privileged and unprivileged keys
153  // separately in case they get out of sync.
154  bool success = true;
155  DWORD count = unprivileged_.GetValueCount();
156  while (count > 0) {
157    // presubmit: allow wstring
158    std::wstring value_name;
159    LONG result = unprivileged_.GetValueNameAt(0, &value_name);
160    if (result == ERROR_SUCCESS)
161      result = unprivileged_.DeleteValue(value_name.c_str());
162
163    success = success && (result == ERROR_SUCCESS);
164    count = unprivileged_.GetValueCount();
165  }
166
167  count = privileged_.GetValueCount();
168  while (count > 0) {
169    // presubmit: allow wstring
170    std::wstring value_name;
171    LONG result = privileged_.GetValueNameAt(0, &value_name);
172    if (result == ERROR_SUCCESS)
173      result = privileged_.DeleteValue(value_name.c_str());
174
175    success = success && (result == ERROR_SUCCESS);
176    count = privileged_.GetValueCount();
177  }
178
179  return success;
180}
181
182PairingRegistry::Pairing PairingRegistryDelegateWin::Load(
183    const std::string& client_id) {
184  // presubmit: allow wstring
185  std::wstring value_name = base::UTF8ToWide(client_id);
186
187  // Read unprivileged fields first.
188  scoped_ptr<base::DictionaryValue> pairing = ReadValue(unprivileged_,
189                                                        value_name.c_str());
190  if (!pairing)
191    return PairingRegistry::Pairing();
192
193  // Read the shared secret.
194  if (privileged_.Valid()) {
195    scoped_ptr<base::DictionaryValue> secret = ReadValue(privileged_,
196                                                         value_name.c_str());
197    if (!secret)
198      return PairingRegistry::Pairing();
199
200    // Merge the two dictionaries.
201    pairing->MergeDictionary(secret.get());
202  }
203
204  return PairingRegistry::Pairing::CreateFromValue(*pairing);
205}
206
207bool PairingRegistryDelegateWin::Save(const PairingRegistry::Pairing& pairing) {
208  if (!privileged_.Valid()) {
209    LOG(ERROR) << "Cannot save pairing entry '" << pairing.client_id()
210                << "': the delegate is read-only.";
211    return false;
212  }
213
214  // Convert pairing to JSON.
215  scoped_ptr<base::DictionaryValue> pairing_json = pairing.ToValue();
216
217  // Extract the shared secret to a separate dictionary.
218  scoped_ptr<base::Value> secret_key;
219  CHECK(pairing_json->Remove(PairingRegistry::kSharedSecretKey, &secret_key));
220  scoped_ptr<base::DictionaryValue> secret_json(new base::DictionaryValue());
221  secret_json->Set(PairingRegistry::kSharedSecretKey, secret_key.release());
222
223  // presubmit: allow wstring
224  std::wstring value_name = base::UTF8ToWide(pairing.client_id());
225
226  // Write pairing to the registry.
227  if (!WriteValue(privileged_, value_name.c_str(), secret_json.Pass()) ||
228      !WriteValue(unprivileged_, value_name.c_str(), pairing_json.Pass())) {
229    return false;
230  }
231
232  return true;
233}
234
235bool PairingRegistryDelegateWin::Delete(const std::string& client_id) {
236  if (!privileged_.Valid()) {
237    LOG(ERROR) << "Cannot delete pairing entry '" << client_id
238                << "': the delegate is read-only.";
239    return false;
240  }
241
242  // presubmit: allow wstring
243  std::wstring value_name = base::UTF8ToWide(client_id);
244  LONG result = privileged_.DeleteValue(value_name.c_str());
245  if (result != ERROR_SUCCESS &&
246      result != ERROR_FILE_NOT_FOUND &&
247      result != ERROR_PATH_NOT_FOUND) {
248    SetLastError(result);
249    PLOG(ERROR) << "Cannot delete pairing entry '" << client_id << "'";
250    return false;
251  }
252
253  result = unprivileged_.DeleteValue(value_name.c_str());
254  if (result != ERROR_SUCCESS &&
255      result != ERROR_FILE_NOT_FOUND &&
256      result != ERROR_PATH_NOT_FOUND) {
257    SetLastError(result);
258    PLOG(ERROR) << "Cannot delete pairing entry '" << client_id << "'";
259    return false;
260  }
261
262  return true;
263}
264
265scoped_ptr<PairingRegistry::Delegate> CreatePairingRegistryDelegate() {
266  return scoped_ptr<PairingRegistry::Delegate>(
267      new PairingRegistryDelegateWin());
268}
269
270}  // namespace remoting
271