simple_mock.py revision cedac228d2dd51db4b79ea1e72c7f249408ee061
1# Copyright (c) 2012 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4"""A very very simple mock object harness."""
5
# Wildcard marker for mocked calls. MockFunctionCall.VerifyEquals skips any
# expected argument that compares equal to this value, and __repr__ renders
# such arguments as '_'. Because the marker is the empty string, an expected
# argument of '' also behaves as a wildcard.
DONT_CARE = ''
7
class MockFunctionCall(object):
  """Describes one function call: its name, arguments and result.

  The same class is used for expected calls (configured through the chained
  WithArgs / WillReturn / WhenCalled setters) and for calls that actually
  happened, which VerifyEquals compares against the expectation.
  """

  def __init__(self, name):
    """Creates a call named |name| with no arguments, result or handlers."""
    self.name = name
    self.args = tuple()
    self.return_value = None
    self.when_called_handlers = []

  def WithArgs(self, *args):
    """Sets the positional arguments of the call. Returns self."""
    self.args = args
    return self

  def WillReturn(self, value):
    """Sets the value stored as the call's result. Returns self."""
    self.return_value = value
    return self

  def WhenCalled(self, handler):
    """Appends |handler| to the callbacks registered for this call.

    Returns self so that it chains like WithArgs and WillReturn.
    """
    self.when_called_handlers.append(handler)
    return self

  def VerifyEquals(self, got):
    """Checks that |got| matches this expected call.

    The names and the number of arguments must be equal. Arguments are then
    compared position by position; an expected argument equal to DONT_CARE
    matches any value.

    Raises:
      Exception: if the two calls do not match.
    """
    if self.name != got.name or len(self.args) != len(got.args):
      raise Exception('Self %s, got %s' % (repr(self), repr(got)))
    # Lengths are known to be equal here, so zip pairs every argument.
    for expected_arg, got_arg in zip(self.args, got.args):
      if expected_arg == DONT_CARE:
        continue
      if expected_arg != got_arg:
        raise Exception('Self %s, got %s' % (repr(self), repr(got)))

  def __repr__(self):
    """Formats the call as name(args), plus ->result when one is set.

    Arguments equal to DONT_CARE are shown as '_'; the result is omitted
    when it is None or DONT_CARE.
    """
    def arg_to_text(a):
      if a == DONT_CARE:
        return '_'
      return repr(a)
    args_text = ', '.join(arg_to_text(a) for a in self.args)
    if self.return_value in (None, DONT_CARE):
      return '%s(%s)' % (self.name, args_text)
    return '%s(%s)->%s' % (self.name, args_text, repr(self.return_value))
48
class MockTrace(object):
  """Ordered list of expected calls plus a cursor into that list.

  expected_calls holds the programmed calls in the order they were added;
  next_call_index is the position of the next call that should occur.
  """

  def __init__(self):
    # Start with nothing programmed and the cursor at the beginning.
    self.expected_calls, self.next_call_index = [], 0
53
class MockObject(object):
  """Mock whose methods are created on demand by ExpectCall.

  Every expected call is appended to a trace. Calls made on the mock are then
  checked against that trace in order. A mock built with a parent reuses the
  parent's trace, so both verify against the same sequence of calls.
  """

  def __init__(self, parent_mock = None):
    """Uses |parent_mock|'s trace when one is given, else a fresh MockTrace."""
    trace = parent_mock._trace if parent_mock else MockTrace() # pylint: disable=W0212
    self._trace = trace

  def __setattr__(self, name, value):
    """Stores |value| under |name|, restricting what may be stored.

    Once _trace exists, only hook functions (objects carrying an is_hook
    attribute) and other MockObject instances are accepted.
    """
    constructed = hasattr(self, '_trace')
    if constructed and not hasattr(value, 'is_hook'):
      assert isinstance(value, MockObject)
    object.__setattr__(self, name, value)

  def SetAttribute(self, name, value):
    """Method form of attribute assignment; the same checks apply."""
    setattr(self, name, value)

  def ExpectCall(self, func_name, *args):
    """Programs the next expected call and returns it for further setup.

    Must be used before any mocked call has been made on the trace. The
    returned MockFunctionCall already has |args| set.
    """
    trace = self._trace
    assert trace.next_call_index == 0
    already_present = hasattr(self, func_name)
    if not already_present:
      self._install_hook(func_name)

    expected = MockFunctionCall(func_name)
    trace.expected_calls.append(expected)
    expected.WithArgs(*args)
    return expected

  def _install_hook(self, func_name):
    """Attaches a function named |func_name| that replays the trace."""
    def handler(*args, **_):
      # Keyword arguments are accepted but take no part in verification.
      trace = self._trace
      actual = MockFunctionCall(func_name)
      actual.WithArgs(*args)
      actual.WillReturn(DONT_CARE)
      programmed = trace.expected_calls
      if trace.next_call_index >= len(programmed):
        raise Exception(
          'Call to %s was not expected, at end of programmed trace.' %
          repr(actual))
      expected = programmed[trace.next_call_index]
      expected.VerifyEquals(actual)
      trace.next_call_index += 1
      for callback in expected.when_called_handlers:
        callback(*args)
      return expected.return_value
    # The marker lets __setattr__ accept the function as an attribute.
    handler.is_hook = True
    setattr(self, func_name, handler)
99
100
class MockTimer(object):
  """Fake timer whose reading only changes when Sleep is called."""

  def __init__(self):
    # Running total of every amount passed to Sleep.
    self._elapsed_time = 0

  def Sleep(self, time):
    """Adds |time| to the elapsed total instead of actually waiting."""
    self._elapsed_time += time

  def GetTime(self):
    """Returns the total accumulated by Sleep calls so far."""
    return self._elapsed_time
110