1"""Unittests for heapq."""
2
3import random
4import unittest
5
6from test import support
7from unittest import TestCase, skipUnless
8from operator import itemgetter
9
# import_fresh_module lives in test.support.import_helper on Python 3.10+
# and directly on test.support on older versions; support both layouts.
try:
    from test.support.import_helper import import_fresh_module
except ImportError:
    import_fresh_module = support.import_fresh_module

# Two independent copies of heapq: one imported with the _heapq accelerator
# blocked and one imported with a fresh _heapq.  c_heapq is falsy when the
# accelerator cannot be imported; the C-specific tests are skipped on that
# basis.
py_heapq = import_fresh_module('heapq', blocked=['_heapq'])
c_heapq = import_fresh_module('heapq', fresh=['_heapq'])

# Functions whose __module__ is inspected by TestModules to confirm which
# implementation each of the two imports picked up.
# NOTE(review): nlargest/nsmallest are not listed.  An older note here said
# the _heapq versions are saved as heapq._nlargest/_nsmallest and should be
# checked there -- confirm whether that still applies.
func_names = ['heapify', 'heappop', 'heappush', 'heappushpop', 'heapreplace',
              '_heappop_max', '_heapreplace_max', '_heapify_max']
17
class TestModules(TestCase):
    # Sanity checks that py_heapq and c_heapq really expose the pure-Python
    # and the C implementations respectively.

    def test_py_functions(self):
        # Every listed function of the blocked import must live in 'heapq'.
        for name in func_names:
            func = getattr(py_heapq, name)
            self.assertEqual(func.__module__, 'heapq')

    @skipUnless(c_heapq, 'requires _heapq')
    def test_c_functions(self):
        # Every listed function of the fresh import must live in '_heapq'.
        for name in func_names:
            func = getattr(c_heapq, name)
            self.assertEqual(func.__module__, '_heapq')
