1import unittest
2from test import support
3import gc
4import weakref
5import operator
6import copy
7import pickle
8from random import randrange, shuffle
9import warnings
10import collections
11import collections.abc
12import itertools
13import string
14
class PassThru(Exception):
    """Sentinel exception raised by check_pass_thru().

    Tests assert that this exact exception escapes set constructors and
    methods, proving errors from an input iterator are not swallowed.
    """
17
def check_pass_thru():
    """Return a generator that raises PassThru on its first next() call.

    Passing the result to a set constructor or method checks that an error
    raised while consuming the argument propagates unchanged.
    """
    raise PassThru
    # Never reached; a yield expression is what makes this a generator
    # function, so the raise is deferred until iteration begins.
    yield from ()
21
class BadCmp:
    """Hashable object whose equality comparison always fails.

    All instances share the hash value 1, so two of them collide in a hash
    table and force an __eq__ call, which raises RuntimeError.
    """

    def __hash__(self):
        # Constant hash: every instance lands in the same bucket.
        return 1

    def __eq__(self, other):
        raise RuntimeError()
27
class ReprWrapper:
    'Used to test self-referential repr() calls'
    # repr() delegates to the ``value`` attribute, which callers assign after
    # construction; pointing it back at a containing set yields a cycle.

    def __repr__(self):
        target = self.value
        return repr(target)
32
class HashCountingInt(int):
    'int-like object that counts the number of times __hash__ is called'

    def __init__(self, *args):
        # int.__new__ has already consumed the arguments; only start the
        # per-instance counter here.
        self.hash_count = 0

    def __hash__(self):
        self.hash_count = self.hash_count + 1
        # Hash exactly like the plain int with the same value.
        return int.__hash__(self)
