1# mock.py
2# Test tools for mocking and patching.
3# Copyright (C) 2007-2009 Michael Foord
4# E-mail: fuzzyman AT voidspace DOT org DOT uk
5
6# mock 0.6.0
7# http://www.voidspace.org.uk/python/mock/
8
9# Released subject to the BSD License
10# Please see http://www.voidspace.org.uk/python/license.shtml
11
12# 2009-11-25: Licence downloaded from above URL.
13# BEGIN DOWNLOADED LICENSE
14#
15# Copyright (c) 2003-2009, Michael Foord
16# All rights reserved.
17# E-mail : fuzzyman AT voidspace DOT org DOT uk
18#
19# Redistribution and use in source and binary forms, with or without
20# modification, are permitted provided that the following conditions are
21# met:
22#
23#
24#     * Redistributions of source code must retain the above copyright
25#       notice, this list of conditions and the following disclaimer.
26#
27#     * Redistributions in binary form must reproduce the above
28#       copyright notice, this list of conditions and the following
29#       disclaimer in the documentation and/or other materials provided
30#       with the distribution.
31#
32#     * Neither the name of Michael Foord nor the name of Voidspace
33#       may be used to endorse or promote products derived from this
34#       software without specific prior written permission.
35#
36# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
37# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
38# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
39# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
40# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
41# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
42# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
43# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
44# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
45# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
46# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
47#
48# END DOWNLOADED LICENSE
49
50# Scripts maintained at http://www.voidspace.org.uk/python/index.shtml
51# Comments, suggestions and bug reports welcome.
52
53
# Public names exported by ``from mock import *``.
__all__ = ('Mock', 'patch', 'patch_object', 'sentinel', 'DEFAULT')

