1# Copyright 2008 Google Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15# This is a fork of the pymox library intended to work with Python 3.
16# The file was modified by quermit@gmail.com and dawid.fatyga@gmail.com
17
18"""Mox, an object-mocking framework for Python.
19
20Mox works in the record-replay-verify paradigm.  When you first create
21a mock object, it is in record mode.  You then programmatically set
22the expected behavior of the mock object (what methods are to be
23called on it, with what parameters, what they should return, and in
24what order).
25
26Once you have set up the expected mock behavior, you put it in replay
27mode.  Now the mock responds to method calls just as you told it to.
28If an unexpected method (or an expected method with unexpected
29parameters) is called, then an exception will be raised.
30
31Once you are done interacting with the mock, you need to verify that
all the expected interactions occurred.  (Maybe your code exited
33prematurely without calling some cleanup method!)  The verify phase
34ensures that every expected method was called; otherwise, an exception
35will be raised.
36
WARNING! Mock objects created by Mox are not thread-safe.  If you
call a mock from multiple threads, it should be guarded by a mutex.
39
40TODO(stevepm): Add the option to make mocks thread-safe!
41
42Suggested usage / workflow:
43
44    # Create Mox factory
45    my_mox = Mox()
46
47    # Create a mock data access object
48    mock_dao = my_mox.CreateMock(DAOClass)
49
50    # Set up expected behavior
51    mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
52    mock_dao.DeletePerson(person)
53
54    # Put mocks in replay mode
55    my_mox.ReplayAll()
56
57    # Inject mock object and run test
58    controller.SetDao(mock_dao)
59    controller.DeletePersonById('1')
60
61    # Verify all methods were called as expected
62    my_mox.VerifyAll()
63"""
64
65import collections
66import difflib
67import inspect
68import re
69import types
70import unittest
71
72from mox3 import stubout
73
74
class Error(AssertionError):
    """Root of every exception raised by this module.

    Deriving from AssertionError lets test frameworks report mock failures
    as ordinary assertion failures.
    """
79
80
class ExpectedMethodCallsError(Error):
    """Signals that Verify() found recorded calls that were never made.

    Typically this means Verify() ran before the code under test had made
    every call that was recorded on the mock.
    """

    def __init__(self, expected_methods):
        """Store the outstanding expected calls.

        Args:
            expected_methods: [MockMethod]. The calls that should have
                happened but did not.

        Raises:
            ValueError: if expected_methods is empty.
        """
        if not expected_methods:
            raise ValueError("There must be at least one expected method")
        Error.__init__(self)
        self._expected_methods = expected_methods

    def __str__(self):
        numbered = ["%3d.  %s" % pair
                    for pair in enumerate(self._expected_methods)]
        listing = "\n".join(numbered)
        return "Verify: Expected methods never called:\n%s" % (listing,)
109
110
class UnexpectedMethodCallError(Error):
    """Signals a call the mock was not expecting.

    The call either had the wrong parameters or came out of the recorded
    order.
    """

    def __init__(self, unexpected_method, expected):
        """Build the failure message up front.

        Args:
            unexpected_method: MockMethod. The call that was made but was not
                at the head of the expected-call queue.
            expected: MockMethod, UnorderedGroup or None. What should have
                been called instead; when given, the message is a line diff
                of the two string forms.
        """
        Error.__init__(self)
        if expected is not None:
            diff_lines = difflib.Differ().compare(
                str(unexpected_method).splitlines(True),
                str(expected).splitlines(True))
            # Differ keeps line endings; strip them before re-joining.
            body = "\n".join(line.rstrip() for line in diff_lines)
            self._str = ("Unexpected method call."
                         "  unexpected:-  expected:+\n%s" % (body,))
        else:
            self._str = "Unexpected method call %s" % (unexpected_method,)

    def __str__(self):
        return self._str
143
144
class UnknownMethodCallError(Error):
    """Signals a request for a method the mocked class does not provide."""

    def __init__(self, unknown_method_name):
        """Remember which name was requested.

        Args:
            unknown_method_name: str. The requested name, which is not part
                of the mocked class's interface.
        """
        Error.__init__(self)
        self._unknown_method_name = unknown_method_name

    def __str__(self):
        template = "Method called is not a member of the object: %s"
        return template % self._unknown_method_name
163
164
class PrivateAttributeError(Error):
    """Raised if a MockObject is passed a private additional attribute name."""

    def __init__(self, attr):
        """Init exception.

        Args:
            attr: str. The private attribute name that was rejected.
        """
        Error.__init__(self)
        self._attr = attr

    def __str__(self):
        # The two literal halves are concatenated at compile time, so the
        # separating space must live inside one of them.
        return ("Attribute '%s' is private and should not be available "
                "in a mock object." % (self._attr,))
175
176
class ExpectedMockCreationError(Error):
    """Raised if mocks should have been created by StubOutClassWithMocks."""

    def __init__(self, expected_mocks):
        """Init exception.

        Args:
            expected_mocks: sequence of MockObjects that were recorded but
                never handed out (i.e. never created during replay).

        Raises:
            ValueError: if expected_mocks is empty.
        """
        if not expected_mocks:
            raise ValueError("There must be at least one expected mock")
        Error.__init__(self)
        self._expected_mocks = expected_mocks

    def __str__(self):
        mocks = "\n".join(["%3d.  %s" % (i, m)
                          for i, m in enumerate(self._expected_mocks)])
        return "Verify: Expected mocks never created:\n%s" % (mocks,)
200
201
class UnexpectedMockCreationError(Error):
    """Raised if too many mocks were created by StubOutClassWithMocks."""

    def __init__(self, instance, *params, **named_params):
        """Init exception.

        Args:
            instance: the class whose (unexpected) instantiation was
                attempted.
            *params: positional parameters given during instantiation.
            **named_params: keyword parameters given during instantiation.
        """
        Error.__init__(self)
        self._instance = instance
        self._params = params
        self._named_params = named_params

    def __str__(self):
        # str() rather than "%s" % value: the latter breaks when a positional
        # argument is itself a tuple.
        arguments = [str(value) for value in self._params]
        arguments.extend("%s=%s" % (key, value)
                         for key, value in self._named_params.items())
        # Joining one combined list avoids a leading ", " when only keyword
        # arguments were given.
        return "Unexpected mock creation: %s(%s)" % (self._instance,
                                                     ", ".join(arguments))
