# test_set.py revision 6e70accaffc61a5af7d78be1b365d1cab804751b
import copy
import operator
import pickle
import unittest
from test import test_support
6
class PassThru(Exception):
    """Sentinel exception used to check that errors raised while iterating
    an argument propagate unchanged out of set operations."""
    pass
9
def check_pass_thru():
    """Return a generator that raises PassThru on its first next() call."""
    raise PassThru
    # Unreachable, but its presence makes this function a generator, so the
    # exception fires during iteration rather than at call time.
    yield 1
13
class TestJointOps(unittest.TestCase):
    """Tests common to both set and frozenset.

    Concrete subclasses set the class attribute ``thetype`` to the type under
    test (set, frozenset, or a subclass of either).
    """

    def setUp(self):
        self.word = word = 'simsalabim'
        self.otherword = 'madagascar'
        self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
        self.s = self.thetype(word)
        # The keys of a dict built from the same word serve as an independent
        # reference for which letters are present.
        self.d = dict.fromkeys(word)

    def test_uniquification(self):
        actual = sorted(self.s)
        expected = sorted(self.d)
        self.assertEqual(actual, expected)
        self.assertRaises(PassThru, self.thetype, check_pass_thru())
        self.assertRaises(TypeError, self.thetype, [[]])

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        self.assertRaises(TypeError, self.s.__contains__, [[]])
        # A set argument must match an equal frozenset element.
        s = self.thetype([frozenset(self.letters)])
        self.assert_(self.thetype(self.letters) in s)

    def test_union(self):
        u = self.s.union(self.otherword)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.otherword)
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(u), self.thetype)
        self.assertRaises(PassThru, self.s.union, check_pass_thru())
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
            self.assertEqual(self.thetype('abcba').union(C('cdc')), set('abcd'))
            self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg'))
            self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc'))
            self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef'))

    def test_or(self):
        i = self.s.union(self.otherword)
        self.assertEqual(self.s | set(self.otherword), i)
        self.assertEqual(self.s | frozenset(self.otherword), i)
        # Unlike the method form, the operator accepts only set operands.
        try:
            self.s | self.otherword
        except TypeError:
            pass
        else:
            self.fail("s|t did not screen-out general iterables")

    def test_intersection(self):
        i = self.s.intersection(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c in self.otherword)
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.thetype)
        self.assertRaises(PassThru, self.s.intersection, check_pass_thru())
        for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
            self.assertEqual(self.thetype('abcba').intersection(C('cdc')), set('c'))
            self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set(''))
            self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc'))
            self.assertEqual(self.thetype('abcba').intersection(C('ef')), set(''))

    def test_and(self):
        i = self.s.intersection(self.otherword)
        self.assertEqual(self.s & set(self.otherword), i)
        self.assertEqual(self.s & frozenset(self.otherword), i)
        try:
            self.s & self.otherword
        except TypeError:
            pass
        else:
            self.fail("s&t did not screen-out general iterables")

    def test_difference(self):
        i = self.s.difference(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.otherword)
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.thetype)
        self.assertRaises(PassThru, self.s.difference, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference, [[]])
        for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
            self.assertEqual(self.thetype('abcba').difference(C('cdc')), set('ab'))
            self.assertEqual(self.thetype('abcba').difference(C('efgfe')), set('abc'))
            self.assertEqual(self.thetype('abcba').difference(C('ccb')), set('a'))
            self.assertEqual(self.thetype('abcba').difference(C('ef')), set('abc'))

    def test_sub(self):
        i = self.s.difference(self.otherword)
        self.assertEqual(self.s - set(self.otherword), i)
        self.assertEqual(self.s - frozenset(self.otherword), i)
        try:
            self.s - self.otherword
        except TypeError:
            pass
        else:
            self.fail("s-t did not screen-out general iterables")

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.otherword)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.otherword))
        self.assertEqual(self.s, self.thetype(self.word))
        self.assertEqual(type(i), self.thetype)
        self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('cdc')), set('abd'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('efgfe')), set('abcefg'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ccb')), set('a'))
            self.assertEqual(self.thetype('abcba').symmetric_difference(C('ef')), set('abcef'))

    def test_xor(self):
        i = self.s.symmetric_difference(self.otherword)
        self.assertEqual(self.s ^ set(self.otherword), i)
        self.assertEqual(self.s ^ frozenset(self.otherword), i)
        try:
            self.s ^ self.otherword
        except TypeError:
            pass
        else:
            self.fail("s^t did not screen-out general iterables")

    def test_equality(self):
        self.assertEqual(self.s, set(self.word))
        self.assertEqual(self.s, frozenset(self.word))
        # Comparing with a non-set iterable is allowed but never equal.
        self.assertEqual(self.s == self.word, False)
        self.assertNotEqual(self.s, set(self.otherword))
        self.assertNotEqual(self.s, frozenset(self.otherword))
        self.assertEqual(self.s != self.word, True)

    def test_setOfFrozensets(self):
        # 'bcd'/'bdcb' and 'abcdef'/'fedccba' collapse to equal frozensets.
        t = map(frozenset, ['abcdef', 'bcd', 'bdcb', 'fed', 'fedccba'])
        s = self.thetype(t)
        self.assertEqual(len(s), 3)

    def test_compare(self):
        self.assertRaises(TypeError, self.s.__cmp__, self.s)

    def test_sub_and_super(self):
        p, q, r = map(self.thetype, ['ab', 'abcde', 'def'])
        self.assert_(p < q)
        self.assert_(p <= q)
        self.assert_(q <= q)
        self.assert_(q > p)
        self.assert_(q >= p)
        self.failIf(q < r)
        self.failIf(q <= r)
        self.failIf(q > r)
        self.failIf(q >= r)
        # Use the type under test (not a hard-coded set) so frozenset and the
        # subclasses exercise their own issubset/issuperset methods too.
        self.assert_(self.thetype('a').issubset('abc'))
        self.assert_(self.thetype('abc').issuperset('a'))
        self.failIf(self.thetype('a').issubset('cbs'))
        self.failIf(self.thetype('cbs').issuperset('a'))

    def test_pickling(self):
        p = pickle.dumps(self.s)
        dup = pickle.loads(p)
        self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))

    def test_deepcopy(self):
        class Tracer:
            # Hashable wrapper whose deepcopy yields value + 1, so a copied
            # element can be told apart from the original.
            def __init__(self, value):
                self.value = value
            def __hash__(self):
                return self.value
            def __deepcopy__(self, memo=None):
                return Tracer(self.value + 1)
        t = Tracer(10)
        s = self.thetype([t])
        dup = copy.deepcopy(s)
        self.assertNotEqual(id(s), id(dup))
        for elem in dup:
            newt = elem
        self.assertNotEqual(id(t), id(newt))
        self.assertEqual(t.value + 1, newt.value)