40
class TestJointOps:
    """Tests shared by set and frozenset (and their subclasses).

    Mixin: concrete subclasses also inherit from unittest.TestCase and set
    ``thetype`` (the type under test) and ``basetype`` (the type that
    non-mutating operations such as union() are expected to return).
    """

    def setUp(self):
        self.word = word = 'simsalabim'
        self.otherword = 'madagascar'
        self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
        self.s = self.thetype(word)
        # Reference model: the unique letters of ``word`` as dict keys.
        self.d = dict.fromkeys(word)

    def test_new_or_init(self):
        self.assertRaises(TypeError, self.thetype, [], 2)
        # NOTE(review): deliberately uses set() rather than self.thetype();
        # presumably because frozenset defines no __init__ of its own and the
        # inherited one does not reject keyword arguments -- confirm first.
        self.assertRaises(TypeError, set().__init__, a=1)

    def test_uniquification(self):
        actual = sorted(self.s)
        expected = sorted(self.d)
        self.assertEqual(actual, expected)
        # Errors raised while consuming the input iterable must propagate.
        self.assertRaises(PassThru, self.thetype, check_pass_thru())
        # Unhashable elements are rejected.
        self.assertRaises(TypeError, self.thetype, [[]])

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        self.assertRaises(TypeError, self.s.__contains__, [[]])
        # A set used as a lookup key matches an equal frozenset element.
        s = self.thetype([frozenset(self.letters)])
        self.assertIn(self.thetype(self.letters), s)

    def test_union(self):
        u = self.s.union(self.otherword)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.otherword)
        # union() must not modify its receiver.
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(u), self.basetype)
        self.assertRaises(PassThru, self.s.union, check_pass_thru())
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, str, list, tuple:
            self.assertEqual(self.thetype('abcba').union(C('cdc')), set('abcd'))
            self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg'))
            self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc'))
            self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef'))
            self.assertEqual(self.thetype('abcba').union(C('ef'), C('fg')), set('abcefg'))

        # Issue #6573
        x = self.thetype()
        self.assertEqual(x.union(set([1]), x, set([2])), self.thetype([1, 2]))

    def test_or(self):
        i = self.s.union(self.otherword)
        self.assertEqual(self.s | set(self.otherword), i)
        self.assertEqual(self.s | frozenset(self.otherword), i)
        # The operator form accepts only sets, not arbitrary iterables.
        with self.assertRaises(TypeError,
                               msg="s|t did not screen-out general iterables"):
            self.s | self.otherword

    def test_intersection(self):
        i = self.s.intersection(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c in self.otherword)
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.basetype)
        self.assertRaises(PassThru, self.s.intersection, check_pass_thru())
        for C in set, frozenset, dict.fromkeys, str, list, tuple:
            self.assertEqual(self.thetype('abcba').intersection(C('cdc')), set('cc'))
            self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set(''))
            self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc'))
            self.assertEqual(self.thetype('abcba').intersection(C('ef')), set(''))
            self.assertEqual(self.thetype('abcba').intersection(C('cbcf'), C('bag')), set('b'))
        s = self.thetype('abcba')
        z = s.intersection()
        # With no arguments, intersection() yields a copy equal to s.
        self.assertEqual(z, s)
        if self.basetype is set:
            # A mutable result must be a distinct object. For immutable
            # frozensets, returning the same object (as frozenset.copy()
            # does) would be a permissible optimization, so identity is not
            # asserted there.
            self.assertIsNot(z, s)

    def test_isdisjoint(self):
        def f(s1, s2):
            'Pure python equivalent of isdisjoint()'
            return not set(s1).intersection(s2)
        for larg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
            s1 = self.thetype(larg)
            for rarg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
                for C in set, frozenset, dict.fromkeys, str, list, tuple:
                    s2 = C(rarg)
                    actual = s1.isdisjoint(s2)
                    expected = f(s1, s2)
                    self.assertEqual(actual, expected)
                    # Must be a real bool (True/False), not merely truthy.
                    self.assertIsInstance(actual, bool)

    def test_and(self):
        i = self.s.intersection(self.otherword)
        self.assertEqual(self.s & set(self.otherword), i)
        self.assertEqual(self.s & frozenset(self.otherword), i)
        with self.assertRaises(TypeError,
                               msg="s&t did not screen-out general iterables"):
            self.s & self.otherword

    def test_difference(self):
        i = self.s.difference(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.otherword)
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.basetype)
        self.assertRaises(PassThru, self.s.difference, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference, [[]])
        for C in set, frozenset, dict.fromkeys, str, list, tuple:
            self.assertEqual(self.thetype('abcba').difference(C('cdc')), set('ab'))
            self.assertEqual(self.thetype('abcba').difference(C('efgfe')), set('abc'))
            self.assertEqual(self.thetype('abcba').difference(C('ccb')), set('a'))
            self.assertEqual(self.thetype('abcba').difference(C('ef')), set('abc'))
            self.assertEqual(self.thetype('abcba').difference(), set('abc'))
            self.assertEqual(self.thetype('abcba').difference(C('a'), C('b')), set('c'))

    def test_sub(self):
        i = self.s.difference(self.otherword)
        self.assertEqual(self.s - set(self.otherword), i)
        self.assertEqual(self.s - frozenset(self.otherword), i)
        with self.assertRaises(TypeError,
                               msg="s-t did not screen-out general iterables"):
            self.s - self.otherword

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.otherword))
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.basetype)
        self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        for C in set, frozenset, dict.fromkeys, str, list, tuple:
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('cdc')), set('abd'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('efgfe')), set('abcefg'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ccb')), set('a'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ef')), set('abcef'))

    def test_xor(self):
        i = self.s.symmetric_difference(self.otherword)
        self.assertEqual(self.s ^ set(self.otherword), i)
        self.assertEqual(self.s ^ frozenset(self.otherword), i)
        with self.assertRaises(TypeError,
                               msg="s^t did not screen-out general iterables"):
            self.s ^ self.otherword

    def test_equality(self):
        self.assertEqual(self.s, set(self.word))
        self.assertEqual(self.s, frozenset(self.word))
        # A set never compares equal to a non-set iterable.
        self.assertEqual(self.s == self.word, False)
        self.assertNotEqual(self.s, set(self.otherword))
        self.assertNotEqual(self.s, frozenset(self.otherword))
        self.assertEqual(self.s != self.word, True)

    def test_setOfFrozensets(self):
        # 'abcdef'/'fedccba' and 'bcd'/'bdcb' collapse pairwise; 'fed' is
        # distinct, leaving three unique frozensets.
        t = map(frozenset, ['abcdef', 'bcd', 'bdcb', 'fed', 'fedccba'])
        s = self.thetype(t)
        self.assertEqual(len(s), 3)

    def test_sub_and_super(self):
        p, q, r = map(self.thetype, ['ab', 'abcde', 'def'])
        self.assertTrue(p < q)
        self.assertTrue(p <= q)
        self.assertTrue(q <= q)
        self.assertTrue(q > p)
        self.assertTrue(q >= p)
        self.assertFalse(q < r)
        self.assertFalse(q <= r)
        self.assertFalse(q > r)
        self.assertFalse(q >= r)
        self.assertTrue(set('a').issubset('abc'))
        self.assertTrue(set('abc').issuperset('a'))
        self.assertFalse(set('a').issubset('cbs'))
        self.assertFalse(set('cbs').issuperset('a'))

    def test_pickling(self):
        for i in range(pickle.HIGHEST_PROTOCOL + 1):
            p = pickle.dumps(self.s, i)
            dup = pickle.loads(p)
            self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
            if type(self.s) not in (set, frozenset):
                # Subclass instances must round-trip their __dict__ too.
                self.s.x = 10
                p = pickle.dumps(self.s, i)
                dup = pickle.loads(p)
                self.assertEqual(self.s.x, dup.x)

    def test_iterator_pickling(self):
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            itorg = iter(self.s)
            data = self.thetype(self.s)
            d = pickle.dumps(itorg, proto)
            it = pickle.loads(d)
            # Set iterators unpickle as list iterators due to the
            # undefined order of set items.
            # self.assertEqual(type(itorg), type(it))
            self.assertIsInstance(it, collections.abc.Iterator)
            self.assertEqual(self.thetype(it), data)

            it = pickle.loads(d)
            try:
                drop = next(it)
            except StopIteration:
                continue
            # A partially consumed iterator resumes after the dropped item.
            d = pickle.dumps(it, proto)
            it = pickle.loads(d)
            self.assertEqual(self.thetype(it), data - self.thetype((drop,)))

    def test_deepcopy(self):
        class Tracer:
            def __init__(self, value):
                self.value = value
            def __hash__(self):
                return self.value
            def __deepcopy__(self, memo=None):
                return Tracer(self.value + 1)
        t = Tracer(10)
        s = self.thetype([t])
        dup = copy.deepcopy(s)
        self.assertIsNot(s, dup)
        # The copy holds exactly one element; unpacking enforces that.
        (newt,) = dup
        # deepcopy must have gone through Tracer.__deepcopy__.
        self.assertIsNot(t, newt)
        self.assertEqual(t.value + 1, newt.value)

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check
        class A:
            pass
        s = set(A() for _ in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = set([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(self.thetype):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_badcmp(self):
        s = self.thetype([BadCmp()])
        # Detect comparison errors during insertion and lookup
        self.assertRaises(RuntimeError, self.thetype, [BadCmp(), BadCmp()])
        self.assertRaises(RuntimeError, s.__contains__, BadCmp())
        # Detect errors during mutating operations
        if hasattr(s, 'add'):
            self.assertRaises(RuntimeError, s.add, BadCmp())
            self.assertRaises(RuntimeError, s.discard, BadCmp())
            self.assertRaises(RuntimeError, s.remove, BadCmp())

    def test_cyclical_repr(self):
        w = ReprWrapper()
        s = self.thetype([w])
        w.value = s
        if self.thetype == set:
            self.assertEqual(repr(s), '{set(...)}')
        else:
            name = repr(s).partition('(')[0]    # strip class name
            self.assertEqual(repr(s), '%s({%s(...)})' % (name, name))

    def test_cyclical_print(self):
        w = ReprWrapper()
        s = self.thetype([w])
        w.value = s
        # The with-blocks guarantee each handle is closed even when the
        # assertion fails; unlink cleans up the scratch file regardless.
        try:
            with open(support.TESTFN, "w") as fo:
                fo.write(str(s))
            with open(support.TESTFN, "r") as fo:
                self.assertEqual(fo.read(), repr(s))
        finally:
            support.unlink(support.TESTFN)

    def test_do_not_rehash_dict_keys(self):
        # Building from / combining with a dict (or set) must reuse the
        # stored hashes instead of calling __hash__ again.
        n = 10
        d = dict.fromkeys(map(HashCountingInt, range(n)))
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        s = self.thetype(d)
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        s.difference(d)
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        if hasattr(s, 'symmetric_difference_update'):
            s.symmetric_difference_update(d)
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        d2 = dict.fromkeys(set(d))
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        d3 = dict.fromkeys(frozenset(d))
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        d3 = dict.fromkeys(frozenset(d), 123)
        self.assertEqual(sum(elem.hash_count for elem in d), n)
        self.assertEqual(d3, dict.fromkeys(d, 123))

    def test_container_iterator(self):
        # Bug #3680: tp_traverse was not implemented for set iterator object
        class C(object):
            pass
        obj = C()
        ref = weakref.ref(obj)
        container = set([obj, 1])
        obj.x = iter(container)
        del obj, container
        gc.collect()
        self.assertIsNone(ref(), "Cycle was not collected")

    def test_free_after_iterating(self):
        support.check_free_after_iterating(self, iter, self.thetype)
368
class TestSet(TestJointOps, unittest.TestCase):
    """Tests for the mutable built-in set type (plus the shared joint tests)."""
    thetype = set
    basetype = set

    def test_init(self):
        s = self.thetype()
        s.__init__(self.word)
        self.assertEqual(s, set(self.word))
        # Re-running __init__ replaces the contents rather than extending them.
        s.__init__(self.otherword)
        self.assertEqual(s, set(self.otherword))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = self.thetype(range(3))
        t = self.thetype(s)
        self.assertNotEqual(id(s), id(t))

    def test_set_literal(self):
        s = set([1,2,3])
        t = {1,2,3}
        self.assertEqual(s, t)

    def test_set_literal_insertion_order(self):
        # SF Issue #26020 -- Expect left to right insertion
        s = {1, 1.0, True}
        self.assertEqual(len(s), 1)
        stored_value = s.pop()
        self.assertEqual(type(stored_value), int)

    def test_set_literal_evaluation_order(self):
        # Expect left to right expression evaluation
        events = []
        def record(obj):
            events.append(obj)
        s = {record(1), record(2), record(3)}
        self.assertEqual(events, [1, 2, 3])

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, set())
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))
        self.assertEqual(type(dup), self.basetype)

    def test_add(self):
        self.s.add('Q')
        self.assertIn('Q', self.s)
        # Adding an element that is already present is a no-op.
        dup = self.s.copy()
        self.s.add('Q')
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])

    def test_remove(self):
        self.s.remove('a')
        self.assertNotIn('a', self.s)
        self.assertRaises(KeyError, self.s.remove, 'Q')
        self.assertRaises(TypeError, self.s.remove, [])
        # A set argument is looked up as an equal frozenset element.
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.remove(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))

    def test_remove_keyerror_unpacking(self):
        # bug:  www.python.org/sf/1576657
        # The KeyError must carry the missing key itself, even a tuple,
        # rather than having a tuple unpacked into the exception args.
        for v1 in ['Q', (1,)]:
            with self.assertRaises(KeyError) as cm:
                self.s.remove(v1)
            self.assertEqual(cm.exception.args[0], v1)

    def test_remove_keyerror_set(self):
        key = self.thetype([3, 4])
        with self.assertRaises(KeyError) as cm:
            self.s.remove(key)
        # The exception must report the original set object, not the
        # temporary frozenset used for the lookup.
        self.assertIs(cm.exception.args[0], key,
                      "KeyError should be {0}, not {1}".format(
                          key, cm.exception.args[0]))

    def test_discard(self):
        self.s.discard('a')
        self.assertNotIn('a', self.s)
        # Discarding a missing element is silent, unlike remove().
        self.s.discard('Q')
        self.assertRaises(TypeError, self.s.discard, [])
        s = self.thetype([frozenset(self.word)])
        self.assertIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))
        self.assertNotIn(self.thetype(self.word), s)
        s.discard(self.thetype(self.word))

    def test_pop(self):
        for i in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)
        self.assertRaises(PassThru, self.s.update, check_pass_thru())
        self.assertRaises(TypeError, self.s.update, [[]])
        for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p)), None)
                self.assertEqual(s, set(q))
        for p in ('cdc', 'efgfe', 'ccb', 'ef', 'abcda'):
            q = 'ahi'
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p), C(q)), None)
                # Compare against the starting contents, not ``s`` itself;
                # otherwise spurious extra elements could never be detected.
                self.assertEqual(s, set('abcba') | set(p) | set(q))

    def test_ior(self):
        self.s |= set(self.otherword)
        for c in (self.word + self.otherword):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.intersection_update, [[]])
        for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.intersection_update(C(p)), None)
                self.assertEqual(s, set(q))
                ss = 'abcba'
                s = self.thetype(ss)
                t = 'cbc'
                self.assertEqual(s.intersection_update(C(p), C(t)), None)
                self.assertEqual(s, set('abcba')&set(p)&set(t))

    def test_iand(self):
        self.s &= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        for p, q in (('cdc', 'ab'), ('efgfe', 'abc'), ('ccb', 'a'), ('ef', 'abc')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.difference_update(C(p)), None)
                self.assertEqual(s, set(q))

        # The remaining checks do not depend on (p, q), so each runs once
        # (per input type where relevant) instead of inside the loop above.
        s = self.thetype('abcdefghih')
        s.difference_update()
        self.assertEqual(s, self.thetype('abcdefghih'))

        for C in set, frozenset, dict.fromkeys, str, list, tuple:
            s = self.thetype('abcdefghih')
            s.difference_update(C('aba'))
            self.assertEqual(s, self.thetype('cdefghih'))

            s = self.thetype('abcdefghih')
            s.difference_update(C('cdc'), C('aba'))
            self.assertEqual(s, self.thetype('efghih'))

    def test_isub(self):
        self.s -= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.symmetric_difference_update(C(p)), None)
                self.assertEqual(s, set(q))

    def test_ixor(self):
        self.s ^= set(self.otherword)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        # In-place operators must behave when both operands are one object.
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, self.thetype())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, self.thetype())

    def test_weakref(self):
        s = self.thetype('gallahad')
        p = weakref.proxy(s)
        self.assertEqual(str(p), str(s))
        s = None
        self.assertRaises(ReferenceError, str, p)

    def test_rich_compare(self):
        class TestRichSetCompare:
            def __gt__(self, some_set):
                self.gt_called = True
                return False
            def __lt__(self, some_set):
                self.lt_called = True
                return False
            def __ge__(self, some_set):
                self.ge_called = True
                return False
            def __le__(self, some_set):
                self.le_called = True
                return False

        # This first tries the builtin rich set comparison, which doesn't know
        # how to handle the custom object. Upon returning NotImplemented, the
        # corresponding comparison on the right object is invoked.
        myset = {1, 2, 3}

        myobj = TestRichSetCompare()
        myset < myobj
        self.assertTrue(myobj.gt_called)

        myobj = TestRichSetCompare()
        myset > myobj
        self.assertTrue(myobj.lt_called)

        myobj = TestRichSetCompare()
        myset <= myobj
        self.assertTrue(myobj.ge_called)

        myobj = TestRichSetCompare()
        myset >= myobj
        self.assertTrue(myobj.le_called)

    @unittest.skipUnless(hasattr(set, "test_c_api"),
                         'C API test only available in a debug build')
    def test_c_api(self):
        self.assertEqual(set().test_c_api(), True)
