1"""Unittests for heapq."""
2
3import sys
4import random
5
6from test import test_support
7from unittest import TestCase, skipUnless
8
# heapq re-imported with the _heapq accelerator blocked, so the functions
# checked below come from the pure-Python heapq.py (see TestModules).
py_heapq = test_support.import_fresh_module('heapq', blocked=['_heapq'])
# heapq re-imported together with a fresh _heapq.  NOTE(review): presumably
# this is None when _heapq cannot be imported -- the skipUnless(c_heapq, ...)
# guards in this file rely on it being falsy in that case.
c_heapq = test_support.import_fresh_module('heapq', fresh=['_heapq'])

# _heapq.nlargest/nsmallest are saved in heapq._nlargest/_nsmallest when
# _heapq is imported, so check them there
func_names = ['heapify', 'heappop', 'heappush', 'heappushpop',
              'heapreplace', '_nlargest', '_nsmallest']
16
class TestModules(TestCase):
    """Verify which implementation supplies each function in func_names."""

    def _assert_all_from(self, heapq_module, expected_module):
        # Every listed function must report expected_module as its origin.
        for name in func_names:
            func = getattr(heapq_module, name)
            self.assertEqual(func.__module__, expected_module)

    def test_py_functions(self):
        self._assert_all_from(py_heapq, 'heapq')

    @skipUnless(c_heapq, 'requires _heapq')
    def test_c_functions(self):
        self._assert_all_from(c_heapq, '_heapq')
