1from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
2                            NAME_MAPPING, REVERSE_NAME_MAPPING)
3import builtins
4import pickle
5import io
6import collections
7import struct
8import sys
9
10import unittest
11from test import support
12
13from test.pickletester import AbstractUnpickleTests
14from test.pickletester import AbstractPickleTests
15from test.pickletester import AbstractPickleModuleTests
16from test.pickletester import AbstractPersistentPicklerTests
17from test.pickletester import AbstractIdentityPersistentPicklerTests
18from test.pickletester import AbstractPicklerUnpicklerObjectTests
19from test.pickletester import AbstractDispatchTableTests
20from test.pickletester import BigmemPickleTests
21
# Probe for the optional C accelerator module.  The flag set here gates the
# definition and registration of every test class that needs ``_pickle``.
try:
    import _pickle
except ImportError:
    has_c_implementation = False
else:
    has_c_implementation = True
27
28
class PickleTests(AbstractPickleModuleTests):
    """Run the tests inherited from AbstractPickleModuleTests unchanged."""
31
32
class PyUnpicklerTests(AbstractUnpickleTests):
    """AbstractUnpickleTests configured with ``pickle._Unpickler``."""

    unpickler = pickle._Unpickler
    # Exception types handed to the inherited tests for this unpickler.
    bad_stack_errors = (IndexError,)
    truncated_errors = (
        pickle.UnpicklingError,
        EOFError,
        AttributeError,
        ValueError,
        struct.error,
        IndexError,
        ImportError,
    )

    def loads(self, buf, **kwds):
        """Unpickle *buf* from an in-memory stream using ``self.unpickler``.

        Extra keyword arguments are forwarded to the unpickler constructor.
        """
        stream = io.BytesIO(buf)
        return self.unpickler(stream, **kwds).load()
45
46
class PyPicklerTests(AbstractPickleTests):
    """AbstractPickleTests configured with ``pickle._Pickler``/``_Unpickler``."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None):
        """Pickle *arg* with ``self.pickler`` and return the resulting bytes."""
        sink = io.BytesIO()
        self.pickler(sink, proto).dump(arg)
        # getvalue() yields the whole buffer regardless of stream position.
        return sink.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds* to it."""
        source = io.BytesIO(buf)
        return self.unpickler(source, **kwds).load()
63
64
class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests):
    """Run the inherited tests through ``pickle.dumps()`` / ``pickle.loads()``."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler
    # Exception types handed to the inherited unpickling tests.
    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (
        pickle.UnpicklingError,
        EOFError,
        AttributeError,
        ValueError,
        struct.error,
        IndexError,
        ImportError,
    )

    def dumps(self, arg, protocol=None):
        """Return ``pickle.dumps(arg, protocol)``."""
        pickled = pickle.dumps(arg, protocol)
        return pickled

    def loads(self, buf, **kwds):
        """Return ``pickle.loads(buf, **kwds)``."""
        restored = pickle.loads(buf, **kwds)
        return restored
80
81
class PersistentPicklerUnpicklerMixin(object):
    """Provide dumps()/loads() that route persistent-ID hooks to the host.

    The host class supplies ``pickler`` and ``unpickler`` classes together
    with ``persistent_id(obj)`` and ``persistent_load(pid)`` methods; the
    pickler/unpickler subclasses built here simply delegate to those.
    """

    def dumps(self, arg, proto=None):
        """Pickle *arg* with a pickler whose persistent_id() defers to self."""
        host = self

        class PersPickler(self.pickler):
            def persistent_id(pickler_self, obj):
                return host.persistent_id(obj)

        buffer = io.BytesIO()
        PersPickler(buffer, proto).dump(arg)
        return buffer.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with an unpickler whose persistent_load() defers to self."""
        host = self

        class PersUnpickler(self.unpickler):
            def persistent_load(unpickler_self, obj):
                return host.persistent_load(obj)

        buffer = io.BytesIO(buf)
        return PersUnpickler(buffer, **kwds).load()
100
101
class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin):
    """Persistent-ID tests configured with ``pickle._Pickler``/``_Unpickler``."""

    pickler, unpickler = pickle._Pickler, pickle._Unpickler
107
108
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin):
    """Identity persistent-ID tests using ``pickle._Pickler``/``_Unpickler``."""

    pickler, unpickler = pickle._Pickler, pickle._Unpickler
114
115
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
    """Pickler/unpickler object tests using ``pickle._Pickler``/``_Unpickler``."""

    pickler_class, unpickler_class = pickle._Pickler, pickle._Unpickler