193
class TestSet(TestJointOps):
    """Tests specific to the mutable set type, on top of TestJointOps."""

    thetype = set

    def test_init(self):
        # Re-running __init__ must replace the contents, not merge into them.
        s = self.thetype()
        s.__init__(self.word)
        self.assertEqual(s, set(self.word))
        s.__init__(self.otherword)
        self.assertEqual(s, set(self.otherword))

    def test_constructor_identity(self):
        # Building a mutable set from another always yields a new object.
        s = self.thetype(range(3))
        t = self.thetype(s)
        self.assertNotEqual(id(s), id(t))

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, set())
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))

    def test_add(self):
        self.s.add('Q')
        self.assert_('Q' in self.s)
        dup = self.s.copy()
        # Adding an element that is already present changes nothing.
        self.s.add('Q')
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])

    def test_remove(self):
        self.s.remove('a')
        self.assert_('a' not in self.s)
        self.assertRaises(KeyError, self.s.remove, 'Q')
        self.assertRaises(TypeError, self.s.remove, [])
        # A set argument must match, and remove, an equal frozenset element.
        s = self.thetype([frozenset(self.word)])
        self.assert_(self.thetype(self.word) in s)
        s.remove(self.thetype(self.word))
        self.assert_(self.thetype(self.word) not in s)
        self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))

    def test_discard(self):
        self.s.discard('a')
        self.assert_('a' not in self.s)
        # Discarding a missing element is silently ignored.
        self.s.discard('Q')
        self.assertRaises(TypeError, self.s.discard, [])
        s = self.thetype([frozenset(self.word)])
        self.assert_(self.thetype(self.word) in s)
        s.discard(self.thetype(self.word))
        self.assert_(self.thetype(self.word) not in s)
        s.discard(self.thetype(self.word))

    def test_pop(self):
        for i in xrange(len(self.s)):
            elem = self.s.pop()
            self.assert_(elem not in self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            self.assert_(c in self.s)
        self.assertRaises(PassThru, self.s.update, check_pass_thru())
        self.assertRaises(TypeError, self.s.update, [[]])
        for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.update(C(p)), None)
                self.assertEqual(s, set(q))

    def test_ior(self):
        self.s |= set(self.otherword)
        for c in (self.word + self.otherword):
            self.assert_(c in self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assert_(c in self.s)
            else:
                self.assert_(c not in self.s)
        self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.intersection_update, [[]])
        for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')):
            for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.intersection_update(C(p)), None)
                self.assertEqual(s, set(q))

    def test_iand(self):
        self.s &= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.otherword and c in self.word:
                self.assert_(c in self.s)
            else:
                self.assert_(c not in self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assert_(c in self.s)
            else:
                self.assert_(c not in self.s)
        self.assertRaises(PassThru, self.s.difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        for p, q in (('cdc', 'ab'), ('efgfe', 'abc'), ('ccb', 'a'), ('ef', 'abc')):
            for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.difference_update(C(p)), None)
                self.assertEqual(s, set(q))

    def test_isub(self):
        self.s -= set(self.otherword)
        for c in (self.word + self.otherword):
            if c in self.word and c not in self.otherword:
                self.assert_(c in self.s)
            else:
                self.assert_(c not in self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.otherword)
        self.assertEqual(retval, None)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assert_(c in self.s)
            else:
                self.assert_(c not in self.s)
        self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru())
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
        for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')):
            for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
                s = self.thetype('abcba')
                self.assertEqual(s.symmetric_difference_update(C(p)), None)
                self.assertEqual(s, set(q))

    def test_ixor(self):
        self.s ^= set(self.otherword)
        for c in (self.word + self.otherword):
            if (c in self.word) ^ (c in self.otherword):
                self.assert_(c in self.s)
            else:
                self.assert_(c not in self.s)
347                self.assert_(c not in self.s)
348
class SetSubclass(set):
    """Trivial set subclass used to run TestSet against a derived type."""
    pass
351
class TestSetSubclass(TestSet):
    """Run every TestSet test against SetSubclass."""
    thetype = SetSubclass
