1// Copyright (c) 2011 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome_frame/vtable_patch_manager.h"
6
7#include <unknwn.h>
8
9#include "base/bind.h"
10#include "base/bind_helpers.h"
11#include "base/message_loop/message_loop.h"
12#include "base/threading/thread.h"
13#include "base/win/scoped_handle.h"
14#include "gmock/gmock.h"
15#include "gtest/gtest.h"
16
17namespace {
18// GMock names we use.
19using testing::_;
20using testing::Return;
21
22class MockClassFactory : public IClassFactory {
23 public:
24  MOCK_METHOD2_WITH_CALLTYPE(__stdcall, QueryInterface,
25      HRESULT(REFIID riid, void **object));
26  MOCK_METHOD0_WITH_CALLTYPE(__stdcall, AddRef, ULONG());
27  MOCK_METHOD0_WITH_CALLTYPE(__stdcall, Release, ULONG());
28  MOCK_METHOD3_WITH_CALLTYPE(__stdcall, CreateInstance,
29        HRESULT (IUnknown *outer, REFIID riid, void **object));
30  MOCK_METHOD1_WITH_CALLTYPE(__stdcall, LockServer, HRESULT(BOOL lock));
31};
32
33// Retrieve the vtable for an interface.
34void* GetVtable(IUnknown* unk) {
35  return *reinterpret_cast<void**>(unk);
36}
37
// Forward declaration of the IClassFactory patch table. The fixture below
// refers to it (TearDown() unpatches with it), but it is only defined further
// down in this file via BEGIN_VTABLE_PATCHES(IClassFactory). That macro is
// assumed to expand to the definition of this array; TODO confirm against
// vtable_patch_manager.h.
extern vtable_patch::MethodPatchInfo IClassFactory_PatchInfo[];
40
41class VtablePatchManagerTest: public testing::Test {
42 public:
43  VtablePatchManagerTest() {
44    EXPECT_TRUE(current_ == NULL);
45    current_ = this;
46  }
47
48  ~VtablePatchManagerTest() {
49    EXPECT_TRUE(current_ == this);
50    current_ = NULL;
51  }
52
53  virtual void SetUp() {
54    // Make a backup of the test vtable and it's page protection settings.
55    void* vtable = GetVtable(&factory_);
56    MEMORY_BASIC_INFORMATION info;
57    ASSERT_TRUE(::VirtualQuery(vtable, &info, sizeof(info)));
58    vtable_protection_ = info.Protect;
59    memcpy(vtable_backup_, vtable, sizeof(vtable_backup_));
60  }
61
62  virtual void TearDown() {
63    // Unpatch to make sure we've restored state for subsequent test.
64    UnpatchInterfaceMethods(IClassFactory_PatchInfo);
65
66    // Restore the test vtable and its page protection settings.
67    void* vtable = GetVtable(&factory_);
68    DWORD old_protect = 0;
69    EXPECT_TRUE(::VirtualProtect(vtable, sizeof(vtable_backup_),
70        PAGE_EXECUTE_WRITECOPY, &old_protect));
71    memcpy(vtable, vtable_backup_, sizeof(vtable_backup_));
72    EXPECT_TRUE(::VirtualProtect(vtable, sizeof(vtable_backup_),
73        vtable_protection_, &old_protect));
74  }
75
76  typedef HRESULT (__stdcall* LockServerFun)(IClassFactory* self, BOOL lock);
77  MOCK_METHOD3(LockServerPatch,
78      HRESULT(LockServerFun old_fun, IClassFactory* self, BOOL lock));
79
80  static HRESULT STDMETHODCALLTYPE LockServerPatchCallback(
81      LockServerFun fun, IClassFactory* self, BOOL lock) {
82    EXPECT_TRUE(current_ != NULL);
83    if (current_ != NULL)
84      return current_->LockServerPatch(fun, self, lock);
85    else
86      return E_UNEXPECTED;
87  }
88
89 protected:
90  // Number of functions in the IClassFactory vtable.
91  static const size_t kFunctionCount = 5;
92
93  // Backup of the factory_ vtable as we found it at Setup.
94  PROC vtable_backup_[kFunctionCount];
95  // VirtualProtect flags on the factory_ vtable as we found it at Setup.
96  DWORD vtable_protection_;
97
98  // The mock factory class we patch.
99  MockClassFactory factory_;
100
101  // Current test running for routing the patch callback function.
102  static VtablePatchManagerTest* current_;
103};
104
105VtablePatchManagerTest* VtablePatchManagerTest::current_ = NULL;
106
// Patch table for IClassFactory. It has a single entry that diverts vtable
// slot 4 to the fixture's static LockServerPatchCallback, which in turn calls
// the LockServerPatch mock. Slot 4 is LockServer, which follows
// QueryInterface, AddRef, Release and CreateInstance.
BEGIN_VTABLE_PATCHES(IClassFactory)
  VTABLE_PATCH_ENTRY(4, &VtablePatchManagerTest::LockServerPatchCallback)
END_VTABLE_PATCHES();
110
111}  // namespace
112
113TEST_F(VtablePatchManagerTest, ReplacePointer) {
114  void* const kFunctionOriginal = reinterpret_cast<void*>(0xCAFEBABE);
115  void* const kFunctionFoo = reinterpret_cast<void*>(0xF0F0F0F0);
116  void* const kFunctionBar = reinterpret_cast<void*>(0xBABABABA);
117
118  using vtable_patch::internal::ReplaceFunctionPointer;
119  // Replacing a non-writable location should fail, but not crash.
120  EXPECT_FALSE(ReplaceFunctionPointer(NULL, kFunctionBar, kFunctionFoo));
121
122  void* foo_entry = kFunctionOriginal;
123  // Replacing with the wrong original function should
124  // fail and not change the entry.
125  EXPECT_FALSE(ReplaceFunctionPointer(&foo_entry, kFunctionBar, kFunctionFoo));
126  EXPECT_EQ(foo_entry, kFunctionOriginal);
127
128  // Replacing with the correct original should succeed.
129  EXPECT_TRUE(ReplaceFunctionPointer(&foo_entry,
130                                     kFunctionBar,
131                                     kFunctionOriginal));
132  EXPECT_EQ(foo_entry, kFunctionBar);
133}
134
135TEST_F(VtablePatchManagerTest, PatchInterfaceMethods) {
136  // Unpatched.
137  EXPECT_CALL(factory_, LockServer(TRUE))
138      .WillOnce(Return(E_FAIL));
139  EXPECT_EQ(E_FAIL, factory_.LockServer(TRUE));
140
141  EXPECT_HRESULT_SUCCEEDED(
142      PatchInterfaceMethods(&factory_, IClassFactory_PatchInfo));
143
144  EXPECT_NE(0, memcmp(GetVtable(&factory_),
145                      vtable_backup_,
146                      sizeof(vtable_backup_)));
147
148  // This should not be called while the patch is in effect.
149  EXPECT_CALL(factory_, LockServer(_))
150      .Times(0);
151
152  EXPECT_CALL(*this, LockServerPatch(testing::_, &factory_, TRUE))
153      .WillOnce(testing::Return(S_FALSE));
154
155  EXPECT_EQ(S_FALSE, factory_.LockServer(TRUE));
156}
157
158TEST_F(VtablePatchManagerTest, UnpatchInterfaceMethods) {
159  // Patch it.
160  EXPECT_HRESULT_SUCCEEDED(
161      PatchInterfaceMethods(&factory_, IClassFactory_PatchInfo));
162
163  EXPECT_NE(0, memcmp(GetVtable(&factory_),
164                      vtable_backup_,
165                      sizeof(vtable_backup_)));
166
167  // This should not be called while the patch is in effect.
168  EXPECT_CALL(factory_, LockServer(testing::_))
169      .Times(0);
170
171  EXPECT_CALL(*this, LockServerPatch(testing::_, &factory_, TRUE))
172      .WillOnce(testing::Return(S_FALSE));
173
174  EXPECT_EQ(S_FALSE, factory_.LockServer(TRUE));
175
176  // Now unpatch.
177  EXPECT_HRESULT_SUCCEEDED(
178      UnpatchInterfaceMethods(IClassFactory_PatchInfo));
179
180  // And check that the call comes through correctly.
181  EXPECT_CALL(factory_, LockServer(FALSE))
182      .WillOnce(testing::Return(E_FAIL));
183  EXPECT_EQ(E_FAIL, factory_.LockServer(FALSE));
184}
185
186TEST_F(VtablePatchManagerTest, DoublePatch) {
187  // Patch it.
188  EXPECT_HRESULT_SUCCEEDED(
189      PatchInterfaceMethods(&factory_, IClassFactory_PatchInfo));
190
191  // Capture the VTable after patching.
192  PROC vtable[kFunctionCount];
193  memcpy(vtable, GetVtable(&factory_), sizeof(vtable));
194
195  // Patch it again, this should be idempotent.
196  EXPECT_HRESULT_SUCCEEDED(
197      PatchInterfaceMethods(&factory_, IClassFactory_PatchInfo));
198
199  // Should not have changed the VTable on second call.
200  EXPECT_EQ(0, memcmp(vtable, GetVtable(&factory_), sizeof(vtable)));
201}
202
namespace vtable_patch {
// Expose internal implementation detail, purely for testing. ThreadSafePatching
// below holds this lock and expects patching and unpatching on other threads
// to block on it.
// NOTE(review): this must match the definition in the vtable_patch
// implementation, assumed to be a non-static namespace-scope base::Lock named
// patch_lock_. base::Lock reaches this file only through transitive includes.
// TODO confirm both.
extern base::Lock patch_lock_;

}  // namespace vtable_patch
208
209TEST_F(VtablePatchManagerTest, ThreadSafePatching) {
210  // It's difficult to test for threadsafe patching, but as a close proxy,
211  // test for no patching happening from a background thread while the patch
212  // lock is held.
213  base::Thread background("Background Test Thread");
214
215  EXPECT_TRUE(background.Start());
216  base::win::ScopedHandle event(::CreateEvent(NULL, TRUE, FALSE, NULL));
217
218  // Grab the patch lock.
219  vtable_patch::patch_lock_.Acquire();
220
221  // Instruct the background thread to patch factory_.
222  background.message_loop()->PostTask(
223      FROM_HERE,
224      base::Bind(base::IgnoreResult(&vtable_patch::PatchInterfaceMethods),
225                 &factory_, &IClassFactory_PatchInfo[0]));
226
227  // And subsequently to signal the event. Neither of these actions should
228  // occur until we've released the patch lock.
229  background.message_loop()->PostTask(
230      FROM_HERE, base::Bind(base::IgnoreResult(::SetEvent), event.Get()));
231
232  // Wait for a little while, to give the background thread time to process.
233  // We expect this wait to time out, as the background thread should end up
234  // blocking on the patch lock.
235  EXPECT_EQ(WAIT_TIMEOUT, ::WaitForSingleObject(event.Get(), 50));
236
237  // Verify that patching did not take place yet.
238  EXPECT_CALL(factory_, LockServer(TRUE))
239      .WillOnce(Return(S_FALSE));
240  EXPECT_EQ(S_FALSE, factory_.LockServer(TRUE));
241
242  // Release the lock and wait on the event again to ensure
243  // the patching has taken place now.
244  vtable_patch::patch_lock_.Release();
245  EXPECT_EQ(WAIT_OBJECT_0, ::WaitForSingleObject(event.Get(), INFINITE));
246
247  // We should not get called here anymore.
248  EXPECT_CALL(factory_, LockServer(TRUE))
249      .Times(0);
250
251  // But should be diverted here.
252  EXPECT_CALL(*this, LockServerPatch(_, &factory_, TRUE))
253      .WillOnce(Return(S_FALSE));
254  EXPECT_EQ(S_FALSE, factory_.LockServer(TRUE));
255
256  // Same deal for unpatching.
257  ::ResetEvent(event.Get());
258
259  // Grab the patch lock.
260  vtable_patch::patch_lock_.Acquire();
261
262  // Instruct the background thread to unpatch.
263  background.message_loop()->PostTask(
264      FROM_HERE,
265      base::Bind(base::IgnoreResult(&vtable_patch::UnpatchInterfaceMethods),
266                 &IClassFactory_PatchInfo[0]));
267
268  // And subsequently to signal the event. Neither of these actions should
269  // occur until we've released the patch lock.
270  background.message_loop()->PostTask(
271      FROM_HERE, base::Bind(base::IgnoreResult(::SetEvent), event.Get()));
272
273  // Wait for a little while, to give the background thread time to process.
274  // We expect this wait to time out, as the background thread should end up
275  // blocking on the patch lock.
276  EXPECT_EQ(WAIT_TIMEOUT, ::WaitForSingleObject(event.Get(), 50));
277
278  // We should still be patched.
279  EXPECT_CALL(factory_, LockServer(TRUE))
280      .Times(0);
281  EXPECT_CALL(*this, LockServerPatch(_, &factory_, TRUE))
282      .WillOnce(Return(S_FALSE));
283  EXPECT_EQ(S_FALSE, factory_.LockServer(TRUE));
284
285  // Release the patch lock and wait on the event.
286  vtable_patch::patch_lock_.Release();
287  EXPECT_EQ(WAIT_OBJECT_0, ::WaitForSingleObject(event.Get(), INFINITE));
288
289  // Verify that unpatching took place.
290  EXPECT_CALL(factory_, LockServer(TRUE))
291      .WillOnce(Return(S_FALSE));
292  EXPECT_EQ(S_FALSE, factory_.LockServer(TRUE));
293}
294