652
class SetSubclass(set):
    """Trivial set subclass used to rerun the set tests on a subclass."""
655
class TestSetSubclass(TestSet):
    """Run the complete TestSet suite against a trivial set subclass."""
    thetype = SetSubclass
    # Non-mutating operations on a subclass still produce plain set objects.
    basetype = set
659
class SetSubclassWithKeywordArgs(set):
    """set subclass whose __init__ accepts an extra, ignored argument.

    Used to check that set subclasses may extend the constructor signature
    (SF bug #1486663).
    """

    def __init__(self, iterable=(), newarg=None):
        # An immutable empty tuple replaces the former shared mutable ``[]``
        # default; ``newarg`` is accepted only to exercise the signature.
        set.__init__(self, iterable)
663
class TestSetSubclassWithKeywordArgs(TestSet):
    # NOTE: ``thetype`` is inherited from TestSet (plain ``set``), so the
    # inherited tests run against set; only the test below touches the
    # keyword-accepting subclass.

    def test_keywords_in_subclass(self):
        'SF bug #1486663 -- this used to erroneously raise a TypeError'
        subclass = SetSubclassWithKeywordArgs
        subclass(newarg=1)
669
class TestFrozenSet(TestJointOps, unittest.TestCase):
    """frozenset-specific tests: hashing, identity shortcuts, immutability."""
    thetype = frozenset
    basetype = frozenset

    def test_init(self):
        # Contents are fixed at construction; a second __init__ call must
        # leave them untouched.
        frozen = self.thetype(self.word)
        frozen.__init__(self.otherword)
        self.assertEqual(frozen, set(self.word))

    def test_singleton_empty_frozenset(self):
        empty = frozenset()
        candidates = [
            frozenset(), frozenset([]), frozenset(()), frozenset(''),
            frozenset(), frozenset([]), frozenset(()), frozenset(''),
            frozenset(range(0)), frozenset(frozenset()),
            frozenset(empty), empty,
        ]
        # All of the empty frozensets should have just one id()
        distinct_ids = {id(obj) for obj in candidates}
        self.assertEqual(len(distinct_ids), 1)

    def test_constructor_identity(self):
        original = self.thetype(range(3))
        rebuilt = self.thetype(original)
        # Constructing from an exact frozenset returns that same object.
        self.assertIs(rebuilt, original)

    def test_hash(self):
        # Equal contents give equal hashes regardless of insertion order.
        self.assertEqual(hash(self.thetype('abcdeb')),
                         hash(self.thetype('ebecda')))

        # make sure that all permutations give the same hash value
        size = 100
        items = [randrange(size) for _ in range(size)]
        hashes = set()
        for _ in range(200):
            shuffle(items)
            hashes.add(hash(self.thetype(items)))
        self.assertEqual(len(hashes), 1)

    def test_copy(self):
        # copy() of an exact frozenset returns the very same object.
        self.assertIs(self.s.copy(), self.s)

    def test_frozen_as_dictkey(self):
        seq = list(range(10))
        seq.extend('abcdefg')
        seq.append('apple')
        forward = self.thetype(seq)
        backward = self.thetype(reversed(seq))
        self.assertEqual(forward, backward)
        self.assertIsNot(forward, backward)
        # A distinct-but-equal frozenset must find the same dict entry.
        lookup = {forward: 42}
        self.assertEqual(lookup[backward], 42)

    def test_hash_caching(self):
        frozen = self.thetype('abcdcda')
        first = hash(frozen)
        self.assertEqual(first, hash(frozen))

    def test_hash_effectiveness(self):
        # Every subset of {1..13} must hash to a distinct value.
        bits = 13
        masks = [(idx + 1, 1 << idx) for idx in range(bits)]
        distinct = {
            hash(frozenset([elem for elem, bit in masks if bit & subset]))
            for subset in range(2 ** bits)
        }
        self.assertEqual(len(distinct), 2 ** bits)

        def zf_range(count):
            # https://en.wikipedia.org/wiki/Set-theoretic_definition_of_natural_numbers
            numerals = [frozenset()]
            while len(numerals) < count:
                numerals.append(frozenset(numerals))
            return numerals[:count]

        def powerset(s):
            # All subsets of s as frozensets, smallest subsets first.
            return itertools.chain.from_iterable(
                map(frozenset, itertools.combinations(s, k))
                for k in range(len(s) + 1))

        # The low-order hash bits of a powerset must stay well spread: at
        # least a quarter of the 2**width buckets get used.
        for width in range(18):
            total = 2 ** width
            mask = total - 1
            for make_nums in (range, zf_range):
                buckets = {hash(subset) & mask
                           for subset in powerset(make_nums(width))}
                self.assertGreater(4 * len(buckets), total)
751
class FrozenSetSubclass(frozenset):
    """Trivial frozenset subclass used to rerun the frozenset tests."""
754
class TestFrozenSetSubclass(TestFrozenSet):
    """Run TestFrozenSet against a trivial frozenset subclass.

    Overrides the tests whose identity expectations differ for subclasses,
    which never share or reuse instances the way exact frozensets do.
    """
    thetype = FrozenSetSubclass
    basetype = frozenset

    def test_constructor_identity(self):
        original = self.thetype(range(3))
        rebuilt = self.thetype(original)
        self.assertIsNot(original, rebuilt)

    def test_copy(self):
        duplicate = self.s.copy()
        self.assertIsNot(self.s, duplicate)

    def test_nested_empty_constructor(self):
        outer = self.thetype()
        inner = self.thetype(outer)
        self.assertEqual(outer, inner)

    def test_singleton_empty_frozenset(self):
        Sub = self.thetype
        plain_empty = frozenset()
        sub_empty = Sub()
        candidates = [
            Sub(), Sub([]), Sub(()), Sub(''),
            Sub(), Sub([]), Sub(()), Sub(''),
            Sub(range(0)), Sub(Sub()),
            Sub(frozenset()), plain_empty, sub_empty,
            Sub(plain_empty), Sub(sub_empty),
        ]
        # All empty frozenset subclass instances should have different ids
        distinct_ids = {id(obj) for obj in candidates}
        self.assertEqual(len(distinct_ids), len(candidates))