120
121
class PyDispatchTableTests(AbstractDispatchTableTests):
    """Dispatch-table tests for ``pickle._Pickler`` with a plain copied table."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a fresh shallow copy of ``pickle.dispatch_table``."""
        table = pickle.dispatch_table.copy()
        return table
128
129
class PyChainDispatchTableTests(AbstractDispatchTableTests):
    """Dispatch-table tests for ``pickle._Pickler`` with a ChainMap table."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a ChainMap layering an empty dict over ``pickle.dispatch_table``."""
        overrides = {}
        return collections.ChainMap(overrides, pickle.dispatch_table)
136
137
138if has_c_implementation:
139    class CUnpicklerTests(PyUnpicklerTests):
140        unpickler = _pickle.Unpickler
141        bad_stack_errors = (pickle.UnpicklingError,)
142        truncated_errors = (pickle.UnpicklingError,)
143
144    class CPicklerTests(PyPicklerTests):
145        pickler = _pickle.Pickler
146        unpickler = _pickle.Unpickler
147
148    class CPersPicklerTests(PyPersPicklerTests):
149        pickler = _pickle.Pickler
150        unpickler = _pickle.Unpickler
151
152    class CIdPersPicklerTests(PyIdPersPicklerTests):
153        pickler = _pickle.Pickler
154        unpickler = _pickle.Unpickler
155
156    class CDumpPickle_LoadPickle(PyPicklerTests):
157        pickler = _pickle.Pickler
158        unpickler = pickle._Unpickler
159
160    class DumpPickle_CLoadPickle(PyPicklerTests):
161        pickler = pickle._Pickler
162        unpickler = _pickle.Unpickler
163
164    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
165        pickler_class = _pickle.Pickler
166        unpickler_class = _pickle.Unpickler
167
168        def test_issue18339(self):
169            unpickler = self.unpickler_class(io.BytesIO())
170            with self.assertRaises(TypeError):
171                unpickler.memo = object
172            # used to cause a segfault
173            with self.assertRaises(ValueError):
174                unpickler.memo = {-1: None}
175            unpickler.memo = {1: None}
176
177    class CDispatchTableTests(AbstractDispatchTableTests):
178        pickler_class = pickle.Pickler
179        def get_dispatch_table(self):
180            return pickle.dispatch_table.copy()
181
182    class CChainDispatchTableTests(AbstractDispatchTableTests):
183        pickler_class = pickle.Pickler
184        def get_dispatch_table(self):
185            return collections.ChainMap({}, pickle.dispatch_table)
186
187    @support.cpython_only
188    class SizeofTests(unittest.TestCase):
189        check_sizeof = support.check_sizeof
190
191        def test_pickler(self):
192            basesize = support.calcobjsize('5P2n3i2n3iP')
193            p = _pickle.Pickler(io.BytesIO())
194            self.assertEqual(object.__sizeof__(p), basesize)
195            MT_size = struct.calcsize('3nP0n')
196            ME_size = struct.calcsize('Pn0P')
197            check = self.check_sizeof
198            check(p, basesize +
199                MT_size + 8 * ME_size +  # Minimal memo table size.
200                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
201            for i in range(6):
202                p.dump(chr(i))
203            check(p, basesize +
204                MT_size + 32 * ME_size +  # Size of memo table required to
205                                          # save references to 6 objects.
206                0)  # Write buffer is cleared after every dump().
207
208        def test_unpickler(self):
209            basesize = support.calcobjsize('2Pn2P 2P2n2i5P 2P3n6P2n2i')
210            unpickler = _pickle.Unpickler
211            P = struct.calcsize('P')  # Size of memo table entry.
212            n = struct.calcsize('n')  # Size of mark table entry.
213            check = self.check_sizeof
214            for encoding in 'ASCII', 'UTF-16', 'latin-1':
215                for errors in 'strict', 'replace':
216                    u = unpickler(io.BytesIO(),
217                                  encoding=encoding, errors=errors)
218                    self.assertEqual(object.__sizeof__(u), basesize)
219                    check(u, basesize +
220                             32 * P +  # Minimal memo table size.
221                             len(encoding) + 1 + len(errors) + 1)
222
223            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1
224            def check_unpickler(data, memo_size, marks_size):
225                dump = pickle.dumps(data)
226                u = unpickler(io.BytesIO(dump),
227                              encoding='ASCII', errors='strict')
228                u.load()
229                check(u, stdsize + memo_size * P + marks_size * n)
230
231            check_unpickler(0, 32, 0)
232            # 20 is minimal non-empty mark stack size.
233            check_unpickler([0] * 100, 32, 20)
234            # 128 is memo table size required to save references to 100 objects.
235            check_unpickler([chr(i) for i in range(100)], 128, 20)
236            def recurse(deep):
237                data = 0
238                for i in range(deep):
239                    data = [data, data]
240                return data
241            check_unpickler(recurse(0), 32, 0)
242            check_unpickler(recurse(1), 32, 20)
243            check_unpickler(recurse(20), 32, 58)
244            check_unpickler(recurse(50), 64, 58)
245            check_unpickler(recurse(100), 128, 134)
246
247            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
248                          encoding='ASCII', errors='strict')
249            u.load()
250            check(u, stdsize + 32 * P + 2 + 1)
251
252
# (old module, new module) pairs exempted from the reverse-mapping
# requirement in CompatPickleTests.test_reverse_import_mapping.
ALT_IMPORT_MAPPING = {
    ('cPickle', 'pickle'),
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
    ('_elementtree', 'xml.etree.ElementTree'),
}