229
230
class Mox(object):
    """Mox: a factory for creating mock objects.

    Every mock created through a Mox instance is remembered, so ReplayAll(),
    VerifyAll() and ResetAll() can act on all of them at once. Stubs
    installed with StubOutWithMock()/StubOutClassWithMocks() are registered
    with self.stubs and removed again by UnsetStubs().
    """

    # A list of types that should be stubbed out with MockObjects (as
    # opposed to MockAnythings).
    _USE_MOCK_OBJECT = [types.FunctionType, types.ModuleType, types.MethodType]

    def __init__(self):
        """Initialize a new Mox."""

        self._mock_objects = []
        self.stubs = stubout.StubOutForTesting()

    def CreateMock(self, class_to_mock, attrs=None, bounded_to=None):
        """Create a new mock object.

        Args:
            # class_to_mock: the class to be mocked
            class_to_mock: class
            attrs: dict of attribute names to values that will be
                   set on the mock object. Only public attributes may be set.
            bounded_to: optionally, when class_to_mock is not a class,
                        it points to a real class object, to which
                        attribute is bound

        Returns:
            MockObject that can be used as the class_to_mock would be.
        """
        if attrs is None:
            attrs = {}
        new_mock = MockObject(class_to_mock, attrs=attrs,
                              class_to_bind=bounded_to)
        self._mock_objects.append(new_mock)
        return new_mock

    def CreateMockAnything(self, description=None):
        """Create a mock that will accept any method calls.

        This does not enforce an interface.

        Args:
            description: str. Optionally, a descriptive name for the mock
                         object being created, for debugging output purposes.

        Returns:
            A new MockAnything, registered with this Mox.
        """
        new_mock = MockAnything(description=description)
        self._mock_objects.append(new_mock)
        return new_mock

    def ReplayAll(self):
        """Set all mock objects to replay mode."""

        for mock_obj in self._mock_objects:
            mock_obj._Replay()

    def VerifyAll(self):
        """Call verify on all mock objects created."""

        for mock_obj in self._mock_objects:
            mock_obj._Verify()

    def ResetAll(self):
        """Call reset on all mock objects. This does not unset stubs."""

        for mock_obj in self._mock_objects:
            mock_obj._Reset()

    def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
        """Replace a method, attribute, etc. with a Mock.

        Unless use_mock_anything is True, the attribute is replaced with a
        MockObject built from the original attribute (bound to obj when obj
        is a class). With use_mock_anything=True a MockAnything whose
        __name__ is attr_name is used instead. The replacement is registered
        with self.stubs, so UnsetStubs() restores the original.

        Args:
            obj: A Python object (class, module, instance, callable).
            attr_name: str. The name of the attribute to replace with a mock.
            use_mock_anything: bool. True if a MockAnything should be used
                               regardless of the type of attribute.

        Raises:
            TypeError: if the attribute is already a mock (for example a
                stub left over from a test that did not call UnsetStubs).
        """

        if inspect.isclass(obj):
            class_to_bind = obj
        else:
            class_to_bind = None

        attr_to_replace = getattr(obj, attr_name)
        attr_type = type(attr_to_replace)

        # issubclass() also catches MockObject subclasses such as the
        # factories installed by StubOutClassWithMocks. type() is used rather
        # than isinstance() because MockObject overrides __class__.
        if issubclass(attr_type, MockAnything):
            raise TypeError('Cannot mock a MockAnything! Did you remember to '
                            'call UnsetStubs in your previous test?')

        # NOTE: isinstance(x, object) holds for every Python 3 object, so
        # type_check is always True and a MockObject is created unless
        # use_mock_anything is set. The expression is kept as-is to preserve
        # this long-standing behavior.
        type_check = (
            attr_type in self._USE_MOCK_OBJECT or
            inspect.isclass(attr_to_replace) or
            isinstance(attr_to_replace, object))
        if type_check and not use_mock_anything:
            stub = self.CreateMock(attr_to_replace, bounded_to=class_to_bind)
        else:
            stub = self.CreateMockAnything(
                description='Stub for %s' % attr_to_replace)
            stub.__name__ = attr_name

        self.stubs.Set(obj, attr_name, stub)

    def StubOutClassWithMocks(self, obj, attr_name):
        """Replace a class with a "mock factory" that will create mock objects.

        This is useful if the code-under-test directly instantiates
        dependencies. Previously some boilerplate was necessary to create a
        mock that would act as a factory. Using StubOutClassWithMocks, once
        you've stubbed out the class you may use the stubbed class as you
        would any other mock created by mox: during the record phase, new
        mock instances will be created, and during the replay phase, the
        recorded mocks will be returned in the order they were created.

        # Example using StubOutWithMock (the old, clunky way):

        mock1 = mox.CreateMock(my_import.FooClass)
        mock2 = mox.CreateMock(my_import.FooClass)
        mox.StubOutWithMock(my_import, 'FooClass', use_mock_anything=True)
        my_import.FooClass(1, 2).AndReturn(mock1)
        my_import.FooClass(9, 10).AndReturn(mock2)
        mox.ReplayAll()

        my_import.FooClass(1, 2)     # Returns mock1.
        my_import.FooClass(9, 10)    # Returns mock2.
        mox.VerifyAll()

        # Example using StubOutClassWithMocks:

        mox.StubOutClassWithMocks(my_import, 'FooClass')
        mock1 = my_import.FooClass(1, 2)     # Returns a new mock of FooClass
        mock2 = my_import.FooClass(9, 10)    # Returns another mock instance
        mox.ReplayAll()

        my_import.FooClass(1, 2)     # Returns mock1 again.
        my_import.FooClass(9, 10)    # Returns mock2 again.
        mox.VerifyAll()

        Args:
            obj: The object (typically a module) holding the class.
            attr_name: str. Name of the class attribute to replace.

        Raises:
            TypeError: if the attribute is already a mock, or is not a class.
        """
        attr_to_replace = getattr(obj, attr_name)
        attr_type = type(attr_to_replace)

        # See StubOutWithMock: issubclass() on type() also detects a factory
        # left over from a previous StubOutClassWithMocks call.
        if issubclass(attr_type, MockAnything):
            raise TypeError('Cannot mock a MockAnything! Did you remember to '
                            'call UnsetStubs in your previous test?')

        if not inspect.isclass(attr_to_replace):
            raise TypeError('Given attr is not a Class. Use StubOutWithMock.')

        factory = _MockObjectFactory(attr_to_replace, self)
        self._mock_objects.append(factory)
        self.stubs.Set(obj, attr_name, factory)

    def UnsetStubs(self):
        """Restore stubs to their original state."""

        self.stubs.UnsetAll()
393
394
def Replay(*args):
    """Switch every given mock into replay mode.

    Args:
        *args: any number of mocks; each one's _Replay() is invoked in the
            order the mocks were passed.
    """
    for mock_obj in args:
        mock_obj._Replay()
404
405
def Verify(*args):
    """Check that every given mock received all of its expected calls.

    Args:
        *args: any number of mocks; each one's _Verify() is invoked in the
            order the mocks were passed.
    """
    for mock_obj in args:
        mock_obj._Verify()
415
416
def Reset(*args):
    """Return every given mock to a fresh record-mode state.

    Args:
        *args: any number of mocks; each one's _Reset() is invoked in the
            order the mocks were passed.
    """
    for mock_obj in args:
        mock_obj._Reset()
426
427
class MockAnything(object):
    """A mock with no fixed interface: any method may be called on it.

    Handy for mocking collaborators that do not expose a well-defined public
    interface. Attribute lookups that miss are turned into MockMethods that
    share this mock's expected-call queue and record/replay state.
    """

    def __init__(self, description=None):
        """Create a MockAnything in record mode with no expected calls.

        Args:
            description: str. Optional human-readable name for the mock, used
                         by repr() and passed to every MockMethod it creates.
        """
        self._description = description
        self._Reset()

    def __repr__(self):
        if not self._description:
            return '<MockAnything instance>'
        return '<MockAnything instance of %s>' % self._description

    def __getattr__(self, method_name):
        """Turn a missed attribute lookup into a MockMethod.

        The returned MockMethod records the call when this mock is in record
        mode and checks it against the expected queue when in replay mode
        (that work happens in MockMethod.__call__).

        Args:
            method_name: str. Name of the attribute being looked up.

        Returns:
            A MockMethod tied to this mock's queue and mode, or the bound
            class __dir__ when '__dir__' itself is requested.
        """
        if method_name != '__dir__':
            return self._CreateMockMethod(method_name)
        owner = self.__class__
        return owner.__dir__.__get__(self, owner)

    def __str__(self):
        str_method = self._CreateMockMethod('__str__')
        return str_method()

    def __call__(self, *args, **kwargs):
        call_method = self._CreateMockMethod('__call__')
        return call_method(*args, **kwargs)

    def __getitem__(self, i):
        getitem_method = self._CreateMockMethod('__getitem__')
        return getitem_method(i)

    def _CreateMockMethod(self, method_name, method_to_mock=None,
                          class_to_bind=object):
        """Build a MockMethod that shares this mock's queue and mode.

        Args:
            method_name: str. Name of the method being called.
            method_to_mock: the real method object, if any, used by
                MockMethod for introspection.
            class_to_bind: class the method is bound to (object by default).

        Returns:
            A new MockMethod aware of this mock's state (record or replay).
        """
        return MockMethod(method_name,
                          self._expected_calls_queue,
                          self._replay_mode,
                          method_to_mock=method_to_mock,
                          description=self._description,
                          class_to_bind=class_to_bind)

    def __nonzero__(self):
        """Python 2 truth hook: always truthy so mocks work in conditionals."""
        return 1

    def __bool__(self):
        """Always truthy so the mock can be used in conditionals."""
        return True

    def __eq__(self, rhs):
        """Equal to another MockAnything in the same mode with equal queues."""
        if not isinstance(rhs, MockAnything):
            return False
        return (self._replay_mode == rhs._replay_mode and
                self._expected_calls_queue == rhs._expected_calls_queue)

    def __ne__(self, rhs):
        """Negation of ==, dispatched through the normal == protocol."""
        return not (self == rhs)

    def _Replay(self):
        """Switch from recording expectations to replaying them."""
        self._replay_mode = True

    def _Verify(self):
        """Check that no expected call is still outstanding.

        Raises:
            ExpectedMethodCallsError: if the expected-call queue still holds
                anything other than a single satisfied MultipleTimesGroup.
        """
        queue = self._expected_calls_queue
        if not queue:
            return
        # A trailing MultipleTimesGroup is never popped from the queue, so a
        # lone, satisfied group is not a failure.
        only_satisfied_group = (len(queue) == 1 and
                                isinstance(queue[0], MultipleTimesGroup) and
                                queue[0].IsSatisfied())
        if not only_satisfied_group:
            raise ExpectedMethodCallsError(queue)

    def _Reset(self):
        """Go back to record mode with an empty expected-call queue."""
        # Calls we expect, in order; MockMethods consume this in replay.
        self._expected_calls_queue = collections.deque()
        # Start out recording, not replaying.
        self._replay_mode = False