783
784# Tests taken from test_sets.py =============================================
785
# Shared empty-set operand for the TestBasicOps set-algebra checks.
empty_set = set(())
787
788#==============================================================================
789
class TestBasicOps:
    """Mixin of basic set-operation tests.

    Concrete subclasses (also deriving from unittest.TestCase) set, in setUp():
      self.case    -- short description of the fixture
      self.values  -- list of the elements used to build the set
      self.set     -- the set under test
      self.dup     -- an equal but distinct set
      self.length  -- expected len(self.set)
      self.repr    -- expected repr(self.set), or None to skip the exact check
    """

    def test_repr(self):
        if self.repr is not None:
            self.assertEqual(repr(self.set), self.repr)

    def check_repr_against_values(self):
        """Check repr(self.set) lists exactly the reprs of self.values.

        The element order in a set's repr is unspecified, so both sides are
        sorted before comparison.
        """
        text = repr(self.set)
        self.assertTrue(text.startswith('{'))
        self.assertTrue(text.endswith('}'))

        result = sorted(text[1:-1].split(', '))
        sorted_repr_values = sorted(repr(value) for value in self.values)
        self.assertEqual(result, sorted_repr_values)

    def test_print(self):
        # Use context managers so each file is closed even if an operation
        # fails, and so a failing open() is not masked by an
        # UnboundLocalError from the cleanup code.
        try:
            with open(support.TESTFN, "w") as fo:
                fo.write(str(self.set))
            with open(support.TESTFN, "r") as fo:
                self.assertEqual(fo.read(), repr(self.set))
        finally:
            support.unlink(support.TESTFN)

    def test_length(self):
        self.assertEqual(len(self.set), self.length)

    def test_self_equality(self):
        self.assertEqual(self.set, self.set)

    def test_equivalent_equality(self):
        self.assertEqual(self.set, self.dup)

    def test_copy(self):
        self.assertEqual(self.set.copy(), self.dup)

    def test_self_union(self):
        result = self.set | self.set
        self.assertEqual(result, self.dup)

    def test_empty_union(self):
        result = self.set | empty_set
        self.assertEqual(result, self.dup)

    def test_union_empty(self):
        result = empty_set | self.set
        self.assertEqual(result, self.dup)

    def test_self_intersection(self):
        result = self.set & self.set
        self.assertEqual(result, self.dup)

    def test_empty_intersection(self):
        result = self.set & empty_set
        self.assertEqual(result, empty_set)

    def test_intersection_empty(self):
        result = empty_set & self.set
        self.assertEqual(result, empty_set)

    def test_self_isdisjoint(self):
        # A set is disjoint from itself only when it is empty.
        result = self.set.isdisjoint(self.set)
        self.assertEqual(result, not self.set)

    def test_empty_isdisjoint(self):
        result = self.set.isdisjoint(empty_set)
        self.assertEqual(result, True)

    def test_isdisjoint_empty(self):
        result = empty_set.isdisjoint(self.set)
        self.assertEqual(result, True)

    def test_self_symmetric_difference(self):
        result = self.set ^ self.set
        self.assertEqual(result, empty_set)

    def test_empty_symmetric_difference(self):
        result = self.set ^ empty_set
        self.assertEqual(result, self.set)

    def test_self_difference(self):
        result = self.set - self.set
        self.assertEqual(result, empty_set)

    def test_empty_difference(self):
        result = self.set - empty_set
        self.assertEqual(result, self.dup)

    def test_empty_difference_rev(self):
        result = empty_set - self.set
        self.assertEqual(result, empty_set)

    def test_iteration(self):
        for v in self.set:
            self.assertIn(v, self.values)
        setiter = iter(self.set)
        self.assertEqual(setiter.__length_hint__(), len(self.set))

    def test_pickling(self):
        # Round-trip through every pickle protocol.  The local is named
        # `restored` so it does not shadow the module-level `copy` import.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            p = pickle.dumps(self.set, proto)
            restored = pickle.loads(p)
            self.assertEqual(self.set, restored,
                             "%s != %s" % (self.set, restored))
898
899#------------------------------------------------------------------------------
900
class TestBasicOpsEmpty(TestBasicOps, unittest.TestCase):
    """Basic operations on the empty set."""

    def setUp(self):
        members = []
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.case = "empty set"
        self.length = 0
        self.repr = "set()"
909
910#------------------------------------------------------------------------------
911
class TestBasicOpsSingleton(TestBasicOps, unittest.TestCase):
    """Basic operations on a one-element set holding an int."""

    def setUp(self):
        members = [3]
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.case = "unit set (number)"
        self.length = 1
        self.repr = "{3}"

    def test_in(self):
        # The sole member must be found.
        self.assertIn(3, self.set)

    def test_not_in(self):
        # A value never inserted must not be found.
        self.assertNotIn(2, self.set)
926
927#------------------------------------------------------------------------------
928
class TestBasicOpsTuple(TestBasicOps, unittest.TestCase):
    """Basic operations on a one-element set holding a tuple."""

    def setUp(self):
        members = [(0, "zero")]
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.case = "unit set (tuple)"
        self.length = 1
        self.repr = "{(0, 'zero')}"

    def test_in(self):
        # An equal (not identical) tuple must be found.
        self.assertIn((0, "zero"), self.set)

    def test_not_in(self):
        self.assertNotIn(9, self.set)
943
944#------------------------------------------------------------------------------
945
class TestBasicOpsTriple(TestBasicOps, unittest.TestCase):
    """Basic operations on a three-element set of heterogeneous objects."""

    def setUp(self):
        members = [0, "zero", operator.add]
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.case = "triple set"
        self.length = 3
        # Element order in the repr is unspecified, so skip the exact check.
        self.repr = None
954
955#------------------------------------------------------------------------------
956
class TestBasicOpsString(TestBasicOps, unittest.TestCase):
    """Basic operations on a set of single-character strings."""

    def setUp(self):
        members = ["a", "b", "c"]
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.case = "string set"
        self.length = 3

    def test_repr(self):
        # Order-insensitive repr check (string hashing is randomized).
        self.check_repr_against_values()
967
968#------------------------------------------------------------------------------
969
class TestBasicOpsBytes(TestBasicOps, unittest.TestCase):
    """Basic operations on a set of single-byte bytes objects."""

    def setUp(self):
        members = [b"a", b"b", b"c"]
        self.values = members
        self.set, self.dup = set(members), set(members)
        self.case = "bytes set"
        self.length = 3

    def test_repr(self):
        # Order-insensitive repr check.
        self.check_repr_against_values()
980
981#------------------------------------------------------------------------------
982
class TestBasicOpsMixedStringBytes(TestBasicOps, unittest.TestCase):
    """Basic operations on a set mixing str and bytes members.

    Comparing str with bytes may emit BytesWarning (e.g. under ``python -b``),
    so that warning category is ignored for the duration of each test.
    """

    def setUp(self):
        self._warning_filters = support.check_warnings()
        self._warning_filters.__enter__()
        # Register the exit right away: cleanups run even if the rest of
        # setUp() fails, whereas tearDown() would be skipped and the
        # modified warning filters would leak into later tests.
        self.addCleanup(self._warning_filters.__exit__, None, None, None)
        warnings.simplefilter('ignore', BytesWarning)
        self.case   = "string and bytes set"
        self.values = ["a", "b", b"a", b"b"]
        self.set    = set(self.values)
        self.dup    = set(self.values)
        self.length = 4

    def test_repr(self):
        self.check_repr_against_values()
999
1000#==============================================================================
1001
def baditer():
    """Generator whose first next() raises TypeError.

    Used by TestExceptionPropagation to check that set() propagates a
    TypeError raised while iterating instead of trapping it.
    """
    raise TypeError
    yield True  # never reached; its presence makes this a generator function
1005
def gooditer():
    """Generator that produces the single value True."""
    yield from (True,)
1008
class TestExceptionPropagation(unittest.TestCase):
    """SF 628246:  Set constructor should not trap iterator TypeErrors"""

    def test_instanceWithException(self):
        # The TypeError comes from inside the iterator and must propagate.
        self.assertRaises(TypeError, set, baditer())

    def test_instancesWithoutException(self):
        # All of these iterables should load without exception.
        set([1,2,3])
        set((1,2,3))
        set({'one':1, 'two':2, 'three':3})
        set(range(3))
        set('abc')
        set(gooditer())

    def test_changingSizeWhileIterating(self):
        s = set([1,2,3])
        # Growing a set while iterating over it must raise RuntimeError.
        with self.assertRaises(
                RuntimeError,
                msg="no exception when changing size during iteration"):
            for _ in s:
                s.update([4])