# (old module, old name, new module, new name) entries exempted from the
# exact round-trip check in CompatPickleTests.test_reverse_name_mapping.
ALT_NAME_MAPPING = {
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
}
266
def mapping(module, name):
    """Translate a (module, name) pair through the forward tables.

    An exact hit in NAME_MAPPING wins; otherwise only the module is
    translated through IMPORT_MAPPING.  Unknown pairs come back unchanged.
    """
    renamed = NAME_MAPPING.get((module, name))
    if renamed is not None:
        module, name = renamed
        return module, name
    return IMPORT_MAPPING.get(module, module), name
273
def reverse_mapping(module, name):
    """Translate a (module, name) pair through the reverse tables.

    An exact hit in REVERSE_NAME_MAPPING wins; otherwise only the module is
    translated through REVERSE_IMPORT_MAPPING.  Unknown pairs come back
    unchanged.
    """
    try:
        module, name = REVERSE_NAME_MAPPING[(module, name)]
    except KeyError:
        module = REVERSE_IMPORT_MAPPING.get(module, module)
    return module, name
280
def getmodule(module):
    """Return the module object named *module*, importing it on demand.

    The lookup goes through ``sys.modules`` so that a dotted name yields the
    submodule itself rather than the top-level package ``__import__``
    returns.

    Raises ImportError when the module cannot be imported.  An
    AttributeError raised while importing is converted to ImportError too,
    with the original exception attached as ``__cause__`` and the module
    name recorded, so the failure stays diagnosable.
    """
    try:
        return sys.modules[module]
    except KeyError:
        # Fall through and import outside this handler, so the KeyError is
        # not chained into every import failure's traceback.
        pass
    try:
        __import__(module)
    except AttributeError as exc:
        if support.verbose:
            print("Can't import module %r: %s" % (module, exc))
        raise ImportError("can't import module %r: %s" % (module, exc),
                          name=module) from exc
    except ImportError as exc:
        if support.verbose:
            print(exc)
        raise
    return sys.modules[module]
296
def getattribute(module, name):
    """Resolve the dotted attribute path *name* inside module *module*.

    The module is obtained with getmodule(); each dot-separated component of
    *name* is then looked up in turn with getattr().
    """
    target = getmodule(module)
    components = name.split('.')
    for component in components:
        target = getattr(target, component)
    return target
302
def get_exceptions(mod):
    """Yield ``(name, class)`` for each exception class found in *mod*.

    Every name from ``dir(mod)`` is inspected; only attributes that are
    classes deriving from BaseException are produced.
    """
    for attr_name in dir(mod):
        candidate = getattr(mod, attr_name)
        if not isinstance(candidate, type):
            continue
        if issubclass(candidate, BaseException):
            yield attr_name, candidate