553
554
class MockObject(MockAnything):
    """Mock object that simulates the public/protected interface of a class.

    Names reported by dir(class_to_mock) are split into callables, which are
    mocked with MockMethods, and plain (non-property) values, which are read
    straight from class_to_mock. Requests for any other attribute raise
    UnknownMethodCallError.
    """

    def __init__(self, class_to_mock, attrs=None, class_to_bind=None):
        """Initialize a mock object.

        Determines the methods and properties of the class and stores them.

        Args:
            # class_to_mock: class to be mocked
            class_to_mock: class
            attrs: dict of attribute names to values that will be set on the
                   mock object. Only public attributes may be set.
            class_to_bind: optionally, when class_to_mock is not a class at
                           all, it points to a real class

        Raises:
            PrivateAttributeError: if a supplied attribute is not public.
            ValueError: if an attribute would mask an existing method.
        """
        if attrs is None:
            attrs = {}

        # Used to hack around the mixin/inheritance of MockAnything, which
        # is not a proper object (it can be anything. :-)
        MockAnything.__dict__['__init__'](self)

        # Get a list of all the public and special methods we should mock.
        self._known_methods = set()
        self._known_vars = set()
        self._class_to_mock = class_to_mock

        if inspect.isclass(class_to_mock):
            self._class_to_bind = self._class_to_mock
        else:
            self._class_to_bind = class_to_bind

        try:
            if inspect.isclass(self._class_to_mock):
                self._description = class_to_mock.__name__
            else:
                self._description = type(class_to_mock).__name__
        except Exception:
            pass

        for method in dir(class_to_mock):
            attr = getattr(class_to_mock, method)
            if callable(attr):
                self._known_methods.add(method)
            elif not (type(attr) is property):
                # treating properties as class vars makes little sense.
                self._known_vars.add(method)

        # Set additional attributes at instantiation time; this is quicker
        # than manually setting attributes that are normally created in
        # __init__.
        for attr, value in attrs.items():
            if attr.startswith("_"):
                raise PrivateAttributeError(attr)
            elif attr in self._known_methods:
                raise ValueError("'%s' is a method of '%s' objects." % (attr,
                                 class_to_mock))
            else:
                setattr(self, attr, value)

    def _CreateMockMethod(self, *args, **kwargs):
        """Overridden to provide self._class_to_bind to class_to_bind."""
        kwargs.setdefault("class_to_bind", self._class_to_bind)
        return super(MockObject, self)._CreateMockMethod(*args, **kwargs)

    def __getattr__(self, name):
        """Intercept attribute request on this object.

        If the attribute is a public class variable, it will be returned and
        not recorded as a call.

        If the attribute is not a variable, it is handled like a method
        call. The method name is checked against the set of mockable
        methods, and a new MockMethod is returned that is aware of the
        MockObject's state (record or replay). The call will be recorded
        or replayed by the MockMethod's __call__.

        Args:
            # name: the name of the attribute being requested.
            name: str

        Returns:
            Either a class variable or a new MockMethod that is aware of the
            state of the mock (record or replay).

        Raises:
            UnknownMethodCallError if the MockObject does not mock the
            requested method.
        """

        if name in self._known_vars:
            return getattr(self._class_to_mock, name)

        if name in self._known_methods:
            return self._CreateMockMethod(
                name,
                method_to_mock=getattr(self._class_to_mock, name))

        raise UnknownMethodCallError(name)

    def __eq__(self, rhs):
        """Provide custom logic to compare objects."""

        return (isinstance(rhs, MockObject) and
                self._class_to_mock == rhs._class_to_mock and
                self._replay_mode == rhs._replay_mode and
                self._expected_calls_queue == rhs._expected_calls_queue)

    def __setitem__(self, key, value):
        """Custom logic for mocking classes that support item assignment.

        Args:
            key: Key to set the value for.
            value: Value to set.

        Returns:
            Expected return value in replay mode. A MockMethod object for the
            __setitem__ method that has already been called if not in replay
            mode.

        Raises:
            TypeError if the underlying class does not support item assignment.
            UnexpectedMethodCallError if the object does not expect the call to
                __setitem__.

        """
        # Verify the class supports item assignment.
        if '__setitem__' not in dir(self._class_to_mock):
            raise TypeError('object does not support item assignment')

        # If we are in replay mode then simply call the mock __setitem__ method
        if self._replay_mode:
            return MockMethod('__setitem__', self._expected_calls_queue,
                              self._replay_mode)(key, value)

        # Otherwise, create a mock method __setitem__.
        return self._CreateMockMethod('__setitem__')(key, value)

    def __getitem__(self, key):
        """Provide custom logic for mocking classes that are subscriptable.

        Args:
            key: Key to return the value for.

        Returns:
            Expected return value in replay mode. A MockMethod object for the
            __getitem__ method that has already been called if not in replay
            mode.

        Raises:
            TypeError if the underlying class is not subscriptable.
            UnexpectedMethodCallError if the object does not expect the call to
                __getitem__.

        """
        # Verify the class supports subscription.
        if '__getitem__' not in dir(self._class_to_mock):
            raise TypeError('unsubscriptable object')

        # If we are in replay mode then simply call the mock __getitem__ method
        if self._replay_mode:
            return MockMethod('__getitem__', self._expected_calls_queue,
                              self._replay_mode)(key)

        # Otherwise, create a mock method __getitem__.
        return self._CreateMockMethod('__getitem__')(key)

    def __iter__(self):
        """Provide custom logic for mocking classes that are iterable.

        Returns:
            Expected return value in replay mode. A MockMethod object for the
            __iter__ method that has already been called if not in replay mode.

        Raises:
            TypeError if the underlying class is not iterable.
            UnexpectedMethodCallError if the object does not expect the call to
                __iter__.

        """
        methods = dir(self._class_to_mock)

        # Verify the class supports iteration.
        if '__iter__' not in methods:
            # If it doesn't have iter method and we are in replay method,
            # then try to iterate using subscripts.
            if '__getitem__' not in methods or not self._replay_mode:
                raise TypeError('not iterable object')
            else:
                results = []
                index = 0
                try:
                    while True:
                        results.append(self[index])
                        index += 1
                except IndexError:
                    return iter(results)

        # If we are in replay mode then simply call the mock __iter__ method.
        if self._replay_mode:
            return MockMethod('__iter__', self._expected_calls_queue,
                              self._replay_mode)()

        # Otherwise, create a mock method __iter__.
        return self._CreateMockMethod('__iter__')()

    def __contains__(self, key):
        """Provide custom logic for mocking classes that contain items.

        Args:
            key: Key to look in container for.

        Returns:
            Expected return value in replay mode. A MockMethod object for the
            __contains__ method that has already been called if not in replay
            mode.

        Raises:
            TypeError if the underlying class does not implement __contains__
            UnexpectedMethodCallError if the object does not expect the call to
            __contains__.

        """
        # Use dir(), like __getitem__/__setitem__/__iter__, so an inherited
        # __contains__ is found and instances (including ones without a
        # __dict__) can be mocked; looking only at __dict__ missed both.
        if '__contains__' not in dir(self._class_to_mock):
            raise TypeError('object does not support membership tests')

        if self._replay_mode:
            return MockMethod('__contains__', self._expected_calls_queue,
                              self._replay_mode)(key)

        return self._CreateMockMethod('__contains__')(key)

    def __call__(self, *params, **named_params):
        """Provide custom logic for mocking classes that are callable."""

        # Verify the class we are mocking is callable.
        is_callable = hasattr(self._class_to_mock, '__call__')
        if not is_callable:
            raise TypeError('Not callable')

        # Because the call is happening directly on this object instead of
        # a method, the call on the mock method is made right here

        # If we are mocking a Function, then use the function, and not the
        # __call__ method
        if type(self._class_to_mock) in (types.FunctionType, types.MethodType):
            method = self._class_to_mock
        else:
            method = getattr(self._class_to_mock, '__call__')
        mock_method = self._CreateMockMethod('__call__', method_to_mock=method)

        return mock_method(*params, **named_params)

    @property
    def __name__(self):
        """Return the name that is being mocked."""
        return self._description

    # TODO(dejw): this property stopped to work after I introduced changes with
    #     binding classes. Fortunately I found a solution in the form of
    #     __getattribute__ method below, but this issue should be investigated
    @property
    def __class__(self):
        return self._class_to_mock

    def __dir__(self):
        """Return only attributes of a class to mock."""
        return dir(self._class_to_mock)

    def __getattribute__(self, name):
        """Return _class_to_mock on __class__ attribute."""
        if name == "__class__":
            return super(MockObject, self).__getattribute__("_class_to_mock")

        return super(MockObject, self).__getattribute__(name)
