1import unittest
2from test import test_support
3
4import UserDict, random, string
5import gc, weakref
6
7
8class DictTest(unittest.TestCase):
9    def test_constructor(self):
10        # calling built-in types without argument must return empty
11        self.assertEqual(dict(), {})
12        self.assertIsNot(dict(), {})
13
14    def test_literal_constructor(self):
15        # check literal constructor for different sized dicts
16        # (to exercise the BUILD_MAP oparg).
17        for n in (0, 1, 6, 256, 400):
18            items = [(''.join(random.sample(string.letters, 8)), i)
19                     for i in range(n)]
20            random.shuffle(items)
21            formatted_items = ('{!r}: {:d}'.format(k, v) for k, v in items)
22            dictliteral = '{' + ', '.join(formatted_items) + '}'
23            self.assertEqual(eval(dictliteral), dict(items))
24
25    def test_bool(self):
26        self.assertIs(not {}, True)
27        self.assertTrue({1: 2})
28        self.assertIs(bool({}), False)
29        self.assertIs(bool({1: 2}), True)
30
31    def test_keys(self):
32        d = {}
33        self.assertEqual(d.keys(), [])
34        d = {'a': 1, 'b': 2}
35        k = d.keys()
36        self.assertTrue(d.has_key('a'))
37        self.assertTrue(d.has_key('b'))
38
39        self.assertRaises(TypeError, d.keys, None)
40
41    def test_values(self):
42        d = {}
43        self.assertEqual(d.values(), [])
44        d = {1:2}
45        self.assertEqual(d.values(), [2])
46
47        self.assertRaises(TypeError, d.values, None)
48
49    def test_items(self):
50        d = {}
51        self.assertEqual(d.items(), [])
52
53        d = {1:2}
54        self.assertEqual(d.items(), [(1, 2)])
55
56        self.assertRaises(TypeError, d.items, None)
57
58    def test_has_key(self):
59        d = {}
60        self.assertFalse(d.has_key('a'))
61        d = {'a': 1, 'b': 2}
62        k = d.keys()
63        k.sort()
64        self.assertEqual(k, ['a', 'b'])
65
66        self.assertRaises(TypeError, d.has_key)
67
68    def test_contains(self):
69        d = {}
70        self.assertNotIn('a', d)
71        self.assertFalse('a' in d)
72        self.assertTrue('a' not in d)
73        d = {'a': 1, 'b': 2}
74        self.assertIn('a', d)
75        self.assertIn('b', d)
76        self.assertNotIn('c', d)
77
78        self.assertRaises(TypeError, d.__contains__)
79
80    def test_len(self):
81        d = {}
82        self.assertEqual(len(d), 0)
83        d = {'a': 1, 'b': 2}
84        self.assertEqual(len(d), 2)
85
86    def test_getitem(self):
87        d = {'a': 1, 'b': 2}
88        self.assertEqual(d['a'], 1)
89        self.assertEqual(d['b'], 2)
90        d['c'] = 3
91        d['a'] = 4
92        self.assertEqual(d['c'], 3)
93        self.assertEqual(d['a'], 4)
94        del d['b']
95        self.assertEqual(d, {'a': 4, 'c': 3})
96
97        self.assertRaises(TypeError, d.__getitem__)
98
99        class BadEq(object):
100            def __eq__(self, other):
101                raise Exc()
102            def __hash__(self):
103                return 24
104
105        d = {}
106        d[BadEq()] = 42
107        self.assertRaises(KeyError, d.__getitem__, 23)
108
109        class Exc(Exception): pass
110
111        class BadHash(object):
112            fail = False
113            def __hash__(self):
114                if self.fail:
115                    raise Exc()
116                else:
117                    return 42
118
119        x = BadHash()
120        d[x] = 42
121        x.fail = True
122        self.assertRaises(Exc, d.__getitem__, x)
123
124    def test_clear(self):
125        d = {1:1, 2:2, 3:3}
126        d.clear()
127        self.assertEqual(d, {})
128
129        self.assertRaises(TypeError, d.clear, None)
130
    def test_update(self):
        # update() from dicts, from objects that only offer keys() and
        # __getitem__, and from iterables of pairs; exceptions raised while
        # reading the source must reach the caller.
        d = {}
        d.update({1:100})
        d.update({2:20})
        d.update({1:1, 2:2, 3:3})
        self.assertEqual(d, {1:1, 2:2, 3:3})

        # update() with no argument leaves the contents unchanged.
        d.update()
        self.assertEqual(d, {1:1, 2:2, 3:3})

        self.assertRaises((TypeError, AttributeError), d.update, None)

        # A minimal mapping: only keys() and __getitem__ are provided.
        class SimpleUserDict:
            def __init__(self):
                self.d = {1:1, 2:2, 3:3}
            def keys(self):
                return self.d.keys()
            def __getitem__(self, i):
                return self.d[i]
        d.clear()
        d.update(SimpleUserDict())
        self.assertEqual(d, {1:1, 2:2, 3:3})

        class Exc(Exception): pass

        d.clear()
        # Case 1: keys() itself raises.
        class FailingUserDict:
            def keys(self):
                raise Exc
        self.assertRaises(Exc, d.update, FailingUserDict())

        # Case 2: keys() returns an iterator that yields 'a' once and then
        # raises Exc instead of StopIteration.
        class FailingUserDict:
            def keys(self):
                class BogonIter:
                    def __init__(self):
                        self.i = 1
                    def __iter__(self):
                        return self
                    def next(self):
                        if self.i:
                            self.i = 0
                            return 'a'
                        raise Exc
                return BogonIter()
            def __getitem__(self, key):
                return key
        self.assertRaises(Exc, d.update, FailingUserDict())

        # Case 3: the key iterator is well behaved ('a'..'z'), but
        # __getitem__ raises.
        class FailingUserDict:
            def keys(self):
                class BogonIter:
                    def __init__(self):
                        self.i = ord('a')
                    def __iter__(self):
                        return self
                    def next(self):
                        if self.i <= ord('z'):
                            rtn = chr(self.i)
                            self.i += 1
                            return rtn
                        raise StopIteration
                return BogonIter()
            def __getitem__(self, key):
                raise Exc
        self.assertRaises(Exc, d.update, FailingUserDict())

        # An iterable (no keys()) whose iteration raises immediately.
        class badseq(object):
            def __iter__(self):
                return self
            def next(self):
                raise Exc()

        self.assertRaises(Exc, {}.update, badseq())

        # Sequence elements must be pairs; a 3-tuple is rejected.
        self.assertRaises(ValueError, {}.update, [(1, 2, 3)])