1033
1034#==============================================================================
1035
class TestSetOfSets(unittest.TestCase):
    """Sets whose members are frozensets."""

    def test_constructor(self):
        member = frozenset([1])
        container = {member}
        popped = container.pop()
        self.assertEqual(type(popped), frozenset)
        # Re-insert through add(), then take it back out with remove().
        container.add(member)
        container.remove(member)
        self.assertEqual(container, set())
        # discard() of an absent member must not raise KeyError.
        container.discard(member)
1046
1047#==============================================================================
1048
1049class TestBinaryOps(unittest.TestCase):
1050    def setUp(self):
1051        self.set = set((2, 4, 6))
1052
1053    def test_eq(self):              # SF bug 643115
1054        self.assertEqual(self.set, set({2:1,4:3,6:5}))
1055
1056    def test_union_subset(self):
1057        result = self.set | set([2])
1058        self.assertEqual(result, set((2, 4, 6)))
1059
1060    def test_union_superset(self):
1061        result = self.set | set([2, 4, 6, 8])
1062        self.assertEqual(result, set([2, 4, 6, 8]))
1063
1064    def test_union_overlap(self):
1065        result = self.set | set([3, 4, 5])
1066        self.assertEqual(result, set([2, 3, 4, 5, 6]))
1067
1068    def test_union_non_overlap(self):
1069        result = self.set | set([8])
1070        self.assertEqual(result, set([2, 4, 6, 8]))
1071
1072    def test_intersection_subset(self):
1073        result = self.set & set((2, 4))
1074        self.assertEqual(result, set((2, 4)))
1075
1076    def test_intersection_superset(self):
1077        result = self.set & set([2, 4, 6, 8])
1078        self.assertEqual(result, set([2, 4, 6]))
1079
1080    def test_intersection_overlap(self):
1081        result = self.set & set([3, 4, 5])
1082        self.assertEqual(result, set([4]))
1083
1084    def test_intersection_non_overlap(self):
1085        result = self.set & set([8])
1086        self.assertEqual(result, empty_set)
1087
1088    def test_isdisjoint_subset(self):
1089        result = self.set.isdisjoint(set((2, 4)))
1090        self.assertEqual(result, False)
1091
1092    def test_isdisjoint_superset(self):
1093        result = self.set.isdisjoint(set([2, 4, 6, 8]))
1094        self.assertEqual(result, False)
1095
1096    def test_isdisjoint_overlap(self):
1097        result = self.set.isdisjoint(set([3, 4, 5]))
1098        self.assertEqual(result, False)
1099
1100    def test_isdisjoint_non_overlap(self):
1101        result = self.set.isdisjoint(set([8]))
1102        self.assertEqual(result, True)
1103
1104    def test_sym_difference_subset(self):
1105        result = self.set ^ set((2, 4))
1106        self.assertEqual(result, set([6]))
1107
1108    def test_sym_difference_superset(self):
1109        result = self.set ^ set((2, 4, 6, 8))
1110        self.assertEqual(result, set([8]))
1111
1112    def test_sym_difference_overlap(self):
1113        result = self.set ^ set((3, 4, 5))
1114        self.assertEqual(result, set([2, 3, 5, 6]))
1115
1116    def test_sym_difference_non_overlap(self):
1117        result = self.set ^ set([8])
1118        self.assertEqual(result, set([2, 4, 6, 8]))
1119
1120#==============================================================================
1121
1122class TestUpdateOps(unittest.TestCase):
1123    def setUp(self):
1124        self.set = set((2, 4, 6))
1125
1126    def test_union_subset(self):
1127        self.set |= set([2])
1128        self.assertEqual(self.set, set((2, 4, 6)))
1129
1130    def test_union_superset(self):
1131        self.set |= set([2, 4, 6, 8])
1132        self.assertEqual(self.set, set([2, 4, 6, 8]))
1133
1134    def test_union_overlap(self):
1135        self.set |= set([3, 4, 5])
1136        self.assertEqual(self.set, set([2, 3, 4, 5, 6]))
1137
1138    def test_union_non_overlap(self):
1139        self.set |= set([8])
1140        self.assertEqual(self.set, set([2, 4, 6, 8]))
1141
1142    def test_union_method_call(self):
1143        self.set.update(set([3, 4, 5]))
1144        self.assertEqual(self.set, set([2, 3, 4, 5, 6]))
1145
1146    def test_intersection_subset(self):
1147        self.set &= set((2, 4))
1148        self.assertEqual(self.set, set((2, 4)))
1149
1150    def test_intersection_superset(self):
1151        self.set &= set([2, 4, 6, 8])
1152        self.assertEqual(self.set, set([2, 4, 6]))
1153
1154    def test_intersection_overlap(self):
1155        self.set &= set([3, 4, 5])
1156        self.assertEqual(self.set, set([4]))
1157
1158    def test_intersection_non_overlap(self):
1159        self.set &= set([8])
1160        self.assertEqual(self.set, empty_set)
1161
1162    def test_intersection_method_call(self):
1163        self.set.intersection_update(set([3, 4, 5]))
1164        self.assertEqual(self.set, set([4]))
1165
1166    def test_sym_difference_subset(self):
1167        self.set ^= set((2, 4))
1168        self.assertEqual(self.set, set([6]))
1169
1170    def test_sym_difference_superset(self):
1171        self.set ^= set((2, 4, 6, 8))
1172        self.assertEqual(self.set, set([8]))
1173
1174    def test_sym_difference_overlap(self):
1175        self.set ^= set((3, 4, 5))
1176        self.assertEqual(self.set, set([2, 3, 5, 6]))
1177
1178    def test_sym_difference_non_overlap(self):
1179        self.set ^= set([8])
1180        self.assertEqual(self.set, set([2, 4, 6, 8]))
1181
1182    def test_sym_difference_method_call(self):
1183        self.set.symmetric_difference_update(set([3, 4, 5]))
1184        self.assertEqual(self.set, set([2, 3, 5, 6]))
1185
1186    def test_difference_subset(self):
1187        self.set -= set((2, 4))
1188        self.assertEqual(self.set, set([6]))
1189
1190    def test_difference_superset(self):
1191        self.set -= set((2, 4, 6, 8))
1192        self.assertEqual(self.set, set([]))
1193
1194    def test_difference_overlap(self):
1195        self.set -= set((3, 4, 5))
1196        self.assertEqual(self.set, set([2, 6]))
1197
1198    def test_difference_non_overlap(self):
1199        self.set -= set([8])
1200        self.assertEqual(self.set, set([2, 4, 6]))
1201
1202    def test_difference_method_call(self):
1203        self.set.difference_update(set([3, 4, 5]))
1204        self.assertEqual(self.set, set([2, 6]))
1205
1206#==============================================================================
1207
1208class TestMutate(unittest.TestCase):
1209    def setUp(self):
1210        self.values = ["a", "b", "c"]
1211        self.set = set(self.values)
1212
1213    def test_add_present(self):
1214        self.set.add("c")
1215        self.assertEqual(self.set, set("abc"))
1216
1217    def test_add_absent(self):
1218        self.set.add("d")
1219        self.assertEqual(self.set, set("abcd"))
1220
1221    def test_add_until_full(self):
1222        tmp = set()
1223        expected_len = 0
1224        for v in self.values:
1225            tmp.add(v)
1226            expected_len += 1
1227            self.assertEqual(len(tmp), expected_len)
1228        self.assertEqual(tmp, self.set)
1229
1230    def test_remove_present(self):
1231        self.set.remove("b")
1232        self.assertEqual(self.set, set("ac"))
1233
1234    def test_remove_absent(self):
1235        try:
1236            self.set.remove("d")
1237            self.fail("Removing missing element should have raised LookupError")
1238        except LookupError:
1239            pass
1240
1241    def test_remove_until_empty(self):
1242        expected_len = len(self.set)
1243        for v in self.values:
1244            self.set.remove(v)
1245            expected_len -= 1
1246            self.assertEqual(len(self.set), expected_len)
1247
1248    def test_discard_present(self):
1249        self.set.discard("c")
1250        self.assertEqual(self.set, set("ab"))
1251
1252    def test_discard_absent(self):
1253        self.set.discard("d")
1254        self.assertEqual(self.set, set("abc"))
1255
1256    def test_clear(self):
1257        self.set.clear()
1258        self.assertEqual(len(self.set), 0)
1259
1260    def test_pop(self):
1261        popped = {}
1262        while self.set:
1263            popped[self.set.pop()] = None
1264        self.assertEqual(len(popped), len(self.values))
1265        for v in self.values:
1266            self.assertIn(v, popped)
1267
1268    def test_update_empty_tuple(self):
1269        self.set.update(())
1270        self.assertEqual(self.set, set(self.values))
1271
1272    def test_update_unit_tuple_overlap(self):
1273        self.set.update(("a",))
1274        self.assertEqual(self.set, set(self.values))
1275
1276    def test_update_unit_tuple_non_overlap(self):
1277        self.set.update(("a", "z"))
1278        self.assertEqual(self.set, set(self.values + ["z"]))
1279
1280#==============================================================================
1281
1282class TestSubsets:
1283
1284    case2method = {"<=": "issubset",
1285                   ">=": "issuperset",
1286                  }
1287
1288    reverse = {"==": "==",
1289               "!=": "!=",
1290               "<":  ">",
1291               ">":  "<",
1292               "<=": ">=",
1293               ">=": "<=",
1294              }
1295
1296    def test_issubset(self):
1297        x = self.left
1298        y = self.right
1299        for case in "!=", "==", "<", "<=", ">", ">=":
1300            expected = case in self.cases
1301            # Test the binary infix spelling.
1302            result = eval("x" + case + "y", locals())
1303            self.assertEqual(result, expected)
1304            # Test the "friendly" method-name spelling, if one exists.
1305            if case in TestSubsets.case2method:
1306                method = getattr(x, TestSubsets.case2method[case])
1307                result = method(y)
1308                self.assertEqual(result, expected)
1309
1310            # Now do the same for the operands reversed.
1311            rcase = TestSubsets.reverse[case]
1312            result = eval("y" + rcase + "x", locals())
1313            self.assertEqual(result, expected)
1314            if rcase in TestSubsets.case2method:
1315                method = getattr(y, TestSubsets.case2method[rcase])
1316                result = method(x)
1317                self.assertEqual(result, expected)
1318#------------------------------------------------------------------------------
1319
class TestSubsetEqualEmpty(TestSubsets, unittest.TestCase):
    # Two distinct empty sets: equal, and each a subset of the other.
    name  = "both empty"
    left  = set()
    right = set()
    cases = ("==", "<=", ">=")
