"""Tests for the pickle module.

The test logic itself lives in test.pickletester; the classes here bind
those abstract suites to concrete implementations:

* the pure-Python ``pickle._Pickler`` / ``pickle._Unpickler``;
* the C ``_pickle.Pickler`` / ``_pickle.Unpickler``, when the ``_pickle``
  extension module can be imported;
* mixed pairs (C pickler with Python unpickler and vice versa).

CompatPickleTests checks the name tables in ``_compat_pickle`` that
translate between Python 2 (``module2``/``name2``) and Python 3
(``module3``/``name3``) module and attribute names.
"""
from _compat_pickle import (IMPORT_MAPPING, REVERSE_IMPORT_MAPPING,
                            NAME_MAPPING, REVERSE_NAME_MAPPING)
import builtins
import pickle
import io
import collections
import struct
import sys

import unittest
from test import support

from test.pickletester import AbstractUnpickleTests
from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
from test.pickletester import AbstractIdentityPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests
from test.pickletester import AbstractDispatchTableTests
from test.pickletester import BigmemPickleTests

try:
    import _pickle
    has_c_implementation = True
except ImportError:
    has_c_implementation = False


class PickleTests(AbstractPickleModuleTests):
    """Module-level API tests from AbstractPickleModuleTests, run as-is."""


class PyUnpicklerTests(AbstractUnpickleTests):
    """Unpickling tests against the pure-Python ``pickle._Unpickler``."""

    unpickler = pickle._Unpickler
    # Exception types accepted for malformed input; these attributes are not
    # read in this file, so they are consumed by AbstractUnpickleTests.
    bad_stack_errors = (IndexError,)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds*."""
        f = io.BytesIO(buf)
        u = self.unpickler(f, **kwds)
        return u.load()


class PyPicklerTests(AbstractPickleTests):
    """Pickling tests using the pure-Python pickler and unpickler."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler

    def dumps(self, arg, proto=None):
        """Pickle *arg* with ``self.pickler`` at protocol *proto*.

        Returns the complete pickle as ``bytes``.
        """
        f = io.BytesIO()
        p = self.pickler(f, proto)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``self.unpickler``, forwarding *kwds*."""
        f = io.BytesIO(buf)
        u = self.unpickler(f, **kwds)
        return u.load()


class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
                          BigmemPickleTests):
    """Run the shared suites through ``pickle.dumps`` / ``pickle.loads``.

    These module-level functions presumably resolve to the C implementation
    when ``_pickle`` is available and to the pure-Python one otherwise.
    """

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler
    bad_stack_errors = (pickle.UnpicklingError, IndexError)
    truncated_errors = (pickle.UnpicklingError, EOFError,
                        AttributeError, ValueError,
                        struct.error, IndexError, ImportError)

    def dumps(self, arg, protocol=None):
        """Pickle *arg* with ``pickle.dumps`` at *protocol*."""
        return pickle.dumps(arg, protocol)

    def loads(self, buf, **kwds):
        """Unpickle *buf* with ``pickle.loads``, forwarding *kwds*."""
        return pickle.loads(buf, **kwds)


class PersistentPicklerUnpicklerMixin(object):
    """Provide dumps/loads whose persistence hooks call back into the test.

    The pickler's ``persistent_id`` and the unpickler's ``persistent_load``
    delegate to ``self.persistent_id`` / ``self.persistent_load`` on the
    test-case instance. Those are not defined here; they are expected to be
    supplied by the abstract persistent-pickler suites this is mixed into.
    """

    def dumps(self, arg, proto=None):
        """Pickle *arg* with a subclass of ``self.pickler`` that has the hook."""
        class PersPickler(self.pickler):
            # ``subself`` is the pickler; ``self`` is the enclosing test case.
            def persistent_id(subself, obj):
                return self.persistent_id(obj)
        f = io.BytesIO()
        p = PersPickler(f, proto)
        p.dump(arg)
        return f.getvalue()

    def loads(self, buf, **kwds):
        """Unpickle *buf* with a subclass of ``self.unpickler`` that has the hook."""
        class PersUnpickler(self.unpickler):
            # ``subself`` is the unpickler; ``self`` is the enclosing test case.
            def persistent_load(subself, obj):
                return self.persistent_load(obj)
        f = io.BytesIO(buf)
        u = PersUnpickler(f, **kwds)
        return u.load()


class PyPersPicklerTests(AbstractPersistentPicklerTests,
                         PersistentPicklerUnpicklerMixin):
    """Persistent-ID tests against the pure-Python pickler/unpickler."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler


class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
                           PersistentPicklerUnpicklerMixin):
    """Identity-preserving persistent-ID tests, pure-Python implementation."""

    pickler = pickle._Pickler
    unpickler = pickle._Unpickler


class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
    """Pickler/Unpickler object tests for the pure-Python classes."""

    pickler_class = pickle._Pickler
    unpickler_class = pickle._Unpickler


class PyDispatchTableTests(AbstractDispatchTableTests):
    """Dispatch-table tests for the pure-Python pickler, using a dict copy."""

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return a fresh copy of ``pickle.dispatch_table``."""
        return pickle.dispatch_table.copy()


class PyChainDispatchTableTests(AbstractDispatchTableTests):
    """Dispatch-table tests for the pure-Python pickler, using a ChainMap.

    The ChainMap is a non-dict mapping layered over ``pickle.dispatch_table``.
    """

    pickler_class = pickle._Pickler

    def get_dispatch_table(self):
        """Return an empty ChainMap layer over ``pickle.dispatch_table``."""
        return collections.ChainMap({}, pickle.dispatch_table)


if has_c_implementation:
    class CUnpicklerTests(PyUnpicklerTests):
        """Unpickling tests against the C ``_pickle.Unpickler``."""
        unpickler = _pickle.Unpickler
        # The C unpickler reports both conditions as UnpicklingError only.
        bad_stack_errors = (pickle.UnpicklingError,)
        truncated_errors = (pickle.UnpicklingError,)

    class CPicklerTests(PyPicklerTests):
        """Pickling tests using the C pickler and unpickler."""
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CPersPicklerTests(PyPersPicklerTests):
        """Persistent-ID tests against the C pickler/unpickler."""
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CIdPersPicklerTests(PyIdPersPicklerTests):
        """Identity-preserving persistent-ID tests, C implementation."""
        pickler = _pickle.Pickler
        unpickler = _pickle.Unpickler

    class CDumpPickle_LoadPickle(PyPicklerTests):
        """Pickle with the C pickler, unpickle with the Python unpickler."""
        pickler = _pickle.Pickler
        unpickler = pickle._Unpickler

    class DumpPickle_CLoadPickle(PyPicklerTests):
        """Pickle with the Python pickler, unpickle with the C unpickler."""
        pickler = pickle._Pickler
        unpickler = _pickle.Unpickler

    class CPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
        """Pickler/Unpickler object tests for the C classes."""
        pickler_class = _pickle.Pickler
        unpickler_class = _pickle.Unpickler

        def test_issue18339(self):
            # Assigning a non-mapping to ``memo`` must raise TypeError and a
            # negative memo key must raise ValueError; a positive key is OK.
            unpickler = self.unpickler_class(io.BytesIO())
            with self.assertRaises(TypeError):
                unpickler.memo = object
            # used to cause a segfault
            with self.assertRaises(ValueError):
                unpickler.memo = {-1: None}
            unpickler.memo = {1: None}

    class CDispatchTableTests(AbstractDispatchTableTests):
        """Dispatch-table tests for the C pickler, using a dict copy."""
        pickler_class = _pickle.Pickler

        def get_dispatch_table(self):
            """Return a fresh copy of ``pickle.dispatch_table``."""
            return pickle.dispatch_table.copy()

    class CChainDispatchTableTests(AbstractDispatchTableTests):
        """Dispatch-table tests for the C pickler, using a ChainMap."""
        pickler_class = _pickle.Pickler

        def get_dispatch_table(self):
            """Return an empty ChainMap layer over ``pickle.dispatch_table``."""
            return collections.ChainMap({}, pickle.dispatch_table)

    @support.cpython_only
    class SizeofTests(unittest.TestCase):
        """Check ``__sizeof__`` of C Pickler/Unpickler objects.

        NOTE(review): the struct format strings below presumably mirror the
        C object layouts in the _pickle extension and must be kept in sync
        with them -- confirm against the C source when either changes.
        """
        check_sizeof = support.check_sizeof

        def test_pickler(self):
            basesize = support.calcobjsize('5P2n3i2n3iP')
            p = _pickle.Pickler(io.BytesIO())
            self.assertEqual(object.__sizeof__(p), basesize)
            MT_size = struct.calcsize('3nP0n')
            ME_size = struct.calcsize('Pn0P')
            check = self.check_sizeof
            check(p, basesize +
                MT_size + 8 * ME_size +  # Minimal memo table size.
                sys.getsizeof(b'x'*4096))  # Minimal write buffer size.
            for i in range(6):
                p.dump(chr(i))
            check(p, basesize +
                MT_size + 32 * ME_size +  # Size of memo table required to
                                          # save references to 6 objects.
                0)  # Write buffer is cleared after every dump().

        def test_unpickler(self):
            basesize = support.calcobjsize('2Pn2P 2P2n2i5P 2P3n6P2n2i')
            unpickler = _pickle.Unpickler
            P = struct.calcsize('P')  # Size of memo table entry.
            n = struct.calcsize('n')  # Size of mark table entry.
            check = self.check_sizeof
            # The encoding and errors strings are included in the reported
            # size, each with one extra byte.
            for encoding in 'ASCII', 'UTF-16', 'latin-1':
                for errors in 'strict', 'replace':
                    u = unpickler(io.BytesIO(),
                                  encoding=encoding, errors=errors)
                    self.assertEqual(object.__sizeof__(u), basesize)
                    check(u, basesize +
                             32 * P +  # Minimal memo table size.
                             len(encoding) + 1 + len(errors) + 1)

            stdsize = basesize + len('ASCII') + 1 + len('strict') + 1

            def check_unpickler(data, memo_size, marks_size):
                dump = pickle.dumps(data)
                u = unpickler(io.BytesIO(dump),
                              encoding='ASCII', errors='strict')
                u.load()
                check(u, stdsize + memo_size * P + marks_size * n)

            check_unpickler(0, 32, 0)
            # 20 is minimal non-empty mark stack size.
            check_unpickler([0] * 100, 32, 20)
            # 128 is memo table size required to save references to 100 objects.
            check_unpickler([chr(i) for i in range(100)], 128, 20)

            def recurse(deep):
                data = 0
                for i in range(deep):
                    data = [data, data]
                return data
            check_unpickler(recurse(0), 32, 0)
            check_unpickler(recurse(1), 32, 20)
            check_unpickler(recurse(20), 32, 58)
            check_unpickler(recurse(50), 64, 58)
            check_unpickler(recurse(100), 128, 134)

            u = unpickler(io.BytesIO(pickle.dumps('a', 0)),
                          encoding='ASCII', errors='strict')
            u.load()
            check(u, stdsize + 32 * P + 2 + 1)


# Forward mappings that deliberately do not round-trip through the reverse
# tables, because several Python 2 names collapse onto one Python 3 name
# (e.g. both StringIO and cStringIO map to io). The round-trip assertions in
# CompatPickleTests skip these entries.
ALT_IMPORT_MAPPING = {
    ('_elementtree', 'xml.etree.ElementTree'),
    ('cPickle', 'pickle'),
    ('StringIO', 'io'),
    ('cStringIO', 'io'),
}

ALT_NAME_MAPPING = {
    ('__builtin__', 'basestring', 'builtins', 'str'),
    ('exceptions', 'StandardError', 'builtins', 'Exception'),
    ('UserDict', 'UserDict', 'collections', 'UserDict'),
    ('socket', '_socketobject', 'socket', 'SocketType'),
}


def mapping(module, name):
    """Translate a Python 2 ``(module, name)`` pair to Python 3.

    NAME_MAPPING is consulted first for the full pair. Otherwise
    IMPORT_MAPPING translates the module alone. The pair is returned
    unchanged if neither table matches.
    """
    if (module, name) in NAME_MAPPING:
        module, name = NAME_MAPPING[(module, name)]
    elif module in IMPORT_MAPPING:
        module = IMPORT_MAPPING[module]
    return module, name


def reverse_mapping(module, name):
    """Translate a Python 3 ``(module, name)`` pair back to Python 2.

    Mirrors :func:`mapping`, using REVERSE_NAME_MAPPING and then
    REVERSE_IMPORT_MAPPING.
    """
    if (module, name) in REVERSE_NAME_MAPPING:
        module, name = REVERSE_NAME_MAPPING[(module, name)]
    elif module in REVERSE_IMPORT_MAPPING:
        module = REVERSE_IMPORT_MAPPING[module]
    return module, name


def getmodule(module):
    """Return the module named *module*, importing it if necessary.

    Raises ImportError if the import fails. An AttributeError raised while
    importing is converted to ImportError, chained to the original error.
    """
    try:
        return sys.modules[module]
    except KeyError:
        try:
            __import__(module)
        except AttributeError as exc:
            if support.verbose:
                print("Can't import module %r: %s" % (module, exc))
            raise ImportError("Can't import module %r: %s"
                              % (module, exc)) from exc
        except ImportError as exc:
            if support.verbose:
                print(exc)
            raise
        return sys.modules[module]


def getattribute(module, name):
    """Resolve the possibly dotted attribute *name* inside *module*."""
    obj = getmodule(module)
    for n in name.split('.'):
        obj = getattr(obj, n)
    return obj


def get_exceptions(mod):
    """Yield ``(name, cls)`` for each exception class attribute of *mod*."""
    for name in dir(mod):
        attr = getattr(mod, name)
        if isinstance(attr, type) and issubclass(attr, BaseException):
            yield name, attr


class CompatPickleTests(unittest.TestCase):
    """Consistency checks for the Python 2 <-> 3 tables in _compat_pickle."""

    def test_import(self):
        # Best effort: every module named by any table should be importable,
        # but modules missing on this platform are silently skipped.
        modules = set(IMPORT_MAPPING.values())
        modules |= set(REVERSE_IMPORT_MAPPING)
        modules |= {module for module, name in REVERSE_NAME_MAPPING}
        modules |= {module for module, name in NAME_MAPPING.values()}
        for module in modules:
            try:
                getmodule(module)
            except ImportError:
                pass

    def test_import_mapping(self):
        for module3, module2 in REVERSE_IMPORT_MAPPING.items():
            with self.subTest((module3, module2)):
                try:
                    getmodule(module3)
                except ImportError:
                    pass
                # Private (underscore-prefixed) Python 3 modules need not be
                # the forward target of their Python 2 name.
                if module3[:1] != '_':
                    self.assertIn(module2, IMPORT_MAPPING)
                    self.assertEqual(IMPORT_MAPPING[module2], module3)

    def test_name_mapping(self):
        for (module3, name3), (module2, name2) in REVERSE_NAME_MAPPING.items():
            with self.subTest(((module3, name3), (module2, name2))):
                # Many Python 3 exceptions reverse-map to the generic
                # OSError / ImportError, so only the subclass relation can
                # be checked for those.
                if (module2, name2) == ('exceptions', 'OSError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, OSError))
                elif (module2, name2) == ('exceptions', 'ImportError'):
                    attr = getattribute(module3, name3)
                    self.assertTrue(issubclass(attr, ImportError))
                else:
                    module, name = mapping(module2, name2)
                    if module3[:1] != '_':
                        self.assertEqual((module, name), (module3, name3))
                    try:
                        attr = getattribute(module3, name3)
                    except ImportError:
                        pass
                    else:
                        self.assertEqual(getattribute(module, name), attr)

    def test_reverse_import_mapping(self):
        for module2, module3 in IMPORT_MAPPING.items():
            with self.subTest((module2, module3)):
                try:
                    getmodule(module3)
                except ImportError as exc:
                    if support.verbose:
                        print(exc)
                # Unless whitelisted, a forward module mapping must be
                # reversible either by module or via some name mapping.
                if ((module2, module3) not in ALT_IMPORT_MAPPING and
                    REVERSE_IMPORT_MAPPING.get(module3, None) != module2):
                    for (m3, n3), (m2, n2) in REVERSE_NAME_MAPPING.items():
                        if (module3, module2) == (m3, m2):
                            break
                    else:
                        self.fail('No reverse mapping from %r to %r' %
                                  (module3, module2))
                module = REVERSE_IMPORT_MAPPING.get(module3, module3)
                module = IMPORT_MAPPING.get(module, module)
                self.assertEqual(module, module3)

    def test_reverse_name_mapping(self):
        for (module2, name2), (module3, name3) in NAME_MAPPING.items():
            with self.subTest(((module2, name2), (module3, name3))):
                # Only checks that the target resolves when importable;
                # the value itself is not needed.
                try:
                    getattribute(module3, name3)
                except ImportError:
                    pass
                module, name = reverse_mapping(module3, name3)
                if (module2, name2, module3, name3) not in ALT_NAME_MAPPING:
                    self.assertEqual((module, name), (module2, name2))
                module, name = mapping(module, name)
                self.assertEqual((module, name), (module3, name3))

    def test_exceptions(self):
        self.assertEqual(mapping('exceptions', 'StandardError'),
                         ('builtins', 'Exception'))
        self.assertEqual(mapping('exceptions', 'Exception'),
                         ('builtins', 'Exception'))
        self.assertEqual(reverse_mapping('builtins', 'Exception'),
                         ('exceptions', 'Exception'))
        self.assertEqual(mapping('exceptions', 'OSError'),
                         ('builtins', 'OSError'))
        self.assertEqual(reverse_mapping('builtins', 'OSError'),
                         ('exceptions', 'OSError'))

        for name, exc in get_exceptions(builtins):
            with self.subTest(name):
                # NOTE(review): the skipped classes presumably have no
                # Python 2 counterpart -- confirm if this list changes.
                if exc in (BlockingIOError,
                           ResourceWarning,
                           StopAsyncIteration,
                           RecursionError):
                    continue
                if exc is not OSError and issubclass(exc, OSError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'OSError'))
                elif exc is not ImportError and issubclass(exc, ImportError):
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', 'ImportError'))
                    self.assertEqual(mapping('exceptions', name),
                                     ('exceptions', name))
                else:
                    self.assertEqual(reverse_mapping('builtins', name),
                                     ('exceptions', name))
                    self.assertEqual(mapping('exceptions', name),
                                     ('builtins', name))

    def test_multiprocessing_exceptions(self):
        module = support.import_module('multiprocessing.context')
        for name, exc in get_exceptions(module):
            with self.subTest(name):
                self.assertEqual(reverse_mapping('multiprocessing.context', name),
                                 ('multiprocessing', name))
                self.assertEqual(mapping('multiprocessing', name),
                                 ('multiprocessing.context', name))


def test_main():
    """Run every test class defined here, then the pickle module's doctests."""
    # Classes exercising only the pure-Python implementation run
    # unconditionally; PyPicklerUnpicklerObjectTests belongs here because it
    # binds pickle._Pickler/_Unpickler and never touches _pickle.
    tests = [PickleTests, PyUnpicklerTests, PyPicklerTests,
             PyPersPicklerTests, PyIdPersPicklerTests,
             PyPicklerUnpicklerObjectTests,
             PyDispatchTableTests, PyChainDispatchTableTests,
             CompatPickleTests]
    if has_c_implementation:
        tests.extend([CUnpicklerTests, CPicklerTests,
                      CPersPicklerTests, CIdPersPicklerTests,
                      CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
                      CPicklerUnpicklerObjectTests,
                      CDispatchTableTests, CChainDispatchTableTests,
                      InMemoryPickleTests, SizeofTests])
    support.run_unittest(*tests)
    support.run_doctest(pickle)


if __name__ == "__main__":
    test_main()