27
28
class TestHeap:
    # Functional tests shared by every heapq implementation.  Concrete
    # subclasses mix this class with TestCase and set ``module`` to the heapq
    # module under test.  Test methods use comments rather than docstrings so
    # that unittest keeps reporting them by method name.

    def test_push_pop(self):
        # 1) Push 256 random numbers and pop them off, verifying all's OK.
        heap = []
        data = []
        self.check_invariant(heap)
        for i in range(256):
            item = random.random()
            data.append(item)
            self.module.heappush(heap, item)
            self.check_invariant(heap)
        results = []
        while heap:
            item = self.module.heappop(heap)
            self.check_invariant(heap)
            results.append(item)
        self.assertEqual(sorted(data), results)
        # 2) Check that the invariant holds for a sorted array
        self.check_invariant(results)

        self.assertRaises(TypeError, self.module.heappush, [])
        # A non-list heap fails with TypeError or AttributeError depending on
        # the implementation.  Accept both explicitly so that each call is
        # always checked instead of being skipped by a blanket except clause.
        self.assertRaises((TypeError, AttributeError),
                          self.module.heappush, None, None)
        self.assertRaises((TypeError, AttributeError),
                          self.module.heappop, None)

    def check_invariant(self, heap):
        # Check the heap invariant: no item is smaller than its parent.
        for pos, item in enumerate(heap):
            if pos:  # pos 0 has no parent
                parentpos = (pos - 1) >> 1
                self.assertLessEqual(heap[parentpos], item)

    def test_heapify(self):
        for size in list(range(30)) + [20000]:
            heap = [random.random() for dummy in range(size)]
            self.module.heapify(heap)
            self.check_invariant(heap)

        self.assertRaises(TypeError, self.module.heapify, None)

    def test_naive_nbest(self):
        # Keep a heap of at most 10 items; what survives must be the 10
        # largest values of data.
        data = [random.randrange(2000) for i in range(1000)]
        heap = []
        for item in data:
            self.module.heappush(heap, item)
            if len(heap) > 10:
                self.module.heappop(heap)
        heap.sort()
        self.assertEqual(heap, sorted(data)[-10:])

    def heapiter(self, heap):
        # An iterator returning a heap's elements, smallest-first.
        # The heap is consumed; IndexError from heappop marks exhaustion.
        try:
            while True:
                yield self.module.heappop(heap)
        except IndexError:
            pass

    def test_nbest(self):
        # Less-naive "N-best" algorithm, much faster (if len(data) is big
        # enough <wink>) than sorting all of data.  However, if we had a max
        # heap instead of a min heap, it could go faster still via
        # heapify'ing all of data (linear time), then doing 10 heappops
        # (10 log-time steps).
        data = [random.randrange(2000) for i in range(1000)]
        heap = data[:10]
        self.module.heapify(heap)
        for item in data[10:]:
            if item > heap[0]:  # this gets rarer the longer we run
                self.module.heapreplace(heap, item)
        self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])

        self.assertRaises(TypeError, self.module.heapreplace, None)
        self.assertRaises(TypeError, self.module.heapreplace, None, None)
        self.assertRaises(IndexError, self.module.heapreplace, [], None)

    def test_nbest_with_pushpop(self):
        # Same N-best scheme as test_nbest, driven by heappushpop.
        data = [random.randrange(2000) for i in range(1000)]
        heap = data[:10]
        self.module.heapify(heap)
        for item in data[10:]:
            self.module.heappushpop(heap, item)
        self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
        self.assertEqual(self.module.heappushpop([], 'x'), 'x')

    def test_heappushpop(self):
        # Empty heap: the pushed item comes straight back.
        h = []
        x = self.module.heappushpop(h, 10)
        self.assertEqual((h, x), ([], 10))

        # Equal items: the heap keeps its int, the pushed float is returned.
        h = [10]
        x = self.module.heappushpop(h, 10.0)
        self.assertEqual((h, x), ([10], 10.0))
        self.assertEqual(type(h[0]), int)
        self.assertEqual(type(x), float)

        # Pushed item smaller than the root: it is returned unchanged.
        h = [10]
        x = self.module.heappushpop(h, 9)
        self.assertEqual((h, x), ([10], 9))

        # Pushed item larger than the root: the root is popped instead.
        h = [10]
        x = self.module.heappushpop(h, 11)
        self.assertEqual((h, x), ([11], 10))

    def test_heapsort(self):
        # Exercise everything with repeated heapsort checks
        for trial in range(100):
            size = random.randrange(50)
            data = [random.randrange(25) for i in range(size)]
            if trial & 1:     # Half of the time, use heapify
                heap = data[:]
                self.module.heapify(heap)
            else:             # The rest of the time, use heappush
                heap = []
                for item in data:
                    self.module.heappush(heap, item)
            heap_sorted = [self.module.heappop(heap) for i in range(size)]
            self.assertEqual(heap_sorted, sorted(data))

    def test_merge(self):
        # Merging individually sorted rows must equal sorting everything at
        # once, for every key/reverse combination.
        inputs = []
        for i in range(random.randrange(25)):
            row = []
            for j in range(random.randrange(100)):
                tup = random.choice('ABC'), random.randrange(-500, 500)
                row.append(tup)
            inputs.append(row)

        for key in [None, itemgetter(0), itemgetter(1), itemgetter(1, 0)]:
            for reverse in [False, True]:
                seqs = []
                for seq in inputs:
                    seqs.append(sorted(seq, key=key, reverse=reverse))
                self.assertEqual(sorted(chain(*inputs), key=key, reverse=reverse),
                                 list(self.module.merge(*seqs, key=key, reverse=reverse)))
                self.assertEqual(list(self.module.merge()), [])

    def test_merge_does_not_suppress_index_error(self):
        # Issue 19018: Heapq.merge suppresses IndexError from user generator
        def iterable():
            s = list(range(10))
            for i in range(20):
                yield s[i]       # IndexError when i >= 10
        with self.assertRaises(IndexError):
            list(self.module.merge(iterable(), iterable()))

    def test_merge_stability(self):
        # Equal values must come out ordered by the stream they came from.
        class Int(int):
            pass
        inputs = [[], [], [], []]
        for i in range(20000):
            stream = random.randrange(4)
            x = random.randrange(500)
            obj = Int(x)
            obj.pair = (x, stream)
            inputs[stream].append(obj)
        for stream in inputs:
            stream.sort()
        result = [i.pair for i in self.module.merge(*inputs)]
        self.assertEqual(result, sorted(result))

    def test_nsmallest(self):
        data = [(random.randrange(2000), i) for i in range(1000)]
        for f in (None, lambda x:  x[0] * 547 % 2000):
            for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
                self.assertEqual(list(self.module.nsmallest(n, data)),
                                 sorted(data)[:n])
                self.assertEqual(list(self.module.nsmallest(n, data, key=f)),
                                 sorted(data, key=f)[:n])

    def test_nlargest(self):
        data = [(random.randrange(2000), i) for i in range(1000)]
        for f in (None, lambda x:  x[0] * 547 % 2000):
            for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
                self.assertEqual(list(self.module.nlargest(n, data)),
                                 sorted(data, reverse=True)[:n])
                self.assertEqual(list(self.module.nlargest(n, data, key=f)),
                                 sorted(data, key=f, reverse=True)[:n])

    def test_comparison_operator(self):
        # Issue 3051: Make sure heapq works with elements that define only
        # __lt__.  For python 3.0, __le__ alone is not enough.
        def hsort(data, comp):
            # Heap-sort data after wrapping every value in comp.
            data = [comp(x) for x in data]
            self.module.heapify(data)
            return [self.module.heappop(data).x for i in range(len(data))]
        class LT:
            # Orders in reverse so a min-heap yields descending values.
            def __init__(self, x):
                self.x = x
            def __lt__(self, other):
                return self.x > other.x
        class LE:
            # Defines only __le__, so '<' between instances is unsupported.
            def __init__(self, x):
                self.x = x
            def __le__(self, other):
                return self.x >= other.x
        data = [random.random() for i in range(100)]
        target = sorted(data, reverse=True)
        self.assertEqual(hsort(data, LT), target)
        # hsort must be the callable here: asserting on ``data(LE)`` would
        # pass vacuously because a list is not callable.
        self.assertRaises(TypeError, hsort, data, LE)
