1# Copyright 2007 Google, Inc. All Rights Reserved.
2# Licensed to PSF under a Contributor Agreement.
3
4"""Unit tests for abc.py."""
5
6import unittest
7from test import support
8
9import abc
10from inspect import isabstract
11
12
class TestLegacyAPI(unittest.TestCase):
    """Checks for the legacy abstract* decorators kept for backward compatibility.

    Each test verifies that the decorator marks its target as abstract, that
    an ABC using it cannot be instantiated, and that a concrete subclass
    overriding the member works (including delegating back via ``super()``).
    """

    def test_abstractproperty_basics(self):
        @abc.abstractproperty
        def foo(self): pass
        self.assertTrue(foo.__isabstractmethod__)

        def bar(self): pass
        # A plain function carries no abstract marker at all.
        self.assertFalse(hasattr(bar, "__isabstractmethod__"))

        class Base(metaclass=abc.ABCMeta):
            @abc.abstractproperty
            def foo(self):
                return 3

        with self.assertRaises(TypeError):
            Base()

        class Concrete(Base):
            @property
            def foo(self):
                # The abstract getter still has a usable body.
                return super().foo

        self.assertEqual(Concrete().foo, 3)
        self.assertFalse(getattr(Concrete.foo, "__isabstractmethod__", False))

    def test_abstractclassmethod_basics(self):
        @abc.abstractclassmethod
        def foo(cls): pass
        self.assertTrue(foo.__isabstractmethod__)

        @classmethod
        def bar(cls): pass
        self.assertFalse(getattr(bar, "__isabstractmethod__", False))

        class Base(metaclass=abc.ABCMeta):
            @abc.abstractclassmethod
            def foo(cls):
                return cls.__name__

        with self.assertRaises(TypeError):
            Base()

        # The name ``D`` matters: the inherited body returns ``cls.__name__``.
        class D(Base):
            @classmethod
            def foo(cls):
                return super().foo()

        self.assertEqual(D.foo(), 'D')
        instance = D()
        self.assertEqual(instance.foo(), 'D')

    def test_abstractstaticmethod_basics(self):
        @abc.abstractstaticmethod
        def foo(): pass
        self.assertTrue(foo.__isabstractmethod__)

        @staticmethod
        def bar(): pass
        self.assertFalse(getattr(bar, "__isabstractmethod__", False))

        class Base(metaclass=abc.ABCMeta):
            @abc.abstractstaticmethod
            def foo():
                return 3

        with self.assertRaises(TypeError):
            Base()

        class Concrete(Base):
            @staticmethod
            def foo():
                return 4

        self.assertEqual(Concrete.foo(), 4)
        instance = Concrete()
        self.assertEqual(instance.foo(), 4)
