1/*
2 *  Copyright 2003 The WebRTC Project Authors. All rights reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
// Registry configuration wrappers class implementation
12//
13// Change made by S. Ganesh - ganesh@google.com:
14//   Use SHQueryValueEx instead of RegQueryValueEx throughout.
15//   A call to the SHLWAPI function is essentially a call to the standard
16//   function but with post-processing:
17//   * to fix REG_SZ or REG_EXPAND_SZ data that is not properly null-terminated;
18//   * to expand REG_EXPAND_SZ data.
19
#include "webrtc/base/win32regkey.h"

#include <shlwapi.h>
#include <string.h>

#include "webrtc/base/common.h"
#include "webrtc/base/logging.h"
#include "webrtc/base/scoped_ptr.h"
27
28namespace rtc {
29
30RegKey::RegKey() {
31  h_key_ = NULL;
32}
33
34RegKey::~RegKey() {
35  Close();
36}
37
38HRESULT RegKey::Create(HKEY parent_key, const wchar_t* key_name) {
39  return Create(parent_key,
40                key_name,
41                REG_NONE,
42                REG_OPTION_NON_VOLATILE,
43                KEY_ALL_ACCESS,
44                NULL,
45                NULL);
46}
47
48HRESULT RegKey::Open(HKEY parent_key, const wchar_t* key_name) {
49  return Open(parent_key, key_name, KEY_ALL_ACCESS);
50}
51
52bool RegKey::HasValue(const TCHAR* value_name) const {
53  return (ERROR_SUCCESS == ::RegQueryValueEx(h_key_, value_name, NULL,
54                                             NULL, NULL, NULL));
55}
56
57HRESULT RegKey::SetValue(const wchar_t* full_key_name,
58                         const wchar_t* value_name,
59                         DWORD value) {
60  ASSERT(full_key_name != NULL);
61
62  return SetValueStaticHelper(full_key_name, value_name, REG_DWORD, &value);
63}
64
65HRESULT RegKey::SetValue(const wchar_t* full_key_name,
66                         const wchar_t* value_name,
67                         DWORD64 value) {
68  ASSERT(full_key_name != NULL);
69
70  return SetValueStaticHelper(full_key_name, value_name, REG_QWORD, &value);
71}
72
73HRESULT RegKey::SetValue(const wchar_t* full_key_name,
74                         const wchar_t* value_name,
75                         float value) {
76  ASSERT(full_key_name != NULL);
77
78  return SetValueStaticHelper(full_key_name, value_name,
79                              REG_BINARY, &value, sizeof(value));
80}
81
82HRESULT RegKey::SetValue(const wchar_t* full_key_name,
83                         const wchar_t* value_name,
84                         double value) {
85  ASSERT(full_key_name != NULL);
86
87  return SetValueStaticHelper(full_key_name, value_name,
88                              REG_BINARY, &value, sizeof(value));
89}
90
91HRESULT RegKey::SetValue(const wchar_t* full_key_name,
92                         const wchar_t* value_name,
93                         const TCHAR* value) {
94  ASSERT(full_key_name != NULL);
95  ASSERT(value != NULL);
96
97  return SetValueStaticHelper(full_key_name, value_name,
98                              REG_SZ, const_cast<wchar_t*>(value));
99}
100
101HRESULT RegKey::SetValue(const wchar_t* full_key_name,
102                         const wchar_t* value_name,
103                         const uint8* value,
104                         DWORD byte_count) {
105  ASSERT(full_key_name != NULL);
106
107  return SetValueStaticHelper(full_key_name, value_name, REG_BINARY,
108                              const_cast<uint8*>(value), byte_count);
109}
110
111HRESULT RegKey::SetValueMultiSZ(const wchar_t* full_key_name,
112                                const wchar_t* value_name,
113                                const uint8* value,
114                                DWORD byte_count) {
115  ASSERT(full_key_name != NULL);
116
117  return SetValueStaticHelper(full_key_name, value_name, REG_MULTI_SZ,
118                              const_cast<uint8*>(value), byte_count);
119}
120
121HRESULT RegKey::GetValue(const wchar_t* full_key_name,
122                         const wchar_t* value_name,
123                         DWORD* value) {
124  ASSERT(full_key_name != NULL);
125  ASSERT(value != NULL);
126
127  return GetValueStaticHelper(full_key_name, value_name, REG_DWORD, value);
128}
129
130HRESULT RegKey::GetValue(const wchar_t* full_key_name,
131                         const wchar_t* value_name,
132                         DWORD64* value) {
133  ASSERT(full_key_name != NULL);
134  ASSERT(value != NULL);
135
136  return GetValueStaticHelper(full_key_name, value_name, REG_QWORD, value);
137}
138
139HRESULT RegKey::GetValue(const wchar_t* full_key_name,
140                         const wchar_t* value_name,
141                         float* value) {
142  ASSERT(value != NULL);
143  ASSERT(full_key_name != NULL);
144
145  DWORD byte_count = 0;
146  scoped_ptr<byte[]> buffer;
147  HRESULT hr = GetValueStaticHelper(full_key_name, value_name,
148                                    REG_BINARY, buffer.accept(), &byte_count);
149  if (SUCCEEDED(hr)) {
150    ASSERT(byte_count == sizeof(*value));
151    if (byte_count == sizeof(*value)) {
152      *value = *reinterpret_cast<float*>(buffer.get());
153    }
154  }
155  return hr;
156}
157
158HRESULT RegKey::GetValue(const wchar_t* full_key_name,
159                         const wchar_t* value_name,
160                         double* value) {
161  ASSERT(value != NULL);
162  ASSERT(full_key_name != NULL);
163
164  DWORD byte_count = 0;
165  scoped_ptr<byte[]> buffer;
166  HRESULT hr = GetValueStaticHelper(full_key_name, value_name,
167                                    REG_BINARY, buffer.accept(), &byte_count);
168  if (SUCCEEDED(hr)) {
169    ASSERT(byte_count == sizeof(*value));
170    if (byte_count == sizeof(*value)) {
171      *value = *reinterpret_cast<double*>(buffer.get());
172    }
173  }
174  return hr;
175}
176
177HRESULT RegKey::GetValue(const wchar_t* full_key_name,
178                         const wchar_t* value_name,
179                         wchar_t** value) {
180  ASSERT(full_key_name != NULL);
181  ASSERT(value != NULL);
182
183  return GetValueStaticHelper(full_key_name, value_name, REG_SZ, value);
184}
185
186HRESULT RegKey::GetValue(const wchar_t* full_key_name,
187                         const wchar_t* value_name,
188                         std::wstring* value) {
189  ASSERT(full_key_name != NULL);
190  ASSERT(value != NULL);
191
192  scoped_ptr<wchar_t[]> buffer;
193  HRESULT hr = RegKey::GetValue(full_key_name, value_name, buffer.accept());
194  if (SUCCEEDED(hr)) {
195    value->assign(buffer.get());
196  }
197  return hr;
198}
199
200HRESULT RegKey::GetValue(const wchar_t* full_key_name,
201                         const wchar_t* value_name,
202                         std::vector<std::wstring>* value) {
203  ASSERT(full_key_name != NULL);
204  ASSERT(value != NULL);
205
206  return GetValueStaticHelper(full_key_name, value_name, REG_MULTI_SZ, value);
207}
208
209HRESULT RegKey::GetValue(const wchar_t* full_key_name,
210                         const wchar_t* value_name,
211                         uint8** value,
212                         DWORD* byte_count) {
213  ASSERT(full_key_name != NULL);
214  ASSERT(value != NULL);
215  ASSERT(byte_count != NULL);
216
217  return GetValueStaticHelper(full_key_name, value_name,
218                              REG_BINARY, value, byte_count);
219}
220
221HRESULT RegKey::DeleteSubKey(const wchar_t* key_name) {
222  ASSERT(key_name != NULL);
223  ASSERT(h_key_ != NULL);
224
225  LONG res = ::RegDeleteKey(h_key_, key_name);
226  HRESULT hr = HRESULT_FROM_WIN32(res);
227  if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
228      hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
229    hr = S_FALSE;
230  }
231  return hr;
232}
233
234HRESULT RegKey::DeleteValue(const wchar_t* value_name) {
235  ASSERT(h_key_ != NULL);
236
237  LONG res = ::RegDeleteValue(h_key_, value_name);
238  HRESULT hr = HRESULT_FROM_WIN32(res);
239  if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
240      hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
241    hr = S_FALSE;
242  }
243  return hr;
244}
245
246HRESULT RegKey::Close() {
247  HRESULT hr = S_OK;
248  if (h_key_ != NULL) {
249    LONG res = ::RegCloseKey(h_key_);
250    hr = HRESULT_FROM_WIN32(res);
251    h_key_ = NULL;
252  }
253  return hr;
254}
255
256HRESULT RegKey::Create(HKEY parent_key,
257                       const wchar_t* key_name,
258                       wchar_t* lpszClass,
259                       DWORD options,
260                       REGSAM sam_desired,
261                       LPSECURITY_ATTRIBUTES lpSecAttr,
262                       LPDWORD lpdwDisposition) {
263  ASSERT(key_name != NULL);
264  ASSERT(parent_key != NULL);
265
266  DWORD dw = 0;
267  HKEY h_key = NULL;
268  LONG res = ::RegCreateKeyEx(parent_key, key_name, 0, lpszClass, options,
269                              sam_desired, lpSecAttr, &h_key, &dw);
270  HRESULT hr = HRESULT_FROM_WIN32(res);
271
272  if (lpdwDisposition) {
273    *lpdwDisposition = dw;
274  }
275
276  // we have to close the currently opened key
277  // before replacing it with the new one
278  if (hr == S_OK) {
279    hr = Close();
280    ASSERT(hr == S_OK);
281    h_key_ = h_key;
282  }
283  return hr;
284}
285
286HRESULT RegKey::Open(HKEY parent_key,
287                     const wchar_t* key_name,
288                     REGSAM sam_desired) {
289  ASSERT(key_name != NULL);
290  ASSERT(parent_key != NULL);
291
292  HKEY h_key = NULL;
293  LONG res = ::RegOpenKeyEx(parent_key, key_name, 0, sam_desired, &h_key);
294  HRESULT hr = HRESULT_FROM_WIN32(res);
295
296  // we have to close the currently opened key
297  // before replacing it with the new one
298  if (hr == S_OK) {
299    // close the currently opened key if any
300    hr = Close();
301    ASSERT(hr == S_OK);
302    h_key_ = h_key;
303  }
304  return hr;
305}
306
307// save the key and all of its subkeys and values to a file
308HRESULT RegKey::Save(const wchar_t* full_key_name, const wchar_t* file_name) {
309  ASSERT(full_key_name != NULL);
310  ASSERT(file_name != NULL);
311
312  std::wstring key_name(full_key_name);
313  HKEY h_key = GetRootKeyInfo(&key_name);
314  if (!h_key) {
315    return E_FAIL;
316  }
317
318  RegKey key;
319  HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
320  if (FAILED(hr)) {
321    return hr;
322  }
323
324  AdjustCurrentProcessPrivilege(SE_BACKUP_NAME, true);
325  LONG res = ::RegSaveKey(key.h_key_, file_name, NULL);
326  AdjustCurrentProcessPrivilege(SE_BACKUP_NAME, false);
327
328  return HRESULT_FROM_WIN32(res);
329}
330
331// restore the key and all of its subkeys and values which are saved into a file
332HRESULT RegKey::Restore(const wchar_t* full_key_name,
333                        const wchar_t* file_name) {
334  ASSERT(full_key_name != NULL);
335  ASSERT(file_name != NULL);
336
337  std::wstring key_name(full_key_name);
338  HKEY h_key = GetRootKeyInfo(&key_name);
339  if (!h_key) {
340    return E_FAIL;
341  }
342
343  RegKey key;
344  HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_WRITE);
345  if (FAILED(hr)) {
346    return hr;
347  }
348
349  AdjustCurrentProcessPrivilege(SE_RESTORE_NAME, true);
350  LONG res = ::RegRestoreKey(key.h_key_, file_name, REG_FORCE_RESTORE);
351  AdjustCurrentProcessPrivilege(SE_RESTORE_NAME, false);
352
353  return HRESULT_FROM_WIN32(res);
354}
355
356// check if the current key has the specified subkey
357bool RegKey::HasSubkey(const wchar_t* key_name) const {
358  ASSERT(key_name != NULL);
359
360  RegKey key;
361  HRESULT hr = key.Open(h_key_, key_name, KEY_READ);
362  key.Close();
363  return hr == S_OK;
364}
365
366// static flush key
367HRESULT RegKey::FlushKey(const wchar_t* full_key_name) {
368  ASSERT(full_key_name != NULL);
369
370  HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
371  // get the root HKEY
372  std::wstring key_name(full_key_name);
373  HKEY h_key = GetRootKeyInfo(&key_name);
374
375  if (h_key != NULL) {
376    LONG res = ::RegFlushKey(h_key);
377    hr = HRESULT_FROM_WIN32(res);
378  }
379  return hr;
380}
381
382// static SET helper
383HRESULT RegKey::SetValueStaticHelper(const wchar_t* full_key_name,
384                                     const wchar_t* value_name,
385                                     DWORD type,
386                                     LPVOID value,
387                                     DWORD byte_count) {
388  ASSERT(full_key_name != NULL);
389
390  HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
391  // get the root HKEY
392  std::wstring key_name(full_key_name);
393  HKEY h_key = GetRootKeyInfo(&key_name);
394
395  if (h_key != NULL) {
396    RegKey key;
397    hr = key.Create(h_key, key_name.c_str());
398    if (hr == S_OK) {
399      switch (type) {
400        case REG_DWORD:
401          hr = key.SetValue(value_name, *(static_cast<DWORD*>(value)));
402          break;
403        case REG_QWORD:
404          hr = key.SetValue(value_name, *(static_cast<DWORD64*>(value)));
405          break;
406        case REG_SZ:
407          hr = key.SetValue(value_name, static_cast<const wchar_t*>(value));
408          break;
409        case REG_BINARY:
410          hr = key.SetValue(value_name, static_cast<const uint8*>(value),
411                            byte_count);
412          break;
413        case REG_MULTI_SZ:
414          hr = key.SetValue(value_name, static_cast<const uint8*>(value),
415                            byte_count, type);
416          break;
417        default:
418          ASSERT(false);
419          hr = HRESULT_FROM_WIN32(ERROR_DATATYPE_MISMATCH);
420          break;
421      }
422      // close the key after writing
423      HRESULT temp_hr = key.Close();
424      if (hr == S_OK) {
425        hr = temp_hr;
426      }
427    }
428  }
429  return hr;
430}
431
432// static GET helper
433HRESULT RegKey::GetValueStaticHelper(const wchar_t* full_key_name,
434                                     const wchar_t* value_name,
435                                     DWORD type,
436                                     LPVOID value,
437                                     DWORD* byte_count) {
438  ASSERT(full_key_name != NULL);
439
440  HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
441  // get the root HKEY
442  std::wstring key_name(full_key_name);
443  HKEY h_key = GetRootKeyInfo(&key_name);
444
445  if (h_key != NULL) {
446    RegKey key;
447    hr = key.Open(h_key, key_name.c_str(), KEY_READ);
448    if (hr == S_OK) {
449      switch (type) {
450        case REG_DWORD:
451          hr = key.GetValue(value_name, reinterpret_cast<DWORD*>(value));
452          break;
453        case REG_QWORD:
454          hr = key.GetValue(value_name, reinterpret_cast<DWORD64*>(value));
455          break;
456        case REG_SZ:
457          hr = key.GetValue(value_name, reinterpret_cast<wchar_t**>(value));
458          break;
459        case REG_MULTI_SZ:
460          hr = key.GetValue(value_name, reinterpret_cast<
461                                            std::vector<std::wstring>*>(value));
462          break;
463        case REG_BINARY:
464          hr = key.GetValue(value_name, reinterpret_cast<uint8**>(value),
465                            byte_count);
466          break;
467        default:
468          ASSERT(false);
469          hr = HRESULT_FROM_WIN32(ERROR_DATATYPE_MISMATCH);
470          break;
471      }
472      // close the key after writing
473      HRESULT temp_hr = key.Close();
474      if (hr == S_OK) {
475        hr = temp_hr;
476      }
477    }
478  }
479  return hr;
480}
481
482// GET helper
483HRESULT RegKey::GetValueHelper(const wchar_t* value_name,
484                               DWORD* type,
485                               uint8** value,
486                               DWORD* byte_count) const {
487  ASSERT(byte_count != NULL);
488  ASSERT(value != NULL);
489  ASSERT(type != NULL);
490
491  // init return buffer
492  *value = NULL;
493
494  // get the size of the return data buffer
495  LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, type, NULL, byte_count);
496  HRESULT hr = HRESULT_FROM_WIN32(res);
497
498  if (hr == S_OK) {
499    // if the value length is 0, nothing to do
500    if (*byte_count != 0) {
501      // allocate the buffer
502      *value = new byte[*byte_count];
503      ASSERT(*value != NULL);
504
505      // make the call again to get the data
506      res = ::SHQueryValueEx(h_key_, value_name, NULL,
507                             type, *value, byte_count);
508      hr = HRESULT_FROM_WIN32(res);
509      ASSERT(hr == S_OK);
510    }
511  }
512  return hr;
513}
514
515// Int32 Get
516HRESULT RegKey::GetValue(const wchar_t* value_name, DWORD* value) const {
517  ASSERT(value != NULL);
518
519  DWORD type = 0;
520  DWORD byte_count = sizeof(DWORD);
521  LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
522                              value, &byte_count);
523  HRESULT hr = HRESULT_FROM_WIN32(res);
524  ASSERT((hr != S_OK) || (type == REG_DWORD));
525  ASSERT((hr != S_OK) || (byte_count == sizeof(DWORD)));
526  return hr;
527}
528
529// Int64 Get
530HRESULT RegKey::GetValue(const wchar_t* value_name, DWORD64* value) const {
531  ASSERT(value != NULL);
532
533  DWORD type = 0;
534  DWORD byte_count = sizeof(DWORD64);
535  LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
536                              value, &byte_count);
537  HRESULT hr = HRESULT_FROM_WIN32(res);
538  ASSERT((hr != S_OK) || (type == REG_QWORD));
539  ASSERT((hr != S_OK) || (byte_count == sizeof(DWORD64)));
540  return hr;
541}
542
543// String Get
544HRESULT RegKey::GetValue(const wchar_t* value_name, wchar_t** value) const {
545  ASSERT(value != NULL);
546
547  DWORD byte_count = 0;
548  DWORD type = 0;
549
550  // first get the size of the string buffer
551  LONG res = ::SHQueryValueEx(h_key_, value_name, NULL,
552                              &type, NULL, &byte_count);
553  HRESULT hr = HRESULT_FROM_WIN32(res);
554
555  if (hr == S_OK) {
556    // allocate room for the string and a terminating \0
557    *value = new wchar_t[(byte_count / sizeof(wchar_t)) + 1];
558
559    if ((*value) != NULL) {
560      if (byte_count != 0) {
561        // make the call again
562        res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
563                               *value, &byte_count);
564        hr = HRESULT_FROM_WIN32(res);
565      } else {
566        (*value)[0] = L'\0';
567      }
568
569      ASSERT((hr != S_OK) || (type == REG_SZ) ||
570             (type == REG_MULTI_SZ) || (type == REG_EXPAND_SZ));
571    } else {
572      hr = E_OUTOFMEMORY;
573    }
574  }
575
576  return hr;
577}
578
579// get a string value
580HRESULT RegKey::GetValue(const wchar_t* value_name, std::wstring* value) const {
581  ASSERT(value != NULL);
582
583  DWORD byte_count = 0;
584  DWORD type = 0;
585
586  // first get the size of the string buffer
587  LONG res = ::SHQueryValueEx(h_key_, value_name, NULL,
588                              &type, NULL, &byte_count);
589  HRESULT hr = HRESULT_FROM_WIN32(res);
590
591  if (hr == S_OK) {
592    if (byte_count != 0) {
593      // Allocate some memory and make the call again
594      value->resize(byte_count / sizeof(wchar_t) + 1);
595      res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
596                             &value->at(0), &byte_count);
597      hr = HRESULT_FROM_WIN32(res);
598      value->resize(wcslen(value->data()));
599    } else {
600      value->clear();
601    }
602
603    ASSERT((hr != S_OK) || (type == REG_SZ) ||
604           (type == REG_MULTI_SZ) || (type == REG_EXPAND_SZ));
605  }
606
607  return hr;
608}
609
610// convert REG_MULTI_SZ bytes to string array
611HRESULT RegKey::MultiSZBytesToStringArray(const uint8* buffer,
612                                          DWORD byte_count,
613                                          std::vector<std::wstring>* value) {
614  ASSERT(buffer != NULL);
615  ASSERT(value != NULL);
616
617  const wchar_t* data = reinterpret_cast<const wchar_t*>(buffer);
618  DWORD data_len = byte_count / sizeof(wchar_t);
619  value->clear();
620  if (data_len > 1) {
621    // must be terminated by two null characters
622    if (data[data_len - 1] != 0 || data[data_len - 2] != 0) {
623      return E_INVALIDARG;
624    }
625
626    // put null-terminated strings into arrays
627    while (*data) {
628      std::wstring str(data);
629      value->push_back(str);
630      data += str.length() + 1;
631    }
632  }
633  return S_OK;
634}
635
636// get a std::vector<std::wstring> value from REG_MULTI_SZ type
637HRESULT RegKey::GetValue(const wchar_t* value_name,
638                         std::vector<std::wstring>* value) const {
639  ASSERT(value != NULL);
640
641  DWORD byte_count = 0;
642  DWORD type = 0;
643  uint8* buffer = 0;
644
645  // first get the size of the buffer
646  HRESULT hr = GetValueHelper(value_name, &type, &buffer, &byte_count);
647  ASSERT((hr != S_OK) || (type == REG_MULTI_SZ));
648
649  if (SUCCEEDED(hr)) {
650    hr = MultiSZBytesToStringArray(buffer, byte_count, value);
651  }
652
653  return hr;
654}
655
656// Binary data Get
657HRESULT RegKey::GetValue(const wchar_t* value_name,
658                         uint8** value,
659                         DWORD* byte_count) const {
660  ASSERT(byte_count != NULL);
661  ASSERT(value != NULL);
662
663  DWORD type = 0;
664  HRESULT hr = GetValueHelper(value_name, &type, value, byte_count);
665  ASSERT((hr != S_OK) || (type == REG_MULTI_SZ) || (type == REG_BINARY));
666  return hr;
667}
668
669// Raw data get
670HRESULT RegKey::GetValue(const wchar_t* value_name,
671                         uint8** value,
672                         DWORD* byte_count,
673                         DWORD*type) const {
674  ASSERT(type != NULL);
675  ASSERT(byte_count != NULL);
676  ASSERT(value != NULL);
677
678  return GetValueHelper(value_name, type, value, byte_count);
679}
680
681// Int32 set
682HRESULT RegKey::SetValue(const wchar_t* value_name, DWORD value) const {
683  ASSERT(h_key_ != NULL);
684
685  LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_DWORD,
686                             reinterpret_cast<const uint8*>(&value),
687                             sizeof(DWORD));
688  return HRESULT_FROM_WIN32(res);
689}
690
691// Int64 set
692HRESULT RegKey::SetValue(const wchar_t* value_name, DWORD64 value) const {
693  ASSERT(h_key_ != NULL);
694
695  LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_QWORD,
696                             reinterpret_cast<const uint8*>(&value),
697                             sizeof(DWORD64));
698  return HRESULT_FROM_WIN32(res);
699}
700
701// String set
702HRESULT RegKey::SetValue(const wchar_t* value_name,
703                         const wchar_t* value) const {
704  ASSERT(value != NULL);
705  ASSERT(h_key_ != NULL);
706
707  LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_SZ,
708                             reinterpret_cast<const uint8*>(value),
709                             (lstrlen(value) + 1) * sizeof(wchar_t));
710  return HRESULT_FROM_WIN32(res);
711}
712
713// Binary data set
714HRESULT RegKey::SetValue(const wchar_t* value_name,
715                         const uint8* value,
716                         DWORD byte_count) const {
717  ASSERT(h_key_ != NULL);
718
719  // special case - if 'value' is NULL make sure byte_count is zero
720  if (value == NULL) {
721    byte_count = 0;
722  }
723
724  LONG res = ::RegSetValueEx(h_key_, value_name, NULL,
725                             REG_BINARY, value, byte_count);
726  return HRESULT_FROM_WIN32(res);
727}
728
729// Raw data set
730HRESULT RegKey::SetValue(const wchar_t* value_name,
731                         const uint8* value,
732                         DWORD byte_count,
733                         DWORD type) const {
734  ASSERT(value != NULL);
735  ASSERT(h_key_ != NULL);
736
737  LONG res = ::RegSetValueEx(h_key_, value_name, NULL, type, value, byte_count);
738  return HRESULT_FROM_WIN32(res);
739}
740
741bool RegKey::HasKey(const wchar_t* full_key_name) {
742  ASSERT(full_key_name != NULL);
743
744  // get the root HKEY
745  std::wstring key_name(full_key_name);
746  HKEY h_key = GetRootKeyInfo(&key_name);
747
748  if (h_key != NULL) {
749    RegKey key;
750    HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
751    key.Close();
752    return S_OK == hr;
753  }
754  return false;
755}
756
757// static version of HasValue
758bool RegKey::HasValue(const wchar_t* full_key_name, const wchar_t* value_name) {
759  ASSERT(full_key_name != NULL);
760
761  bool has_value = false;
762  // get the root HKEY
763  std::wstring key_name(full_key_name);
764  HKEY h_key = GetRootKeyInfo(&key_name);
765
766  if (h_key != NULL) {
767    RegKey key;
768    if (key.Open(h_key, key_name.c_str(), KEY_READ) == S_OK) {
769      has_value = key.HasValue(value_name);
770      key.Close();
771    }
772  }
773  return has_value;
774}
775
776HRESULT RegKey::GetValueType(const wchar_t* full_key_name,
777                             const wchar_t* value_name,
778                             DWORD* value_type) {
779  ASSERT(full_key_name != NULL);
780  ASSERT(value_type != NULL);
781
782  *value_type = REG_NONE;
783
784  std::wstring key_name(full_key_name);
785  HKEY h_key = GetRootKeyInfo(&key_name);
786
787  RegKey key;
788  HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
789  if (SUCCEEDED(hr)) {
790    LONG res = ::SHQueryValueEx(key.h_key_, value_name, NULL, value_type,
791                                NULL, NULL);
792    if (res != ERROR_SUCCESS) {
793      hr = HRESULT_FROM_WIN32(res);
794    }
795  }
796
797  return hr;
798}
799
800HRESULT RegKey::DeleteKey(const wchar_t* full_key_name) {
801  ASSERT(full_key_name != NULL);
802
803  return DeleteKey(full_key_name, true);
804}
805
806HRESULT RegKey::DeleteKey(const wchar_t* full_key_name, bool recursively) {
807  ASSERT(full_key_name != NULL);
808
809  // need to open the parent key first
810  // get the root HKEY
811  std::wstring key_name(full_key_name);
812  HKEY h_key = GetRootKeyInfo(&key_name);
813
814  // get the parent key
815  std::wstring parent_key(GetParentKeyInfo(&key_name));
816
817  RegKey key;
818  HRESULT hr = key.Open(h_key, parent_key.c_str());
819
820  if (hr == S_OK) {
821    hr = recursively ? key.RecurseDeleteSubKey(key_name.c_str())
822                     : key.DeleteSubKey(key_name.c_str());
823  } else if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
824             hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
825    hr = S_FALSE;
826  }
827
828  key.Close();
829  return hr;
830}
831
832HRESULT RegKey::DeleteValue(const wchar_t* full_key_name,
833                            const wchar_t* value_name) {
834  ASSERT(full_key_name != NULL);
835
836  HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
837  // get the root HKEY
838  std::wstring key_name(full_key_name);
839  HKEY h_key = GetRootKeyInfo(&key_name);
840
841  if (h_key != NULL) {
842    RegKey key;
843    hr = key.Open(h_key, key_name.c_str());
844    if (hr == S_OK) {
845      hr = key.DeleteValue(value_name);
846      key.Close();
847    }
848  }
849  return hr;
850}
851
852HRESULT RegKey::RecurseDeleteSubKey(const wchar_t* key_name) {
853  ASSERT(key_name != NULL);
854
855  RegKey key;
856  HRESULT hr = key.Open(h_key_, key_name);
857
858  if (hr == S_OK) {
859    // enumerate all subkeys of this key and recursivelly delete them
860    FILETIME time = {0};
861    wchar_t key_name_buf[kMaxKeyNameChars] = {0};
862    DWORD key_name_buf_size = kMaxKeyNameChars;
863    while (hr == S_OK &&
864        ::RegEnumKeyEx(key.h_key_, 0, key_name_buf, &key_name_buf_size,
865                       NULL, NULL, NULL,  &time) == ERROR_SUCCESS) {
866      hr = key.RecurseDeleteSubKey(key_name_buf);
867
868      // restore the buffer size
869      key_name_buf_size = kMaxKeyNameChars;
870    }
871    // close the top key
872    key.Close();
873  }
874
875  if (hr == S_OK) {
876    // the key has no more children keys
877    // delete the key and all of its values
878    hr = DeleteSubKey(key_name);
879  }
880
881  return hr;
882}
883
884HKEY RegKey::GetRootKeyInfo(std::wstring* full_key_name) {
885  ASSERT(full_key_name != NULL);
886
887  HKEY h_key = NULL;
888  // get the root HKEY
889  size_t index = full_key_name->find(L'\\');
890  std::wstring root_key;
891
892  if (index == -1) {
893    root_key = *full_key_name;
894    *full_key_name = L"";
895  } else {
896    root_key = full_key_name->substr(0, index);
897    *full_key_name = full_key_name->substr(index + 1,
898                                           full_key_name->length() - index - 1);
899  }
900
901  for (std::wstring::iterator iter = root_key.begin();
902       iter != root_key.end(); ++iter) {
903    *iter = toupper(*iter);
904  }
905
906  if (!root_key.compare(L"HKLM") ||
907      !root_key.compare(L"HKEY_LOCAL_MACHINE")) {
908    h_key = HKEY_LOCAL_MACHINE;
909  } else if (!root_key.compare(L"HKCU") ||
910             !root_key.compare(L"HKEY_CURRENT_USER")) {
911    h_key = HKEY_CURRENT_USER;
912  } else if (!root_key.compare(L"HKU") ||
913             !root_key.compare(L"HKEY_USERS")) {
914    h_key = HKEY_USERS;
915  } else if (!root_key.compare(L"HKCR") ||
916             !root_key.compare(L"HKEY_CLASSES_ROOT")) {
917    h_key = HKEY_CLASSES_ROOT;
918  }
919
920  return h_key;
921}
922
923
924// Returns true if this key name is 'safe' for deletion
925// (doesn't specify a key root)
926bool RegKey::SafeKeyNameForDeletion(const wchar_t* key_name) {
927  ASSERT(key_name != NULL);
928  std::wstring key(key_name);
929
930  HKEY root_key = GetRootKeyInfo(&key);
931
932  if (!root_key) {
933    key = key_name;
934  }
935  if (key.empty()) {
936    return false;
937  }
938  bool found_subkey = false, backslash_found = false;
939  for (size_t i = 0 ; i < key.length() ; ++i) {
940    if (key[i] == L'\\') {
941      backslash_found = true;
942    } else if (backslash_found) {
943      found_subkey = true;
944      break;
945    }
946  }
947  return (root_key == HKEY_USERS) ? found_subkey : true;
948}
949
950std::wstring RegKey::GetParentKeyInfo(std::wstring* key_name) {
951  ASSERT(key_name != NULL);
952
953  // get the parent key
954  size_t index = key_name->rfind(L'\\');
955  std::wstring parent_key;
956  if (index == -1) {
957    parent_key = L"";
958  } else {
959    parent_key = key_name->substr(0, index);
960    *key_name = key_name->substr(index + 1, key_name->length() - index - 1);
961  }
962
963  return parent_key;
964}
965
966// get the number of values for this key
967uint32 RegKey::GetValueCount() {
968  DWORD num_values = 0;
969
970  if (ERROR_SUCCESS != ::RegQueryInfoKey(
971        h_key_,  // key handle
972        NULL,  // buffer for class name
973        NULL,  // size of class string
974        NULL,  // reserved
975        NULL,  // number of subkeys
976        NULL,  // longest subkey size
977        NULL,  // longest class string
978        &num_values,  // number of values for this key
979        NULL,  // longest value name
980        NULL,  // longest value data
981        NULL,  // security descriptor
982        NULL)) {  // last write time
983    ASSERT(false);
984  }
985  return num_values;
986}
987
988// Enumerators for the value_names for this key
989
990// Called to get the value name for the given value name index
991// Use GetValueCount() to get the total value_name count for this key
992// Returns failure if no key at the specified index
993HRESULT RegKey::GetValueNameAt(int index, std::wstring* value_name,
994                               DWORD* type) {
995  ASSERT(value_name != NULL);
996
997  LONG res = ERROR_SUCCESS;
998  wchar_t value_name_buf[kMaxValueNameChars] = {0};
999  DWORD value_name_buf_size = kMaxValueNameChars;
1000  res = ::RegEnumValue(h_key_, index, value_name_buf, &value_name_buf_size,
1001                       NULL, type, NULL, NULL);
1002
1003  if (res == ERROR_SUCCESS) {
1004    value_name->assign(value_name_buf);
1005  }
1006
1007  return HRESULT_FROM_WIN32(res);
1008}
1009
1010uint32 RegKey::GetSubkeyCount() {
1011  // number of values for key
1012  DWORD num_subkeys = 0;
1013
1014  if (ERROR_SUCCESS != ::RegQueryInfoKey(
1015          h_key_,  // key handle
1016          NULL,  // buffer for class name
1017          NULL,  // size of class string
1018          NULL,  // reserved
1019          &num_subkeys,  // number of subkeys
1020          NULL,  // longest subkey size
1021          NULL,  // longest class string
1022          NULL,  // number of values for this key
1023          NULL,  // longest value name
1024          NULL,  // longest value data
1025          NULL,  // security descriptor
1026          NULL)) { // last write time
1027    ASSERT(false);
1028  }
1029  return num_subkeys;
1030}
1031
1032HRESULT RegKey::GetSubkeyNameAt(int index, std::wstring* key_name) {
1033  ASSERT(key_name != NULL);
1034
1035  LONG res = ERROR_SUCCESS;
1036  wchar_t key_name_buf[kMaxKeyNameChars] = {0};
1037  DWORD key_name_buf_size = kMaxKeyNameChars;
1038
1039  res = ::RegEnumKeyEx(h_key_, index, key_name_buf, &key_name_buf_size,
1040                       NULL, NULL, NULL, NULL);
1041
1042  if (res == ERROR_SUCCESS) {
1043    key_name->assign(key_name_buf);
1044  }
1045
1046  return HRESULT_FROM_WIN32(res);
1047}
1048
1049// Is the key empty: having no sub-keys and values
1050bool RegKey::IsKeyEmpty(const wchar_t* full_key_name) {
1051  ASSERT(full_key_name != NULL);
1052
1053  bool is_empty = true;
1054
1055  // Get the root HKEY
1056  std::wstring key_name(full_key_name);
1057  HKEY h_key = GetRootKeyInfo(&key_name);
1058
1059  // Open the key to check
1060  if (h_key != NULL) {
1061    RegKey key;
1062    HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
1063    if (SUCCEEDED(hr)) {
1064      is_empty = key.GetSubkeyCount() == 0 && key.GetValueCount() == 0;
1065      key.Close();
1066    }
1067  }
1068
1069  return is_empty;
1070}
1071
1072bool AdjustCurrentProcessPrivilege(const TCHAR* privilege, bool to_enable) {
1073  ASSERT(privilege != NULL);
1074
1075  bool ret = false;
1076  HANDLE token;
1077  if (::OpenProcessToken(::GetCurrentProcess(),
1078                         TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &token)) {
1079    LUID luid;
1080    memset(&luid, 0, sizeof(luid));
1081    if (::LookupPrivilegeValue(NULL, privilege, &luid)) {
1082      TOKEN_PRIVILEGES privs;
1083      privs.PrivilegeCount = 1;
1084      privs.Privileges[0].Luid = luid;
1085      privs.Privileges[0].Attributes = to_enable ? SE_PRIVILEGE_ENABLED : 0;
1086      if (::AdjustTokenPrivileges(token, FALSE, &privs, 0, NULL, 0)) {
1087        ret = true;
1088      } else {
1089        LOG_GLE(LS_ERROR) << "AdjustTokenPrivileges failed";
1090      }
1091    } else {
1092      LOG_GLE(LS_ERROR) << "LookupPrivilegeValue failed";
1093    }
1094    CloseHandle(token);
1095  } else {
1096    LOG_GLE(LS_ERROR) << "OpenProcessToken(GetCurrentProcess) failed";
1097  }
1098
1099  return ret;
1100}
1101
1102}  // namespace rtc
1103