import functools
import unittest
from test import test_support
3
def funcattrs(**kwds):
    """Return a decorator that merges *kwds* into a function's ``__dict__``.

    Every keyword argument becomes an attribute of the decorated function,
    and the very same function object is handed back.
    """
    def apply_attrs(func):
        namespace = func.__dict__
        namespace.update(kwds)
        return func
    return apply_attrs
9
class MiscDecorators(object):
    """Container whose decorator factories are reached via dotted names."""

    @staticmethod
    def author(name):
        """Return a decorator that tags a function with ``author = name``.

        The decorated function is returned as-is, with ``author`` stored
        directly in its ``__dict__``.
        """
        def tag_author(func):
            attrs = func.__dict__
            attrs['author'] = name
            return func
        return tag_author
17
18# -----------------------------------------------
19
class DbcheckError(Exception):
    """Raised by a ``dbcheck`` wrapper when its assertion expression is false."""

    def __init__(self, exprstr, func, args, kwds):
        # A real version of this would set attributes here; this one only
        # folds the failing expression and call details into the message.
        message = "dbcheck %r failed (func=%s args=%s kwds=%s)" % (
            exprstr, func, args, kwds)
        Exception.__init__(self, message)
25
26
def dbcheck(exprstr, globals=None, locals=None):
    """Decorator to implement debugging assertions.

    *exprstr* is compiled once per decorated function and evaluated before
    every call.  When *globals* and *locals* are left as None, eval() uses
    the wrapper's own frame, so the expression can refer to the call's
    positional arguments as ``args`` and its keyword arguments as ``kwds``
    (e.g. ``'args[1] is not None'``).  If the expression is false,
    DbcheckError is raised and the wrapped function is not called.

    The wrapper keeps the wrapped function's name and docstring.
    """
    def decorate(func):
        # __name__ is available on every Python version (func_name is 2.x only).
        expr = compile(exprstr, "dbcheck-%s" % func.__name__, "eval")

        @functools.wraps(func)
        def check(*args, **kwds):
            # eval() with None globals/locals reads this frame, which is why
            # the parameters must stay named ``args`` and ``kwds``.
            if not eval(expr, globals, locals):
                raise DbcheckError(exprstr, func, args, kwds)
            return func(*args, **kwds)
        return check
    return decorate
37
38# -----------------------------------------------
39
def countcalls(counts):
    """Decorator factory that counts calls to the decorated function.

    *counts* is a mapping supplied by the caller.  Decorating a function sets
    ``counts[name]`` to 0, where *name* is the function's ``__name__``.  Every
    call through the wrapper then increments that entry before delegating.
    The wrapper carries the wrapped function's name, docstring and module, so
    stacked decorators (e.g. ``memoize``) still see the original name.
    """
    def decorate(func):
        # __name__ is portable; func_name exists only on Python 2.
        func_name = func.__name__
        counts[func_name] = 0

        @functools.wraps(func)
        def call(*args, **kwds):
            # Counted before delegating, so calls that raise are counted too.
            counts[func_name] += 1
            return func(*args, **kwds)
        return call
    return decorate
51
52# -----------------------------------------------
53
def memoize(func):
    """Cache results of *func*, keyed on its positional-argument tuple.

    If every argument is hashable, the result is looked up in a private dict
    and *func* runs only on a cache miss.  If any argument is unhashable, the
    tuple cannot be a dict key, so the cache is bypassed and *func* is called
    every time.  Results of calls that raise are not cached.  The wrapper
    accepts positional arguments only and keeps *func*'s name and docstring.
    """
    saved = {}

    @functools.wraps(func)
    def call(*args):
        try:
            return saved[args]
        except KeyError:
            # Cache miss: compute below, outside the handler, so an exception
            # raised by func is not chained onto this KeyError (Python 3).
            pass
        except TypeError:
            # Unhashable argument: it cannot be cached, just call through.
            return func(*args)
        res = saved[args] = func(*args)
        return res
    return call