354
class TestFrozenSet(TestJointOps):
    """Tests specific to the immutable frozenset type, on top of TestJointOps."""

    thetype = frozenset

    def test_init(self):
        # Calling __init__ on an existing frozenset must not change it.
        s = self.thetype(self.word)
        s.__init__(self.otherword)
        self.assertEqual(s, set(self.word))

    def test_constructor_identity(self):
        # frozenset(f) on an exact frozenset returns f itself.
        s = self.thetype(range(3))
        t = self.thetype(s)
        self.assertEqual(id(s), id(t))

    def test_hash(self):
        # Equal contents hash equally regardless of construction order.
        self.assertEqual(hash(self.thetype('abcdeb')),
                         hash(self.thetype('ebecda')))

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(id(self.s), id(dup))

    def test_frozen_as_dictkey(self):
        seq = range(10) + list('abcdefg') + ['apple']
        key1 = self.thetype(seq)
        key2 = self.thetype(reversed(seq))
        self.assertEqual(key1, key2)
        self.assertNotEqual(id(key1), id(key2))
        d = {}
        d[key1] = 42
        self.assertEqual(d[key2], 42)

    def test_hash_caching(self):
        f = self.thetype('abcdcda')
        self.assertEqual(hash(f), hash(f))

    def test_hash_effectiveness(self):
        # Every one of the 2**n subsets of {1, ..., n} must hash to a
        # distinct value. Build them with the type under test so the
        # subclass test case checks subclass hashing as well.
        n = 13
        hashvalues = set()
        addhashvalue = hashvalues.add
        elemmasks = [(i+1, 1<<i) for i in range(n)]
        for i in xrange(2**n):
            addhashvalue(hash(self.thetype([e for e, m in elemmasks if m&i])))
        self.assertEqual(len(hashvalues), 2**n)
398
class FrozenSetSubclass(frozenset):
    """Trivial frozenset subclass used to run TestFrozenSet on a derived type."""
    pass
401
class TestFrozenSetSubclass(TestFrozenSet):
    """Run TestFrozenSet against FrozenSetSubclass, overriding the identity
    checks: unlike exact frozensets, subclass instances are never shared."""

    thetype = FrozenSetSubclass

    def test_constructor_identity(self):
        s = self.thetype(range(3))
        t = self.thetype(s)
        self.assertNotEqual(id(s), id(t))

    def test_copy(self):
        dup = self.s.copy()
        self.assertNotEqual(id(self.s), id(dup))

    def test_nested_empty_constructor(self):
        s = self.thetype()
        t = self.thetype(s)
        self.assertEqual(s, t)
418
# Tests taken from test_sets.py =============================================

# Shared empty operand for the binary-operation tests below.
empty_set = set()

#==============================================================================
424
class TestBasicOps(unittest.TestCase):
    """Basic operations on one fixed set.

    Subclasses define setUp() to provide: ``case`` (a description),
    ``values`` (the elements), ``set`` and ``dup`` (two independently built
    sets of those values), ``length`` (the expected size) and ``repr`` (the
    expected repr() text, or None to skip that check).
    """

    def test_repr(self):
        if self.repr is not None:
            self.assertEqual(repr(self.set), self.repr)

    def test_length(self):
        self.assertEqual(len(self.set), self.length)

    def test_self_equality(self):
        self.assertEqual(self.set, self.set)

    def test_equivalent_equality(self):
        self.assertEqual(self.set, self.dup)

    def test_copy(self):
        self.assertEqual(self.set.copy(), self.dup)

    def test_self_union(self):
        result = self.set | self.set
        self.assertEqual(result, self.dup)

    def test_empty_union(self):
        result = self.set | empty_set
        self.assertEqual(result, self.dup)

    def test_union_empty(self):
        result = empty_set | self.set
        self.assertEqual(result, self.dup)

    def test_self_intersection(self):
        result = self.set & self.set
        self.assertEqual(result, self.dup)

    def test_empty_intersection(self):
        result = self.set & empty_set
        self.assertEqual(result, empty_set)

    def test_intersection_empty(self):
        result = empty_set & self.set
        self.assertEqual(result, empty_set)

    def test_self_symmetric_difference(self):
        result = self.set ^ self.set
        self.assertEqual(result, empty_set)

    def test_empty_symmetric_difference(self):
        # Carries the "test" prefix so the unittest loader actually runs it.
        result = self.set ^ empty_set
        self.assertEqual(result, self.dup)

    # Old spelling, kept as an alias for backward compatibility; the loader
    # ignores it because it lacks the "test" prefix.
    checkempty_symmetric_difference = test_empty_symmetric_difference

    def test_self_difference(self):
        result = self.set - self.set
        self.assertEqual(result, empty_set)

    def test_empty_difference(self):
        result = self.set - empty_set
        self.assertEqual(result, self.dup)

    def test_empty_difference_rev(self):
        result = empty_set - self.set
        self.assertEqual(result, empty_set)

    def test_iteration(self):
        for v in self.set:
            self.assert_(v in self.values)

    def test_pickling(self):
        p = pickle.dumps(self.set)
        # Named "dup" rather than "copy" so the copy module is not shadowed.
        dup = pickle.loads(p)
        self.assertEqual(self.set, dup,
                         "%s != %s" % (self.set, dup))
496
497#------------------------------------------------------------------------------
498
class TestBasicOpsEmpty(TestBasicOps):
    """TestBasicOps applied to the empty set."""

    def setUp(self):
        self.case   = "empty set"
        self.values = []
        self.set    = set(self.values)
        self.dup    = set(self.values)
        self.length = 0
        self.repr   = "set([])"
507
508#------------------------------------------------------------------------------
509
class TestBasicOpsSingleton(TestBasicOps):
    """TestBasicOps applied to a one-element set holding a number."""

    def setUp(self):
        self.case   = "unit set (number)"
        self.values = [3]
        self.set    = set(self.values)
        self.dup    = set(self.values)
        self.length = 1
        self.repr   = "set([3])"

    def test_in(self):
        self.failUnless(3 in self.set)

    def test_not_in(self):
        self.failUnless(2 not in self.set)