26
27
28class TestHeap(TestCase):
29    module = None
30
31    def test_push_pop(self):
32        # 1) Push 256 random numbers and pop them off, verifying all's OK.
33        heap = []
34        data = []
35        self.check_invariant(heap)
36        for i in range(256):
37            item = random.random()
38            data.append(item)
39            self.module.heappush(heap, item)
40            self.check_invariant(heap)
41        results = []
42        while heap:
43            item = self.module.heappop(heap)
44            self.check_invariant(heap)
45            results.append(item)
46        data_sorted = data[:]
47        data_sorted.sort()
48        self.assertEqual(data_sorted, results)
49        # 2) Check that the invariant holds for a sorted array
50        self.check_invariant(results)
51
52        self.assertRaises(TypeError, self.module.heappush, [])
53        try:
54            self.assertRaises(TypeError, self.module.heappush, None, None)
55            self.assertRaises(TypeError, self.module.heappop, None)
56        except AttributeError:
57            pass
58
59    def check_invariant(self, heap):
60        # Check the heap invariant.
61        for pos, item in enumerate(heap):
62            if pos: # pos 0 has no parent
63                parentpos = (pos-1) >> 1
64                self.assertTrue(heap[parentpos] <= item)
65
66    def test_heapify(self):
67        for size in range(30):
68            heap = [random.random() for dummy in range(size)]
69            self.module.heapify(heap)
70            self.check_invariant(heap)
71
72        self.assertRaises(TypeError, self.module.heapify, None)
73
74    def test_naive_nbest(self):
75        data = [random.randrange(2000) for i in range(1000)]
76        heap = []
77        for item in data:
78            self.module.heappush(heap, item)
79            if len(heap) > 10:
80                self.module.heappop(heap)
81        heap.sort()
82        self.assertEqual(heap, sorted(data)[-10:])
83
84    def heapiter(self, heap):
85        # An iterator returning a heap's elements, smallest-first.
86        try:
87            while 1:
88                yield self.module.heappop(heap)
89        except IndexError:
90            pass
91
92    def test_nbest(self):
93        # Less-naive "N-best" algorithm, much faster (if len(data) is big
94        # enough <wink>) than sorting all of data.  However, if we had a max
95        # heap instead of a min heap, it could go faster still via
96        # heapify'ing all of data (linear time), then doing 10 heappops
97        # (10 log-time steps).
98        data = [random.randrange(2000) for i in range(1000)]
99        heap = data[:10]
100        self.module.heapify(heap)
101        for item in data[10:]:
102            if item > heap[0]:  # this gets rarer the longer we run
103                self.module.heapreplace(heap, item)
104        self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
105
106        self.assertRaises(TypeError, self.module.heapreplace, None)
107        self.assertRaises(TypeError, self.module.heapreplace, None, None)
108        self.assertRaises(IndexError, self.module.heapreplace, [], None)
109
110    def test_nbest_with_pushpop(self):
111        data = [random.randrange(2000) for i in range(1000)]
112        heap = data[:10]
113        self.module.heapify(heap)
114        for item in data[10:]:
115            self.module.heappushpop(heap, item)
116        self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
117        self.assertEqual(self.module.heappushpop([], 'x'), 'x')
118
119    def test_heappushpop(self):
120        h = []
121        x = self.module.heappushpop(h, 10)
122        self.assertEqual((h, x), ([], 10))
123
124        h = [10]
125        x = self.module.heappushpop(h, 10.0)
126        self.assertEqual((h, x), ([10], 10.0))
127        self.assertEqual(type(h[0]), int)
128        self.assertEqual(type(x), float)
129
130        h = [10];
131        x = self.module.heappushpop(h, 9)
132        self.assertEqual((h, x), ([10], 9))
133
134        h = [10];
135        x = self.module.heappushpop(h, 11)
136        self.assertEqual((h, x), ([11], 10))
137
138    def test_heapsort(self):
139        # Exercise everything with repeated heapsort checks
140        for trial in xrange(100):
141            size = random.randrange(50)
142            data = [random.randrange(25) for i in range(size)]
143            if trial & 1:     # Half of the time, use heapify
144                heap = data[:]
145                self.module.heapify(heap)
146            else:             # The rest of the time, use heappush
147                heap = []
148                for item in data:
149                    self.module.heappush(heap, item)
150            heap_sorted = [self.module.heappop(heap) for i in range(size)]
151            self.assertEqual(heap_sorted, sorted(data))
152
153    def test_merge(self):
154        inputs = []
155        for i in xrange(random.randrange(5)):
156            row = sorted(random.randrange(1000) for j in range(random.randrange(10)))
157            inputs.append(row)
158        self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs)))
159        self.assertEqual(list(self.module.merge()), [])
160
161    def test_merge_stability(self):
162        class Int(int):
163            pass
164        inputs = [[], [], [], []]
165        for i in range(20000):
166            stream = random.randrange(4)
167            x = random.randrange(500)
168            obj = Int(x)
169            obj.pair = (x, stream)
170            inputs[stream].append(obj)
171        for stream in inputs:
172            stream.sort()
173        result = [i.pair for i in self.module.merge(*inputs)]
174        self.assertEqual(result, sorted(result))
175
176    def test_nsmallest(self):
177        data = [(random.randrange(2000), i) for i in range(1000)]
178        for f in (None, lambda x:  x[0] * 547 % 2000):
179            for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
180                self.assertEqual(self.module.nsmallest(n, data), sorted(data)[:n])
181                self.assertEqual(self.module.nsmallest(n, data, key=f),
182                                 sorted(data, key=f)[:n])
183
184    def test_nlargest(self):
185        data = [(random.randrange(2000), i) for i in range(1000)]
186        for f in (None, lambda x:  x[0] * 547 % 2000):
187            for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
188                self.assertEqual(self.module.nlargest(n, data),
189                                 sorted(data, reverse=True)[:n])
190                self.assertEqual(self.module.nlargest(n, data, key=f),
191                                 sorted(data, key=f, reverse=True)[:n])
192
193    def test_comparison_operator(self):
194        # Issue 3051: Make sure heapq works with both __lt__ and __le__
195        def hsort(data, comp):
196            data = map(comp, data)
197            self.module.heapify(data)
198            return [self.module.heappop(data).x for i in range(len(data))]
199        class LT:
200            def __init__(self, x):
201                self.x = x
202            def __lt__(self, other):
203                return self.x > other.x
204        class LE:
205            def __init__(self, x):
206                self.x = x
207            def __le__(self, other):
208                return self.x >= other.x
209        data = [random.random() for i in range(100)]
210        target = sorted(data, reverse=True)
211        self.assertEqual(hsort(data, LT), target)
212        self.assertEqual(hsort(data, LE), target)
213
214
class TestHeapPython(TestHeap):
    """Run the shared TestHeap suite against heapq with _heapq blocked."""
    module = py_heapq