838
839
class _MockObjectFactory(MockObject):
    """A mock stand-in for a class that hands out MockObject instances.

    Removes the boilerplate previously needed to stub out direct
    instantiation of a class. In record mode, each call records the expected
    __init__ arguments and creates, and remembers, a new MockObject of the
    mocked class. In replay mode, each call checks the __init__ arguments
    against the recorded expectations and returns the remembered mocks in
    the order they were created.

    See StubOutWithMock vs StubOutClassWithMocks for more detail.
    """

    def __init__(self, class_to_mock, mox_instance):
        """Create a factory for class_to_mock.

        Args:
            class_to_mock: the class whose instantiation is being mocked.
            mox_instance: the Mox used to create (and track) new mocks.
        """
        MockObject.__init__(self, class_to_mock)
        self._mox = mox_instance
        # Mocks created while recording. New ones go on the left and replay
        # takes them from the right, so they come back in creation order.
        self._instance_queue = collections.deque()

    def __call__(self, *params, **named_params):
        """Record (or, in replay, check) an instantiation and return a mock."""
        init_method = getattr(self._class_to_mock, '__init__')
        recorder = self._CreateMockMethod('__init__',
                                          method_to_mock=init_method)

        if not self._replay_mode:
            recorder(*params, **named_params)
            new_mock = self._mox.CreateMock(self._class_to_mock)
            self._instance_queue.appendleft(new_mock)
            return new_mock

        # Check for an exhausted queue before touching the expected-call
        # queue, so the more specific creation error wins.
        if not self._instance_queue:
            raise UnexpectedMockCreationError(self._class_to_mock, *params,
                                              **named_params)
        recorder(*params, **named_params)
        return self._instance_queue.pop()

    def _Verify(self):
        """Check every recorded mock was created, then verify the calls."""
        pending = self._instance_queue
        if pending:
            raise ExpectedMockCreationError(pending)
        super(_MockObjectFactory, self)._Verify()
886
887
class MethodSignatureChecker(object):
    """Ensures that methods are called correctly."""

    _NEEDED, _DEFAULT, _GIVEN = range(3)

    def __init__(self, method, class_to_bind=None):
        """Creates a checker.

        Args:
            # method: A method to check.
            # class_to_bind: optionally, a class used to type check first
            #                method parameter, only used with unbound methods
            method: function
            class_to_bind: type or None

        Raises:
            ValueError: method could not be inspected, so checks aren't
                        possible. Some methods and functions like built-ins
                        can't be inspected.
        """
        # inspect.getargspec() was removed in Python 3.11 and, before that,
        # raised ValueError for functions with annotations or keyword-only
        # parameters. getfullargspec() handles both.
        try:
            spec = inspect.getfullargspec(method)
        except TypeError:
            raise ValueError('Could not get argument specification for %r'
                             % (method,))
        self._args = spec.args
        if inspect.ismethod(method) or class_to_bind:
            self._args = self._args[1:]    # Skip 'self'.
        self._method = method
        # May contain the instance this is bound to.
        self._instance = getattr(method, "__self__", None)

        # _bounded_to determines whether the method is bound or not.
        # Compare against None instead of testing truthiness: bool() on an
        # arbitrary instance may run a user-defined __bool__/__len__, and a
        # falsy instance (e.g. an empty list subclass) would otherwise lose
        # its binding and make Check() call isinstance(..., None).
        if self._instance is not None:
            self._bounded_to = self._instance.__class__
        else:
            self._bounded_to = class_to_bind or getattr(method, "im_class",
                                                        None)

        self._has_varargs = spec.varargs is not None
        self._has_varkw = spec.varkw is not None
        defaults = spec.defaults
        if not defaults:
            self._required_args = self._args
            self._default_args = []
        else:
            self._required_args = self._args[:-len(defaults)]
            self._default_args = self._args[-len(defaults):]

        # Keyword-only parameters can only be supplied as named params.
        kwonly_defaults = spec.kwonlydefaults or {}
        self._required_kwonly_args = [a for a in spec.kwonlyargs
                                      if a not in kwonly_defaults]
        self._default_kwonly_args = [a for a in spec.kwonlyargs
                                     if a in kwonly_defaults]

    def _RecordArgumentGiven(self, arg_name, arg_status):
        """Mark an argument as being given.

        Args:
            # arg_name: The name of the argument to mark in arg_status.
            # arg_status: Maps argument names to one of
            #             _NEEDED, _DEFAULT, _GIVEN.
            arg_name: string
            arg_status: dict

        Raises:
            AttributeError: arg_name is already marked as _GIVEN.
        """
        if arg_status.get(arg_name, None) == MethodSignatureChecker._GIVEN:
            raise AttributeError('%s provided more than once' % (arg_name,))
        arg_status[arg_name] = MethodSignatureChecker._GIVEN

    def Check(self, params, named_params):
        """Ensures that the parameters used while recording a call are valid.

        Args:
            # params: A list of positional parameters.
            # named_params: A dict of named parameters.
            params: list
            named_params: dict

        Raises:
            AttributeError: the given parameters don't work with the given
                            method.
        """
        arg_status = dict((a, MethodSignatureChecker._NEEDED)
                          for a in self._required_args)
        for arg in self._default_args:
            arg_status[arg] = MethodSignatureChecker._DEFAULT
        for arg in self._required_kwonly_args:
            arg_status[arg] = MethodSignatureChecker._NEEDED
        for arg in self._default_kwonly_args:
            arg_status[arg] = MethodSignatureChecker._DEFAULT

        # WARNING: Suspect hack ahead.
        #
        # Check to see if this is an unbound method, where the instance
        # should be bound as the first argument.    We try to determine if
        # the first argument (param[0]) is an instance of the class, or it
        # is equivalent to the class (used to account for Comparators).
        #
        # NOTE: If a Func() comparator is used, and the signature is not
        # correct, this will cause extra executions of the function.
        if inspect.ismethod(self._method) or self._bounded_to:
            # The extra param accounts for the bound instance.
            if len(params) > len(self._required_args):
                expected = self._bounded_to

                # Check if the param is an instance of the expected class,
                # or check equality (useful for checking Comparators).

                # This is a hack to work around the fact that the first
                # parameter can be a Comparator, and the comparison may raise
                # an exception during this comparison, which is OK.
                try:
                    param_equality = (params[0] == expected)
                except Exception:
                    param_equality = False

                if isinstance(params[0], expected) or param_equality:
                    params = params[1:]
                # If the IsA() comparator is being used, we need to check the
                # inverse of the usual case - that the given instance is a
                # subclass of the expected class. For example, the code under
                # test does late binding to a subclass.
                elif (isinstance(params[0], IsA) and
                      params[0]._IsSubClass(expected)):
                    params = params[1:]

        # Check that the positional params fit the signature. Positional
        # names are distinct, so recording them can never raise; validating
        # the count first is therefore equivalent to checking per param.
        if len(params) > len(self._args) and not self._has_varargs:
            raise AttributeError(
                '%s does not take %d or more positional arguments'
                % (self._method.__name__, len(self._args) + 1))
        for arg_name in self._args[:len(params)]:
            self._RecordArgumentGiven(arg_name, arg_status)

        # Check each keyword argument.
        for arg_name in named_params:
            if arg_name not in arg_status and not self._has_varkw:
                raise AttributeError('%s is not expecting keyword argument %s'
                                     % (self._method.__name__, arg_name))
            self._RecordArgumentGiven(arg_name, arg_status)

        # Ensure all the required arguments have been given.
        still_needed = [k for k, v in arg_status.items()
                        if v == MethodSignatureChecker._NEEDED]
        if still_needed:
            raise AttributeError('No values given for arguments: %s'
                                 % (' '.join(sorted(still_needed))))