524
525#------------------------------------------------------------------------------
526
class TestBasicOpsTuple(TestBasicOps):
    """TestBasicOps applied to a one-element set holding a tuple."""

    def setUp(self):
        self.case   = "unit set (tuple)"
        self.values = [(0, "zero")]
        self.set    = set(self.values)
        self.dup    = set(self.values)
        self.length = 1
        self.repr   = "set([(0, 'zero')])"

    def test_in(self):
        self.failUnless((0, "zero") in self.set)

    def test_not_in(self):
        self.failUnless(9 not in self.set)
541
542#------------------------------------------------------------------------------
543
class TestBasicOpsTriple(TestBasicOps):
    """TestBasicOps applied to a set of three heterogeneous elements."""

    def setUp(self):
        self.case   = "triple set"
        self.values = [0, "zero", operator.add]
        self.set    = set(self.values)
        self.dup    = set(self.values)
        self.length = 3
        # Element order in the repr is not fixed, so skip the repr check.
        self.repr   = None
552
553#==============================================================================
554
def baditer():
    """Return a generator that raises TypeError as soon as it is iterated."""
    raise TypeError
    # Unreachable; makes this a generator so the error occurs during
    # iteration rather than when baditer() is called.
    yield True
558
def gooditer():
    """Return a generator that yields a single True and then stops."""
    yield True
561
class TestExceptionPropagation(unittest.TestCase):
    """SF 628246:  Set constructor should not trap iterator TypeErrors"""

    def test_instanceWithException(self):
        self.assertRaises(TypeError, set, baditer())

    def test_instancesWithoutException(self):
        # All of these iterables should load without exception.
        set([1,2,3])
        set((1,2,3))
        set({'one':1, 'two':2, 'three':3})
        set(xrange(3))
        set('abc')
        set(gooditer())
576
577#==============================================================================
578
class TestSetOfSets(unittest.TestCase):
    """A set whose elements are frozensets supports add/remove/discard."""

    def test_constructor(self):
        inner = frozenset([1])
        outer = set([inner])
        element = outer.pop()
        self.assertEqual(type(element), frozenset)
        outer.add(inner)        # Rebuild set of sets with .add method
        outer.remove(inner)
        self.assertEqual(outer, set())   # Verify that remove worked
        outer.discard(inner)    # Absence of KeyError indicates working fine
589
590#==============================================================================
591
class TestBinaryOps(unittest.TestCase):
    """Binary set operators applied to the fixed set {2, 4, 6}."""

    def setUp(self):
        self.set = set((2, 4, 6))

    def test_eq(self):              # SF bug 643115
        self.assertEqual(self.set, set({2:1,4:3,6:5}))

    def test_union_subset(self):
        result = self.set | set([2])
        self.assertEqual(result, set((2, 4, 6)))

    def test_union_superset(self):
        result = self.set | set([2, 4, 6, 8])
        self.assertEqual(result, set([2, 4, 6, 8]))

    def test_union_overlap(self):
        result = self.set | set([3, 4, 5])
        self.assertEqual(result, set([2, 3, 4, 5, 6]))

    def test_union_non_overlap(self):
        result = self.set | set([8])
        self.assertEqual(result, set([2, 4, 6, 8]))

    def test_intersection_subset(self):
        result = self.set & set((2, 4))
        self.assertEqual(result, set((2, 4)))

    def test_intersection_superset(self):
        result = self.set & set([2, 4, 6, 8])
        self.assertEqual(result, set([2, 4, 6]))

    def test_intersection_overlap(self):
        result = self.set & set([3, 4, 5])
        self.assertEqual(result, set([4]))

    def test_intersection_non_overlap(self):
        result = self.set & set([8])
        self.assertEqual(result, empty_set)

    def test_sym_difference_subset(self):
        result = self.set ^ set((2, 4))
        self.assertEqual(result, set([6]))

    def test_sym_difference_superset(self):
        result = self.set ^ set((2, 4, 6, 8))
        self.assertEqual(result, set([8]))

    def test_sym_difference_overlap(self):
        result = self.set ^ set((3, 4, 5))
        self.assertEqual(result, set([2, 3, 5, 6]))

    def test_sym_difference_non_overlap(self):
        result = self.set ^ set([8])
        self.assertEqual(result, set([2, 4, 6, 8]))

    def test_cmp(self):
        a, b = set('a'), set('b')
        self.assertRaises(TypeError, cmp, a, b)

        # You can view this as a buglet:  cmp(a, a) does not raise TypeError,
        # because __eq__ is tried before __cmp__, and a.__eq__(a) returns True,
        # which Python thinks is good enough to synthesize a cmp() result
        # without calling __cmp__.
        self.assertEqual(cmp(a, a), 0)

        self.assertRaises(TypeError, cmp, a, 12)
        self.assertRaises(TypeError, cmp, "abc", a)