1325
1326#------------------------------------------------------------------------------
1327
class TestSubsetEqualNonEmpty(TestSubsets, unittest.TestCase):
    # Two distinct but equal non-empty sets.
    name  = "equal pair"
    left  = {1, 2}
    right = {1, 2}
    cases = ("==", "<=", ">=")
1333
1334#------------------------------------------------------------------------------
1335
class TestSubsetEmptyNonEmpty(TestSubsets, unittest.TestCase):
    # The empty set is a proper subset of any non-empty set.
    name  = "one empty, one non-empty"
    left  = set()
    right = {1, 2}
    cases = ("!=", "<", "<=")
1341
1342#------------------------------------------------------------------------------
1343
class TestSubsetPartial(TestSubsets, unittest.TestCase):
    # A non-empty proper subset.
    name  = "one a non-empty proper subset of other"
    left  = {1}
    right = {1, 2}
    cases = ("!=", "<", "<=")
1349
1350#------------------------------------------------------------------------------
1351
class TestSubsetNonOverlap(TestSubsets, unittest.TestCase):
    # Disjoint non-empty sets: only "!=" holds.
    left  = set([1])
    right = set([2])
    name  = "neither empty, neither contains"
    # A one-element tuple, like the sibling classes.  A bare string would make
    # `case in self.cases` a substring test that only works by coincidence.
    cases = ("!=",)