1030
1031
class MockMethod(object):
    """Callable mock method.

    A MockMethod should act exactly like the method it mocks, accepting
    parameters and returning a value, or throwing an exception (as specified).
    When this method is called, it can optionally verify whether the called
    method (name and signature) matches the expected method.
    """

    def __init__(self, method_name, call_queue, replay_mode,
                 method_to_mock=None, description=None, class_to_bind=None):
        """Construct a new mock method.

        Args:
            # method_name: the name of the method
            # call_queue: deque of calls, verify this call against the head,
            #             or add this call to the queue.
            # replay_mode: False if we are recording, True if we are verifying
            #              calls against the call queue.
            # method_to_mock: The actual method being mocked, used for
            #                 introspection.
            # description: optionally, a descriptive name for this method.
            #              Typically this is equal to the descriptive name of
            #              the method's class.
            # class_to_bind: optionally, a class that is used for unbound
            #                methods (or functions in Python3) to which method
            #                is bound, in order not to loose binding
            #                information. If given, it will be used for
            #                checking the type of first method parameter
            method_name: str
            call_queue: list or deque
            replay_mode: bool
            method_to_mock: a method object
            description: str or None
            class_to_bind: type or None
        """

        self._name = method_name
        self.__name__ = method_name
        self._call_queue = call_queue
        # A non-deque queue is copied into a new deque, so later appends are
        # not visible through the caller's original object.
        if not isinstance(call_queue, collections.deque):
            self._call_queue = collections.deque(self._call_queue)
        self._replay_mode = replay_mode
        self._description = description

        self._params = None
        self._named_params = None
        self._return_value = None
        self._exception = None
        self._side_effects = None

        try:
            self._checker = MethodSignatureChecker(method_to_mock,
                                                   class_to_bind=class_to_bind)
        except ValueError:
            # The mocked method cannot be introspected; skip signature checks.
            self._checker = None

    def __call__(self, *params, **named_params):
        """Log parameters and return the specified return value.

        If the Mock(Anything/Object) associated with this call is in record
        mode, this MockMethod will be pushed onto the expected call queue.
        If the mock is in replay mode, this will pop a MockMethod off the
        top of the queue and verify this call is equal to the expected call.

        Returns:
            In record mode, this MockMethod (so AndReturn/AndRaise/etc. can
            be chained). In replay mode, the value set with AndReturn, or,
            when that is None, the result of the WithSideEffects callable.

        Raises:
            UnexpectedMethodCall if this call is supposed to match an expected
                method call and it does not.
        """

        self._params = params
        self._named_params = named_params

        if not self._replay_mode:
            if self._checker is not None:
                self._checker.Check(params, named_params)
            self._call_queue.append(self)
            return self

        expected_method = self._VerifyMethodCall()

        # Keep the side-effect result local instead of storing it on
        # expected_method: a method group may hand back the same expected
        # MockMethod for several calls, and persisting the first result would
        # make every later call return that stale value.
        return_value = expected_method._return_value
        if expected_method._side_effects:
            result = expected_method._side_effects(*params, **named_params)
            if return_value is None:
                return_value = result

        if expected_method._exception:
            raise expected_method._exception

        return return_value

    def __getattr__(self, name):
        """Raise an AttributeError with a helpful message."""

        raise AttributeError(
            'MockMethod has no attribute "%s". '
            'Did you remember to put your mocks in replay mode?' % name)

    def __iter__(self):
        """Raise a TypeError with a helpful message."""
        raise TypeError(
            'MockMethod cannot be iterated. '
            'Did you remember to put your mocks in replay mode?')

    def next(self):
        """Raise a TypeError with a helpful message."""
        raise TypeError(
            'MockMethod cannot be iterated. '
            'Did you remember to put your mocks in replay mode?')

    def __next__(self):
        """Raise a TypeError with a helpful message."""
        raise TypeError(
            'MockMethod cannot be iterated. '
            'Did you remember to put your mocks in replay mode?')

    def _PopNextMethod(self):
        """Pop the next method from our call queue."""
        try:
            return self._call_queue.popleft()
        except IndexError:
            raise UnexpectedMethodCallError(self, None)

    def _VerifyMethodCall(self):
        """Verify the called method is expected.

        This can be an ordered method, or part of an unordered set.

        Returns:
            The expected mock method.

        Raises:
            UnexpectedMethodCall if the method called was not expected.
        """

        expected = self._PopNextMethod()

        # Loop here, because we might have a MethodGroup followed by another
        # group.
        while isinstance(expected, MethodGroup):
            expected, method = expected.MethodCalled(self)
            if method is not None:
                return method

        # This is a mock method, so just check equality.
        if expected != self:
            raise UnexpectedMethodCallError(self, expected)

        return expected

    def __str__(self):
        params = ', '.join(
            [repr(p) for p in self._params or []] +
            ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
        full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
        if self._description:
            full_desc = "%s.%s" % (self._description, full_desc)
        return full_desc

    def __hash__(self):
        return id(self)

    def __eq__(self, rhs):
        """Test whether this MockMethod is equivalent to another MockMethod.

        Args:
            # rhs: the right hand side of the test
            rhs: MockMethod
        """

        return (isinstance(rhs, MockMethod) and
                self._name == rhs._name and
                self._params == rhs._params and
                self._named_params == rhs._named_params)

    def __ne__(self, rhs):
        """Test if this MockMethod is not equivalent to another MockMethod.

        Args:
            # rhs: the right hand side of the test
            rhs: MockMethod
        """

        return not self == rhs

    def GetPossibleGroup(self):
        """Returns a possible group from the end of the call queue.

        Return None if no other methods are on the stack.
        """

        # Remove this method from the tail of the queue so we can add it
        # to a group.
        this_method = self._call_queue.pop()
        assert this_method == self

        # Determine if the tail of the queue is a group, or just a regular
        # ordered mock method.
        group = None
        try:
            group = self._call_queue[-1]
        except IndexError:
            pass

        return group

    def _CheckAndCreateNewGroup(self, group_name, group_class):
        """Checks if the last method (a possible group) is an instance of our
        group_class. Adds the current method to this group or creates a
        new one.

        Args:

            group_name: the name of the group.
            group_class: the class used to create instance of this new group
        """
        group = self.GetPossibleGroup()

        # If this is a group, and it is the correct group, add the method.
        if isinstance(group, group_class) and group.group_name() == group_name:
            group.AddMethod(self)
            return self

        # Create a new group and add the method.
        new_group = group_class(group_name)
        new_group.AddMethod(self)
        self._call_queue.append(new_group)
        return self

    def InAnyOrder(self, group_name="default"):
        """Move this method into a group of unordered calls.

        A group of unordered calls must be defined together, and must be
        executed in full before the next expected method can be called.
        There can be multiple groups that are expected serially, if they are
        given different group names. The same group name can be reused if there
        is a standard method call, or a group with a different name, spliced
        between usages.

        Args:
            group_name: the name of the unordered group.

        Returns:
            self
        """
        return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)

    def MultipleTimes(self, group_name="default"):
        """Move method into group of calls which may be called multiple times.

        A group of repeating calls must be defined together, and must be
        executed in full before the next expected method can be called.

        Args:
            group_name: the name of the unordered group.

        Returns:
            self
        """
        return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)

    def AndReturn(self, return_value):
        """Set the value to return when this method is called.

        Args:
            # return_value can be anything.

        Returns:
            return_value itself (not self).
        """

        self._return_value = return_value
        return return_value

    def AndRaise(self, exception):
        """Set the exception to raise when this method is called.

        Args:
            # exception: the exception to raise when this method is called.
            exception: Exception
        """

        self._exception = exception

    def WithSideEffects(self, side_effects):
        """Set the side effects that are simulated when this method is called.

        Args:
            side_effects: A callable which modifies the parameters or other
                          relevant state which a given test case depends on.

        Returns:
            Self for chaining with AndReturn and AndRaise.
        """
        self._side_effects = side_effects
        return self