234
235
class TestHeapPython(TestHeap, TestCase):
    # Runs the shared TestHeap cases against the heapq copy that was
    # imported with _heapq blocked.
    module = py_heapq
238
239
# Runs the shared TestHeap cases against the heapq copy imported with a
# fresh _heapq; skipped entirely when that import produced nothing.
@skipUnless(c_heapq, 'requires _heapq')
class TestHeapC(TestHeap, TestCase):
    module = c_heapq
243
244
245#==============================================================================
246
class LenOnly:
    """Dummy sequence class that has __len__ but no __getitem__."""

    def __len__(self):
        # Always reports the same fixed length.
        return 10
251
class GetOnly:
    """Dummy sequence class that has __getitem__ but no __len__."""

    def __getitem__(self, ndx):
        # Every index, whatever its value or type, maps to the same item.
        return 10
256
class CmpErr:
    """Dummy element for which every rich comparison raises an error."""

    def __eq__(self, other):
        raise ZeroDivisionError

    # The remaining comparison operators reuse the failing __eq__.
    __ne__ = __eq__
    __lt__ = __eq__
    __le__ = __eq__
    __gt__ = __eq__
    __ge__ = __eq__
262
def R(seqn):
    """Regular generator yielding each item of seqn in order."""
    yield from seqn
267
class G:
    """Sequence that can only be accessed through __getitem__."""

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        # Delegate to the wrapped sequence; its IndexError ends iteration.
        item = self.seqn[i]
        return item
274
class I:
    """Sequence wrapped in a hand-written iterator (__iter__ plus __next__)."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            value = self.seqn[pos]
            # Advance only after the item has been fetched successfully.
            self.i = pos + 1
            return value
        raise StopIteration
287
class Ig:
    """Sequence whose __iter__ is implemented as a generator."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        # Each call starts a new generator over the wrapped sequence.
        yield from self.seqn
296
class X:
    """Has __next__ only: neither __getitem__ nor __iter__ is defined."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __next__(self):
        pos = self.i
        if pos < len(self.seqn):
            value = self.seqn[pos]
            # Advance only after the item has been fetched successfully.
            self.i = pos + 1
            return value
        raise StopIteration
307
class N:
    """Broken iterator: __iter__ returns self but __next__ is missing."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        # Hands back an object that cannot actually be advanced.
        return self