206
    def test_fromkeys(self):
        # fromkeys() called on the type, on instances and on subclasses,
        # including subclasses whose construction hooks misbehave.
        self.assertEqual(dict.fromkeys('abc'), {'a':None, 'b':None, 'c':None})
        d = {}
        # Called on an instance, fromkeys() still builds a new dict.
        self.assertIsNot(d.fromkeys('abc'), d)
        self.assertEqual(d.fromkeys('abc'), {'a':None, 'b':None, 'c':None})
        self.assertEqual(d.fromkeys((4,5),0), {4:0, 5:0})
        self.assertEqual(d.fromkeys([]), {})
        def g():
            yield 1
        self.assertEqual(d.fromkeys(g()), {1:None})
        self.assertRaises(TypeError, {}.fromkeys, 3)
        # The result is an instance of the subclass fromkeys() is called on.
        class dictlike(dict): pass
        self.assertEqual(dictlike.fromkeys('a'), {'a':None})
        self.assertEqual(dictlike().fromkeys('a'), {'a':None})
        self.assertIsInstance(dictlike.fromkeys('a'), dictlike)
        self.assertIsInstance(dictlike().fromkeys('a'), dictlike)
        # __new__ may return an unrelated mapping; fromkeys() fills and
        # returns whatever object it got back.
        class mydict(dict):
            def __new__(cls):
                return UserDict.UserDict()
        ud = mydict.fromkeys('ab')
        self.assertEqual(ud, {'a':None, 'b':None})
        self.assertIsInstance(ud, UserDict.UserDict)
        self.assertRaises(TypeError, dict.fromkeys)

        class Exc(Exception): pass

        # Errors from __init__, from iterating the keys, and from
        # __setitem__ must all propagate.
        class baddict1(dict):
            def __init__(self):
                raise Exc()

        self.assertRaises(Exc, baddict1.fromkeys, [1])

        class BadSeq(object):
            def __iter__(self):
                return self
            def next(self):
                raise Exc()

        self.assertRaises(Exc, dict.fromkeys, BadSeq())

        class baddict2(dict):
            def __setitem__(self, key, value):
                raise Exc()

        self.assertRaises(Exc, baddict2.fromkeys, [1])

        # test fast path for dictionary inputs
        d = dict(zip(range(6), range(6)))
        self.assertEqual(dict.fromkeys(d, 0), dict(zip(range(6), [0]*6)))

        # __new__ hands back an already populated dict.  Note that d is
        # rebound right after the class statement, so the dict returned is
        # the 10-item one; the new keys must be added to its existing items.
        class baddict3(dict):
            def __new__(cls):
                return d
        d = {i : i for i in range(10)}
        res = d.copy()
        res.update(a=None, b=None, c=None)
        self.assertEqual(baddict3.fromkeys({"a", "b", "c"}), res)