1325
1326
class Comparator(object):
    """Base class for all Mox comparators.

    A Comparator stands in for a mocked method's parameter when the exact
    value is not known ahead of time. Suppose the code under test builds a
    long SQL string and passes it to your mock DAO, but all you care about is
    that the IN clause holds the right primary keys. You can then record:

    mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)

    and any query containing the string 'IN (1, 2, 4, 5)' will match.

    Comparators can be mixed freely with literal parameters, e.g.:
    # return at most 10 rows
    mock_dao.RunQuery(StrContains('SELECT'), 10)

    or

    # Return some non-deterministic number of rows
    mock_dao.RunQuery(StrContains('SELECT'), IsA(int))

    Subclasses implement equals(); == and != both delegate to it. Because
    __eq__ is defined without __hash__, comparator instances are unhashable.
    """

    def equals(self, rhs):
        """Decide whether rhs matches; every subclass must override this.

        Args:
            rhs: any python object
        """
        raise NotImplementedError('method must be implemented by a subclass.')

    def __eq__(self, rhs):
        return self.equals(rhs)

    def __ne__(self, rhs):
        matched = self.equals(rhs)
        return not matched
1364
1365
class Is(Comparator):
    """Matches only the very object given (identity, not equality)."""

    def __init__(self, obj):
        self._obj = obj

    def equals(self, rhs):
        expected = self._obj
        return expected is rhs

    def __repr__(self):
        target = self._obj
        return "<is %r (%s)>" % (target, id(target))
1377
1378
class IsA(Comparator):
    """Matches any parameter that is an instance of a given type or class.

    Example:
    mock_dao.Connect(IsA(DbConnectInfo))
    """

    def __init__(self, class_name):
        """Initialize IsA.

        Args:
            class_name: basic python type or a class
        """
        self._class_name = class_name

    def equals(self, rhs):
        """Return whether rhs is an instance of class_name.

        Args:
            # rhs: the right hand side of the test
            rhs: object

        Returns:
            bool
        """
        expected_type = self._class_name
        try:
            matched = isinstance(rhs, expected_type)
        except TypeError:
            # isinstance() rejected class_name; compare raw types instead.
            # This is helpful for things like cStringIO.StringIO.
            matched = type(rhs) == type(expected_type)
        return matched

    def _IsSubClass(self, clazz):
        """Return whether this comparator's class is a subclass of clazz.

        Args:
            # clazz: a class object

        Returns:
            bool
        """
        own_type = self._class_name
        try:
            is_sub = issubclass(own_type, clazz)
        except TypeError:
            # issubclass() rejected an argument; compare raw types instead.
            # This is helpful for things like cStringIO.StringIO.
            is_sub = type(clazz) == type(own_type)
        return is_sub

    def __repr__(self):
        return 'mox.IsA(%s) ' % (str(self._class_name),)
1433
1434
class IsAlmost(Comparator):
    """Matches a parameter that is nearly equal to a reference value.

    Mostly useful for floating point numbers.

    Example mock_dao.SetTimeout((IsAlmost(3.9)))
    """

    def __init__(self, float_value, places=7):
        """Initialize IsAlmost.

        Args:
            float_value: The value for making the comparison.
            places: The number of decimal places to round to.
        """
        self._float_value = float_value
        self._places = places

    def equals(self, rhs):
        """Return whether rhs - float_value rounds to zero at `places` digits.

        Args:
            rhs: the value to compare to float_value

        Returns:
            bool
        """
        try:
            difference = rhs - self._float_value
            return round(difference, self._places) == 0
        except Exception:
            # Probably because either float_value or rhs is not a number.
            return False

    def __repr__(self):
        return '%s' % (self._float_value,)
1471
1472
class StrContains(Comparator):
    """Matches a string parameter that contains a given substring.

    Handy when mocking a database that receives SQL as a string parameter.

    Example:
    mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
    """

    def __init__(self, search_string):
        """Initialize.

        Args:
            # search_string: the string you are searching for
            search_string: str
        """
        self._search_string = search_string

    def equals(self, rhs):
        """Return whether rhs.find(search_string) locates the substring.

        Any exception (e.g. rhs has no find method) counts as a mismatch.

        Args:
            # rhs: the right hand side of the test
            rhs: object

        Returns:
            bool
        """
        try:
            position = rhs.find(self._search_string)
            return position > -1
        except Exception:
            return False

    def __repr__(self):
        return '<str containing \'%s\'>' % self._search_string
1510
1511
class Regex(Comparator):
    """Matches a string parameter against a regular expression.

    A parameter matches when re.search finds the pattern anywhere in it.
    """

    def __init__(self, pattern, flags=0):
        """Initialize.

        Args:
            # pattern is the regular expression to search for
            pattern: str
            # flags passed to re.compile function as the second argument
            flags: int
        """
        self.flags = flags
        self.regex = re.compile(pattern, flags=flags)

    def equals(self, rhs):
        """Return whether the compiled pattern is found in rhs.

        Any exception (e.g. rhs is not a string) counts as a mismatch.

        Returns:
            bool
        """
        try:
            found = self.regex.search(rhs)
            return found is not None
        except Exception:
            return False

    def __repr__(self):
        pieces = ['<regular expression \'%s\'' % self.regex.pattern]
        if self.flags:
            pieces.append(', flags=%d' % self.flags)
        pieces.append('>')
        return ''.join(pieces)
1548
1549
class In(Comparator):
    """Matches a list (or dict) parameter that contains an item (or key).

    Example:
    mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
    """

    def __init__(self, key):
        """Initialize.

        Args:
            # key: anything that could be an item of a list or a key of a dict
        """
        self._key = key

    def equals(self, rhs):
        """Return whether `key in rhs` holds; errors count as a mismatch.

        Args:
            rhs: dict

        Returns:
            bool
        """
        try:
            found = self._key in rhs
            return found
        except Exception:
            return False

    def __repr__(self):
        return '<sequence or map containing \'%s\'>' % (str(self._key),)
1583
1584
class Not(Comparator):
    """Matches whenever the wrapped predicate does not.

    Example:
        mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm',
                                                  stevepm_user_info)))
    """

    def __init__(self, predicate):
        """Initialize.

        Args:
            # predicate: a Comparator instance.
        """
        assert isinstance(predicate, Comparator), (
            "predicate %r must be a Comparator." % predicate)
        self._predicate = predicate

    def equals(self, rhs):
        """Return the negation of predicate.equals(rhs).

        If the predicate raises, the result is False.

        Args:
            rhs: A value that will be given in argument of the predicate.

        Returns:
            bool
        """
        try:
            matched = self._predicate.equals(rhs)
            return not matched
        except Exception:
            return False

    def __repr__(self):
        return '<not \'%s\'>' % self._predicate
1621
1622
class ContainsKeyValue(Comparator):
    """Matches a dict parameter whose entry for a key equals a given value.

    Example:
    mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
    """

    def __init__(self, key, value):
        """Initialize.

        Args:
            # key: a key in a dict
            # value: the corresponding value
        """
        self._key = key
        self._value = value

    def equals(self, rhs):
        """Return whether rhs[key] == value; lookup errors mean no match.

        Returns:
            bool
        """
        try:
            actual = rhs[self._key]
            return actual == self._value
        except Exception:
            return False

    def __repr__(self):
        entry = (str(self._key), str(self._value))
        return '<map containing the entry \'%s: %s\'>' % entry