308
class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the translation tables from ``_compat_pickle``."""

    def test_import(self):
        # Try importing every module named on the "module3" side of the
        # tables; a module that is not importable is tolerated.
        candidates = set(IMPORT_MAPPING.values()).union(
            REVERSE_IMPORT_MAPPING,
            (modname for modname, _ in REVERSE_NAME_MAPPING),
            (modname for modname, _ in NAME_MAPPING.values()),
        )
        for modname in candidates:
            try:
                getmodule(modname)
            except ImportError:
                continue

    def test_import_mapping(self):
        # Each public reverse import entry must be undone by IMPORT_MAPPING.
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                # Modules with a leading underscore are exempt.
                if not module3.startswith('_'):
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        # Reverse targets that stand for a whole family of classes: the
        # source only has to be a subclass of the listed base.
        family_bases = {
            ('exceptions', 'OSError'): OSError,
            ('exceptions', 'ImportError'): ImportError,
        }
        for new_key, old_key in REVERSE_NAME_MAPPING.items():
            module3, name3 = new_key
            module2, name2 = old_key
            with self.subTest(((module3, name3), (module2, name2))):
                base = family_bases.get((module2, name2))
                if base is not None:
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, base))
                else:
                    module, name = mapping(module2, name2)
                    if not module3.startswith('_'):
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                is_alternate = (module2, module3) in ALT_IMPORT_MAPPING
                reversed_module = REVERSE_IMPORT_MAPPING.get(module3, None)
                if not is_alternate and reversed_module != module2:
                    # No module-level reverse entry: a name-level entry
                    # linking the same two modules is accepted instead.
                    linked_by_name = any(
                        new_key[0] == module3 and old_key[0] == module2
                        for new_key, old_key in REVERSE_NAME_MAPPING.items()
                    )
                    if not linked_by_name:
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                # Reverse-then-forward translation must land on module3 again.
                round_trip = REVERSE_IMPORT_MAPPING.get(module3, module3)
                round_trip = IMPORT_MAPPING.get(round_trip, round_trip)
                self.assertEqual(round_trip, module3)

    def test_reverse_name_mapping(self):
        for old_key, new_key in NAME_MAPPING.items():
            module2, name2 = old_key
            module3, name3 = new_key
            with self.subTest(((module2, name2), (module3, name3))):
                # The target must resolve when its module is importable.
                try:
                    getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        # Spot checks of individual entries.
        spot_checks = (
            (mapping, ('exceptions', 'StandardError'),
             ('builtins', 'Exception')),
            (mapping, ('exceptions', 'Exception'),
             ('builtins', 'Exception')),
            (reverse_mapping, ('builtins', 'Exception'),
             ('exceptions', 'Exception')),
            (mapping, ('exceptions', 'OSError'),
             ('builtins', 'OSError')),
            (reverse_mapping, ('builtins', 'OSError'),
             ('exceptions', 'OSError')),
        )
        for translate, source, expected in spot_checks:
            self.assertEqual(translate(*source), expected)

        # Exception classes excluded from the per-class checks below.
        excluded = (BlockingIOError,
                    ResourceWarning,
                    StopAsyncIteration,
                    RecursionError)
        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                if exc in excluded:
                    continue
                reverse = reverse_mapping('builtins', name)
                if exc is not OSError and issubclass(exc, OSError):
                    # Proper OSError subclasses reverse-map to OSError.
                    self.assertEqual(reverse, ('exceptions', 'OSError'))
                elif exc is not ImportError and issubclass(exc, ImportError):
                    # Proper ImportError subclasses reverse-map to
                    # ImportError and have no forward entry.
                    self.assertEqual(reverse, ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                else:
                    self.assertEqual(reverse, ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        context = support.import_module('multiprocessing.context')
        for name, exc in get_exceptions(context):
            with self.subTest(name):
                self.assertEqual(
                    reverse_mapping('multiprocessing.context', name),
                    ('multiprocessing', name))
                self.assertEqual(
                    mapping('multiprocessing', name),
                    ('multiprocessing.context', name))
426
427
def test_main():
    """Run every test class of this module, then the ``pickle`` doctests.

    Test classes that only rely on the ``pickle`` module are always run.
    Classes that reference ``_pickle`` are defined, and therefore
    registered, only when ``has_c_implementation`` is true.
    """
    # PyPicklerUnpicklerObjectTests and InMemoryPickleTests use only names
    # from ``pickle``, so they must not be skipped on builds that lack the
    # C accelerator.
    tests = [PickleTests, PyUnpicklerTests, PyPicklerTests,
             PyPersPicklerTests, PyIdPersPicklerTests,
             PyPicklerUnpicklerObjectTests,
             PyDispatchTableTests, PyChainDispatchTableTests,
             InMemoryPickleTests, CompatPickleTests]
    if has_c_implementation:
        tests.extend([CUnpicklerTests, CPicklerTests,
                      CPersPicklerTests, CIdPersPicklerTests,
                      CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
                      CPicklerUnpicklerObjectTests,
                      CDispatchTableTests, CChainDispatchTableTests,
                      SizeofTests])
    support.run_unittest(*tests)
    support.run_doctest(pickle)
443
# Run the whole suite when this file is executed as a script.
if __name__ == '__main__':
    test_main()
446