264
265    def test_copy(self):
266        d = {1:1, 2:2, 3:3}
267        self.assertEqual(d.copy(), {1:1, 2:2, 3:3})
268        self.assertEqual({}.copy(), {})
269        self.assertRaises(TypeError, d.copy, None)
270
271    def test_get(self):
272        d = {}
273        self.assertIs(d.get('c'), None)
274        self.assertEqual(d.get('c', 3), 3)
275        d = {'a': 1, 'b': 2}
276        self.assertIs(d.get('c'), None)
277        self.assertEqual(d.get('c', 3), 3)
278        self.assertEqual(d.get('a'), 1)
279        self.assertEqual(d.get('a', 3), 1)
280        self.assertRaises(TypeError, d.get)
281        self.assertRaises(TypeError, d.get, None, None, None)
282
283    def test_setdefault(self):
284        # dict.setdefault()
285        d = {}
286        self.assertIs(d.setdefault('key0'), None)
287        d.setdefault('key0', [])
288        self.assertIs(d.setdefault('key0'), None)
289        d.setdefault('key', []).append(3)
290        self.assertEqual(d['key'][0], 3)
291        d.setdefault('key', []).append(4)
292        self.assertEqual(len(d['key']), 2)
293        self.assertRaises(TypeError, d.setdefault)
294
295        class Exc(Exception): pass
296
297        class BadHash(object):
298            fail = False
299            def __hash__(self):
300                if self.fail:
301                    raise Exc()
302                else:
303                    return 42
304
305        x = BadHash()
306        d[x] = 42
307        x.fail = True
308        self.assertRaises(Exc, d.setdefault, x, [])
309
310    def test_setdefault_atomic(self):
311        # Issue #13521: setdefault() calls __hash__ and __eq__ only once.
312        class Hashed(object):
313            def __init__(self):
314                self.hash_count = 0
315                self.eq_count = 0
316            def __hash__(self):
317                self.hash_count += 1
318                return 42
319            def __eq__(self, other):
320                self.eq_count += 1
321                return id(self) == id(other)
322        hashed1 = Hashed()
323        y = {hashed1: 5}
324        hashed2 = Hashed()
325        y.setdefault(hashed2, [])
326        self.assertEqual(hashed1.hash_count, 1)
327        self.assertEqual(hashed2.hash_count, 1)
328        self.assertEqual(hashed1.eq_count + hashed2.eq_count, 1)
329
330    def test_popitem(self):
331        # dict.popitem()
332        for copymode in -1, +1:
333            # -1: b has same structure as a
334            # +1: b is a.copy()
335            for log2size in range(12):
336                size = 2**log2size
337                a = {}
338                b = {}
339                for i in range(size):
340                    a[repr(i)] = i
341                    if copymode < 0:
342                        b[repr(i)] = i
343                if copymode > 0:
344                    b = a.copy()
345                for i in range(size):
346                    ka, va = ta = a.popitem()
347                    self.assertEqual(va, int(ka))
348                    kb, vb = tb = b.popitem()
349                    self.assertEqual(vb, int(kb))
350                    self.assertFalse(copymode < 0 and ta != tb)
351                self.assertFalse(a)
352                self.assertFalse(b)
353
354        d = {}
355        self.assertRaises(KeyError, d.popitem)
356
357    def test_pop(self):
358        # Tests for pop with specified key
359        d = {}
360        k, v = 'abc', 'def'
361        d[k] = v
362        self.assertRaises(KeyError, d.pop, 'ghi')
363
364        self.assertEqual(d.pop(k), v)
365        self.assertEqual(len(d), 0)
366
367        self.assertRaises(KeyError, d.pop, k)
368
369        # verify longs/ints get same value when key > 32 bits
370        # (for 64-bit archs).  See SF bug #689659.
371        x = 4503599627370496L
372        y = 4503599627370496
373        h = {x: 'anything', y: 'something else'}
374        self.assertEqual(h[x], h[y])
375
376        self.assertEqual(d.pop(k, v), v)
377        d[k] = v
378        self.assertEqual(d.pop(k, 1), v)
379
380        self.assertRaises(TypeError, d.pop)
381
382        class Exc(Exception): pass
383
384        class BadHash(object):
385            fail = False
386            def __hash__(self):
387                if self.fail:
388                    raise Exc()
389                else:
390                    return 42
391
392        x = BadHash()
393        d[x] = 42
394        x.fail = True
395        self.assertRaises(Exc, d.pop, x)
396
397    def test_mutatingiteration(self):
398        # changing dict size during iteration
399        d = {}
400        d[1] = 1
401        with self.assertRaises(RuntimeError):
402            for i in d:
403                d[i+1] = 1
404
405    def test_repr(self):
406        d = {}
407        self.assertEqual(repr(d), '{}')
408        d[1] = 2
409        self.assertEqual(repr(d), '{1: 2}')
410        d = {}
411        d[1] = d
412        self.assertEqual(repr(d), '{1: {...}}')
413
414        class Exc(Exception): pass
415
416        class BadRepr(object):
417            def __repr__(self):
418                raise Exc()
419
420        d = {1: BadRepr()}
421        self.assertRaises(Exc, repr, d)
422
423    def test_le(self):
424        self.assertFalse({} < {})
425        self.assertFalse({1: 2} < {1L: 2L})
426
427        class Exc(Exception): pass
428
429        class BadCmp(object):
430            def __eq__(self, other):
431                raise Exc()
432            def __hash__(self):
433                return 42
434
435        d1 = {BadCmp(): 1}
436        d2 = {1: 1}
437
438        with self.assertRaises(Exc):
439            d1 < d2
440
441    def test_missing(self):
442        # Make sure dict doesn't have a __missing__ method
443        self.assertFalse(hasattr(dict, "__missing__"))
444        self.assertFalse(hasattr({}, "__missing__"))
445        # Test several cases:
446        # (D) subclass defines __missing__ method returning a value
447        # (E) subclass defines __missing__ method raising RuntimeError
448        # (F) subclass sets __missing__ instance variable (no effect)
449        # (G) subclass doesn't define __missing__ at a all
450        class D(dict):
451            def __missing__(self, key):
452                return 42
453        d = D({1: 2, 3: 4})
454        self.assertEqual(d[1], 2)
455        self.assertEqual(d[3], 4)
456        self.assertNotIn(2, d)
457        self.assertNotIn(2, d.keys())
458        self.assertEqual(d[2], 42)
459
460        class E(dict):
461            def __missing__(self, key):
462                raise RuntimeError(key)
463        e = E()
464        with self.assertRaises(RuntimeError) as c:
465            e[42]
466        self.assertEqual(c.exception.args, (42,))
467
468        class F(dict):
469            def __init__(self):
470                # An instance variable __missing__ should have no effect
471                self.__missing__ = lambda key: None
472        f = F()
473        with self.assertRaises(KeyError) as c:
474            f[42]
475        self.assertEqual(c.exception.args, (42,))
476
477        class G(dict):
478            pass
479        g = G()
480        with self.assertRaises(KeyError) as c:
481            g[42]
482        self.assertEqual(c.exception.args, (42,))
483
484    def test_tuple_keyerror(self):
485        # SF #1576657
486        d = {}
487        with self.assertRaises(KeyError) as c:
488            d[(1,)]
489        self.assertEqual(c.exception.args, ((1,),))
490
    def test_bad_key(self):
        # Dictionary lookups should fail if __cmp__() raises an exception.
        class CustomException(Exception):
            pass

        class BadDictKey:
            # Every instance shares one hash value (that of the class).
            def __hash__(self):
                return hash(self.__class__)

            # Comparing two BadDictKey instances raises CustomException.
            def __cmp__(self, other):
                if isinstance(other, self.__class__):
                    raise CustomException
                return other

        d = {}
        x1 = BadDictKey()
        x2 = BadDictKey()
        d[x1] = 1
        # Each statement below uses x2 as a key of d, which already holds
        # x1.  The statements are executed with this function's locals as
        # their namespace, so the names d and x2 must stay as they are.
        for stmt in ['d[x2] = 2',
                     'z = d[x2]',
                     'x2 in d',
                     'd.has_key(x2)',
                     'd.get(x2)',
                     'd.setdefault(x2, 42)',
                     'd.pop(x2)',
                     'd.update({x2: 2})']:
            with self.assertRaises(CustomException):
                exec stmt in locals()