217
218
@skipUnless(c_heapq, 'requires _heapq')
class TestHeapC(TestHeap):
    """Run the shared TestHeap suite against heapq backed by _heapq."""
    module = c_heapq
222
223
224#==============================================================================
225
class LenOnly:
    """Sequence stand-in that reports a length but cannot be indexed."""
    def __len__(self):
        length = 10
        return length
230
class GetOnly:
    """Sequence stand-in that can be indexed but has no length."""
    def __getitem__(self, ndx):
        constant = 10   # the same value for every index, never IndexError
        return constant
235
class CmpErr:
    """Heap element whose comparisons always fail with ZeroDivisionError."""
    def __cmp__(self, other):
        # No rich comparison methods are defined, so Python 2 routes
        # comparisons of these objects through __cmp__.
        raise ZeroDivisionError
240
def R(seqn):
    """Re-yield the items of seqn from an ordinary generator function."""
    source = iter(seqn)
    while True:
        try:
            element = next(source)
        except StopIteration:
            return
        yield element
245
class G:
    """Expose seqn only through the legacy __getitem__ iteration protocol."""
    def __init__(self, seqn):
        self.seqn = seqn
    def __getitem__(self, i):
        wrapped = self.seqn
        return wrapped[i]
252
253class I:
254    'Sequence using iterator protocol'
255    def __init__(self, seqn):
256        self.seqn = seqn
257        self.i = 0
258    def __iter__(self):
259        return self
260    def next(self):
261        if self.i >= len(self.seqn): raise StopIteration
262        v = self.seqn[self.i]
263        self.i += 1
264        return v
265
class Ig:
    """Iterate over seqn through an __iter__ written as a generator function."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0      # never read by __iter__ below
    def __iter__(self):
        source = iter(self.seqn)
        while True:
            try:
                element = next(source)
            except StopIteration:
                return
            yield element
274
class X:
    """Has a next() method but neither __iter__ nor __getitem__."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def next(self):
        index = self.i
        if index < len(self.seqn):
            value = self.seqn[index]
            self.i = index + 1
            return value
        raise StopIteration
