1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "base/win/iat_patch_function.h"
6
7#include "base/logging.h"
8#include "base/win/pe_image.h"
9
10namespace base {
11namespace win {
12
13namespace {
14
// Carries the inputs and outputs of a single IAT interception attempt through
// the PEImage import-enumeration callback (passed as its opaque cookie).
// Keep the field order stable: the struct may be brace-initialized
// positionally.
struct InterceptFunctionInformation {
  // Set to true by the callback once the target import has been found,
  // regardless of whether the patch write itself succeeded.
  bool finished_operation;
  // Name of the module that exports the function (compared case-insensitively).
  const char* imported_from_module;
  // Name of the imported function to intercept (compared case-insensitively).
  const char* function_name;
  // Pointer value written into the matching IAT slot.
  void* new_function;
  // Optional out-param (may be NULL): receives the IAT value seen before
  // patching.
  void** old_function;
  // Optional out-param (may be NULL): receives the address of the matching
  // IAT thunk.
  IMAGE_THUNK_DATA** iat_thunk;
  // Windows error code (winerror.h) describing the outcome of the attempt.
  DWORD return_code;
};
24
25void* GetIATFunction(IMAGE_THUNK_DATA* iat_thunk) {
26  if (NULL == iat_thunk) {
27    NOTREACHED();
28    return NULL;
29  }
30
31  // Works around the 64 bit portability warning:
32  // The Function member inside IMAGE_THUNK_DATA is really a pointer
33  // to the IAT function. IMAGE_THUNK_DATA correctly maps to IMAGE_THUNK_DATA32
34  // or IMAGE_THUNK_DATA64 for correct pointer size.
35  union FunctionThunk {
36    IMAGE_THUNK_DATA thunk;
37    void* pointer;
38  } iat_function;
39
40  iat_function.thunk = *iat_thunk;
41  return iat_function.pointer;
42}
43
44bool InterceptEnumCallback(const base::win::PEImage& image, const char* module,
45                           DWORD ordinal, const char* name, DWORD hint,
46                           IMAGE_THUNK_DATA* iat, void* cookie) {
47  InterceptFunctionInformation* intercept_information =
48    reinterpret_cast<InterceptFunctionInformation*>(cookie);
49
50  if (NULL == intercept_information) {
51    NOTREACHED();
52    return false;
53  }
54
55  DCHECK(module);
56
57  if ((0 == lstrcmpiA(module, intercept_information->imported_from_module)) &&
58     (NULL != name) &&
59     (0 == lstrcmpiA(name, intercept_information->function_name))) {
60    // Save the old pointer.
61    if (NULL != intercept_information->old_function) {
62      *(intercept_information->old_function) = GetIATFunction(iat);
63    }
64
65    if (NULL != intercept_information->iat_thunk) {
66      *(intercept_information->iat_thunk) = iat;
67    }
68
69    // portability check
70    COMPILE_ASSERT(sizeof(iat->u1.Function) ==
71      sizeof(intercept_information->new_function), unknown_IAT_thunk_format);
72
73    // Patch the function.
74    intercept_information->return_code =
75      ModifyCode(&(iat->u1.Function),
76                 &(intercept_information->new_function),
77                 sizeof(intercept_information->new_function));
78
79    // Terminate further enumeration.
80    intercept_information->finished_operation = true;
81    return false;
82  }
83
84  return true;
85}
86
87// Helper to intercept a function in an import table of a specific
88// module.
89//
90// Arguments:
91// module_handle          Module to be intercepted
92// imported_from_module   Module that exports the symbol
93// function_name          Name of the API to be intercepted
94// new_function           Interceptor function
95// old_function           Receives the original function pointer
96// iat_thunk              Receives pointer to IAT_THUNK_DATA
97//                        for the API from the import table.
98//
99// Returns: Returns NO_ERROR on success or Windows error code
100//          as defined in winerror.h
101DWORD InterceptImportedFunction(HMODULE module_handle,
102                                const char* imported_from_module,
103                                const char* function_name, void* new_function,
104                                void** old_function,
105                                IMAGE_THUNK_DATA** iat_thunk) {
106  if ((NULL == module_handle) || (NULL == imported_from_module) ||
107     (NULL == function_name) || (NULL == new_function)) {
108    NOTREACHED();
109    return ERROR_INVALID_PARAMETER;
110  }
111
112  base::win::PEImage target_image(module_handle);
113  if (!target_image.VerifyMagic()) {
114    NOTREACHED();
115    return ERROR_INVALID_PARAMETER;
116  }
117
118  InterceptFunctionInformation intercept_information = {
119    false,
120    imported_from_module,
121    function_name,
122    new_function,
123    old_function,
124    iat_thunk,
125    ERROR_GEN_FAILURE};
126
127  // First go through the IAT. If we don't find the import we are looking
128  // for in IAT, search delay import table.
129  target_image.EnumAllImports(InterceptEnumCallback, &intercept_information);
130  if (!intercept_information.finished_operation) {
131    target_image.EnumAllDelayImports(InterceptEnumCallback,
132                                     &intercept_information);
133  }
134
135  return intercept_information.return_code;
136}
137
138// Restore intercepted IAT entry with the original function.
139//
140// Arguments:
141// intercept_function     Interceptor function
142// original_function      Receives the original function pointer
143//
144// Returns: Returns NO_ERROR on success or Windows error code
145//          as defined in winerror.h
146DWORD RestoreImportedFunction(void* intercept_function,
147                              void* original_function,
148                              IMAGE_THUNK_DATA* iat_thunk) {
149  if ((NULL == intercept_function) || (NULL == original_function) ||
150      (NULL == iat_thunk)) {
151    NOTREACHED();
152    return ERROR_INVALID_PARAMETER;
153  }
154
155  if (GetIATFunction(iat_thunk) != intercept_function) {
156    // Check if someone else has intercepted on top of us.
157    // We cannot unpatch in this case, just raise a red flag.
158    NOTREACHED();
159    return ERROR_INVALID_FUNCTION;
160  }
161
162  return ModifyCode(&(iat_thunk->u1.Function),
163                    &original_function,
164                    sizeof(original_function));
165}
166
167}  // namespace
168
169// Change the page protection (of code pages) to writable and copy
170// the data at the specified location
171//
172// Arguments:
173// old_code               Target location to copy
174// new_code               Source
175// length                 Number of bytes to copy
176//
177// Returns: Windows error code (winerror.h). NO_ERROR if successful
178DWORD ModifyCode(void* old_code, void* new_code, int length) {
179  if ((NULL == old_code) || (NULL == new_code) || (0 == length)) {
180    NOTREACHED();
181    return ERROR_INVALID_PARAMETER;
182  }
183
184  // Change the page protection so that we can write.
185  MEMORY_BASIC_INFORMATION memory_info;
186  DWORD error = NO_ERROR;
187  DWORD old_page_protection = 0;
188
189  if (!VirtualQuery(old_code, &memory_info, sizeof(memory_info))) {
190    error = GetLastError();
191    return error;
192  }
193
194  DWORD is_executable = (PAGE_EXECUTE | PAGE_EXECUTE_READ |
195                        PAGE_EXECUTE_READWRITE | PAGE_EXECUTE_WRITECOPY) &
196                        memory_info.Protect;
197
198  if (VirtualProtect(old_code,
199                     length,
200                     is_executable ? PAGE_EXECUTE_READWRITE :
201                                     PAGE_READWRITE,
202                     &old_page_protection)) {
203
204    // Write the data.
205    CopyMemory(old_code, new_code, length);
206
207    // Restore the old page protection.
208    error = ERROR_SUCCESS;
209    VirtualProtect(old_code,
210                  length,
211                  old_page_protection,
212                  &old_page_protection);
213  } else {
214    error = GetLastError();
215  }
216
217  return error;
218}
219
220IATPatchFunction::IATPatchFunction()
221    : module_handle_(NULL),
222      original_function_(NULL),
223      iat_thunk_(NULL),
224      intercept_function_(NULL) {
225}
226
227IATPatchFunction::~IATPatchFunction() {
228  if (NULL != intercept_function_) {
229    DWORD error = Unpatch();
230    DCHECK_EQ(static_cast<DWORD>(NO_ERROR), error);
231  }
232}
233
234DWORD IATPatchFunction::Patch(const wchar_t* module,
235                              const char* imported_from_module,
236                              const char* function_name,
237                              void* new_function) {
238  DCHECK_EQ(static_cast<void*>(NULL), original_function_);
239  DCHECK_EQ(static_cast<IMAGE_THUNK_DATA*>(NULL), iat_thunk_);
240  DCHECK_EQ(static_cast<void*>(NULL), intercept_function_);
241
242  HMODULE module_handle = LoadLibraryW(module);
243
244  if (module_handle == NULL) {
245    NOTREACHED();
246    return GetLastError();
247  }
248
249  DWORD error = InterceptImportedFunction(module_handle,
250                                          imported_from_module,
251                                          function_name,
252                                          new_function,
253                                          &original_function_,
254                                          &iat_thunk_);
255
256  if (NO_ERROR == error) {
257    DCHECK_NE(original_function_, intercept_function_);
258    module_handle_ = module_handle;
259    intercept_function_ = new_function;
260  } else {
261    FreeLibrary(module_handle);
262  }
263
264  return error;
265}
266
267DWORD IATPatchFunction::Unpatch() {
268  DWORD error = RestoreImportedFunction(intercept_function_,
269                                        original_function_,
270                                        iat_thunk_);
271  DCHECK_EQ(static_cast<DWORD>(NO_ERROR), error);
272
273  // Hands off the intercept if we fail to unpatch.
274  // If IATPatchFunction::Unpatch fails during RestoreImportedFunction
275  // it means that we cannot safely unpatch the import address table
276  // patch. In this case its better to be hands off the intercept as
277  // trying to unpatch again in the destructor of IATPatchFunction is
278  // not going to be any safer
279  if (module_handle_)
280    FreeLibrary(module_handle_);
281  module_handle_ = NULL;
282  intercept_function_ = NULL;
283  original_function_ = NULL;
284  iat_thunk_ = NULL;
285
286  return error;
287}
288
// Returns the function pointer the IAT slot held before Patch() succeeded.
// Must only be called while patched (checked with DCHECK in debug builds).
void* IATPatchFunction::original_function() const {
  DCHECK(is_patched());
  return original_function_;
}
293
294}  // namespace win
295}  // namespace base
296