519
    def test_resize1(self):
        # Dict resizing bug, found by Jack Jansen in 2.2 CVS development.
        # This version got an assert failure in debug build, infinite loop in
        # release build.  Unfortunately, provoking this kind of stuff requires
        # a mix of inserts and deletes hitting exactly the right hash codes in
        # exactly the right order, and I can't think of a randomized approach
        # that would be *likely* to hit a failing case in reasonable time.
        #
        # There are no assertions here: the test passes when the sequence
        # of operations below simply runs to completion.

        d = {}
        # Insert keys 0..4, delete them all, then insert keys 5..8.
        for i in range(5):
            d[i] = i
        for i in range(5):
            del d[i]
        for i in range(5, 9):  # i==8 was the problem
            d[i] = i
535
    def test_resize2(self):
        # Another dict resizing bug (SF bug #1456209).
        # This caused Segmentation faults or Illegal instructions.
        #
        # There are no assertions here: the test passes when the final
        # insertion completes without crashing the interpreter.

        class X(object):
            # All instances share the hash value 5 and never compare equal.
            def __hash__(self):
                return 5
            def __eq__(self, other):
                # Once the flag below is set, every comparison empties the
                # dict while the dict operation is still in progress.
                if resizing:
                    d.clear()
                return False
        d = {}
        resizing = False
        d[X()] = 1
        d[X()] = 2
        d[X()] = 3
        d[X()] = 4
        d[X()] = 5
        # now trigger a resize
        resizing = True
        d[9] = 6