659
660#==============================================================================
661
class TestUpdateOps(unittest.TestCase):
    """In-place operators and *_update methods applied to {2, 4, 6}."""

    def setUp(self):
        self.set = set((2, 4, 6))

    def test_union_subset(self):
        self.set |= set([2])
        self.assertEqual(self.set, set((2, 4, 6)))

    def test_union_superset(self):
        self.set |= set([2, 4, 6, 8])
        self.assertEqual(self.set, set([2, 4, 6, 8]))

    def test_union_overlap(self):
        self.set |= set([3, 4, 5])
        self.assertEqual(self.set, set([2, 3, 4, 5, 6]))

    def test_union_non_overlap(self):
        self.set |= set([8])
        self.assertEqual(self.set, set([2, 4, 6, 8]))

    def test_union_method_call(self):
        self.set.update(set([3, 4, 5]))
        self.assertEqual(self.set, set([2, 3, 4, 5, 6]))

    def test_intersection_subset(self):
        self.set &= set((2, 4))
        self.assertEqual(self.set, set((2, 4)))

    def test_intersection_superset(self):
        self.set &= set([2, 4, 6, 8])
        self.assertEqual(self.set, set([2, 4, 6]))

    def test_intersection_overlap(self):
        self.set &= set([3, 4, 5])
        self.assertEqual(self.set, set([4]))

    def test_intersection_non_overlap(self):
        self.set &= set([8])
        self.assertEqual(self.set, empty_set)

    def test_intersection_method_call(self):
        self.set.intersection_update(set([3, 4, 5]))
        self.assertEqual(self.set, set([4]))

    def test_sym_difference_subset(self):
        self.set ^= set((2, 4))
        self.assertEqual(self.set, set([6]))

    def test_sym_difference_superset(self):
        self.set ^= set((2, 4, 6, 8))
        self.assertEqual(self.set, set([8]))

    def test_sym_difference_overlap(self):
        self.set ^= set((3, 4, 5))
        self.assertEqual(self.set, set([2, 3, 5, 6]))

    def test_sym_difference_non_overlap(self):
        self.set ^= set([8])
        self.assertEqual(self.set, set([2, 4, 6, 8]))

    def test_sym_difference_method_call(self):
        self.set.symmetric_difference_update(set([3, 4, 5]))
        self.assertEqual(self.set, set([2, 3, 5, 6]))

    def test_difference_subset(self):
        self.set -= set((2, 4))
        self.assertEqual(self.set, set([6]))

    def test_difference_superset(self):
        self.set -= set((2, 4, 6, 8))
        self.assertEqual(self.set, set([]))

    def test_difference_overlap(self):
        self.set -= set((3, 4, 5))
        self.assertEqual(self.set, set([2, 6]))

    def test_difference_non_overlap(self):
        self.set -= set([8])
        self.assertEqual(self.set, set([2, 4, 6]))

    def test_difference_method_call(self):
        self.set.difference_update(set([3, 4, 5]))
        self.assertEqual(self.set, set([2, 6]))
745
746#==============================================================================
747
class TestMutate(unittest.TestCase):
    """Element-level mutators (add/remove/discard/clear/pop/update) on a
    set initialised to {"a", "b", "c"}."""

    def setUp(self):
        self.values = ["a", "b", "c"]
        self.set = set(self.values)

    def test_add_present(self):
        self.set.add("c")
        self.assertEqual(self.set, set("abc"))

    def test_add_absent(self):
        self.set.add("d")
        self.assertEqual(self.set, set("abcd"))

    def test_add_until_full(self):
        tmp = set()
        expected_len = 0
        for v in self.values:
            tmp.add(v)
            expected_len += 1
            self.assertEqual(len(tmp), expected_len)
        self.assertEqual(tmp, self.set)

    def test_remove_present(self):
        self.set.remove("b")
        self.assertEqual(self.set, set("ac"))

    def test_remove_absent(self):
        # Removing a missing element must raise LookupError (KeyError).
        self.assertRaises(LookupError, self.set.remove, "d")

    def test_remove_until_empty(self):
        expected_len = len(self.set)
        for v in self.values:
            self.set.remove(v)
            expected_len -= 1
            self.assertEqual(len(self.set), expected_len)

    def test_discard_present(self):
        self.set.discard("c")
        self.assertEqual(self.set, set("ab"))

    def test_discard_absent(self):
        self.set.discard("d")
        self.assertEqual(self.set, set("abc"))

    def test_clear(self):
        self.set.clear()
        self.assertEqual(len(self.set), 0)

    def test_pop(self):
        popped = {}
        while self.set:
            popped[self.set.pop()] = None
        # pop() must hand back every original element exactly once.
        self.assertEqual(len(popped), len(self.values))
        self.assertEqual(sorted(popped), sorted(self.values))

    def test_update_empty_tuple(self):
        self.set.update(())
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_overlap(self):
        self.set.update(("a",))
        self.assertEqual(self.set, set(self.values))

    def test_update_unit_tuple_non_overlap(self):
        self.set.update(("a", "z"))
        self.assertEqual(self.set, set(self.values + ["z"]))
819
820#==============================================================================
821
822class TestSubsets(unittest.TestCase):
823
824    case2method = {"<=": "issubset",
825                   ">=": "issuperset",
826                  }
827
828    reverse = {"==": "==",
829               "!=": "!=",
830               "<":  ">",
831               ">":  "<",
832               "<=": ">=",
833               ">=": "<=",
834              }
835
836    def test_issubset(self):
837        x = self.left
838        y = self.right
839        for case in "!=", "==", "<", "<=", ">", ">=":
840            expected = case in self.cases
841            # Test the binary infix spelling.
842            result = eval("x" + case + "y", locals())
843            self.assertEqual(result, expected)
844            # Test the "friendly" method-name spelling, if one exists.
845            if case in TestSubsets.case2method:
846                method = getattr(x, TestSubsets.case2method[case])
847                result = method(y)
848                self.assertEqual(result, expected)
849
850            # Now do the same for the operands reversed.
851            rcase = TestSubsets.reverse[case]
852            result = eval("y" + rcase + "x", locals())
853            self.assertEqual(result, expected)
854            if rcase in TestSubsets.case2method:
855                method = getattr(y, TestSubsets.case2method[rcase])
856                result = method(x)
857                self.assertEqual(result, expected)
858#------------------------------------------------------------------------------
859
class TestSubsetEqualEmpty(TestSubsets):
    # Two empty sets are equal, so ==, <= and >= hold and the strict and
    # inequality comparisons do not.
    name  = "both empty"
    left  = set([])
    right = set([])
    cases = ("==", "<=", ">=")
