# Copyright 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""A very very simple mock object harness."""

# Wildcard argument: an expected argument equal to DONT_CARE matches any
# actual argument in that position. Because it is the empty string, an
# expected argument of '' is always treated as a wildcard.
DONT_CARE = ''


class MockFunctionCall(object):
  """A single function call: a name, positional args and a return value.

  Used both for programmed expectations and for calls actually observed by a
  hook. The builder methods (WithArgs, WillReturn, WhenCalled) return self so
  they can be chained.
  """

  def __init__(self, name):
    self.name = name
    self.args = ()
    self.return_value = None
    self.when_called_handlers = []

  def WithArgs(self, *args):
    """Sets the positional arguments of this call. Returns self."""
    self.args = args
    return self

  def WillReturn(self, value):
    """Sets the value the mocked call returns. Returns self."""
    self.return_value = value
    return self

  def WhenCalled(self, handler):
    """Registers handler(*args) to run when this call is matched.

    Returns self so it chains like WithArgs/WillReturn (previously it returned
    None, which broke e.g. ExpectCall(...).WhenCalled(h).WillReturn(x)).
    """
    self.when_called_handlers.append(handler)
    return self

  def VerifyEquals(self, got):
    """Raises Exception unless `got` matches this (expected) call.

    Names and argument counts must match exactly. Each argument must compare
    equal, except positions where the expected argument is DONT_CARE.
    """
    if self.name != got.name or len(self.args) != len(got.args):
      self._RaiseMismatch(got)
    for expected_arg, got_arg in zip(self.args, got.args):
      if expected_arg == DONT_CARE:
        continue
      if expected_arg != got_arg:
        self._RaiseMismatch(got)

  def _RaiseMismatch(self, got):
    """Raises the mismatch Exception shared by every VerifyEquals failure."""
    raise Exception('Self %s, got %s' % (repr(self), repr(got)))

  def __repr__(self):
    def arg_to_text(a):
      # Wildcard arguments are rendered as '_'.
      if a == DONT_CARE:
        return '_'
      return repr(a)
    args_text = ', '.join(arg_to_text(a) for a in self.args)
    # A None or DONT_CARE return value is omitted from the rendering.
    if self.return_value in (None, DONT_CARE):
      return '%s(%s)' % (self.name, args_text)
    return '%s(%s)->%s' % (self.name, args_text, repr(self.return_value))


class MockTrace(object):
  """Ordered list of expected calls, shared by a mock and its child mocks."""

  def __init__(self):
    self.expected_calls = []
    # Index into expected_calls of the next call that must be made.
    self.next_call_index = 0


class MockObject(object):
  """A mock whose methods are installed on demand by ExpectCall.

  A MockObject created with a parent_mock shares the parent's MockTrace, so
  calls made on the parent and its children are checked against one ordered
  list of expectations.
  """

  def __init__(self, parent_mock=None):
    if parent_mock:
      self._trace = parent_mock._trace  # pylint: disable=protected-access
    else:
      self._trace = MockTrace()

  def __setattr__(self, name, value):
    # Anything may be assigned until _trace exists (i.e. inside __init__), and
    # hook functions created by _install_hook are always allowed. Otherwise
    # only child MockObjects may be attached. This check uses an explicit
    # raise rather than `assert` so it still runs under `python -O`.
    if (hasattr(self, '_trace') and not hasattr(value, 'is_hook') and
        not isinstance(value, MockObject)):
      raise AssertionError(
          'Only MockObjects may be assigned to a MockObject; got %r for %r' %
          (value, name))
    object.__setattr__(self, name, value)

  def SetAttribute(self, name, value):
    """Same as setattr(self, name, value), subject to the __setattr__ check."""
    setattr(self, name, value)

  def ExpectCall(self, func_name, *args):
    """Appends an expected call to the shared trace and returns it.

    Installs a hook named func_name on this object if no such attribute exists
    yet. All expectations must be programmed before any mocked call is made.

    Raises:
      AssertionError: if a mocked call on the shared trace already happened.
    """
    if self._trace.next_call_index != 0:
      raise AssertionError(
          'ExpectCall(%r) made after mocked calls have started' % func_name)
    if not hasattr(self, func_name):
      self._install_hook(func_name)

    call = MockFunctionCall(func_name).WithArgs(*args)
    self._trace.expected_calls.append(call)
    return call

  def _install_hook(self, func_name):
    """Installs a function attribute that checks each call against the trace."""
    def handler(*args, **_):
      # Keyword arguments are accepted but not checked. WillReturn(DONT_CARE)
      # keeps the '->value' suffix out of this call's repr in error messages.
      got_call = MockFunctionCall(
          func_name).WithArgs(*args).WillReturn(DONT_CARE)
      trace = self._trace
      if trace.next_call_index >= len(trace.expected_calls):
        raise Exception(
            'Call to %s was not expected, at end of programmed trace.' %
            repr(got_call))
      expected_call = trace.expected_calls[trace.next_call_index]
      # VerifyEquals raises before the index advances, so a mismatch does not
      # consume the expectation.
      expected_call.VerifyEquals(got_call)
      trace.next_call_index += 1
      for h in expected_call.when_called_handlers:
        h(*args)
      return expected_call.return_value
    # Marker that lets __setattr__ accept this non-MockObject value.
    handler.is_hook = True
    setattr(self, func_name, handler)