67
68
class TestABC(unittest.TestCase):
    """Tests for abc.ABC, abc.ABCMeta, abc.abstractmethod and ABC registration."""

    def test_ABC_helper(self):
        # create an ABC using the helper class and perform basic checks
        class C(abc.ABC):
            @classmethod
            @abc.abstractmethod
            def foo(cls): return cls.__name__
        self.assertEqual(type(C), abc.ABCMeta)
        self.assertRaises(TypeError, C)
        class D(C):
            @classmethod
            def foo(cls): return super().foo()
        self.assertEqual(D.foo(), 'D')

    def test_abstractmethod_basics(self):
        @abc.abstractmethod
        def foo(self): pass
        self.assertTrue(foo.__isabstractmethod__)
        def bar(self): pass
        self.assertFalse(hasattr(bar, "__isabstractmethod__"))

    def test_abstractproperty_basics(self):
        @property
        @abc.abstractmethod
        def foo(self): pass
        self.assertTrue(foo.__isabstractmethod__)
        def bar(self): pass
        self.assertFalse(getattr(bar, "__isabstractmethod__", False))

        class C(metaclass=abc.ABCMeta):
            @property
            @abc.abstractmethod
            def foo(self): return 3
        self.assertRaises(TypeError, C)
        class D(C):
            @C.foo.getter
            def foo(self): return super().foo
        self.assertEqual(D().foo, 3)
        # Replacing the only abstract accessor makes the property concrete.
        self.assertFalse(getattr(D.foo, "__isabstractmethod__", False))

    def test_abstractclassmethod_basics(self):
        @classmethod
        @abc.abstractmethod
        def foo(cls): pass
        self.assertTrue(foo.__isabstractmethod__)
        @classmethod
        def bar(cls): pass
        self.assertFalse(getattr(bar, "__isabstractmethod__", False))

        class C(metaclass=abc.ABCMeta):
            @classmethod
            @abc.abstractmethod
            def foo(cls): return cls.__name__
        self.assertRaises(TypeError, C)
        class D(C):
            @classmethod
            def foo(cls): return super().foo()
        self.assertEqual(D.foo(), 'D')
        self.assertEqual(D().foo(), 'D')

    def test_abstractstaticmethod_basics(self):
        @staticmethod
        @abc.abstractmethod
        def foo(): pass
        self.assertTrue(foo.__isabstractmethod__)
        @staticmethod
        def bar(): pass
        self.assertFalse(getattr(bar, "__isabstractmethod__", False))

        class C(metaclass=abc.ABCMeta):
            @staticmethod
            @abc.abstractmethod
            def foo(): return 3
        self.assertRaises(TypeError, C)
        class D(C):
            @staticmethod
            def foo(): return 4
        self.assertEqual(D.foo(), 4)
        self.assertEqual(D().foo(), 4)

    def test_abstractmethod_integration(self):
        for abstractthing in [abc.abstractmethod, abc.abstractproperty,
                              abc.abstractclassmethod,
                              abc.abstractstaticmethod]:
            class C(metaclass=abc.ABCMeta):
                @abstractthing
                def foo(self): pass  # abstract
                def bar(self): pass  # concrete
            self.assertEqual(C.__abstractmethods__, {"foo"})
            self.assertRaises(TypeError, C)  # because foo is abstract
            self.assertTrue(isabstract(C))
            class D(C):
                def bar(self): pass  # concrete override of concrete
            self.assertEqual(D.__abstractmethods__, {"foo"})
            self.assertRaises(TypeError, D)  # because foo is still abstract
            self.assertTrue(isabstract(D))
            class E(D):
                def foo(self): pass
            self.assertEqual(E.__abstractmethods__, set())
            E()  # now foo is concrete, too
            self.assertFalse(isabstract(E))
            class F(E):
                @abstractthing
                def bar(self): pass  # abstract override of concrete
            self.assertEqual(F.__abstractmethods__, {"bar"})
            self.assertRaises(TypeError, F)  # because bar is abstract now
            self.assertTrue(isabstract(F))

    def test_descriptors_with_abstractmethod(self):
        class C(metaclass=abc.ABCMeta):
            @property
            @abc.abstractmethod
            def foo(self): return 3
            @foo.setter
            @abc.abstractmethod
            def foo(self, val): pass
        self.assertRaises(TypeError, C)
        class D(C):
            @C.foo.getter
            def foo(self): return super().foo
        # The setter inherited from C is still abstract.
        self.assertRaises(TypeError, D)
        class E(D):
            @D.foo.setter
            def foo(self, val): pass
        self.assertEqual(E().foo, 3)
        self.assertFalse(E.foo.__isabstractmethod__)
        # check that the property's __isabstractmethod__ descriptor does the
        # right thing when presented with a value that fails truth testing:
        class NotBool(object):
            def __bool__(self):
                raise ValueError()
            __len__ = __bool__
        with self.assertRaises(ValueError):
            class F(C):
                def bar(self):
                    pass
                bar.__isabstractmethod__ = NotBool()
                foo = property(bar)

    def test_customdescriptors_with_abstractmethod(self):
        class Descriptor:
            """Minimal property-like descriptor exposing __isabstractmethod__."""
            def __init__(self, fget, fset=None):
                self._fget = fget
                self._fset = fset
            def getter(self, func):
                # Replace the getter but keep the existing setter, mirroring
                # property.getter (previously the old getter was passed as
                # the new setter by mistake).
                return Descriptor(func, self._fset)
            def setter(self, func):
                return Descriptor(self._fget, func)
            @property
            def __isabstractmethod__(self):
                return (getattr(self._fget, '__isabstractmethod__', False)
                        or getattr(self._fset, '__isabstractmethod__', False))
        class C(metaclass=abc.ABCMeta):
            @Descriptor
            @abc.abstractmethod
            def foo(self): return 3
            @foo.setter
            @abc.abstractmethod
            def foo(self, val): pass
        self.assertRaises(TypeError, C)
        class D(C):
            @C.foo.getter
            def foo(self): return super().foo
        # D must stay abstract because C's abstract *setter* was preserved.
        self.assertIs(D.foo._fset, C.foo._fset)
        self.assertRaises(TypeError, D)
        class E(D):
            @D.foo.setter
            def foo(self, val): pass
        self.assertFalse(E.foo.__isabstractmethod__)

    def test_metaclass_abc(self):
        # Metaclasses can be ABCs, too.
        class A(metaclass=abc.ABCMeta):
            @abc.abstractmethod
            def x(self):
                pass
        self.assertEqual(A.__abstractmethods__, {"x"})
        class meta(type, A):
            def x(self):
                return 1
        class C(metaclass=meta):
            pass
        # The concrete override makes ``meta`` non-abstract; classes built
        # with it are instances of the ABC and see ``x`` via the metaclass.
        self.assertEqual(meta.__abstractmethods__, set())
        self.assertIsInstance(C, A)
        self.assertEqual(C.x(), 1)

    def test_registration_basics(self):
        class A(metaclass=abc.ABCMeta):
            pass
        class B(object):
            pass
        b = B()
        self.assertFalse(issubclass(B, A))
        self.assertFalse(issubclass(B, (A,)))
        self.assertNotIsInstance(b, A)
        self.assertNotIsInstance(b, (A,))
        B1 = A.register(B)
        self.assertTrue(issubclass(B, A))
        self.assertTrue(issubclass(B, (A,)))
        self.assertIsInstance(b, A)
        self.assertIsInstance(b, (A,))
        self.assertIs(B1, B)
        class C(B):
            pass
        c = C()
        self.assertTrue(issubclass(C, A))
        self.assertTrue(issubclass(C, (A,)))
        self.assertIsInstance(c, A)
        self.assertIsInstance(c, (A,))

    def test_register_as_class_deco(self):
        class A(metaclass=abc.ABCMeta):
            pass
        @A.register
        class B(object):
            pass
        b = B()
        self.assertTrue(issubclass(B, A))
        self.assertTrue(issubclass(B, (A,)))
        self.assertIsInstance(b, A)
        self.assertIsInstance(b, (A,))
        @A.register
        class C(B):
            pass
        c = C()
        self.assertTrue(issubclass(C, A))
        self.assertTrue(issubclass(C, (A,)))
        self.assertIsInstance(c, A)
        self.assertIsInstance(c, (A,))
        self.assertIs(C, A.register(C))

    def test_isinstance_invalidation(self):
        class A(metaclass=abc.ABCMeta):
            pass
        class B:
            pass
        b = B()
        self.assertFalse(isinstance(b, A))
        self.assertFalse(isinstance(b, (A,)))
        # Registering must bump the global cache token so stale negative
        # isinstance results are discarded.
        token_old = abc.get_cache_token()
        A.register(B)
        token_new = abc.get_cache_token()
        self.assertNotEqual(token_old, token_new)
        self.assertTrue(isinstance(b, A))
        self.assertTrue(isinstance(b, (A,)))

    def test_registration_builtins(self):
        class A(metaclass=abc.ABCMeta):
            pass
        A.register(int)
        self.assertIsInstance(42, A)
        self.assertIsInstance(42, (A,))
        self.assertTrue(issubclass(int, A))
        self.assertTrue(issubclass(int, (A,)))
        class B(A):
            pass
        B.register(str)
        class C(str): pass
        self.assertIsInstance("", A)
        self.assertIsInstance("", (A,))
        self.assertTrue(issubclass(str, A))
        self.assertTrue(issubclass(str, (A,)))
        self.assertTrue(issubclass(C, A))
        self.assertTrue(issubclass(C, (A,)))

    def test_registration_edge_cases(self):
        class A(metaclass=abc.ABCMeta):
            pass
        A.register(A)  # should pass silently
        class A1(A):
            pass
        self.assertRaises(RuntimeError, A1.register, A)  # cycles not allowed
        class B(object):
            pass
        A1.register(B)  # ok
        A1.register(B)  # should pass silently
        class C(A):
            pass
        A.register(C)  # should pass silently
        self.assertRaises(RuntimeError, C.register, A)  # cycles not allowed
        C.register(B)  # ok

    def test_register_non_class(self):
        class A(metaclass=abc.ABCMeta):
            pass
        self.assertRaisesRegex(TypeError, "Can only register classes",
                               A.register, 4)

    def test_registration_transitiveness(self):
        class A(metaclass=abc.ABCMeta):
            pass
        self.assertTrue(issubclass(A, A))
        self.assertTrue(issubclass(A, (A,)))
        class B(metaclass=abc.ABCMeta):
            pass
        self.assertFalse(issubclass(A, B))
        self.assertFalse(issubclass(A, (B,)))
        self.assertFalse(issubclass(B, A))
        self.assertFalse(issubclass(B, (A,)))
        class C(metaclass=abc.ABCMeta):
            pass
        A.register(B)
        class B1(B):
            pass
        self.assertTrue(issubclass(B1, A))
        self.assertTrue(issubclass(B1, (A,)))
        class C1(C):
            pass
        B1.register(C1)
        self.assertFalse(issubclass(C, B))
        self.assertFalse(issubclass(C, (B,)))
        self.assertFalse(issubclass(C, B1))
        self.assertFalse(issubclass(C, (B1,)))
        self.assertTrue(issubclass(C1, A))
        self.assertTrue(issubclass(C1, (A,)))
        self.assertTrue(issubclass(C1, B))
        self.assertTrue(issubclass(C1, (B,)))
        self.assertTrue(issubclass(C1, B1))
        self.assertTrue(issubclass(C1, (B1,)))
        C1.register(int)
        class MyInt(int):
            pass
        self.assertTrue(issubclass(MyInt, A))
        self.assertTrue(issubclass(MyInt, (A,)))
        self.assertIsInstance(42, A)
        self.assertIsInstance(42, (A,))

    def test_all_new_methods_are_called(self):
        # ABCMeta must not short-circuit __new__ of other bases in the MRO.
        class A(metaclass=abc.ABCMeta):
            pass
        class B(object):
            counter = 0
            def __new__(cls):
                B.counter += 1
                return super().__new__(cls)
        class C(A, B):
            pass
        self.assertEqual(B.counter, 0)
        C()
        self.assertEqual(B.counter, 1)
405
406
if __name__ == "__main__":
    # When run as a script, let unittest discover and run this module's tests.
    unittest.main()
409