865
866#------------------------------------------------------------------------------
867
class TestSubsetEqualNonEmpty(TestSubsets):
    # Two distinct but equal non-empty sets.
    name  = "equal pair"
    left  = set((1, 2))
    right = set((1, 2))
    cases = ("==", "<=", ">=")
873
874#------------------------------------------------------------------------------
875
class TestSubsetEmptyNonEmpty(TestSubsets):
    # The empty left set is a proper subset of the non-empty right set.
    name  = "one empty, one non-empty"
    left  = set([])
    right = set((1, 2))
    cases = ("!=", "<", "<=")
881
882#------------------------------------------------------------------------------
883
class TestSubsetPartial(TestSubsets):
    # left is a non-empty proper subset of right.
    name  = "one a non-empty proper subset of other"
    left  = set((1,))
    right = set((1, 2))
    cases = ("!=", "<", "<=")
889
890#------------------------------------------------------------------------------
891
class TestSubsetNonOverlap(TestSubsets):
    # Disjoint non-empty sets: only != holds.
    left  = set([1])
    right = set([2])
    name  = "neither empty, neither contains"
    # Must be a tuple like in the sibling classes: with a bare string "!=",
    # `case in self.cases` becomes a substring test rather than membership.
    cases = ("!=",)
897
898#==============================================================================
899
900class TestOnlySetsInBinaryOps(unittest.TestCase):
901
902    def test_eq_ne(self):
903        # Unlike the others, this is testing that == and != *are* allowed.
904        self.assertEqual(self.other == self.set, False)
905        self.assertEqual(self.set == self.other, False)
906        self.assertEqual(self.other != self.set, True)
907        self.assertEqual(self.set != self.other, True)
908
909    def test_ge_gt_le_lt(self):
910        self.assertRaises(TypeError, lambda: self.set < self.other)
911        self.assertRaises(TypeError, lambda: self.set <= self.other)
912        self.assertRaises(TypeError, lambda: self.set > self.other)
913        self.assertRaises(TypeError, lambda: self.set >= self.other)
914
915        self.assertRaises(TypeError, lambda: self.other < self.set)
916        self.assertRaises(TypeError, lambda: self.other <= self.set)
917        self.assertRaises(TypeError, lambda: self.other > self.set)
918        self.assertRaises(TypeError, lambda: self.other >= self.set)
919
920    def test_update_operator(self):
921        try:
922            self.set |= self.other
923        except TypeError:
924            pass
925        else:
926            self.fail("expected TypeError")
927
928    def test_update(self):
929        if self.otherIsIterable:
930            self.set.update(self.other)
931        else:
932            self.assertRaises(TypeError, self.set.update, self.other)
933
934    def test_union(self):
935        self.assertRaises(TypeError, lambda: self.set | self.other)
936        self.assertRaises(TypeError, lambda: self.other | self.set)
937        if self.otherIsIterable:
938            self.set.union(self.other)
939        else:
940            self.assertRaises(TypeError, self.set.union, self.other)
941
942    def test_intersection_update_operator(self):
943        try:
944            self.set &= self.other
945        except TypeError:
946            pass
947        else:
948            self.fail("expected TypeError")
949
950    def test_intersection_update(self):
951        if self.otherIsIterable:
952            self.set.intersection_update(self.other)
953        else:
954            self.assertRaises(TypeError,
955                              self.set.intersection_update,
956                              self.other)
957
958    def test_intersection(self):
959        self.assertRaises(TypeError, lambda: self.set & self.other)
960        self.assertRaises(TypeError, lambda: self.other & self.set)
961        if self.otherIsIterable:
962            self.set.intersection(self.other)
963        else:
964            self.assertRaises(TypeError, self.set.intersection, self.other)
965
966    def test_sym_difference_update_operator(self):
967        try:
968            self.set ^= self.other
969        except TypeError:
970            pass
971        else:
972            self.fail("expected TypeError")
973
974    def test_sym_difference_update(self):
975        if self.otherIsIterable:
976            self.set.symmetric_difference_update(self.other)
977        else:
978            self.assertRaises(TypeError,
979                              self.set.symmetric_difference_update,
980                              self.other)
981
982    def test_sym_difference(self):
983        self.assertRaises(TypeError, lambda: self.set ^ self.other)
984        self.assertRaises(TypeError, lambda: self.other ^ self.set)
985        if self.otherIsIterable:
986            self.set.symmetric_difference(self.other)
987        else:
988            self.assertRaises(TypeError, self.set.symmetric_difference, self.other)
989
990    def test_difference_update_operator(self):
991        try:
992            self.set -= self.other
993        except TypeError:
994            pass
995        else:
996            self.fail("expected TypeError")
997
998    def test_difference_update(self):
999        if self.otherIsIterable:
1000            self.set.difference_update(self.other)
1001        else:
1002            self.assertRaises(TypeError,
1003                              self.set.difference_update,
1004                              self.other)
1005
1006    def test_difference(self):
1007        self.assertRaises(TypeError, lambda: self.set - self.other)
1008        self.assertRaises(TypeError, lambda: self.other - self.set)
1009        if self.otherIsIterable:
1010            self.set.difference(self.other)
1011        else:
1012            self.assertRaises(TypeError, self.set.difference, self.other)
1013
1014#------------------------------------------------------------------------------
1015
class TestOnlySetsNumeric(TestOnlySetsInBinaryOps):
    # `other` is a plain int: neither a set nor iterable.
    def setUp(self):
        self.otherIsIterable = False
        self.other = 19
        self.set = set([1, 2, 3])