557
    def test_empty_presized_dict_in_freelist(self):
        # Bug #3537: if an empty but presized dict with a size larger
        # than 7 was in the freelist, it triggered an assertion failure
        with self.assertRaises(ZeroDivisionError):
            # Evaluating the first value (1 // 0) raises, so this
            # eight-entry literal is never completed.
            d = {'a': 1 // 0, 'b': None, 'c': None, 'd': None, 'e': None,
                 'f': None, 'g': None, 'h': None}
        # Creating a fresh empty dict afterwards must work normally.
        d = {}
565
566    def test_container_iterator(self):
567        # Bug #3680: tp_traverse was not implemented for dictiter objects
568        class C(object):
569            pass
570        iterators = (dict.iteritems, dict.itervalues, dict.iterkeys)
571        for i in iterators:
572            obj = C()
573            ref = weakref.ref(obj)
574            container = {obj: 1}
575            obj.x = i(container)
576            del obj, container
577            gc.collect()
578            self.assertIs(ref(), None, "Cycle was not collected")
579
580    def _not_tracked(self, t):
581        # Nested containers can take several collections to untrack
582        gc.collect()
583        gc.collect()
584        self.assertFalse(gc.is_tracked(t), t)
585
586    def _tracked(self, t):
587        self.assertTrue(gc.is_tracked(t), t)
588        gc.collect()
589        gc.collect()
590        self.assertTrue(gc.is_tracked(t), t)
591
    @test_support.cpython_only
    def test_track_literals(self):
        # Test GC-optimization of dict literals
        x, y, z, w = 1.5, "a", (1, None), []

        # Literals for which the test expects no GC tracking.
        self._not_tracked({})
        self._not_tracked({x:(), y:x, z:1})
        self._not_tracked({1: "a", "b": 2})
        self._not_tracked({1: 2, (None, True, False, ()): int})
        self._not_tracked({1: object()})

        # Dicts with mutable elements are always tracked, even if those
        # elements are not tracked right now.
        self._tracked({1: []})
        self._tracked({1: ([],)})
        self._tracked({1: {}})
        self._tracked({1: set()})
