1import copy_reg
2import unittest
3
4from test import test_support
5from test.pickletester import ExtensionSaver
6
class C:
    """Classic (old-style) class; copy_reg.pickle() must reject it."""
9
10
class WithoutSlots(object):
    """New-style class that declares no __slots__ at all."""
13
class WithWeakref(object):
    # The only slot is __weakref__; test_slotnames expects
    # copy_reg._slotnames() to leave it out and return [].
    __slots__ = ('__weakref__',)
16
class WithPrivate(object):
    # '__spam' is a private (double-underscore) name, so it is name-mangled;
    # test_slotnames expects _slotnames() to report '_WithPrivate__spam'.
    __slots__ = ('__spam',)
19
class WithSingleString(object):
    # Deliberately a bare string rather than a tuple: a string __slots__
    # declares a single slot, and test_slotnames expects ['spam'].
    __slots__ = 'spam'
22
class WithInherited(WithSingleString):
    # Adds 'eggs' on top of the 'spam' slot inherited from WithSingleString;
    # test_slotnames expects _slotnames() to report both.
    __slots__ = ('eggs',)
25
26
27class CopyRegTestCase(unittest.TestCase):
28
29    def test_class(self):
30        self.assertRaises(TypeError, copy_reg.pickle,
31                          C, None, None)
32
33    def test_noncallable_reduce(self):
34        self.assertRaises(TypeError, copy_reg.pickle,
35                          type(1), "not a callable")
36
37    def test_noncallable_constructor(self):
38        self.assertRaises(TypeError, copy_reg.pickle,
39                          type(1), int, "not a callable")
40
41    def test_bool(self):
42        import copy
43        self.assertEqual(True, copy.copy(True))
44
45    def test_extension_registry(self):
46        mod, func, code = 'junk1 ', ' junk2', 0xabcd
47        e = ExtensionSaver(code)
48        try:
49            # Shouldn't be in registry now.
50            self.assertRaises(ValueError, copy_reg.remove_extension,
51                              mod, func, code)
52            copy_reg.add_extension(mod, func, code)
53            # Should be in the registry.
54            self.assertTrue(copy_reg._extension_registry[mod, func] == code)
55            self.assertTrue(copy_reg._inverted_registry[code] == (mod, func))
56            # Shouldn't be in the cache.
57            self.assertNotIn(code, copy_reg._extension_cache)
58            # Redundant registration should be OK.
59            copy_reg.add_extension(mod, func, code)  # shouldn't blow up
60            # Conflicting code.
61            self.assertRaises(ValueError, copy_reg.add_extension,
62                              mod, func, code + 1)
63            self.assertRaises(ValueError, copy_reg.remove_extension,
64                              mod, func, code + 1)
65            # Conflicting module name.
66            self.assertRaises(ValueError, copy_reg.add_extension,
67                              mod[1:], func, code )
68            self.assertRaises(ValueError, copy_reg.remove_extension,
69                              mod[1:], func, code )
70            # Conflicting function name.
71            self.assertRaises(ValueError, copy_reg.add_extension,
72                              mod, func[1:], code)
73            self.assertRaises(ValueError, copy_reg.remove_extension,
74                              mod, func[1:], code)
75            # Can't remove one that isn't registered at all.
76            if code + 1 not in copy_reg._inverted_registry:
77                self.assertRaises(ValueError, copy_reg.remove_extension,
78                                  mod[1:], func[1:], code + 1)
79
80        finally:
81            e.restore()
82
83        # Shouldn't be there anymore.
84        self.assertNotIn((mod, func), copy_reg._extension_registry)
85        # The code *may* be in copy_reg._extension_registry, though, if
86        # we happened to pick on a registered code.  So don't check for
87        # that.
88
89        # Check valid codes at the limits.
90        for code in 1, 0x7fffffff:
91            e = ExtensionSaver(code)
92            try:
93                copy_reg.add_extension(mod, func, code)
94                copy_reg.remove_extension(mod, func, code)
95            finally:
96                e.restore()
97
98        # Ensure invalid codes blow up.
99        for code in -1, 0, 0x80000000L:
100            self.assertRaises(ValueError, copy_reg.add_extension,
101                              mod, func, code)
102
103    def test_slotnames(self):
104        self.assertEqual(copy_reg._slotnames(WithoutSlots), [])
105        self.assertEqual(copy_reg._slotnames(WithWeakref), [])
106        expected = ['_WithPrivate__spam']
107        self.assertEqual(copy_reg._slotnames(WithPrivate), expected)
108        self.assertEqual(copy_reg._slotnames(WithSingleString), ['spam'])
109        expected = ['eggs', 'spam']
110        expected.sort()
111        result = copy_reg._slotnames(WithInherited)
112        result.sort()
113        self.assertEqual(result, expected)
114
115
def test_main():
    """Run every test in CopyRegTestCase through test_support.run_unittest."""
    suites = (CopyRegTestCase,)
    test_support.run_unittest(*suites)
118
119
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    test_main()
122