1021
1022#------------------------------------------------------------------------------
1023
class TestOnlySetsDict(TestOnlySetsInBinaryOps):
    # `other` is a dict: not a set, but iterable.
    def setUp(self):
        self.otherIsIterable = True
        self.other = {1: 2, 3: 4}
        self.set = set([1, 2, 3])
1029
1030#------------------------------------------------------------------------------
1031
class TestOnlySetsOperator(TestOnlySetsInBinaryOps):
    # `other` is a function object, which is not iterable.
    def setUp(self):
        self.otherIsIterable = False
        self.other = operator.add
        self.set = set([1, 2, 3])
1037
1038#------------------------------------------------------------------------------
1039
class TestOnlySetsTuple(TestOnlySetsInBinaryOps):
    # `other` is a tuple: iterable, but not a set.
    def setUp(self):
        self.otherIsIterable = True
        self.other = (2, 4, 6)
        self.set = set([1, 2, 3])
1045
1046#------------------------------------------------------------------------------
1047
class TestOnlySetsString(TestOnlySetsInBinaryOps):
    # `other` is a string: iterable, but not a set.
    def setUp(self):
        self.otherIsIterable = True
        self.other = 'abc'
        self.set = set([1, 2, 3])
1053
1054#------------------------------------------------------------------------------
1055
class TestOnlySetsGenerator(TestOnlySetsInBinaryOps):
    # `other` is a fresh generator over 0, 2, 4, 6, 8: iterable, not a set.
    def setUp(self):
        def evens():
            for value in xrange(0, 10, 2):
                yield value
        self.set = set([1, 2, 3])
        self.other = evens()
        self.otherIsIterable = True
1064
1065#==============================================================================
1066
1067class TestCopying(unittest.TestCase):
1068
1069    def test_copy(self):
1070        dup = self.set.copy()
1071        dup_list = list(dup); dup_list.sort()
1072        set_list = list(self.set); set_list.sort()
1073        self.assertEqual(len(dup_list), len(set_list))
1074        for i in range(len(dup_list)):
1075            self.failUnless(dup_list[i] is set_list[i])
1076
1077    def test_deep_copy(self):
1078        dup = copy.deepcopy(self.set)
1079        ##print type(dup), `dup`
1080        dup_list = list(dup); dup_list.sort()
1081        set_list = list(self.set); set_list.sort()
1082        self.assertEqual(len(dup_list), len(set_list))
1083        for i in range(len(dup_list)):
1084            self.assertEqual(dup_list[i], set_list[i])
1085
1086#------------------------------------------------------------------------------
1087
class TestCopyingEmpty(TestCopying):
    # Copying a set with no elements.
    def setUp(self):
        self.set = set([])
1091
1092#------------------------------------------------------------------------------
1093
class TestCopyingSingleton(TestCopying):
    # Copying a one-element set.
    def setUp(self):
        self.set = set(("hello",))
1097
1098#------------------------------------------------------------------------------
1099
class TestCopyingTriple(TestCopying):
    # Copying a set mixing a string, an int and None.
    def setUp(self):
        self.set = set(("zero", 0, None))
1103
1104#------------------------------------------------------------------------------
1105
class TestCopyingTuple(TestCopying):
    # Copying a set whose single element is a flat tuple.
    def setUp(self):
        pair = (1, 2)
        self.set = set([pair])
1109
1110#------------------------------------------------------------------------------
1111
class TestCopyingNested(TestCopying):
    # Copying a set whose single element is a tuple of tuples.
    def setUp(self):
        nested = ((1, 2), (3, 4))
        self.set = set([nested])
1115
1116#==============================================================================
1117
class TestIdentities(unittest.TestCase):
    # Algebraic identities between the set operators, checked on two
    # overlapping sets of letters.

    def setUp(self):
        self.a = set('abracadabra')
        self.b = set('alacazam')

    def test_binopsVsSubsets(self):
        # For these inputs, -, & and ^ give proper subsets and | gives a
        # proper superset of the operands involved.
        a, b = self.a, self.b
        outcomes = (a - b < a,
                    b - a < b,
                    a & b < a,
                    a & b < b,
                    a | b > a,
                    a | b > b,
                    a ^ b < a | b)
        for holds in outcomes:
            self.assert_(holds)

    def test_commutativity(self):
        # &, | and ^ commute; - differs in each direction unless a == b.
        a, b = self.a, self.b
        for combine in (lambda p, q: p & q,
                        lambda p, q: p | q,
                        lambda p, q: p ^ q):
            self.assertEqual(combine(a, b), combine(b, a))
        if a != b:
            self.assertNotEqual(a - b, b - a)

    def test_summations(self):
        # check that sums of parts equal the whole
        a, b = self.a, self.b
        pairs = (((a-b)|(a&b)|(b-a), a|b),
                 ((a&b)|(a^b),       a|b),
                 (a|(b-a),           a|b),
                 ((a-b)|b,           a|b),
                 ((a-b)|(a&b),       a),
                 ((b-a)|(a&b),       b),
                 ((a-b)|(b-a),       a^b))
        for pieces, whole in pairs:
            self.assertEqual(pieces, whole)

    def test_exclusion(self):
        # check that inverse operations show non-overlap
        a, b = self.a, self.b
        empty = set()
        for overlap in ((a-b)&b, (b-a)&a, (a&b)&(a^b)):
            self.assertEqual(overlap, empty)
1158
1159# Tests derived from test_itertools.py =======================================
1160
def R(seqn):
    """Re-yield every item of *seqn* through a plain generator function."""
    for item in seqn:
        yield item
1165
class G:
    """Wrap *seqn* so it can be iterated only through __getitem__."""
    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, index):
        item = self.seqn[index]
        return item