# Release version of this module.
__version__ = '0.6.0'
63
class SentinelObject(object):
    """A named marker object whose repr shows its name."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        template = '<SentinelObject "%s">'
        return template % self.name


class Sentinel(object):
    """Hands out exactly one ``SentinelObject`` per attribute name.

    Repeated access to the same attribute returns the same object, so
    ``sentinel.foo is sentinel.foo`` holds while different names yield
    different objects.
    """

    def __init__(self):
        self._sentinels = {}

    def __getattr__(self, name):
        registry = self._sentinels
        try:
            return registry[name]
        except KeyError:
            # First request for this name: create and remember its marker.
            created = registry[name] = SentinelObject(name)
            return created


sentinel = Sentinel()

# Marker meaning "no value supplied" throughout this module.
DEFAULT = sentinel.DEFAULT
83
# A "classic" class: it deliberately does not inherit from ``object``.
# On Python 2 its type is the classic-class metaclass; on Python 3 every
# class is new-style, so its type is simply ``type``.
class OldStyleClass:
    pass
# Metaclass of classic classes. It is used alongside ``type`` in
# ``isinstance(x, (type, ClassType))`` checks so that both new-style and
# classic classes are recognised as classes.
ClassType = type(OldStyleClass)
87
def _is_magic(name):
    """Return True if *name* is a dunder name such as ``__init__``.

    The name must be at least four characters long and both start and end
    with a double underscore.
    """
    return len(name) > 3 and name[:2] == '__' == name[-2:]
90
def _copy(value):
    """Return a shallow copy of *value* if it is exactly a builtin container.

    Only plain ``dict``, ``list``, ``tuple`` and ``set`` instances are copied
    (subclasses are not); any other value is returned unchanged.
    """
    kind = type(value)
    if kind not in (dict, list, tuple, set):
        return value
    return kind(value)
95
96
class Mock(object):
    """A configurable stand-in object that records how it is used.

    Calling a ``Mock`` records the call (``called``, ``call_count``,
    ``call_args``, ``call_args_list``) and returns ``return_value``.
    Accessing an unknown attribute creates and caches a child ``Mock``.
    Calls made on children are also appended to every ancestor's
    ``method_calls`` as ``(dotted_name, args, kwargs)``.
    """

    def __init__(self, spec=None, side_effect=None, return_value=DEFAULT,
                 name=None, parent=None, wraps=None):
        """Create a mock.

        spec: a list of permitted attribute names, or any other object whose
            non-magic ``dir()`` entries become that list. Accessing an
            attribute outside the list raises ``AttributeError``.
        side_effect: an exception instance or class to raise on every call,
            or a callable invoked with the call arguments. A callable's result
            is returned unless it is ``DEFAULT``.
        return_value: value returned by calls. When left as ``DEFAULT``, a
            child ``Mock`` is created lazily on first use.
        name, parent: set when this mock is an attribute of another mock, so
            that calls can be recorded in the ancestors' ``method_calls``.
        wraps: object that calls are forwarded to while no explicit
            ``return_value`` is configured. Child mocks wrap the corresponding
            attribute of it.
        """
        self._parent = parent
        self._name = name
        if spec is not None and not isinstance(spec, list):
            spec = [member for member in dir(spec) if not _is_magic(member)]

        self._methods = spec
        self._children = {}
        self._return_value = return_value
        self.side_effect = side_effect
        self._wraps = wraps

        self.reset_mock()


    def reset_mock(self):
        """Clear the call records of this mock, its children and its return value."""
        self.called = False
        self.call_args = None
        self.call_count = 0
        self.call_args_list = []
        self.method_calls = []
        # values() works on both Python 2 and 3 (itervalues is Python 2 only).
        for child in self._children.values():
            child.reset_mock()
        ret = self._return_value
        # Guard against infinite recursion when a mock returns itself.
        if isinstance(ret, Mock) and ret is not self:
            ret.reset_mock()


    def __get_return_value(self):
        # Create the default return value lazily so repeated calls share it.
        if self._return_value is DEFAULT:
            self._return_value = Mock()
        return self._return_value

    def __set_return_value(self, value):
        self._return_value = value

    return_value = property(__get_return_value, __set_return_value)


    def __call__(self, *args, **kwargs):
        """Record the call, then produce the result.

        Precedence: raise ``side_effect`` if it is an exception; else return
        the callable ``side_effect``'s result unless it is ``DEFAULT``; else
        call the wrapped object if one is set and no explicit
        ``return_value`` was configured; else return ``return_value``.
        """
        self.called = True
        self.call_count += 1
        self.call_args = (args, kwargs)
        self.call_args_list.append((args, kwargs))

        # Record the call on every ancestor, using a dotted name relative to
        # that ancestor.
        parent = self._parent
        name = self._name
        while parent is not None:
            parent.method_calls.append((name, args, kwargs))
            if parent._parent is None:
                break
            name = parent._name + '.' + name
            parent = parent._parent

        effect = self.side_effect
        if effect is not None:
            # BaseException (not Exception) so KeyboardInterrupt, SystemExit
            # and similar are raised rather than instantiated and returned.
            if (isinstance(effect, BaseException) or
                    isinstance(effect, (type, ClassType)) and
                    issubclass(effect, BaseException)):
                raise effect

            ret_val = effect(*args, **kwargs)
            # A real result from side_effect wins; it must not be discarded
            # in favour of the wrapped object.
            if ret_val is not DEFAULT:
                return ret_val

        if self._wraps is not None and self._return_value is DEFAULT:
            return self._wraps(*args, **kwargs)
        return self.return_value


    def __getattr__(self, name):
        """Return the child mock called *name*, creating it on first access."""
        if name == '_methods':
            # Only reached when __init__ has not run (e.g. an instance built
            # by copy or pickle). Without this guard, the self._methods lookup
            # below would recurse into __getattr__ forever.
            raise AttributeError(name)
        if self._methods is not None:
            if name not in self._methods:
                raise AttributeError("Mock object has no attribute '%s'" % name)
        elif _is_magic(name):
            raise AttributeError(name)

        if name not in self._children:
            wraps = None
            if self._wraps is not None:
                wraps = getattr(self._wraps, name)
            self._children[name] = Mock(parent=self, name=name, wraps=wraps)

        return self._children[name]


    def assert_called_with(self, *args, **kwargs):
        """Raise AssertionError unless the most recent call used these arguments."""
        expected = (args, kwargs)
        # An explicit raise, unlike an assert statement, still works under -O.
        if self.call_args != expected:
            raise AssertionError('Expected: %s\nCalled with: %s' % (expected, self.call_args))
189
190
def _dot_lookup(thing, comp, import_path):
    """Return attribute *comp* of *thing*, importing *import_path* if needed.

    Submodules of a package are not attributes of it until imported, so on
    a failed lookup the dotted module path is imported and the lookup retried.
    """
    try:
        found = getattr(thing, comp)
    except AttributeError:
        __import__(import_path)
        found = getattr(thing, comp)
    return found
197
198
def _importer(target):
    """Import and return the object named by the dotted path *target*.

    The first component is imported as a module. Each remaining component
    is resolved as an attribute, importing the matching submodule if it is
    not already present.
    """
    parts = target.split('.')
    path = parts[0]
    obj = __import__(path)
    for part in parts[1:]:
        path = '%s.%s' % (path, part)
        obj = _dot_lookup(obj, part, path)
    return obj
208
209
class _patch(object):
    """Replace ``target.attribute`` for the duration of a call or ``with`` block.

    Instances work as context managers (``__enter__``/``__exit__``) and as
    function decorators (``__call__``). Stacked decorators share a single
    wrapper function that holds all of them in its ``patchings`` list.
    """

    def __init__(self, target, attribute, new, spec, create):
        self.target = target
        self.attribute = attribute
        self.new = new
        self.spec = spec
        self.create = create
        self.has_local = False


    def __call__(self, func):
        """Decorate *func* so that the patch is active while it runs.

        If *func* is already a patched wrapper, this patch is appended to its
        ``patchings`` and the same wrapper is returned. Every patch whose
        ``new`` is ``DEFAULT`` passes its generated mock as an extra
        positional argument, in ``patchings`` order.
        """
        if hasattr(func, 'patchings'):
            func.patchings.append(self)
            return func

        def patched(*args, **keywargs):
            # don't use a with here (backwards compatibility with 2.5)
            extra_args = []
            entered = []
            try:
                for patching in patched.patchings:
                    arg = patching.__enter__()
                    # Track successful entries so that a failure part-way
                    # through only undoes the patches actually applied.
                    entered.append(patching)
                    if patching.new is DEFAULT:
                        extra_args.append(arg)
                args += tuple(extra_args)
                return func(*args, **keywargs)
            finally:
                # Undo in reverse order: when two patches replace the same
                # attribute, the one applied last must be removed first.
                for patching in reversed(entered):
                    patching.__exit__()

        patched.patchings = [self]
        patched.__name__ = func.__name__
        patched.__doc__ = func.__doc__
        code = getattr(func, 'func_code', None)
        if code is None:
            code = func.__code__  # Python 3 spelling of func_code
        patched.compat_co_firstlineno = getattr(func, "compat_co_firstlineno",
                                                code.co_firstlineno)
        return patched


    def get_original(self):
        """Return the value stored directly on the target, or ``DEFAULT``.

        ``DEFAULT`` means the attribute is not stored locally on the target
        (for example, it is inherited, or absent with ``create`` true). In
        that case ``__exit__`` deletes the patch instead of restoring a value.

        Raises AttributeError if the attribute does not exist and ``create``
        is false.
        """
        target = self.target
        name = self.attribute
        create = self.create

        original = DEFAULT
        if _has_local_attr(target, name):
            try:
                # Read __dict__ directly so descriptors (e.g. staticmethod)
                # are restored as-is rather than as their bound results.
                original = target.__dict__[name]
            except AttributeError:
                # for instances of classes with slots, they have no __dict__
                original = getattr(target, name)
        elif not create and not hasattr(target, name):
            raise AttributeError("%s does not have the attribute %r" % (target, name))
        return original


    def __enter__(self):
        """Apply the patch and return the object now installed on the target."""
        new, spec = self.new, self.spec
        original = self.get_original()
        if new is DEFAULT:
            inherit = False
            # Identity test: ``==`` would invoke an arbitrary spec's __eq__.
            if spec is True:
                # Use the object being replaced as the spec. If it is not
                # stored locally (e.g. inherited), look it up normally rather
                # than speccing on the DEFAULT marker.
                spec = original
                if spec is DEFAULT:
                    spec = getattr(self.target, self.attribute, DEFAULT)
                if isinstance(spec, (type, ClassType)):
                    inherit = True
            new = Mock(spec=spec)
            if inherit:
                # Patching a class: instances it "creates" share the spec too.
                new.return_value = Mock(spec=spec)
        self.temp_original = original
        setattr(self.target, self.attribute, new)
        return new


    def __exit__(self, *_):
        """Restore the original value, or delete the attribute if there was none locally."""
        if self.temp_original is not DEFAULT:
            setattr(self.target, self.attribute, self.temp_original)
        else:
            delattr(self.target, self.attribute)
        del self.temp_original
288
289
def patch_object(target, attribute, new=DEFAULT, spec=None, create=False):
    """Return a patcher that replaces ``target.<attribute>``.

    The result works as a decorator or a context manager. When *new* is left
    as ``DEFAULT``, a ``Mock`` is installed instead; passing ``spec=True``
    specs it on the attribute being replaced. *create* allows patching an
    attribute that does not yet exist.
    """
    return _patch(target=target, attribute=attribute, new=new, spec=spec,
                  create=create)
292
293
def patch(target, new=DEFAULT, spec=None, create=False):
    """Return a patcher for the object named by the dotted string *target*.

    *target* has the form ``'package.module.attribute'``. Everything before
    the last dot is imported, and the final component is the attribute to
    replace. *new*, *spec* and *create* have the same meaning as for
    ``patch_object``: by default a ``Mock`` is installed, ``spec=True`` specs
    it on the replaced attribute, and *create* allows a missing attribute.

    Raises TypeError if *target* is not a string of the form
    ``'module.attribute'`` with both parts non-empty.
    """
    message = "Need a valid target to patch. You supplied: %r" % (target,)
    try:
        module_path, attribute = target.rsplit('.', 1)
    except (TypeError, ValueError, AttributeError):
        # ValueError: the string has no dot.
        # AttributeError: target is not a string at all (e.g. None).
        raise TypeError(message)
    if not module_path or not attribute:
        # '.name' or 'module.' would otherwise fail obscurely later on.
        raise TypeError(message)
    return _patch(_importer(module_path), attribute, new, spec, create)
301
302
303
def _has_local_attr(obj, name):
    """Return True if *name* is stored in *obj*'s own ``__dict__``.

    Objects without a ``__dict__`` (e.g. instances of ``__slots__`` classes)
    fall back to a plain ``hasattr`` check.
    """
    try:
        found = name in vars(obj)
    except TypeError:
        # vars() rejects objects that have no __dict__.
        found = hasattr(obj, name)
    return found
310