1357
1358#==============================================================================
1359
1360class TestOnlySetsInBinaryOps:
1361
1362    def test_eq_ne(self):
1363        # Unlike the others, this is testing that == and != *are* allowed.
1364        self.assertEqual(self.other == self.set, False)
1365        self.assertEqual(self.set == self.other, False)
1366        self.assertEqual(self.other != self.set, True)
1367        self.assertEqual(self.set != self.other, True)
1368
1369    def test_ge_gt_le_lt(self):
1370        self.assertRaises(TypeError, lambda: self.set < self.other)
1371        self.assertRaises(TypeError, lambda: self.set <= self.other)
1372        self.assertRaises(TypeError, lambda: self.set > self.other)
1373        self.assertRaises(TypeError, lambda: self.set >= self.other)
1374
1375        self.assertRaises(TypeError, lambda: self.other < self.set)
1376        self.assertRaises(TypeError, lambda: self.other <= self.set)
1377        self.assertRaises(TypeError, lambda: self.other > self.set)
1378        self.assertRaises(TypeError, lambda: self.other >= self.set)
1379
1380    def test_update_operator(self):
1381        try:
1382            self.set |= self.other
1383        except TypeError:
1384            pass
1385        else:
1386            self.fail("expected TypeError")
1387
1388    def test_update(self):
1389        if self.otherIsIterable:
1390            self.set.update(self.other)
1391        else:
1392            self.assertRaises(TypeError, self.set.update, self.other)
1393
1394    def test_union(self):
1395        self.assertRaises(TypeError, lambda: self.set | self.other)
1396        self.assertRaises(TypeError, lambda: self.other | self.set)
1397        if self.otherIsIterable:
1398            self.set.union(self.other)
1399        else:
1400            self.assertRaises(TypeError, self.set.union, self.other)
1401
1402    def test_intersection_update_operator(self):
1403        try:
1404            self.set &= self.other
1405        except TypeError:
1406            pass
1407        else:
1408            self.fail("expected TypeError")
1409
1410    def test_intersection_update(self):
1411        if self.otherIsIterable:
1412            self.set.intersection_update(self.other)
1413        else:
1414            self.assertRaises(TypeError,
1415                              self.set.intersection_update,
1416                              self.other)
1417
1418    def test_intersection(self):
1419        self.assertRaises(TypeError, lambda: self.set & self.other)
1420        self.assertRaises(TypeError, lambda: self.other & self.set)
1421        if self.otherIsIterable:
1422            self.set.intersection(self.other)
1423        else:
1424            self.assertRaises(TypeError, self.set.intersection, self.other)
1425
1426    def test_sym_difference_update_operator(self):
1427        try:
1428            self.set ^= self.other
1429        except TypeError:
1430            pass
1431        else:
1432            self.fail("expected TypeError")
1433
1434    def test_sym_difference_update(self):
1435        if self.otherIsIterable:
1436            self.set.symmetric_difference_update(self.other)
1437        else:
1438            self.assertRaises(TypeError,
1439                              self.set.symmetric_difference_update,
1440                              self.other)
1441
1442    def test_sym_difference(self):
1443        self.assertRaises(TypeError, lambda: self.set ^ self.other)
1444        self.assertRaises(TypeError, lambda: self.other ^ self.set)
1445        if self.otherIsIterable:
1446            self.set.symmetric_difference(self.other)
1447        else:
1448            self.assertRaises(TypeError, self.set.symmetric_difference, self.other)
1449
1450    def test_difference_update_operator(self):
1451        try:
1452            self.set -= self.other
1453        except TypeError:
1454            pass
1455        else:
1456            self.fail("expected TypeError")
1457
1458    def test_difference_update(self):
1459        if self.otherIsIterable:
1460            self.set.difference_update(self.other)
1461        else:
1462            self.assertRaises(TypeError,
1463                              self.set.difference_update,
1464                              self.other)
1465
1466    def test_difference(self):
1467        self.assertRaises(TypeError, lambda: self.set - self.other)
1468        self.assertRaises(TypeError, lambda: self.other - self.set)
1469        if self.otherIsIterable:
1470            self.set.difference(self.other)
1471        else:
1472            self.assertRaises(TypeError, self.set.difference, self.other)
1473
1474#------------------------------------------------------------------------------
1475
class TestOnlySetsNumeric(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a plain int (not iterable)."""

    def setUp(self):
        self.otherIsIterable = False
        self.other = 19
        self.set = {1, 2, 3}
1481
1482#------------------------------------------------------------------------------
1483
class TestOnlySetsDict(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a dict (iterable over its keys)."""

    def setUp(self):
        self.otherIsIterable = True
        self.other = {1: 2, 3: 4}
        self.set = {1, 2, 3}
1489
1490#------------------------------------------------------------------------------
1491
class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a builtin function (not iterable)."""

    def setUp(self):
        self.otherIsIterable = False
        self.other = operator.add
        self.set = {1, 2, 3}
1497
1498#------------------------------------------------------------------------------
1499
class TestOnlySetsTuple(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a tuple (iterable)."""

    def setUp(self):
        self.otherIsIterable = True
        self.other = (2, 4, 6)
        self.set = {1, 2, 3}
1505
1506#------------------------------------------------------------------------------
1507
class TestOnlySetsString(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a str (iterable over its characters)."""

    def setUp(self):
        self.otherIsIterable = True
        self.other = 'abc'
        self.set = {1, 2, 3}
1513
1514#------------------------------------------------------------------------------
1515
class TestOnlySetsGenerator(TestOnlySetsInBinaryOps, unittest.TestCase):
    """Non-set operand: a fresh generator of even numbers below 10."""

    def setUp(self):
        def evens():
            yield from range(0, 10, 2)
        self.set = {1, 2, 3}
        self.other = evens()
        self.otherIsIterable = True
1524
1525#==============================================================================
1526
1527class TestCopying:
1528
1529    def test_copy(self):
1530        dup = self.set.copy()
1531        dup_list = sorted(dup, key=repr)
1532        set_list = sorted(self.set, key=repr)
1533        self.assertEqual(len(dup_list), len(set_list))
1534        for i in range(len(dup_list)):
1535            self.assertTrue(dup_list[i] is set_list[i])
1536
1537    def test_deep_copy(self):
1538        dup = copy.deepcopy(self.set)
1539        ##print type(dup), repr(dup)
1540        dup_list = sorted(dup, key=repr)
1541        set_list = sorted(self.set, key=repr)
1542        self.assertEqual(len(dup_list), len(set_list))
1543        for i in range(len(dup_list)):
1544            self.assertEqual(dup_list[i], set_list[i])
1545
1546#------------------------------------------------------------------------------
1547
class TestCopyingEmpty(TestCopying, unittest.TestCase):
    """Copying the empty set."""

    def setUp(self):
        self.set = set()
1551
1552#------------------------------------------------------------------------------
1553
class TestCopyingSingleton(TestCopying, unittest.TestCase):
    """Copying a one-element set holding a string."""

    def setUp(self):
        self.set = {"hello"}
1557
1558#------------------------------------------------------------------------------
1559
class TestCopyingTriple(TestCopying, unittest.TestCase):
    """Copying a set of a str, an int and None."""

    def setUp(self):
        self.set = {"zero", 0, None}
1563
1564#------------------------------------------------------------------------------
1565
class TestCopyingTuple(TestCopying, unittest.TestCase):
    """Copying a set holding a flat tuple."""

    def setUp(self):
        self.set = {(1, 2)}
1569
1570#------------------------------------------------------------------------------
1571
class TestCopyingNested(TestCopying, unittest.TestCase):
    """Copying a set holding a tuple of tuples."""

    def setUp(self):
        self.set = {((1, 2), (3, 4))}
1575
1576#==============================================================================
1577
class TestIdentities(unittest.TestCase):
    """Algebraic identities relating the set operators."""

    def setUp(self):
        self.a = set('abracadabra')
        self.b = set('alacazam')

    def test_binopsVsSubsets(self):
        left, right = self.a, self.b
        self.assertLess(left - right, left)
        self.assertLess(right - left, right)
        self.assertLess(left & right, left)
        self.assertLess(left & right, right)
        self.assertGreater(left | right, left)
        self.assertGreater(left | right, right)
        self.assertLess(left ^ right, left | right)

    def test_commutativity(self):
        left, right = self.a, self.b
        self.assertEqual(left & right, right & left)
        self.assertEqual(left | right, right | left)
        self.assertEqual(left ^ right, right ^ left)
        # Difference is only commutative for equal operands.
        if left != right:
            self.assertNotEqual(left - right, right - left)

    def test_summations(self):
        # The disjoint pieces must reassemble into the whole.
        left, right = self.a, self.b
        whole = left | right
        self.assertEqual((left - right) | (left & right) | (right - left), whole)
        self.assertEqual((left & right) | (left ^ right), whole)
        self.assertEqual(left | (right - left), whole)
        self.assertEqual((left - right) | right, whole)
        self.assertEqual((left - right) | (left & right), left)
        self.assertEqual((right - left) | (left & right), right)
        self.assertEqual((left - right) | (right - left), left ^ right)

    def test_exclusion(self):
        # Complementary pieces must not overlap.
        left, right, nothing = self.a, self.b, set()
        self.assertEqual((left - right) & right, nothing)
        self.assertEqual((right - left) & left, nothing)
        self.assertEqual((left & right) & (left ^ right), nothing)
1618
1619# Tests derived from test_itertools.py =======================================
1620
def R(seqn):
    """Plain generator re-yielding every item of *seqn*."""
    yield from seqn
1625
class G:
    """Sequence iterable only through the legacy __getitem__ protocol."""

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, index):
        # IndexError past the end terminates the legacy iteration.
        return self.seqn[index]
1632
class I:
    """Hand-written iterator (explicit __iter__/__next__) over a sequence."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            item = self.seqn[pos]
            self.i = pos + 1
            return item
        raise StopIteration
1645
class Ig:
    """Iterable whose __iter__ is a generator method over the sequence."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        yield from self.seqn
1654
class X:
    """Has __next__ but neither __iter__ nor __getitem__: not iterable."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            item = self.seqn[pos]
            self.i = pos + 1
            return item
        raise StopIteration
1665
class N:
    """Claims to be an iterator (__iter__ returns self) but lacks __next__."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        # Returning an object without __next__ makes iter() fail.
        return self
1673
class E:
    """Iterator whose every next() raises ZeroDivisionError."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Deliberate division by zero, so consumers see the exception.
        return 3 // 0
1683
class S:
    """Iterator that is exhausted from the start, ignoring its argument."""

    def __init__(self, seqn):
        """The sequence argument is accepted but deliberately unused."""

    def __iter__(self):
        return self

    def __next__(self):
        raise StopIteration
1692
1693from itertools import chain
def L(seqn):
    """Wrap *seqn* in several tiers of iterators (G -> Ig -> R -> map -> chain)."""
    layered = R(Ig(G(seqn)))
    return chain(map(lambda item: item, layered))
1697
class TestVariousIteratorArgs(unittest.TestCase):
    """Check set/frozenset against many kinds of iterable arguments.

    G, I, Ig, S, L and R are the well-behaved wrappers defined in this
    module.  X (no __iter__), N (__iter__ returns a non-iterator) and
    E (__next__ divides by zero) are broken ones.  They must make the
    operation fail with TypeError or ZeroDivisionError respectively.
    """

    def test_constructor(self):
        for cons in (set, frozenset):
            for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5)):
                for g in (G, I, Ig, S, L, R):
                    # Sort by repr: items such as ('do', 1.2) are not
                    # mutually orderable.
                    self.assertEqual(sorted(cons(g(data)), key=repr), sorted(g(data), key=repr))
                self.assertRaises(TypeError, cons, X(data))
                self.assertRaises(TypeError, cons, N(data))
                self.assertRaises(ZeroDivisionError, cons, E(data))

    def test_inline_methods(self):
        s = set('november')
        for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'):
            for meth in (s.union, s.intersection, s.difference, s.symmetric_difference, s.isdisjoint):
                # None of these methods mutate s, so the reference result
                # from the raw data is the same for every wrapper below.
                expected = meth(data)
                # S is left out: it ignores its argument and is always
                # empty, so it would not reproduce meth(data).
                for g in (G, I, Ig, L, R):
                    actual = meth(g(data))
                    if isinstance(expected, bool):   # isdisjoint
                        self.assertEqual(actual, expected)
                    else:
                        self.assertEqual(sorted(actual, key=repr), sorted(expected, key=repr))
                # Wrap the input data (not the receiver set s) in the broken
                # wrappers, matching the other two tests in this class.
                self.assertRaises(TypeError, meth, X(data))
                self.assertRaises(TypeError, meth, N(data))
                self.assertRaises(ZeroDivisionError, meth, E(data))

    def test_inplace_methods(self):
        for data in ("123", "", range(1000), ('do', 1.2), range(2000,2200,5), 'december'):
            for methname in ('update', 'intersection_update',
                             'difference_update', 'symmetric_difference_update'):
                for g in (G, I, Ig, S, L, R):
                    # Feeding a materialized list and the lazy wrapper
                    # itself must leave the two sets equal.
                    s = set('january')
                    t = s.copy()
                    getattr(s, methname)(list(g(data)))
                    getattr(t, methname)(g(data))
                    self.assertEqual(sorted(s, key=repr), sorted(t, key=repr))

                self.assertRaises(TypeError, getattr(set('january'), methname), X(data))
                self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
                self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