1656
1657
class ContainsAttributeValue(Comparator):
    """Checks whether passed parameter contains attributes with a given value.

    Example:
    mock_dao.UpdateSomething(ContainsAttributeValue('stevepm',
                                                    stevepm_user_info))
    """

    def __init__(self, key, value):
        """Initialize.

        Args:
            # key: an attribute name of an object
            # value: the corresponding value
        """

        self._key = key
        self._value = value

    def equals(self, rhs):
        """Check if the given attribute has a matching value in the rhs object.

        A missing attribute (or any error while comparing) is a mismatch.

        Returns:
            bool
        """

        try:
            return getattr(rhs, self._key) == self._value
        except Exception:
            return False

    def __repr__(self):
        # Every other comparator provides a readable repr; without one,
        # failure messages would show only an opaque object address.
        return '<object with attribute \'%s\' equal to \'%s\'>' % (
            str(self._key), str(self._value))
1687
1688
class SameElementsAs(Comparator):
    """Checks whether sequences contain the same elements (ignoring order).

    When every element is hashable the comparison is set-based, so duplicate
    elements are ignored. Otherwise it falls back to a count-sensitive
    element-by-element match.

    Example:
    mock_dao.ProcessUsers(SameElementsAs(['stevepm', 'salomaki']))
    """

    def __init__(self, expected_seq):
        """Initialize.

        Args:
            expected_seq: a sequence
        """
        # Store in case expected_seq is an iterator.
        self._expected_list = list(expected_seq)

    def equals(self, actual_seq):
        """Check to see whether actual_seq has same elements as expected_seq.

        Args:
            actual_seq: sequence

        Returns:
            bool
        """
        try:
            # Store in case actual_seq is an iterator. We potentially iterate
            # twice: once to build the set, once in the list fallback.
            actual_list = list(actual_seq)
        except TypeError:
            # actual_seq cannot be read as a sequence.
            #
            # This happens because Mox uses __eq__ both to check object
            # equality (in MethodSignatureChecker) and to invoke Comparators.
            return False

        try:
            return set(self._expected_list) == set(actual_list)
        except TypeError:
            # Fall back to slower list-compare if any of the objects
            # are unhashable. Each actual element consumes one matching
            # expected element, so every expected element must be matched
            # too (a one-way membership test would accept [[1], [1]] for an
            # expected [[1], [2]]).
            if len(self._expected_list) != len(actual_list):
                return False
            unmatched = list(self._expected_list)
            for el in actual_list:
                try:
                    unmatched.remove(el)
                except ValueError:
                    return False
        return True

    def __repr__(self):
        return '<sequence with same elements as \'%s\'>' % self._expected_list
1739
1740
class And(Comparator):
    """Comparator that matches only when every wrapped Comparator matches.

    Evaluation stops at the first wrapped Comparator that rejects RHS.
    """

    def __init__(self, *args):
        """Initialize.

        Args:
            *args: One or more Comparator instances to combine.
        """
        self._comparators = args

    def equals(self, rhs):
        """Check whether all wrapped Comparators are equal to rhs.

        Args:
            rhs: any object.

        Returns:
            bool: True if every Comparator's equals(rhs) is truthy (also True
            when no Comparators were given), otherwise False.
        """
        return all(comparator.equals(rhs) for comparator in self._comparators)

    def __repr__(self):
        return '<AND %s>' % (self._comparators,)
1772
1773
class Or(Comparator):
    """Comparator that matches when at least one wrapped Comparator matches.

    Evaluation stops at the first wrapped Comparator that accepts RHS.
    """

    def __init__(self, *args):
        """Initialize.

        Args:
            *args: One or more Mox comparators to combine.
        """
        self._comparators = args

    def equals(self, rhs):
        """Check whether any wrapped Comparator is equal to rhs.

        Args:
            rhs: any object.

        Returns:
            bool: True as soon as one Comparator's equals(rhs) is truthy;
            False if none is (including when no Comparators were given).
        """
        return any(comparator.equals(rhs) for comparator in self._comparators)

    def __repr__(self):
        return '<OR %s>' % (self._comparators,)
1804
1805
class Func(Comparator):
    """Comparator that delegates the match decision to a user callable.

    Use this when validating a parameter needs more logic than the built-in
    comparators offer. The callable receives the actual argument and should
    return True (match) or False (no match).

    Example:

    def myParamValidator(param):
        # Advanced logic here
        return True

    mock_dao.DoSomething(Func(myParamValidator), True)
    """

    def __init__(self, func):
        """Initialize.

        Args:
            func: callable taking one argument and returning a bool.
        """
        self._func = func

    def equals(self, rhs):
        """Run the stored callable on rhs.

        Args:
            rhs: the actual argument; any Python object.

        Returns:
            Whatever func(rhs) returns, passed through unchanged.
        """
        validator = self._func
        return validator(rhs)

    def __repr__(self):
        return '%s' % (self._func,)
1848
1849
class IgnoreArg(Comparator):
    """Comparator that accepts any value for an argument.

    Use it for arguments of a method call whose value does not matter.

    Example:
    # Check if CastMagic is called with 3 as first arg and
    # 'disappear' as third.
    mymock.CastMagic(3, IgnoreArg(), 'disappear')
    """

    def equals(self, unused_rhs):
        """Accept any argument.

        Args:
            unused_rhs: any Python object; it is not inspected.

        Returns:
            True, unconditionally.
        """
        del unused_rhs  # Deliberately ignored.
        return True

    def __repr__(self):
        return '<IgnoreArg>'
1875
1876
class Value(Comparator):
    """Compares argument against a remembered value.

    To be used in conjunction with the Remember comparator; see Remember()
    for an example. Until a value has been stored, equals() returns False.
    """

    def __init__(self):
        # _has_value distinguishes "nothing stored yet" from a stored None.
        self._value = None
        self._has_value = False

    def store_value(self, rhs):
        """Remember rhs for later comparisons.

        Args:
            rhs: any Python object.
        """
        self._value = rhs
        self._has_value = True

    def equals(self, rhs):
        """Compare rhs with the stored value.

        Args:
            rhs: any Python object.

        Returns:
            False if no value has been stored yet, else rhs == stored value.
        """
        if not self._has_value:
            return False
        return rhs == self._value

    def __repr__(self):
        if self._has_value:
            # Wrap in a 1-tuple so a stored tuple is not unpacked by %.
            return "<Value %r>" % (self._value,)
        return "<Value>"
1903
1904
class Remember(Comparator):
    """Comparator that records the actual argument into a Value store.

    It always matches; its job is the side effect of storing the argument so
    that a later expectation can compare against it through the same Value.

    Example:
    # Remember the argument for one method call.
    users_list = Value()
    mock_dao.ProcessUsers(Remember(users_list))

    # Check argument against remembered value.
    mock_dao.ReportUsers(users_list)
    """

    def __init__(self, value_store):
        """Initialize.

        Args:
            value_store: the Value instance that receives the argument.

        Raises:
            TypeError: if value_store is not a Value.
        """
        if isinstance(value_store, Value):
            self._value_store = value_store
        else:
            raise TypeError(
                "value_store is not an instance of the Value class")

    def equals(self, rhs):
        """Store rhs in the Value store and report a match."""
        store = self._value_store
        store.store_value(rhs)
        return True

    def __repr__(self):
        return "<Remember %d>" % (id(self._value_store),)
1931
1932
class MethodGroup(object):
    """Abstract base for named groups of expected mock method calls.

    Subclasses decide how incoming calls are matched and when the group
    counts as satisfied.
    """

    def __init__(self, group_name):
        self._group_name = group_name

    def group_name(self):
        """Return the name this group was created with."""
        return self._group_name

    def __str__(self):
        cls_name = self.__class__.__name__
        return '<%s "%s">' % (cls_name, self._group_name)

    def AddMethod(self, mock_method):
        """Register an expected call; subclasses must override."""
        raise NotImplementedError()

    def MethodCalled(self, mock_method):
        """Match an actual call against the group; subclasses must override."""
        raise NotImplementedError()

    def IsSatisfied(self):
        """Report whether the group is complete; subclasses must override."""
        raise NotImplementedError()