609
    @test_support.cpython_only
    def test_track_dynamic(self):
        # Test GC-optimization of dynamically-created dicts
        class MyObject(object):
            pass
        x, y, z, w, o = 1.5, "a", (1, object()), [], MyObject()

        # Item assignment: the dict is expected to stay untracked until the
        # list w is stored in it, and to be untracked again once that value
        # has been replaced by None.  Copies must agree at each stage.
        d = dict()
        self._not_tracked(d)
        d[1] = "a"
        self._not_tracked(d)
        d[y] = 2
        self._not_tracked(d)
        d[z] = 3
        self._not_tracked(d)
        self._not_tracked(d.copy())
        d[4] = w
        self._tracked(d)
        self._tracked(d.copy())
        d[4] = None
        self._not_tracked(d)
        self._not_tracked(d.copy())

        # dd isn't tracked right now, but it may mutate and therefore d
        # which contains it must be tracked.
        d = dict()
        dd = dict()
        d[1] = dd
        self._not_tracked(dd)
        self._tracked(d)
        dd[1] = d
        self._tracked(dd)

        # fromkeys() and update() from another dict: adding the MyObject
        # instance o as a key is expected to make the result tracked.
        d = dict.fromkeys([x, y, z])
        self._not_tracked(d)
        dd = dict()
        dd.update(d)
        self._not_tracked(dd)
        d = dict.fromkeys([x, y, z, o])
        self._tracked(d)
        dd = dict()
        dd.update(d)
        self._tracked(dd)

        # Keyword arguments to the constructor and to update().
        d = dict(x=x, y=y, z=z)
        self._not_tracked(d)
        d = dict(x=x, y=y, z=z, w=w)
        self._tracked(d)
        d = dict()
        d.update(x=x, y=y, z=z)
        self._not_tracked(d)
        d.update(w=w)
        self._tracked(d)

        # Sequences of pairs passed to the constructor and to update().
        d = dict([(x, y), (z, 1)])
        self._not_tracked(d)
        d = dict([(x, y), (z, w)])
        self._tracked(d)
        d = dict()
        d.update([(x, y), (z, 1)])
        self._not_tracked(d)
        d.update([(x, y), (z, w)])
        self._tracked(d)
673
674    @test_support.cpython_only
675    def test_track_subtypes(self):
676        # Dict subtypes are always tracked
677        class MyDict(dict):
678            pass
679        self._tracked(MyDict())
680
681
682from test import mapping_tests
683
# Runs the tests inherited from mapping_tests.BasicTestMappingProtocol
# with the built-in dict as the type under test.
class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol):
    type2test = dict
686
# A dict subclass that adds nothing; used below as a type under test.
class Dict(dict):
    pass
689
# Runs the tests inherited from mapping_tests.BasicTestMappingProtocol
# with the Dict subclass as the type under test.
class SubclassMappingTests(mapping_tests.BasicTestMappingProtocol):
    type2test = Dict
692
def test_main():
    # Run all test classes of this module.  The warning filter is passed
    # to check_py3k_warnings() so that the has_key() and dict ordering
    # comparison warnings are accounted for.
    py3k_filter = (
        'dict(.has_key..| inequality comparisons) not supported in 3.x',
        DeprecationWarning)
    test_classes = (
        DictTest,
        GeneralMappingTests,
        SubclassMappingTests,
    )
    with test_support.check_py3k_warnings(py3k_filter):
        test_support.run_unittest(*test_classes)
702
# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    test_main()
705