1import unittest
2from test import test_support
3from weakref import proxy, ref, WeakSet
4import operator
5import copy
6import string
7import os
8from random import randrange, shuffle
9import sys
10import warnings
11import collections
12import gc
13import contextlib
14
15
class Foo:
    """Bare, weak-referenceable class with no state or behavior of its own."""
18
class SomeClass(object):
    """Hashable wrapper around a single ``value``, usable as a WeakSet element.

    Two instances compare equal only when they have exactly the same type
    and equal ``value`` attributes; the hash is built from the class and
    ``value`` so equal instances hash alike.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Objects of any other type (including subclasses) are never equal.
        if type(other) is type(self):
            return other.value == self.value
        return False

    def __ne__(self, other):
        equal = self.__eq__(other)
        return not equal

    def __hash__(self):
        key = (SomeClass, self.value)
        return hash(key)
32
class RefCycle(object):
    """Object whose only attribute points back at itself.

    Because of the self-reference, dropping all external references does
    not bring the refcount to zero; the cyclic garbage collector is needed
    to reclaim an instance.
    """

    def __init__(self):
        # Deliberate reference cycle: self -> self.cycle -> self.
        self.cycle = self
36
class TestWeakSet(unittest.TestCase):
    """Tests for weakref.WeakSet.

    Most tests compare a WeakSet with an equivalent strong container built
    from the same elements. Several also check that an element disappears
    from a WeakSet once its last strong reference is dropped.
    """

    def setUp(self):
        # A WeakSet holds only weak references, so every element stored in
        # one of the sets below is also kept alive by an attribute of self.
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.ab_items = [SomeClass(c) for c in 'ab']
        self.abcde_items = [SomeClass(c) for c in 'abcde']
        self.def_items = [SomeClass(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the builtin set must also exist on WeakSet.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                         "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        # Dropping the only strong reference removes the element.
        del self.obj
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        # The result must not depend on the type of the argument container.
        x = WeakSet(self.items + self.items2)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            del c
        # u holds its elements weakly: dropping one of them shrinks u.
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        # i and self.items share no elements, whatever the container type.
        x = WeakSet([])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            self.assertEqual(i.intersection(C(self.items)), x)
        # i holds its elements weakly: dropping one of them shrinks i.
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        # The named methods accept any iterable of elements, not just
        # another WeakSet; exercise WeakSet's own implementation of them.
        self.assertTrue(self.ab_weakset.issubset(self.abcde_items))
        self.assertTrue(self.abcde_weakset.issuperset(self.ab_items))
        self.assertFalse(self.ab_weakset.issubset(self.def_items))
        self.assertFalse(self.def_weakset.issuperset(self.ab_items))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check.
        # The elements must be kept alive by a strong container while the
        # cycles are built; otherwise each Foo() would die as soon as it was
        # added and the WeakSet (and the loop below) would be empty.
        items = [Foo() for i in range(1000)]
        s = WeakSet(items)
        self.assertEqual(len(s), len(items))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        # Calling __init__ again replaces the contents of the set.
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        # Constructing from a WeakSet makes a new object, not an alias.
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertIsNot(s, t)

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertIsNot(self.s, dup)

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # The temporary Foo() has no strong reference left after add(),
        # so it must not be counted.
        self.fs.add(Foo())
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for i in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)
        @contextlib.contextmanager
        def testcontext():
            try:
                it = iter(s)
                # Start the iterator; keep only the yielded *value* so this
                # frame holds no strong reference to the element itself.
                yielded_value = next(it).value
                # Schedule an item for removal and recreate an equal one.
                # (SomeClass has no __str__, so the recreation must copy
                # .value; str() of the instance would never compare equal.)
                u = SomeClass(items.pop().value)
                if yielded_value == u.value:
                    # The suspended iterator still references the removed
                    # item; advance it so that item can die (issue #20006).
                    next(it)
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for i in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for i in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)
444
445
def test_main(verbose=None):
    """Run the TestWeakSet suite via test_support.run_unittest.

    The *verbose* argument is accepted but not used by this function.
    """
    suite_classes = (TestWeakSet,)
    test_support.run_unittest(*suite_classes)
448
if __name__ == "__main__":
    # Only when executed as a script (not on import): run the suite.
    test_main(verbose=True)
451