1738
class bad_eq:
    """Set element whose __eq__ sabotages the set being merged.

    While the module-level flag ``be_bad`` is false, instances compare by
    identity.  Once it is true, __eq__ empties the module-level ``set2``
    and raises ZeroDivisionError.
    """
    def __eq__(self, other):
        if not be_bad:
            return self is other
        set2.clear()
        raise ZeroDivisionError
    def __hash__(self):
        # A constant hash makes every instance collide, so set operations
        # must call __eq__ to tell instances apart.
        return 0
1747
class bad_dict_clear:
    """Dict key whose __eq__ empties the module-level ``dict2`` while the
    module-level flag ``be_bad`` is true.

    The comparison result itself is always plain identity.
    """
    def __eq__(self, other):
        same_object = self is other
        if be_bad:
            dict2.clear()
        return same_object
    def __hash__(self):
        # Constant hash: all instances collide, forcing __eq__ calls.
        return 0
1755
class TestWeirdBugs(unittest.TestCase):
    """Regression tests where element comparisons mutate a container
    that a set operation is still using; they only need to not crash."""
    def test_8420_set_merge(self):
        # This used to segfault
        # Issue #8420.  bad_eq/bad_dict_clear read these names as module
        # globals, hence the `global` declaration.
        global be_bad, set2, dict2
        be_bad = False
        set1 = {bad_eq()}
        set2 = {bad_eq() for i in range(75)}
        # From here on, any bad_eq comparison clears set2 and raises, so
        # set1.update(set2) sees its source emptied mid-merge.
        be_bad = True
        self.assertRaises(ZeroDivisionError, set1.update, set2)

        # Same idea with a dict argument: comparing keys clears dict2 while
        # symmetric_difference_update is reading it (no exception expected).
        be_bad = False
        set1 = {bad_dict_clear()}
        dict2 = {bad_dict_clear(): None}
        be_bad = True
        set1.symmetric_difference_update(dict2)
        # NOTE(review): be_bad is left True after this test; nothing else
        # visible here reads it, but confirm before relying on it.

    def test_iter_and_mutate(self):
        # Issue #24581
        # Create an iterator, then clear and refill the set to the same
        # size before exhausting the iterator.  Only the absence of a
        # crash is checked; the items produced by list(si) are discarded.
        s = set(range(100))
        s.clear()
        s.update(range(100))
        si = iter(s)
        s.clear()
        # NOTE(review): `a` is never read; presumably it is kept to churn
        # the allocator between clear() and update() -- confirm before
        # removing.
        a = list(range(100))
        s.update(range(100))
        list(si)

    def test_merge_and_mutate(self):
        # Every X hashes like 0 and its __eq__ empties the enclosing
        # `other` set, which s.update() is iterating.
        class X:
            def __hash__(self):
                return hash(0)
            def __eq__(self, o):
                other.clear()
                return False

        # Must be bound before the comprehension below: the X instances
        # all collide there, so X.__eq__ (which closes over `other`)
        # already runs while that set is being built.
        other = set()
        other = {X() for i in range(10)}
        # 0 is already in s with the same hash, so merging compares each X
        # against it and triggers the clear() mid-update.
        s = {0}
        s.update(other)
1795
# Application tests (based on David Eppstein's graph recipes) ===================================
1797
def powerset(U):
    """Generate all subsets of a set or sequence U as frozensets.

    Subsets come in binary-counting order.  The i-th item of U belongs
    to the k-th subset exactly when bit i of k is set, so the empty set
    comes first and the full set last.  U is consumed completely, and
    every item is checked for hashability, before the first subset is
    produced.  Duplicate items in U are kept and give repeated subsets.

    The enumeration is iterative rather than recursive.  A long U
    therefore does not hit the recursion limit when only a prefix of the
    2**len(U) subsets is requested.
    """
    # One single-item frozenset per element; building them up front raises
    # TypeError for unhashable items before anything is yielded.
    singletons = [frozenset([item]) for item in U]
    for mask in range(1 << len(singletons)):
        chosen = [single for bit, single in enumerate(singletons)
                  if mask >> bit & 1]
        yield frozenset().union(*chosen)
1808
def cube(n):
    """Return the graph of the n-dimensional hypercube.

    Vertices are the frozenset subsets of range(n) produced by powerset().
    Two vertices are adjacent when they differ in exactly one element.
    The result maps each vertex to the frozenset of its n neighbours.
    """
    # Moving along one axis toggles that coordinate in or out of the
    # vertex, i.e. a symmetric difference with a one-element set.
    unit_steps = [frozenset([axis]) for axis in range(n)]
    return {vertex: frozenset(vertex ^ step for step in unit_steps)
            for vertex in powerset(range(n))}
1814
def linegraph(G):
    """Return the line graph of G.

    Each vertex of the result is an edge {u, v} of G (a frozenset).  Its
    neighbours are the other edges of G that share u or v.  G maps each
    vertex to an iterable of its neighbours.
    """
    line = {}
    for u, around_u in G.items():
        for v in around_u:
            via_u = [frozenset([u, w]) for w in around_u if w != v]
            via_v = [frozenset([v, w]) for w in G[v] if w != u]
            line[frozenset([u, v])] = frozenset(via_u + via_v)
    return line
1826
def faces(G):
    """Return the set of faces of graph G, each as a frozenset of vertices.

    G maps each vertex to an iterable of its neighbours.  Faces are found
    by walking short paths v1-v2-v3(-v4(-v5)) and closing them back to v1.
    Only triangles, quadrilaterals and pentagons are detected.  No
    embedding is used, so any such short cycle found by the walk is
    reported as a face.
    """
    # currently limited to triangles, squares, and pentagons
    f = set()
    for v1, edges in G.items():
        for v2 in edges:
            for v3 in G[v2]:
                # Do not walk straight back to the start.
                if v1 == v3:
                    continue
                if v1 in G[v3]:
                    # v1-v2-v3 closes: triangle.
                    f.add(frozenset([v1, v2, v3]))
                else:
                    # Longer cycles are only sought when v1-v2-v3 is not
                    # already a triangle.  v4 cannot equal v1 here because
                    # v1 is not in G[v3].
                    for v4 in G[v3]:
                        if v4 == v2:
                            continue
                        if v1 in G[v4]:
                            # v1-v2-v3-v4 closes: quadrilateral.
                            f.add(frozenset([v1, v2, v3, v4]))
                        else:
                            # v5 cannot equal v1 here because v1 is not in
                            # G[v4]; skip stepping back onto v3 or v2.
                            for v5 in G[v4]:
                                if v5 == v3 or v5 == v2:
                                    continue
                                if v1 in G[v5]:
                                    # v1-...-v5 closes: pentagon.
                                    f.add(frozenset([v1, v2, v3, v4, v5]))
    return f
1851
1852
class TestGraphs(unittest.TestCase):
    """Exercise sets/frozensets through the cube and cuboctahedron graphs."""

    def test_cube(self):

        graph = cube(3)                          # vertex --> frozenset of 3 neighbours
        verts = set(graph)
        self.assertEqual(len(verts), 8)          # eight vertices
        for nbrs in graph.values():
            self.assertEqual(len(nbrs), 3)       # every vertex has three neighbours
        reachable = {v for nbrs in graph.values() for v in nbrs}
        self.assertEqual(verts, reachable)       # neighbours are all known vertices

        found = faces(graph)
        self.assertEqual(len(found), 6)          # six faces
        for face in found:
            self.assertEqual(len(face), 4)       # every face is a square

    def test_cuboctahedron(self):

        # http://en.wikipedia.org/wiki/Cuboctahedron
        # 8 triangular faces and 6 square faces
        # 12 identical vertices, each where a triangle meets a square

        cube_graph = cube(3)
        cubo = linegraph(cube_graph)             # cube edge --> 4 adjacent cube edges
        self.assertEqual(len(cubo), 12)          # twelve vertices

        verts = set(cubo)
        for nbrs in cubo.values():
            self.assertEqual(len(nbrs), 4)       # every vertex has four neighbours
        reachable = {v for nbrs in cubo.values() for v in nbrs}
        self.assertEqual(verts, reachable)       # neighbours are all known vertices

        sizes = collections.Counter(len(face) for face in faces(cubo))
        self.assertEqual(sizes[3], 8)            # eight triangular faces
        self.assertEqual(sizes[4], 6)            # six square faces

        for cube_edge in cubo:                   # each vertex here is an edge of the cube
            self.assertEqual(len(cube_edge), 2)  # an edge joins two cube vertices
            for endpoint in cube_edge:
                self.assertIn(endpoint, cube_graph)
1898
1899
1900#==============================================================================
1901
if __name__ == "__main__":
    # Run this module's TestCase classes when executed as a script.
    unittest.main()
1904