285
class N:
    """Claims to be its own iterator yet provides no next() method."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        itself = self
        return itself
293
class E:
    """Iterator whose next() always fails, to check exception propagation."""
    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0
    def __iter__(self):
        return self
    def next(self):
        # Deliberately divide by zero; callers expect ZeroDivisionError.
        return 3 // 0
303
class S:
    """Iterator that is already exhausted on its first next() call."""
    def __init__(self, seqn):
        pass    # the wrapped sequence is ignored on purpose
    def __iter__(self):
        itself = self
        return itself
    def next(self):
        raise StopIteration
312
313from itertools import chain, imap
def L(seqn):
    """Return seqn wrapped in several stacked iterator layers.

    G (__getitem__ protocol) feeds Ig (generator __iter__), whose output is
    re-yielded by R, passed unchanged through imap, and finally chained.
    """
    layered = R(Ig(G(seqn)))
    passthrough = imap(lambda value: value, layered)
    return chain(passthrough)
317
class SideEffectLT:
    """Comparable wrapper whose __lt__ empties the heap list it refers to.

    Lets tests check that heap functions cope with the list shrinking
    while they are still walking it (Issue #17278).
    """
    def __init__(self, value, heap):
        self.value = value
        self.heap = heap

    def __lt__(self, other):
        del self.heap[:]    # clear the shared list in place before answering
        return self.value < other.value
326
327
328class TestErrorHandling(TestCase):
329    module = None
330
331    def test_non_sequence(self):
332        for f in (self.module.heapify, self.module.heappop):
333            self.assertRaises((TypeError, AttributeError), f, 10)
334        for f in (self.module.heappush, self.module.heapreplace,
335                  self.module.nlargest, self.module.nsmallest):
336            self.assertRaises((TypeError, AttributeError), f, 10, 10)
337
338    def test_len_only(self):
339        for f in (self.module.heapify, self.module.heappop):
340            self.assertRaises((TypeError, AttributeError), f, LenOnly())
341        for f in (self.module.heappush, self.module.heapreplace):
342            self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
343        for f in (self.module.nlargest, self.module.nsmallest):
344            self.assertRaises(TypeError, f, 2, LenOnly())
345
346    def test_get_only(self):
347        seq = [CmpErr(), CmpErr(), CmpErr()]
348        for f in (self.module.heapify, self.module.heappop):
349            self.assertRaises(ZeroDivisionError, f, seq)
350        for f in (self.module.heappush, self.module.heapreplace):
351            self.assertRaises(ZeroDivisionError, f, seq, 10)
352        for f in (self.module.nlargest, self.module.nsmallest):
353            self.assertRaises(ZeroDivisionError, f, 2, seq)
354
355    def test_arg_parsing(self):
356        for f in (self.module.heapify, self.module.heappop,
357                  self.module.heappush, self.module.heapreplace,
358                  self.module.nlargest, self.module.nsmallest):
359            self.assertRaises((TypeError, AttributeError), f, 10)
360
361    def test_iterable_args(self):
362        for f in (self.module.nlargest, self.module.nsmallest):
363            for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
364                for g in (G, I, Ig, L, R):
365                    with test_support.check_py3k_warnings(
366                            ("comparing unequal types not supported",
367                             DeprecationWarning), quiet=True):
368                        self.assertEqual(f(2, g(s)), f(2,s))
369                self.assertEqual(f(2, S(s)), [])
370                self.assertRaises(TypeError, f, 2, X(s))
371                self.assertRaises(TypeError, f, 2, N(s))
372                self.assertRaises(ZeroDivisionError, f, 2, E(s))
373
374    # Issue #17278: the heap may change size while it's being walked.
375
376    def test_heappush_mutating_heap(self):
377        heap = []
378        heap.extend(SideEffectLT(i, heap) for i in range(200))
379        # Python version raises IndexError, C version RuntimeError
380        with self.assertRaises((IndexError, RuntimeError)):
381            self.module.heappush(heap, SideEffectLT(5, heap))
382
383    def test_heappop_mutating_heap(self):
384        heap = []
385        heap.extend(SideEffectLT(i, heap) for i in range(200))
386        # Python version raises IndexError, C version RuntimeError
387        with self.assertRaises((IndexError, RuntimeError)):
388            self.module.heappop(heap)
389
390
class TestErrorHandlingPython(TestErrorHandling):
    """Run the shared error-handling tests against heapq with _heapq blocked."""
    module = py_heapq
393
394
@skipUnless(c_heapq, 'requires _heapq')
class TestErrorHandlingC(TestErrorHandling):
    """Run the shared error-handling tests against heapq backed by _heapq."""
    module = c_heapq
398
399
400#==============================================================================
401
402
403def test_main(verbose=None):
404    test_classes = [TestModules, TestHeapPython, TestHeapC,
405                    TestErrorHandlingPython, TestErrorHandlingC]
406    test_support.run_unittest(*test_classes)
407
408    # verify reference counting
409    if verbose and hasattr(sys, "gettotalrefcount"):
410        import gc
411        counts = [None] * 5
412        for i in xrange(len(counts)):
413            test_support.run_unittest(*test_classes)
414            gc.collect()
415            counts[i] = sys.gettotalrefcount()
416        print counts
417
if __name__ == "__main__":
    # verbose=True also enables test_main's repeated refcount runs whenever
    # sys.gettotalrefcount is available.
    test_main(verbose=True)
420