1172
class I:
    """Explicit iterator over *seqn*: __iter__ returns self, next() advances."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def next(self):
        pos = self.i
        if pos >= len(self.seqn):
            raise StopIteration
        value = self.seqn[pos]
        self.i = pos + 1
        return value
1185
class Ig:
    """Iterable whose __iter__ is a generator over the wrapped sequence."""
    def __init__(self, seqn):
        self.i = 0  # not used by __iter__
        self.seqn = seqn

    def __iter__(self):
        for element in self.seqn:
            yield element
1194
class X:
    """Has next() but neither __iter__ nor __getitem__, so is not iterable."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def next(self):
        pos = self.i
        if pos >= len(self.seqn):
            raise StopIteration
        value = self.seqn[pos]
        self.i = pos + 1
        return value
1205
class N:
    """Claims to be an iterator (__iter__ returns self) but lacks next()."""
    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn

    def __iter__(self):
        return self
1213
class E:
    """Iterator whose next() always raises ZeroDivisionError."""
    def __init__(self, seqn):
        self.i = 0
        self.seqn = seqn

    def __iter__(self):
        return self

    def next(self):
        # Fail on the very first step; used to check exception propagation.
        3/0
1223
class S:
    """Iterator that is exhausted from the start: next() raises StopIteration."""
    def __init__(self, seqn):
        # The argument is accepted for a uniform call signature but ignored.
        pass

    def __iter__(self):
        return self

    def next(self):
        raise StopIteration
1232
1233from itertools import chain, imap
def L(seqn):
    """Feed *seqn* through several stacked iterator layers.

    Layering, innermost first: G (__getitem__) -> Ig (generator __iter__)
    -> R (generator function) -> imap (identity) -> chain.
    """
    def identity(x):
        return x
    layered = R(Ig(G(seqn)))
    return chain(imap(identity, layered))
1237
class TestVariousIteratorArgs(unittest.TestCase):
    # Feeds set/frozenset constructors and methods each iterator flavour
    # built by the wrappers G, I, Ig, S, L and R, plus the broken wrappers
    # X and N (TypeError expected) and E (ZeroDivisionError expected).

    def test_constructor(self):
        for cons in (set, frozenset):
            for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
                for g in (G, I, Ig, S, L, R):
                    self.assertEqual(sorted(cons(g(s))), sorted(g(s)))
                self.assertRaises(TypeError, cons, X(s))
                self.assertRaises(TypeError, cons, N(s))
                self.assertRaises(ZeroDivisionError, cons, E(s))

    def test_inline_methods(self):
        s = set('november')
        for data in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5), 'december'):
            for meth in (s.union, s.intersection, s.difference, s.symmetric_difference):
                # These methods return a new set and leave s alone, so the
                # plain-data result is computed once per method.
                expected = meth(data)
                # S is omitted: it yields nothing, so it cannot match
                # meth(data).
                for g in (G, I, Ig, L, R):
                    # Wrap with the current flavour g; always using G here
                    # would exercise only the __getitem__ path.
                    actual = meth(g(data))
                    self.assertEqual(sorted(actual), sorted(expected))
                # The broken iterables wrap the same data as above.
                self.assertRaises(TypeError, meth, X(data))
                self.assertRaises(TypeError, meth, N(data))
                self.assertRaises(ZeroDivisionError, meth, E(data))

    def test_inplace_methods(self):
        for data in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5), 'december'):
            for methname in ('update', 'intersection_update',
                             'difference_update', 'symmetric_difference_update'):
                for g in (G, I, Ig, S, L, R):
                    # Applying the method to a materialised list and to the
                    # wrapped iterator must give the same resulting set.
                    s = set('january')
                    t = s.copy()
                    getattr(s, methname)(list(g(data)))
                    getattr(t, methname)(g(data))
                    self.assertEqual(sorted(s), sorted(t))

                self.assertRaises(TypeError, getattr(set('january'), methname), X(data))
                self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
                self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
1275
1276#==============================================================================
1277
def test_main(verbose=None):
    """Run all of this module's test classes through test_support.

    If *verbose* is true and ``sys`` has ``gettotalrefcount`` (presumably
    only on a debug build -- confirm), the whole suite is run five more
    times. The total reference count after each run is printed; a count
    that keeps growing suggests a reference leak.
    """
    import sys
    test_classes = (
        TestSet,
        TestSetSubclass,
        TestFrozenSet,
        TestFrozenSetSubclass,
        TestSetOfSets,
        TestExceptionPropagation,
        TestBasicOpsEmpty,
        TestBasicOpsSingleton,
        TestBasicOpsTuple,
        TestBasicOpsTriple,
        TestBinaryOps,
        TestUpdateOps,
        TestMutate,
        TestSubsetEqualEmpty,
        TestSubsetEqualNonEmpty,
        TestSubsetEmptyNonEmpty,
        TestSubsetPartial,
        TestSubsetNonOverlap,
        TestOnlySetsNumeric,
        TestOnlySetsDict,
        TestOnlySetsOperator,
        TestOnlySetsTuple,
        TestOnlySetsString,
        TestOnlySetsGenerator,
        TestCopyingEmpty,
        TestCopyingSingleton,
        TestCopyingTriple,
        TestCopyingTuple,
        TestCopyingNested,
        TestIdentities,
        TestVariousIteratorArgs,
        )

    test_support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = [None] * 5
        for i in xrange(len(counts)):
            test_support.run_unittest(*test_classes)
            # Presumably collects cyclic garbage so it does not inflate the count.
            gc.collect()
            counts[i] = sys.gettotalrefcount()
        print(counts)
1326
if __name__ == "__main__":
    # Direct execution passes verbose=True, which also turns on test_main's
    # optional reference-count check.
    test_main(verbose=True)
1329