"""Unit tests for weakref.WeakSet."""

import unittest
from weakref import proxy, ref, WeakSet
import operator
import copy
import string
import os
from random import randrange, shuffle
import sys
import warnings
import collections
import gc
import contextlib

try:
    from test import test_support
except ImportError:
    try:
        # Newer interpreters expose the regrtest helpers as ``test.support``.
        from test import support as test_support
    except ImportError:
        # Some installs ship without the ``test`` package; test_main() then
        # falls back to plain unittest.
        test_support = None


class Foo:
    """Bare class whose instances serve as throwaway WeakSet members."""
    pass


class SomeClass(object):
    """Value wrapper that compares and hashes by its ``value``.

    Two distinct instances built from the same value are equal, which lets
    the tests look up set members through freshly created objects.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if type(other) != type(self):
            return False
        return other.value == self.value

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash((SomeClass, self.value))


class RefCycle(object):
    """Object holding a reference to itself (a one-element cycle)."""

    def __init__(self):
        self.cycle = self


class TestWeakSet(unittest.TestCase):
    """Exercise the set-like API and the weak-reference semantics of WeakSet."""

    def setUp(self):
        # need to keep references to them
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.ab_items = [SomeClass(c) for c in 'ab']
        self.abcde_items = [SomeClass(c) for c in 'abcde']
        self.def_items = [SomeClass(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        # self.obj is the only strong reference to the single member of fs.
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the builtin set must exist on WeakSet too.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        del self.obj
        # Do not rely on reference counting to finalize the object at once.
        gc.collect()
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        # Do not rely on reference counting to finalize the object at once.
        gc.collect()
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            del c
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        # Dropping the last strong reference must shrink the union as well.
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(i.intersection(C(self.items)), x)
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        # The method spellings on WeakSet itself, not only the operators.
        self.assertTrue(self.ab_weakset.issubset(self.abcde_items))
        self.assertTrue(self.abcde_weakset.issuperset(self.ab_items))
        self.assertFalse(self.ab_weakset.issubset(self.def_items))
        self.assertFalse(self.def_weakset.issuperset(self.ab_items))
        # Same checks against the builtin set.
        self.assertTrue(set('a').issubset('abc'))
        self.assertTrue(set('abc').issuperset('a'))
        self.assertFalse(set('a').issubset('cbs'))
        self.assertFalse(set('cbs').issuperset('a'))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check
        s = WeakSet(Foo() for i in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertIsNot(s, t)

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertIsNot(self.s, dup)

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # The temporary Foo has no other referent, so it must not stay in
        # the set; collect explicitly instead of relying on refcounting.
        self.fs.add(Foo())
        gc.collect()
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        # Discarding an absent element must be a silent no-op.
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for i in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        # The == operator is spelled out on purpose: it is what is under test.
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            try:
                it = iter(s)
                next(it)
                # Schedule an item for removal and recreate it
                u = SomeClass(str(items.pop()))
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for i in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        # Restore the collector thresholds once the test is over.
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for i in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)


def test_main(verbose=None):
    """Run the WeakSet tests.

    Uses the regrtest ``run_unittest`` helper when one could be imported and
    otherwise falls back to a plain unittest text runner, raising
    AssertionError if that run is not successful.
    """
    run_unittest = getattr(test_support, 'run_unittest', None)
    if run_unittest is not None:
        run_unittest(TestWeakSet)
        return
    suite = unittest.TestLoader().loadTestsFromTestCase(TestWeakSet)
    runner = unittest.TextTestRunner(verbosity=2 if verbose else 1)
    if not runner.run(suite).wasSuccessful():
        raise AssertionError("TestWeakSet failed")


if __name__ == "__main__":
    test_main(verbose=True)