68
69# -----------------------------------------------
70
class TestDecorators(unittest.TestCase):
    """Tests for function decorators: syntax, argument forms and ordering."""

    def test_single(self):
        class C(object):
            @staticmethod
            def foo(): return 42
        self.assertEqual(C.foo(), 42)
        self.assertEqual(C().foo(), 42)

    def test_staticmethod_function(self):
        # Calling a staticmethod object directly (not through a class) is
        # expected to raise TypeError.
        # NOTE(review): Python 3.10 made staticmethod objects callable, so
        # this expectation only holds on older interpreters.
        @staticmethod
        def notamethod(x):
            return x
        self.assertRaises(TypeError, notamethod, 1)

    def test_dotted(self):
        decorators = MiscDecorators()
        @decorators.author('Cleese')
        def foo(): return 42
        self.assertEqual(foo(), 42)
        self.assertEqual(foo.author, 'Cleese')

    def test_argforms(self):
        # A few tests of argument passing, as we use restricted form
        # of expressions for decorators.

        def noteargs(*args, **kwds):
            def decorate(func):
                setattr(func, 'dbval', (args, kwds))
                return func
            return decorate

        args = ('Now', 'is', 'the', 'time')
        kwds = dict(one=1, two=2)
        @noteargs(*args, **kwds)
        def f1(): return 42
        self.assertEqual(f1(), 42)
        self.assertEqual(f1.dbval, (args, kwds))

        @noteargs('terry', 'gilliam', eric='idle', john='cleese')
        def f2(): return 84
        self.assertEqual(f2(), 84)
        self.assertEqual(f2.dbval, (('terry', 'gilliam'),
                                     dict(eric='idle', john='cleese')))

        @noteargs(1, 2,)
        def f3(): pass
        self.assertEqual(f3.dbval, ((1, 2), {}))

    def test_dbcheck(self):
        @dbcheck('args[1] is not None')
        def f(a, b):
            return a + b
        self.assertEqual(f(1, 2), 3)
        self.assertRaises(DbcheckError, f, 1, None)

    def test_memoize(self):
        counts = {}

        @memoize
        @countcalls(counts)
        def double(x):
            return x * 2
        # __name__ is the portable spelling (Python 2 aliases it as func_name).
        self.assertEqual(double.__name__, 'double')

        self.assertEqual(counts, dict(double=0))

        # Only the first call with a given argument bumps the call count:
        #
        self.assertEqual(double(2), 4)
        self.assertEqual(counts['double'], 1)
        self.assertEqual(double(2), 4)
        self.assertEqual(counts['double'], 1)
        self.assertEqual(double(3), 6)
        self.assertEqual(counts['double'], 2)

        # Unhashable arguments do not get memoized:
        #
        self.assertEqual(double([10]), [10, 10])
        self.assertEqual(counts['double'], 3)
        self.assertEqual(double([10]), [10, 10])
        self.assertEqual(counts['double'], 4)

    def test_errors(self):
        # Test syntax restrictions - these are all compile-time errors:
        # NOTE(review): PEP 614 (Python 3.9) relaxed the decorator grammar,
        # so these particular cases are only errors on older interpreters.
        #
        for expr in ["1+2", "x[3]", "(1, 2)"]:
            # Sanity check: is expr a valid expression by itself?
            compile(expr, "testexpr", "eval")

            codestr = "@%s\ndef f(): pass" % expr
            self.assertRaises(SyntaxError, compile, codestr, "test", "exec")

        # You can't put multiple decorators on a single line:
        #
        self.assertRaises(SyntaxError, compile,
                          "@f1 @f2\ndef f(): pass", "test", "exec")

        # Test runtime errors

        def unimp(func):
            raise NotImplementedError
        context = dict(nullval=None, unimp=unimp)

        for expr, exc in [("undef", NameError),
                          ("nullval", TypeError),
                          ("nullval.attr", AttributeError),
                          ("unimp", NotImplementedError)]:
            codestr = "@%s\ndef f(): pass\nassert f() is None" % expr
            code = compile(codestr, "test", "exec")
            self.assertRaises(exc, eval, code, context)

    def test_double(self):
        class C(object):
            @funcattrs(abc=1, xyz="haha")
            @funcattrs(booh=42)
            def foo(self): return 42
        self.assertEqual(C().foo(), 42)
        self.assertEqual(C.foo.abc, 1)
        self.assertEqual(C.foo.xyz, "haha")
        self.assertEqual(C.foo.booh, 42)

    def test_order(self):
        # Test that decorators are applied in the proper order to the function
        # they are decorating.
        def callnum(num):
            """Decorator factory that returns a decorator that replaces the
            passed-in function with one that returns the value of 'num'"""
            def deco(func):
                return lambda: num
            return deco
        @callnum(2)
        @callnum(1)
        def foo(): return 42
        self.assertEqual(foo(), 2,
                         "Application order of decorators is incorrect")

    def test_eval_order(self):
        # Evaluating a decorated function involves four steps for each
        # decorator-maker (the function that returns a decorator):
        #
        #    1: Evaluate the decorator-maker name
        #    2: Evaluate the decorator-maker arguments (if any)
        #    3: Call the decorator-maker to make a decorator
        #    4: Call the decorator
        #
        # When there are multiple decorators, these steps should be
        # performed in the above order for each decorator, but we should
        # iterate through the decorators in the reverse of the order they
        # appear in the source.

        actions = []

        def make_decorator(tag):
            actions.append('makedec' + tag)
            def decorate(func):
                actions.append('calldec' + tag)
                return func
            return decorate

        class NameLookupTracer (object):
            def __init__(self, index):
                self.index = index

            def __getattr__(self, fname):
                if fname == 'make_decorator':
                    opname, res = ('evalname', make_decorator)
                elif fname == 'arg':
                    opname, res = ('evalargs', str(self.index))
                else:
                    # Raise instead of `assert False`: an assert vanishes
                    # under -O (leaving opname unbound), and AttributeError
                    # is what the __getattr__ protocol expects.
                    raise AttributeError("Unknown attrname %s" % fname)
                actions.append('%s%d' % (opname, self.index))
                return res

        c1, c2, c3 = map(NameLookupTracer, [1, 2, 3])

        expected_actions = ['evalname1', 'evalargs1', 'makedec1',
                            'evalname2', 'evalargs2', 'makedec2',
                            'evalname3', 'evalargs3', 'makedec3',
                            'calldec3', 'calldec2', 'calldec1']

        actions = []
        @c1.make_decorator(c1.arg)
        @c2.make_decorator(c2.arg)
        @c3.make_decorator(c3.arg)
        def foo(): return 42
        self.assertEqual(foo(), 42)

        self.assertEqual(actions, expected_actions)

        # Test the equivalence claim in chapter 7 of the reference manual.
        #
        actions = []
        def bar(): return 42
        bar = c1.make_decorator(c1.arg)(c2.make_decorator(c2.arg)(c3.make_decorator(c3.arg)(bar)))
        self.assertEqual(bar(), 42)
        self.assertEqual(actions, expected_actions)
268
class TestClassDecorators(unittest.TestCase):
    """Tests for decorators applied to class statements."""

    def test_simple(self):
        def add_greeting(cls):
            cls.extra = 'Hello'
            return cls

        @add_greeting
        class C(object):
            pass

        self.assertEqual(C.extra, 'Hello')

    def test_double(self):
        # The inner decorator runs first (sets 10), the outer one adds 5.
        def set_ten(cls):
            cls.extra = 10
            return cls

        def bump_by_five(cls):
            cls.extra += 5
            return cls

        @bump_by_five
        @set_ten
        class C(object):
            pass

        self.assertEqual(C.extra, 15)

    def test_order(self):
        # The decorator nearest the class is applied first, so the outer
        # one's assignment is the one that sticks.
        def mark_first(cls):
            cls.extra = 'first'
            return cls

        def mark_second(cls):
            cls.extra = 'second'
            return cls

        @mark_second
        @mark_first
        class C(object):
            pass

        self.assertEqual(C.extra, 'second')
303
def test_main():
    # One run_unittest call per test case class, in file order.
    for case in (TestDecorators, TestClassDecorators):
        test_support.run_unittest(case)
307
if __name__ == "__main__":
    # Only when this file is executed as a script.
    test_main()
310