1953
1954
class UnorderedGroup(MethodGroup):
    """UnorderedGroup holds a set of method calls that may occur in any order.

    This construct is helpful for non-deterministic events, such as iterating
    over the keys of a dict. Each expected call is consumed when it is matched;
    the group is satisfied once no expected calls remain.
    """

    def __init__(self, group_name):
        super(UnorderedGroup, self).__init__(group_name)
        # Expected calls not yet made, in insertion order. A list rather than
        # a set, so an expectation added twice must be called twice.
        self._methods = []

    def __str__(self):
        # Class name, group name, then one pending call per line.
        return '%s "%s" pending calls:\n%s' % (
            self.__class__.__name__,
            self._group_name,
            "\n".join(str(method) for method in self._methods))

    def AddMethod(self, mock_method):
        """Add a method to this group.

        Args:
            mock_method: A mock method to be added to this group.
        """

        self._methods.append(mock_method)

    def MethodCalled(self, mock_method):
        """Consume the pending expected call that matches mock_method.

        Pending calls are scanned in insertion order. On the first one that
        compares equal to mock_method, an entry equal to mock_method is
        removed from the pending list. If calls are still pending afterwards,
        the group pushes itself back onto the front of
        mock_method._call_queue (presumably so the next call is checked
        against this group again).

        Args:
            mock_method: a mock method that should be equal to a method in the
                         group.

        Returns:
            A tuple (self, method), where method is the expected method from
            the group that matched.

        Raises:
            UnexpectedMethodCallError if the mock_method was not in the group.
        """

        # Check to see if this method exists, and if so, remove it from the
        # pending list and return it.
        for method in self._methods:
            if method == mock_method:
                # Remove the called mock_method instead of the method in the
                # group. The called method will match any comparators when
                # equality is checked during removal. The method in the group
                # could pass a comparator to another comparator during the
                # equality check.
                # NOTE(review): list.remove re-evaluates == against the
                # pending entries, so comparators with side effects (e.g.
                # Remember) run a second time here.
                self._methods.remove(mock_method)

                # If group is not empty, put it back at the head of the queue.
                if not self.IsSatisfied():
                    mock_method._call_queue.appendleft(self)

                return self, method

        raise UnexpectedMethodCallError(mock_method, self)

    def IsSatisfied(self):
        """Return True if there are not any methods in this group."""

        return len(self._methods) == 0
2021
2022
class MultipleTimesGroup(MethodGroup):
    """MultipleTimesGroup holds methods that may be called any number of times.

    Note: Each method must be called at least once.

    This is helpful, if you don't know or care how many times a method is
    called.
    """

    def __init__(self, group_name):
        super(MultipleTimesGroup, self).__init__(group_name)
        # Every expected method; kept so repeat calls still match.
        self._methods = set()
        # Expected methods that have not yet been called even once.
        self._methods_left = set()

    def AddMethod(self, mock_method):
        """Add a method to this group.

        Args:
            mock_method: A mock method to be added to this group. It is stored
                in sets, so it must be hashable.
        """

        self._methods.add(mock_method)
        self._methods_left.add(mock_method)

    def MethodCalled(self, mock_method):
        """Match an actual call against the group.

        If mock_method equals one of the group's methods, that method is
        marked as called and the group is pushed back onto the front of
        mock_method._call_queue. If nothing matches but every method has
        already been called at least once, the group is left off the queue,
        and the call is handed to mock_method._PopNextMethod().

        Args:
            mock_method: a mock method that should be equal to a method in the
                         group.

        Returns:
            (self, method) when a method in this group matched; otherwise
            (result of mock_method._PopNextMethod(), None) when the group is
            already satisfied.

        Raises:
            UnexpectedMethodCallError if the mock_method was not in the group
            and the group is not yet satisfied.
        """

        # Check to see if this method exists, and if so add it to the set of
        # called methods.
        # NOTE(review): set iteration order is arbitrary, so if several
        # expected methods match (e.g. via comparators), which one is credited
        # as called is unspecified.
        for method in self._methods:
            if method == mock_method:
                self._methods_left.discard(method)
                # Always put this group back on top of the queue,
                # because we don't know when we are done.
                mock_method._call_queue.appendleft(self)
                return self, method

        if self.IsSatisfied():
            # Group is done; let the caller's queue handle this call.
            next_method = mock_method._PopNextMethod()
            return next_method, None
        else:
            raise UnexpectedMethodCallError(mock_method, self)

    def IsSatisfied(self):
        """Return True if all methods in group are called at least once."""
        return len(self._methods_left) == 0
2083
2084
class MoxMetaTestBase(type):
    """Metaclass that adds Mox cleanup and verification to every test.

    When a class using this metaclass (MoxTestBase or a subclass) is built,
    every callable attribute whose name starts with 'test' is replaced by a
    wrapper that calls the original and then unsets stubs and verifies mocks.
    Tests therefore need no extra teardown code, and verification failures
    surface as test failures rather than errors.
    """

    def __init__(cls, name, bases, d):
        type.__init__(cls, name, bases, d)

        # Also consider attributes visible on the direct bases, so test
        # methods defined on an intermediate parent class get wrapped too.
        # NOTE(review): a base built by this metaclass already holds wrapped
        # test methods, and those are wrapped again here.
        for parent in bases:
            for inherited_name in dir(parent):
                if inherited_name in d:
                    continue
                d[inherited_name] = getattr(parent, inherited_name)

        for attr_name, attr in d.items():
            if not attr_name.startswith('test'):
                continue
            if not callable(attr):
                continue
            wrapped = MoxMetaTestBase.CleanUpTest(cls, attr)
            setattr(cls, attr_name, wrapped)

    @staticmethod
    def CleanUpTest(cls, func):
        """Wrap a test method so Mox/stubout cleanup always runs afterwards.

        Stubs are always unset, even if the test raises. Mocks are verified
        only when the test body completes without raising.

        Args:
            cls: MoxTestBase or subclass; the class whose method is wrapped
                (not used by the wrapper).
            func: the test method to wrap.

        Returns:
            The wrapping function, carrying func's __name__, __doc__ and
            __module__.
        """
        def new_method(self, *args, **kwargs):
            mox_obj = getattr(self, 'mox', None)
            stubout_obj = getattr(self, 'stubs', None)
            # Only clean up attributes that really are Mox / stubout objects.
            cleanup_mox = bool(mox_obj) and isinstance(mox_obj, Mox)
            cleanup_stubout = (bool(stubout_obj) and
                               isinstance(stubout_obj,
                                          stubout.StubOutForTesting))
            try:
                func(self, *args, **kwargs)
            finally:
                if cleanup_mox:
                    mox_obj.UnsetStubs()
                if cleanup_stubout:
                    stubout_obj.UnsetAll()
                    stubout_obj.SmartUnsetAll()
            # Reached only when the test itself did not raise.
            if cleanup_mox:
                mox_obj.VerifyAll()

        for copied_attr in ('__name__', '__doc__', '__module__'):
            setattr(new_method, copied_attr, getattr(func, copied_attr))
        return new_method
2150
2151
# unittest.TestCase base built by calling the metaclass directly rather than
# via class-statement metaclass syntax (which differs between Python 2 and 3);
# presumably done for cross-version compatibility -- confirm before changing.
_MoxTestBase = MoxMetaTestBase('_MoxTestBase', (unittest.TestCase, ), {})
2153
2154
class MoxTestBase(_MoxTestBase):
    """TestCase base class with Mox and stubout wired in.

    setUp() gives every test a fresh "mox" attribute (a Mox instance) and a
    "stubs" attribute (a StubOutForTesting instance). Because the class
    derives from a MoxMetaTestBase-built base, each test* method is wrapped
    so stubs are unset after the test, and all mocks are verified when the
    test otherwise passes.
    """

    def setUp(self):
        """Create the per-test Mox factory and stub manager."""
        super(MoxTestBase, self).setUp()
        # 'mox' and 'stubs' are the attribute names the MoxMetaTestBase
        # wrapper looks up when cleaning up after each test.
        self.mox = Mox()
        self.stubs = stubout.StubOutForTesting()
2169