1/*
2 * libjingle
3 * Copyright 2003-2008, Google Inc.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 *  1. Redistributions of source code must retain the above copyright notice,
9 *     this list of conditions and the following disclaimer.
10 *  2. Redistributions in binary form must reproduce the above copyright notice,
11 *     this list of conditions and the following disclaimer in the documentation
12 *     and/or other materials provided with the distribution.
13 *  3. The name of the author may not be used to endorse or promote products
14 *     derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
// Registry configuration wrapper class implementation
29//
30// Change made by S. Ganesh - ganesh@google.com:
31//   Use SHQueryValueEx instead of RegQueryValueEx throughout.
32//   A call to the SHLWAPI function is essentially a call to the standard
33//   function but with post-processing:
34//   * to fix REG_SZ or REG_EXPAND_SZ data that is not properly null-terminated;
35//   * to expand REG_EXPAND_SZ data.
36
#include "talk/base/win32regkey.h"

#include <shlwapi.h>
#include <string.h>
#include <wctype.h>

#include "talk/base/common.h"
#include "talk/base/logging.h"
#include "talk/base/scoped_ptr.h"
44
45namespace talk_base {
46
47RegKey::RegKey() {
48  h_key_ = NULL;
49}
50
51RegKey::~RegKey() {
52  Close();
53}
54
55HRESULT RegKey::Create(HKEY parent_key, const wchar_t* key_name) {
56  return Create(parent_key,
57                key_name,
58                REG_NONE,
59                REG_OPTION_NON_VOLATILE,
60                KEY_ALL_ACCESS,
61                NULL,
62                NULL);
63}
64
65HRESULT RegKey::Open(HKEY parent_key, const wchar_t* key_name) {
66  return Open(parent_key, key_name, KEY_ALL_ACCESS);
67}
68
69bool RegKey::HasValue(const TCHAR* value_name) const {
70  return (ERROR_SUCCESS == ::RegQueryValueEx(h_key_, value_name, NULL,
71                                             NULL, NULL, NULL));
72}
73
74HRESULT RegKey::SetValue(const wchar_t* full_key_name,
75                         const wchar_t* value_name,
76                         DWORD value) {
77  ASSERT(full_key_name != NULL);
78
79  return SetValueStaticHelper(full_key_name, value_name, REG_DWORD, &value);
80}
81
82HRESULT RegKey::SetValue(const wchar_t* full_key_name,
83                         const wchar_t* value_name,
84                         DWORD64 value) {
85  ASSERT(full_key_name != NULL);
86
87  return SetValueStaticHelper(full_key_name, value_name, REG_QWORD, &value);
88}
89
90HRESULT RegKey::SetValue(const wchar_t* full_key_name,
91                         const wchar_t* value_name,
92                         float value) {
93  ASSERT(full_key_name != NULL);
94
95  return SetValueStaticHelper(full_key_name, value_name,
96                              REG_BINARY, &value, sizeof(value));
97}
98
99HRESULT RegKey::SetValue(const wchar_t* full_key_name,
100                         const wchar_t* value_name,
101                         double value) {
102  ASSERT(full_key_name != NULL);
103
104  return SetValueStaticHelper(full_key_name, value_name,
105                              REG_BINARY, &value, sizeof(value));
106}
107
108HRESULT RegKey::SetValue(const wchar_t* full_key_name,
109                         const wchar_t* value_name,
110                         const TCHAR* value) {
111  ASSERT(full_key_name != NULL);
112  ASSERT(value != NULL);
113
114  return SetValueStaticHelper(full_key_name, value_name,
115                              REG_SZ, const_cast<wchar_t*>(value));
116}
117
118HRESULT RegKey::SetValue(const wchar_t* full_key_name,
119                         const wchar_t* value_name,
120                         const uint8* value,
121                         DWORD byte_count) {
122  ASSERT(full_key_name != NULL);
123
124  return SetValueStaticHelper(full_key_name, value_name, REG_BINARY,
125                              const_cast<uint8*>(value), byte_count);
126}
127
128HRESULT RegKey::SetValueMultiSZ(const wchar_t* full_key_name,
129                                const wchar_t* value_name,
130                                const uint8* value,
131                                DWORD byte_count) {
132  ASSERT(full_key_name != NULL);
133
134  return SetValueStaticHelper(full_key_name, value_name, REG_MULTI_SZ,
135                              const_cast<uint8*>(value), byte_count);
136}
137
138HRESULT RegKey::GetValue(const wchar_t* full_key_name,
139                         const wchar_t* value_name,
140                         DWORD* value) {
141  ASSERT(full_key_name != NULL);
142  ASSERT(value != NULL);
143
144  return GetValueStaticHelper(full_key_name, value_name, REG_DWORD, value);
145}
146
147HRESULT RegKey::GetValue(const wchar_t* full_key_name,
148                         const wchar_t* value_name,
149                         DWORD64* value) {
150  ASSERT(full_key_name != NULL);
151  ASSERT(value != NULL);
152
153  return GetValueStaticHelper(full_key_name, value_name, REG_QWORD, value);
154}
155
156HRESULT RegKey::GetValue(const wchar_t* full_key_name,
157                         const wchar_t* value_name,
158                         float* value) {
159  ASSERT(value != NULL);
160  ASSERT(full_key_name != NULL);
161
162  DWORD byte_count = 0;
163  scoped_array<byte> buffer;
164  HRESULT hr = GetValueStaticHelper(full_key_name, value_name,
165                                    REG_BINARY, buffer.accept(), &byte_count);
166  if (SUCCEEDED(hr)) {
167    ASSERT(byte_count == sizeof(*value));
168    if (byte_count == sizeof(*value)) {
169      *value = *reinterpret_cast<float*>(buffer.get());
170    }
171  }
172  return hr;
173}
174
175HRESULT RegKey::GetValue(const wchar_t* full_key_name,
176                         const wchar_t* value_name,
177                         double* value) {
178  ASSERT(value != NULL);
179  ASSERT(full_key_name != NULL);
180
181  DWORD byte_count = 0;
182  scoped_array<byte> buffer;
183  HRESULT hr = GetValueStaticHelper(full_key_name, value_name,
184                                    REG_BINARY, buffer.accept(), &byte_count);
185  if (SUCCEEDED(hr)) {
186    ASSERT(byte_count == sizeof(*value));
187    if (byte_count == sizeof(*value)) {
188      *value = *reinterpret_cast<double*>(buffer.get());
189    }
190  }
191  return hr;
192}
193
194HRESULT RegKey::GetValue(const wchar_t* full_key_name,
195                         const wchar_t* value_name,
196                         wchar_t** value) {
197  ASSERT(full_key_name != NULL);
198  ASSERT(value != NULL);
199
200  return GetValueStaticHelper(full_key_name, value_name, REG_SZ, value);
201}
202
203HRESULT RegKey::GetValue(const wchar_t* full_key_name,
204                         const wchar_t* value_name,
205                         std::wstring* value) {
206  ASSERT(full_key_name != NULL);
207  ASSERT(value != NULL);
208
209  scoped_array<wchar_t> buffer;
210  HRESULT hr = RegKey::GetValue(full_key_name, value_name, buffer.accept());
211  if (SUCCEEDED(hr)) {
212    value->assign(buffer.get());
213  }
214  return hr;
215}
216
217HRESULT RegKey::GetValue(const wchar_t* full_key_name,
218                         const wchar_t* value_name,
219                         std::vector<std::wstring>* value) {
220  ASSERT(full_key_name != NULL);
221  ASSERT(value != NULL);
222
223  return GetValueStaticHelper(full_key_name, value_name, REG_MULTI_SZ, value);
224}
225
226HRESULT RegKey::GetValue(const wchar_t* full_key_name,
227                         const wchar_t* value_name,
228                         uint8** value,
229                         DWORD* byte_count) {
230  ASSERT(full_key_name != NULL);
231  ASSERT(value != NULL);
232  ASSERT(byte_count != NULL);
233
234  return GetValueStaticHelper(full_key_name, value_name,
235                              REG_BINARY, value, byte_count);
236}
237
238HRESULT RegKey::DeleteSubKey(const wchar_t* key_name) {
239  ASSERT(key_name != NULL);
240  ASSERT(h_key_ != NULL);
241
242  LONG res = ::RegDeleteKey(h_key_, key_name);
243  HRESULT hr = HRESULT_FROM_WIN32(res);
244  if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
245      hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
246    hr = S_FALSE;
247  }
248  return hr;
249}
250
251HRESULT RegKey::DeleteValue(const wchar_t* value_name) {
252  ASSERT(h_key_ != NULL);
253
254  LONG res = ::RegDeleteValue(h_key_, value_name);
255  HRESULT hr = HRESULT_FROM_WIN32(res);
256  if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
257      hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
258    hr = S_FALSE;
259  }
260  return hr;
261}
262
263HRESULT RegKey::Close() {
264  HRESULT hr = S_OK;
265  if (h_key_ != NULL) {
266    LONG res = ::RegCloseKey(h_key_);
267    hr = HRESULT_FROM_WIN32(res);
268    h_key_ = NULL;
269  }
270  return hr;
271}
272
273HRESULT RegKey::Create(HKEY parent_key,
274                       const wchar_t* key_name,
275                       wchar_t* lpszClass,
276                       DWORD options,
277                       REGSAM sam_desired,
278                       LPSECURITY_ATTRIBUTES lpSecAttr,
279                       LPDWORD lpdwDisposition) {
280  ASSERT(key_name != NULL);
281  ASSERT(parent_key != NULL);
282
283  DWORD dw = 0;
284  HKEY h_key = NULL;
285  LONG res = ::RegCreateKeyEx(parent_key, key_name, 0, lpszClass, options,
286                              sam_desired, lpSecAttr, &h_key, &dw);
287  HRESULT hr = HRESULT_FROM_WIN32(res);
288
289  if (lpdwDisposition) {
290    *lpdwDisposition = dw;
291  }
292
293  // we have to close the currently opened key
294  // before replacing it with the new one
295  if (hr == S_OK) {
296    hr = Close();
297    ASSERT(hr == S_OK);
298    h_key_ = h_key;
299  }
300  return hr;
301}
302
303HRESULT RegKey::Open(HKEY parent_key,
304                     const wchar_t* key_name,
305                     REGSAM sam_desired) {
306  ASSERT(key_name != NULL);
307  ASSERT(parent_key != NULL);
308
309  HKEY h_key = NULL;
310  LONG res = ::RegOpenKeyEx(parent_key, key_name, 0, sam_desired, &h_key);
311  HRESULT hr = HRESULT_FROM_WIN32(res);
312
313  // we have to close the currently opened key
314  // before replacing it with the new one
315  if (hr == S_OK) {
316    // close the currently opened key if any
317    hr = Close();
318    ASSERT(hr == S_OK);
319    h_key_ = h_key;
320  }
321  return hr;
322}
323
324// save the key and all of its subkeys and values to a file
325HRESULT RegKey::Save(const wchar_t* full_key_name, const wchar_t* file_name) {
326  ASSERT(full_key_name != NULL);
327  ASSERT(file_name != NULL);
328
329  std::wstring key_name(full_key_name);
330  HKEY h_key = GetRootKeyInfo(&key_name);
331  if (!h_key) {
332    return E_FAIL;
333  }
334
335  RegKey key;
336  HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
337  if (FAILED(hr)) {
338    return hr;
339  }
340
341  AdjustCurrentProcessPrivilege(SE_BACKUP_NAME, true);
342  LONG res = ::RegSaveKey(key.h_key_, file_name, NULL);
343  AdjustCurrentProcessPrivilege(SE_BACKUP_NAME, false);
344
345  return HRESULT_FROM_WIN32(res);
346}
347
348// restore the key and all of its subkeys and values which are saved into a file
349HRESULT RegKey::Restore(const wchar_t* full_key_name,
350                        const wchar_t* file_name) {
351  ASSERT(full_key_name != NULL);
352  ASSERT(file_name != NULL);
353
354  std::wstring key_name(full_key_name);
355  HKEY h_key = GetRootKeyInfo(&key_name);
356  if (!h_key) {
357    return E_FAIL;
358  }
359
360  RegKey key;
361  HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_WRITE);
362  if (FAILED(hr)) {
363    return hr;
364  }
365
366  AdjustCurrentProcessPrivilege(SE_RESTORE_NAME, true);
367  LONG res = ::RegRestoreKey(key.h_key_, file_name, REG_FORCE_RESTORE);
368  AdjustCurrentProcessPrivilege(SE_RESTORE_NAME, false);
369
370  return HRESULT_FROM_WIN32(res);
371}
372
373// check if the current key has the specified subkey
374bool RegKey::HasSubkey(const wchar_t* key_name) const {
375  ASSERT(key_name != NULL);
376
377  RegKey key;
378  HRESULT hr = key.Open(h_key_, key_name, KEY_READ);
379  key.Close();
380  return hr == S_OK;
381}
382
383// static flush key
384HRESULT RegKey::FlushKey(const wchar_t* full_key_name) {
385  ASSERT(full_key_name != NULL);
386
387  HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
388  // get the root HKEY
389  std::wstring key_name(full_key_name);
390  HKEY h_key = GetRootKeyInfo(&key_name);
391
392  if (h_key != NULL) {
393    LONG res = ::RegFlushKey(h_key);
394    hr = HRESULT_FROM_WIN32(res);
395  }
396  return hr;
397}
398
399// static SET helper
400HRESULT RegKey::SetValueStaticHelper(const wchar_t* full_key_name,
401                                     const wchar_t* value_name,
402                                     DWORD type,
403                                     LPVOID value,
404                                     DWORD byte_count) {
405  ASSERT(full_key_name != NULL);
406
407  HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
408  // get the root HKEY
409  std::wstring key_name(full_key_name);
410  HKEY h_key = GetRootKeyInfo(&key_name);
411
412  if (h_key != NULL) {
413    RegKey key;
414    hr = key.Create(h_key, key_name.c_str());
415    if (hr == S_OK) {
416      switch (type) {
417        case REG_DWORD:
418          hr = key.SetValue(value_name, *(static_cast<DWORD*>(value)));
419          break;
420        case REG_QWORD:
421          hr = key.SetValue(value_name, *(static_cast<DWORD64*>(value)));
422          break;
423        case REG_SZ:
424          hr = key.SetValue(value_name, static_cast<const wchar_t*>(value));
425          break;
426        case REG_BINARY:
427          hr = key.SetValue(value_name, static_cast<const uint8*>(value),
428                            byte_count);
429          break;
430        case REG_MULTI_SZ:
431          hr = key.SetValue(value_name, static_cast<const uint8*>(value),
432                            byte_count, type);
433          break;
434        default:
435          ASSERT(false);
436          hr = HRESULT_FROM_WIN32(ERROR_DATATYPE_MISMATCH);
437          break;
438      }
439      // close the key after writing
440      HRESULT temp_hr = key.Close();
441      if (hr == S_OK) {
442        hr = temp_hr;
443      }
444    }
445  }
446  return hr;
447}
448
449// static GET helper
450HRESULT RegKey::GetValueStaticHelper(const wchar_t* full_key_name,
451                                     const wchar_t* value_name,
452                                     DWORD type,
453                                     LPVOID value,
454                                     DWORD* byte_count) {
455  ASSERT(full_key_name != NULL);
456
457  HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
458  // get the root HKEY
459  std::wstring key_name(full_key_name);
460  HKEY h_key = GetRootKeyInfo(&key_name);
461
462  if (h_key != NULL) {
463    RegKey key;
464    hr = key.Open(h_key, key_name.c_str(), KEY_READ);
465    if (hr == S_OK) {
466      switch (type) {
467        case REG_DWORD:
468          hr = key.GetValue(value_name, reinterpret_cast<DWORD*>(value));
469          break;
470        case REG_QWORD:
471          hr = key.GetValue(value_name, reinterpret_cast<DWORD64*>(value));
472          break;
473        case REG_SZ:
474          hr = key.GetValue(value_name, reinterpret_cast<wchar_t**>(value));
475          break;
476        case REG_MULTI_SZ:
477          hr = key.GetValue(value_name, reinterpret_cast<
478                                            std::vector<std::wstring>*>(value));
479          break;
480        case REG_BINARY:
481          hr = key.GetValue(value_name, reinterpret_cast<uint8**>(value),
482                            byte_count);
483          break;
484        default:
485          ASSERT(false);
486          hr = HRESULT_FROM_WIN32(ERROR_DATATYPE_MISMATCH);
487          break;
488      }
489      // close the key after writing
490      HRESULT temp_hr = key.Close();
491      if (hr == S_OK) {
492        hr = temp_hr;
493      }
494    }
495  }
496  return hr;
497}
498
499// GET helper
500HRESULT RegKey::GetValueHelper(const wchar_t* value_name,
501                               DWORD* type,
502                               uint8** value,
503                               DWORD* byte_count) const {
504  ASSERT(byte_count != NULL);
505  ASSERT(value != NULL);
506  ASSERT(type != NULL);
507
508  // init return buffer
509  *value = NULL;
510
511  // get the size of the return data buffer
512  LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, type, NULL, byte_count);
513  HRESULT hr = HRESULT_FROM_WIN32(res);
514
515  if (hr == S_OK) {
516    // if the value length is 0, nothing to do
517    if (*byte_count != 0) {
518      // allocate the buffer
519      *value = new byte[*byte_count];
520      ASSERT(*value != NULL);
521
522      // make the call again to get the data
523      res = ::SHQueryValueEx(h_key_, value_name, NULL,
524                             type, *value, byte_count);
525      hr = HRESULT_FROM_WIN32(res);
526      ASSERT(hr == S_OK);
527    }
528  }
529  return hr;
530}
531
532// Int32 Get
533HRESULT RegKey::GetValue(const wchar_t* value_name, DWORD* value) const {
534  ASSERT(value != NULL);
535
536  DWORD type = 0;
537  DWORD byte_count = sizeof(DWORD);
538  LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
539                              value, &byte_count);
540  HRESULT hr = HRESULT_FROM_WIN32(res);
541  ASSERT((hr != S_OK) || (type == REG_DWORD));
542  ASSERT((hr != S_OK) || (byte_count == sizeof(DWORD)));
543  return hr;
544}
545
546// Int64 Get
547HRESULT RegKey::GetValue(const wchar_t* value_name, DWORD64* value) const {
548  ASSERT(value != NULL);
549
550  DWORD type = 0;
551  DWORD byte_count = sizeof(DWORD64);
552  LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
553                              value, &byte_count);
554  HRESULT hr = HRESULT_FROM_WIN32(res);
555  ASSERT((hr != S_OK) || (type == REG_QWORD));
556  ASSERT((hr != S_OK) || (byte_count == sizeof(DWORD64)));
557  return hr;
558}
559
560// String Get
561HRESULT RegKey::GetValue(const wchar_t* value_name, wchar_t** value) const {
562  ASSERT(value != NULL);
563
564  DWORD byte_count = 0;
565  DWORD type = 0;
566
567  // first get the size of the string buffer
568  LONG res = ::SHQueryValueEx(h_key_, value_name, NULL,
569                              &type, NULL, &byte_count);
570  HRESULT hr = HRESULT_FROM_WIN32(res);
571
572  if (hr == S_OK) {
573    // allocate room for the string and a terminating \0
574    *value = new wchar_t[(byte_count / sizeof(wchar_t)) + 1];
575
576    if ((*value) != NULL) {
577      if (byte_count != 0) {
578        // make the call again
579        res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
580                               *value, &byte_count);
581        hr = HRESULT_FROM_WIN32(res);
582      } else {
583        (*value)[0] = L'\0';
584      }
585
586      ASSERT((hr != S_OK) || (type == REG_SZ) ||
587             (type == REG_MULTI_SZ) || (type == REG_EXPAND_SZ));
588    } else {
589      hr = E_OUTOFMEMORY;
590    }
591  }
592
593  return hr;
594}
595
596// get a string value
597HRESULT RegKey::GetValue(const wchar_t* value_name, std::wstring* value) const {
598  ASSERT(value != NULL);
599
600  DWORD byte_count = 0;
601  DWORD type = 0;
602
603  // first get the size of the string buffer
604  LONG res = ::SHQueryValueEx(h_key_, value_name, NULL,
605                              &type, NULL, &byte_count);
606  HRESULT hr = HRESULT_FROM_WIN32(res);
607
608  if (hr == S_OK) {
609    if (byte_count != 0) {
610      // Allocate some memory and make the call again
611      value->resize(byte_count / sizeof(wchar_t) + 1);
612      res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
613                             &value->at(0), &byte_count);
614      hr = HRESULT_FROM_WIN32(res);
615      value->resize(wcslen(value->data()));
616    } else {
617      value->clear();
618    }
619
620    ASSERT((hr != S_OK) || (type == REG_SZ) ||
621           (type == REG_MULTI_SZ) || (type == REG_EXPAND_SZ));
622  }
623
624  return hr;
625}
626
627// convert REG_MULTI_SZ bytes to string array
628HRESULT RegKey::MultiSZBytesToStringArray(const uint8* buffer,
629                                          DWORD byte_count,
630                                          std::vector<std::wstring>* value) {
631  ASSERT(buffer != NULL);
632  ASSERT(value != NULL);
633
634  const wchar_t* data = reinterpret_cast<const wchar_t*>(buffer);
635  DWORD data_len = byte_count / sizeof(wchar_t);
636  value->clear();
637  if (data_len > 1) {
638    // must be terminated by two null characters
639    if (data[data_len - 1] != 0 || data[data_len - 2] != 0) {
640      return E_INVALIDARG;
641    }
642
643    // put null-terminated strings into arrays
644    while (*data) {
645      std::wstring str(data);
646      value->push_back(str);
647      data += str.length() + 1;
648    }
649  }
650  return S_OK;
651}
652
653// get a std::vector<std::wstring> value from REG_MULTI_SZ type
654HRESULT RegKey::GetValue(const wchar_t* value_name,
655                         std::vector<std::wstring>* value) const {
656  ASSERT(value != NULL);
657
658  DWORD byte_count = 0;
659  DWORD type = 0;
660  uint8* buffer = 0;
661
662  // first get the size of the buffer
663  HRESULT hr = GetValueHelper(value_name, &type, &buffer, &byte_count);
664  ASSERT((hr != S_OK) || (type == REG_MULTI_SZ));
665
666  if (SUCCEEDED(hr)) {
667    hr = MultiSZBytesToStringArray(buffer, byte_count, value);
668  }
669
670  return hr;
671}
672
673// Binary data Get
674HRESULT RegKey::GetValue(const wchar_t* value_name,
675                         uint8** value,
676                         DWORD* byte_count) const {
677  ASSERT(byte_count != NULL);
678  ASSERT(value != NULL);
679
680  DWORD type = 0;
681  HRESULT hr = GetValueHelper(value_name, &type, value, byte_count);
682  ASSERT((hr != S_OK) || (type == REG_MULTI_SZ) || (type == REG_BINARY));
683  return hr;
684}
685
686// Raw data get
687HRESULT RegKey::GetValue(const wchar_t* value_name,
688                         uint8** value,
689                         DWORD* byte_count,
690                         DWORD*type) const {
691  ASSERT(type != NULL);
692  ASSERT(byte_count != NULL);
693  ASSERT(value != NULL);
694
695  return GetValueHelper(value_name, type, value, byte_count);
696}
697
698// Int32 set
699HRESULT RegKey::SetValue(const wchar_t* value_name, DWORD value) const {
700  ASSERT(h_key_ != NULL);
701
702  LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_DWORD,
703                             reinterpret_cast<const uint8*>(&value),
704                             sizeof(DWORD));
705  return HRESULT_FROM_WIN32(res);
706}
707
708// Int64 set
709HRESULT RegKey::SetValue(const wchar_t* value_name, DWORD64 value) const {
710  ASSERT(h_key_ != NULL);
711
712  LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_QWORD,
713                             reinterpret_cast<const uint8*>(&value),
714                             sizeof(DWORD64));
715  return HRESULT_FROM_WIN32(res);
716}
717
718// String set
719HRESULT RegKey::SetValue(const wchar_t* value_name,
720                         const wchar_t* value) const {
721  ASSERT(value != NULL);
722  ASSERT(h_key_ != NULL);
723
724  LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_SZ,
725                             reinterpret_cast<const uint8*>(value),
726                             (lstrlen(value) + 1) * sizeof(wchar_t));
727  return HRESULT_FROM_WIN32(res);
728}
729
730// Binary data set
731HRESULT RegKey::SetValue(const wchar_t* value_name,
732                         const uint8* value,
733                         DWORD byte_count) const {
734  ASSERT(h_key_ != NULL);
735
736  // special case - if 'value' is NULL make sure byte_count is zero
737  if (value == NULL) {
738    byte_count = 0;
739  }
740
741  LONG res = ::RegSetValueEx(h_key_, value_name, NULL,
742                             REG_BINARY, value, byte_count);
743  return HRESULT_FROM_WIN32(res);
744}
745
746// Raw data set
747HRESULT RegKey::SetValue(const wchar_t* value_name,
748                         const uint8* value,
749                         DWORD byte_count,
750                         DWORD type) const {
751  ASSERT(value != NULL);
752  ASSERT(h_key_ != NULL);
753
754  LONG res = ::RegSetValueEx(h_key_, value_name, NULL, type, value, byte_count);
755  return HRESULT_FROM_WIN32(res);
756}
757
758bool RegKey::HasKey(const wchar_t* full_key_name) {
759  ASSERT(full_key_name != NULL);
760
761  // get the root HKEY
762  std::wstring key_name(full_key_name);
763  HKEY h_key = GetRootKeyInfo(&key_name);
764
765  if (h_key != NULL) {
766    RegKey key;
767    HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
768    key.Close();
769    return S_OK == hr;
770  }
771  return false;
772}
773
774// static version of HasValue
775bool RegKey::HasValue(const wchar_t* full_key_name, const wchar_t* value_name) {
776  ASSERT(full_key_name != NULL);
777
778  bool has_value = false;
779  // get the root HKEY
780  std::wstring key_name(full_key_name);
781  HKEY h_key = GetRootKeyInfo(&key_name);
782
783  if (h_key != NULL) {
784    RegKey key;
785    if (key.Open(h_key, key_name.c_str(), KEY_READ) == S_OK) {
786      has_value = key.HasValue(value_name);
787      key.Close();
788    }
789  }
790  return has_value;
791}
792
793HRESULT RegKey::GetValueType(const wchar_t* full_key_name,
794                             const wchar_t* value_name,
795                             DWORD* value_type) {
796  ASSERT(full_key_name != NULL);
797  ASSERT(value_type != NULL);
798
799  *value_type = REG_NONE;
800
801  std::wstring key_name(full_key_name);
802  HKEY h_key = GetRootKeyInfo(&key_name);
803
804  RegKey key;
805  HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
806  if (SUCCEEDED(hr)) {
807    LONG res = ::SHQueryValueEx(key.h_key_, value_name, NULL, value_type,
808                                NULL, NULL);
809    if (res != ERROR_SUCCESS) {
810      hr = HRESULT_FROM_WIN32(res);
811    }
812  }
813
814  return hr;
815}
816
817HRESULT RegKey::DeleteKey(const wchar_t* full_key_name) {
818  ASSERT(full_key_name != NULL);
819
820  return DeleteKey(full_key_name, true);
821}
822
823HRESULT RegKey::DeleteKey(const wchar_t* full_key_name, bool recursively) {
824  ASSERT(full_key_name != NULL);
825
826  // need to open the parent key first
827  // get the root HKEY
828  std::wstring key_name(full_key_name);
829  HKEY h_key = GetRootKeyInfo(&key_name);
830
831  // get the parent key
832  std::wstring parent_key(GetParentKeyInfo(&key_name));
833
834  RegKey key;
835  HRESULT hr = key.Open(h_key, parent_key.c_str());
836
837  if (hr == S_OK) {
838    hr = recursively ? key.RecurseDeleteSubKey(key_name.c_str())
839                     : key.DeleteSubKey(key_name.c_str());
840  } else if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
841             hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
842    hr = S_FALSE;
843  }
844
845  key.Close();
846  return hr;
847}
848
849HRESULT RegKey::DeleteValue(const wchar_t* full_key_name,
850                            const wchar_t* value_name) {
851  ASSERT(full_key_name != NULL);
852
853  HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
854  // get the root HKEY
855  std::wstring key_name(full_key_name);
856  HKEY h_key = GetRootKeyInfo(&key_name);
857
858  if (h_key != NULL) {
859    RegKey key;
860    hr = key.Open(h_key, key_name.c_str());
861    if (hr == S_OK) {
862      hr = key.DeleteValue(value_name);
863      key.Close();
864    }
865  }
866  return hr;
867}
868
// Deletes the subkey |key_name| of the held key, including everything
// beneath it. It repeatedly takes the child at enumeration index 0 and
// deletes it recursively, then removes the now childless key itself via
// DeleteSubKey.
// The loop stops at the first child result other than S_OK, and that result
// is returned without deleting |key_name|.
// NOTE(review): a missing |key_name| surfaces as the Open() error, whereas
// DeleteSubKey reports a missing key as S_FALSE. A child returning S_FALSE
// also stops the loop.
HRESULT RegKey::RecurseDeleteSubKey(const wchar_t* key_name) {
  ASSERT(key_name != NULL);

  RegKey key;
  HRESULT hr = key.Open(h_key_, key_name);

  if (hr == S_OK) {
    // enumerate all subkeys of this key and recursively delete them;
    // index 0 is queried each time, presumably because each deleted child
    // drops out of the enumeration
    FILETIME time = {0};
    wchar_t key_name_buf[kMaxKeyNameChars] = {0};
    DWORD key_name_buf_size = kMaxKeyNameChars;
    while (hr == S_OK &&
        ::RegEnumKeyEx(key.h_key_, 0, key_name_buf, &key_name_buf_size,
                       NULL, NULL, NULL,  &time) == ERROR_SUCCESS) {
      hr = key.RecurseDeleteSubKey(key_name_buf);

      // restore the buffer size (it is an in/out parameter of RegEnumKeyEx)
      key_name_buf_size = kMaxKeyNameChars;
    }
    // close the top key
    key.Close();
  }

  if (hr == S_OK) {
    // the key has no more children keys
    // delete the key and all of its values
    hr = DeleteSubKey(key_name);
  }

  return hr;
}
900
901HKEY RegKey::GetRootKeyInfo(std::wstring* full_key_name) {
902  ASSERT(full_key_name != NULL);
903
904  HKEY h_key = NULL;
905  // get the root HKEY
906  size_t index = full_key_name->find(L'\\');
907  std::wstring root_key;
908
909  if (index == -1) {
910    root_key = *full_key_name;
911    *full_key_name = L"";
912  } else {
913    root_key = full_key_name->substr(0, index);
914    *full_key_name = full_key_name->substr(index + 1,
915                                           full_key_name->length() - index - 1);
916  }
917
918  for (std::wstring::iterator iter = root_key.begin();
919       iter != root_key.end(); ++iter) {
920    *iter = toupper(*iter);
921  }
922
923  if (!root_key.compare(L"HKLM") ||
924      !root_key.compare(L"HKEY_LOCAL_MACHINE")) {
925    h_key = HKEY_LOCAL_MACHINE;
926  } else if (!root_key.compare(L"HKCU") ||
927             !root_key.compare(L"HKEY_CURRENT_USER")) {
928    h_key = HKEY_CURRENT_USER;
929  } else if (!root_key.compare(L"HKU") ||
930             !root_key.compare(L"HKEY_USERS")) {
931    h_key = HKEY_USERS;
932  } else if (!root_key.compare(L"HKCR") ||
933             !root_key.compare(L"HKEY_CLASSES_ROOT")) {
934    h_key = HKEY_CLASSES_ROOT;
935  }
936
937  return h_key;
938}
939
940
941// Returns true if this key name is 'safe' for deletion
942// (doesn't specify a key root)
943bool RegKey::SafeKeyNameForDeletion(const wchar_t* key_name) {
944  ASSERT(key_name != NULL);
945  std::wstring key(key_name);
946
947  HKEY root_key = GetRootKeyInfo(&key);
948
949  if (!root_key) {
950    key = key_name;
951  }
952  if (key.empty()) {
953    return false;
954  }
955  bool found_subkey = false, backslash_found = false;
956  for (size_t i = 0 ; i < key.length() ; ++i) {
957    if (key[i] == L'\\') {
958      backslash_found = true;
959    } else if (backslash_found) {
960      found_subkey = true;
961      break;
962    }
963  }
964  return (root_key == HKEY_USERS) ? found_subkey : true;
965}
966
967std::wstring RegKey::GetParentKeyInfo(std::wstring* key_name) {
968  ASSERT(key_name != NULL);
969
970  // get the parent key
971  size_t index = key_name->rfind(L'\\');
972  std::wstring parent_key;
973  if (index == -1) {
974    parent_key = L"";
975  } else {
976    parent_key = key_name->substr(0, index);
977    *key_name = key_name->substr(index + 1, key_name->length() - index - 1);
978  }
979
980  return parent_key;
981}
982
983// get the number of values for this key
984uint32 RegKey::GetValueCount() {
985  DWORD num_values = 0;
986
987  LONG res = ::RegQueryInfoKey(
988        h_key_,                  // key handle
989        NULL,                    // buffer for class name
990        NULL,                    // size of class string
991        NULL,                    // reserved
992        NULL,                    // number of subkeys
993        NULL,                    // longest subkey size
994        NULL,                    // longest class string
995        &num_values,             // number of values for this key
996        NULL,                    // longest value name
997        NULL,                    // longest value data
998        NULL,                    // security descriptor
999        NULL);                   // last write time
1000
1001  ASSERT(res == ERROR_SUCCESS);
1002  return num_values;
1003}
1004
1005// Enumerators for the value_names for this key
1006
1007// Called to get the value name for the given value name index
1008// Use GetValueCount() to get the total value_name count for this key
1009// Returns failure if no key at the specified index
1010HRESULT RegKey::GetValueNameAt(int index, std::wstring* value_name,
1011                               DWORD* type) {
1012  ASSERT(value_name != NULL);
1013
1014  LONG res = ERROR_SUCCESS;
1015  wchar_t value_name_buf[kMaxValueNameChars] = {0};
1016  DWORD value_name_buf_size = kMaxValueNameChars;
1017  res = ::RegEnumValue(h_key_, index, value_name_buf, &value_name_buf_size,
1018                       NULL, type, NULL, NULL);
1019
1020  if (res == ERROR_SUCCESS) {
1021    value_name->assign(value_name_buf);
1022  }
1023
1024  return HRESULT_FROM_WIN32(res);
1025}
1026
1027uint32 RegKey::GetSubkeyCount() {
1028  // number of values for key
1029  DWORD num_subkeys = 0;
1030
1031  LONG res = ::RegQueryInfoKey(
1032    h_key_,                  // key handle
1033    NULL,                    // buffer for class name
1034    NULL,                    // size of class string
1035    NULL,                    // reserved
1036    &num_subkeys,            // number of subkeys
1037    NULL,                    // longest subkey size
1038    NULL,                    // longest class string
1039    NULL,                    // number of values for this key
1040    NULL,                    // longest value name
1041    NULL,                    // longest value data
1042    NULL,                    // security descriptor
1043    NULL);                   // last write time
1044
1045  ASSERT(res == ERROR_SUCCESS);
1046  return num_subkeys;
1047}
1048
1049HRESULT RegKey::GetSubkeyNameAt(int index, std::wstring* key_name) {
1050  ASSERT(key_name != NULL);
1051
1052  LONG res = ERROR_SUCCESS;
1053  wchar_t key_name_buf[kMaxKeyNameChars] = {0};
1054  DWORD key_name_buf_size = kMaxKeyNameChars;
1055
1056  res = ::RegEnumKeyEx(h_key_, index, key_name_buf, &key_name_buf_size,
1057                       NULL, NULL, NULL, NULL);
1058
1059  if (res == ERROR_SUCCESS) {
1060    key_name->assign(key_name_buf);
1061  }
1062
1063  return HRESULT_FROM_WIN32(res);
1064}
1065
1066// Is the key empty: having no sub-keys and values
1067bool RegKey::IsKeyEmpty(const wchar_t* full_key_name) {
1068  ASSERT(full_key_name != NULL);
1069
1070  bool is_empty = true;
1071
1072  // Get the root HKEY
1073  std::wstring key_name(full_key_name);
1074  HKEY h_key = GetRootKeyInfo(&key_name);
1075
1076  // Open the key to check
1077  if (h_key != NULL) {
1078    RegKey key;
1079    HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
1080    if (SUCCEEDED(hr)) {
1081      is_empty = key.GetSubkeyCount() == 0 && key.GetValueCount() == 0;
1082      key.Close();
1083    }
1084  }
1085
1086  return is_empty;
1087}
1088
1089bool AdjustCurrentProcessPrivilege(const TCHAR* privilege, bool to_enable) {
1090  ASSERT(privilege != NULL);
1091
1092  bool ret = false;
1093  HANDLE token;
1094  if (::OpenProcessToken(::GetCurrentProcess(),
1095                         TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &token)) {
1096    LUID luid;
1097    memset(&luid, 0, sizeof(luid));
1098    if (::LookupPrivilegeValue(NULL, privilege, &luid)) {
1099      TOKEN_PRIVILEGES privs;
1100      privs.PrivilegeCount = 1;
1101      privs.Privileges[0].Luid = luid;
1102      privs.Privileges[0].Attributes = to_enable ? SE_PRIVILEGE_ENABLED : 0;
1103      if (::AdjustTokenPrivileges(token, FALSE, &privs, 0, NULL, 0)) {
1104        ret = true;
1105      } else {
1106        LOG_GLE(LS_ERROR) << "AdjustTokenPrivileges failed";
1107      }
1108    } else {
1109      LOG_GLE(LS_ERROR) << "LookupPrivilegeValue failed";
1110    }
1111    CloseHandle(token);
1112  } else {
1113    LOG_GLE(LS_ERROR) << "OpenProcessToken(GetCurrentProcess) failed";
1114  }
1115
1116  return ret;
1117}
1118
1119}  // namespace talk_base
1120