315
class E:
    """Iterator used to test that exceptions propagate to the caller."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Integer division by zero: fails with ZeroDivisionError on the
        # first item that is requested.
        3 // 0
325
class S:
    """Iterator that stops immediately, whatever it was given."""

    def __init__(self, seqn):
        # The argument is accepted for signature compatibility and ignored.
        pass

    def __iter__(self):
        return self

    def __next__(self):
        raise StopIteration
334
335from itertools import chain
def L(seqn):
    """Test multiple tiers of iterators."""
    # Built innermost to outermost: __getitem__ sequence, generator-backed
    # iterable, plain generator, identity map(), and finally chain().
    indexed = G(seqn)
    generator_backed = Ig(indexed)
    plain_generator = R(generator_backed)
    passthrough = map(lambda x: x, plain_generator)
    return chain(passthrough)
339
340
class SideEffectLT:
    # Element whose '<' comparison empties the heap list it was given, so the
    # heap changes size while a heapq function is still walking it.

    def __init__(self, value, heap):
        self.value = value
        self.heap = heap

    def __lt__(self, other):
        # Empty the shared list in place first, then compare as usual.
        del self.heap[:]
        is_smaller = self.value < other.value
        return is_smaller
349
350
class TestErrorHandling:
    # Error-path tests shared by every heapq implementation.  Concrete
    # subclasses mix this class with TestCase and set ``module`` to the heapq
    # module under test.  Where the expected exception is a tuple, the type
    # raised depends on the implementation.

    def test_non_sequence(self):
        for f in (self.module.heapify, self.module.heappop):
            self.assertRaises((TypeError, AttributeError), f, 10)
        for f in (self.module.heappush, self.module.heapreplace,
                  self.module.nlargest, self.module.nsmallest):
            self.assertRaises((TypeError, AttributeError), f, 10, 10)

    def test_len_only(self):
        for f in (self.module.heapify, self.module.heappop):
            self.assertRaises((TypeError, AttributeError), f, LenOnly())
        for f in (self.module.heappush, self.module.heapreplace):
            self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
        for f in (self.module.nlargest, self.module.nsmallest):
            self.assertRaises(TypeError, f, 2, LenOnly())

    def test_get_only(self):
        # GetOnly supports indexing only, so the heap functions must reject
        # it.  Both exception types are accepted, as in the sibling tests,
        # because a missing list method surfaces as AttributeError.
        for f in (self.module.heapify, self.module.heappop):
            self.assertRaises((TypeError, AttributeError), f, GetOnly())
        for f in (self.module.heappush, self.module.heapreplace):
            self.assertRaises((TypeError, AttributeError), f, GetOnly(), 10)
        # nlargest/nsmallest are deliberately not exercised with GetOnly:
        # its __getitem__ never raises IndexError, so iterating an instance
        # through the sequence protocol never terminates.

    def test_cmp_err(self):
        # Errors raised by element comparisons must propagate unchanged.
        # (This test used to be a second ``test_get_only`` definition that
        # silently replaced the one above.)
        seq = [CmpErr(), CmpErr(), CmpErr()]
        for f in (self.module.heapify, self.module.heappop):
            self.assertRaises(ZeroDivisionError, f, seq)
        for f in (self.module.heappush, self.module.heapreplace):
            self.assertRaises(ZeroDivisionError, f, seq, 10)
        for f in (self.module.nlargest, self.module.nsmallest):
            self.assertRaises(ZeroDivisionError, f, 2, seq)

    def test_arg_parsing(self):
        # A lone non-list argument is invalid for every function.
        for f in (self.module.heapify, self.module.heappop,
                  self.module.heappush, self.module.heapreplace,
                  self.module.nlargest, self.module.nsmallest):
            self.assertRaises((TypeError, AttributeError), f, 10)

    def test_iterable_args(self):
        # nlargest/nsmallest must accept any kind of iterable and give the
        # same answer as for the plain sequence.
        for f in (self.module.nlargest, self.module.nsmallest):
            for s in ("123", "", range(1000), (1, 1.2), range(2000,2200,5)):
                for g in (G, I, Ig, L, R):
                    self.assertEqual(list(f(2, g(s))), list(f(2,s)))
                self.assertEqual(list(f(2, S(s))), [])
                self.assertRaises(TypeError, f, 2, X(s))
                self.assertRaises(TypeError, f, 2, N(s))
                self.assertRaises(ZeroDivisionError, f, 2, E(s))

    # Issue #17278: the heap may change size while it's being walked.

    def test_heappush_mutating_heap(self):
        heap = []
        heap.extend(SideEffectLT(i, heap) for i in range(200))
        # Python version raises IndexError, C version RuntimeError
        with self.assertRaises((IndexError, RuntimeError)):
            self.module.heappush(heap, SideEffectLT(5, heap))

    def test_heappop_mutating_heap(self):
        heap = []
        heap.extend(SideEffectLT(i, heap) for i in range(200))
        # Python version raises IndexError, C version RuntimeError
        with self.assertRaises((IndexError, RuntimeError)):
            self.module.heappop(heap)
416
417
class TestErrorHandlingPython(TestErrorHandling, TestCase):
    # Runs the shared error-handling cases against the heapq copy that was
    # imported with _heapq blocked.
    module = py_heapq
420
# Runs the shared error-handling cases against the heapq copy imported with
# a fresh _heapq; skipped entirely when that import produced nothing.
@skipUnless(c_heapq, 'requires _heapq')
class TestErrorHandlingC(TestErrorHandling, TestCase):
    module = c_heapq
424
425
# Run the whole suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
428