1from test import support
2import random
3import unittest
4from functools import cmp_to_key
5
# Verbosity flag taken from test.support; when true, the helpers and tests
# below print progress messages.
verbose = support.verbose
# Running count of mismatches reported by check(); incremented on each failure.
nerrors = 0
8
9
def check(tag, expected, raw, compare=None):
    """Sort *raw* in place and verify it matches *expected* element by element.

    tag      -- label printed in progress and error messages.
    expected -- list holding the very same objects as *raw*, in sorted order;
                elements are compared by identity, not by equality.
    raw      -- list to sort; it is modified in place.
    compare  -- optional old-style cmp function, adapted with cmp_to_key().

    A mismatch is reported on stdout and counted in the module-level
    ``nerrors``; nothing is raised and nothing is returned.
    """
    global nerrors

    if verbose:
        print("    checking", tag)

    snapshot = raw[:]   # keep the unsorted input for the error report
    raw.sort(key=cmp_to_key(compare) if compare else None)

    # Work out the first problem, if any, as the tuple of things to print.
    problem = None
    if len(raw) != len(expected):
        problem = ("length mismatch;", len(expected), len(raw))
    else:
        for position, (want, got) in enumerate(zip(expected, raw)):
            if want is not got:
                problem = ("out of order at index", position, want, got)
                break

    if problem is None:
        return

    print("error in", tag)
    print(*problem)
    print(expected)
    print(snapshot)
    print(raw)
    nerrors += 1
41
class TestBase(unittest.TestCase):
    """Stress test of list.sort() across many sizes and comparison behaviours."""

    def testStressfully(self):
        """Sort lists of many sizes with ordinary, reversed, random, raising
        and stability-sensitive comparisons.

        Individual results are verified by the module-level check() helper,
        which only prints and counts mismatches; this test fails at the end
        if that count grew while it was running.
        """
        # Remember the global failure count so that only mismatches reported
        # during this run make the test fail.
        errors_before = nerrors

        # Try a variety of sizes at and around powers of 2, and at powers of 10.
        sizes = [0]
        for power in range(1, 10):
            n = 2 ** power
            sizes.extend(range(n-1, n+2))
        sizes.extend([10, 100, 1000])

        class Complains(object):
            """Element whose comparison occasionally raises RuntimeError."""
            maybe_complain = True

            def __init__(self, i):
                self.i = i

            def __lt__(self, other):
                if Complains.maybe_complain and random.random() < 0.001:
                    if verbose:
                        print("        complaining at", self, other)
                    raise RuntimeError
                return self.i < other.i

            def __repr__(self):
                return "Complains(%d)" % self.i

        class Stable(object):
            """Element ordered by ``key`` only; ``index`` records input position."""
            def __init__(self, key, i):
                self.key = key
                self.index = i

            def __lt__(self, other):
                return self.key < other.key

            def __repr__(self):
                return "Stable(%d, %d)" % (self.key, self.index)

        for n in sizes:
            x = list(range(n))
            if verbose:
                print("Testing size", n)

            s = x[:]
            check("identity", x, s)

            s = x[:]
            s.reverse()
            check("reversed", x, s)

            s = x[:]
            random.shuffle(s)
            check("random permutation", x, s)

            y = x[:]
            y.reverse()
            s = x[:]
            check("reversed via function", y, s, lambda a, b: (b>a)-(b<a))

            if verbose:
                print("    Checking against an insane comparison function.")
                print("        If the implementation isn't careful, this may segfault.")
            s = x[:]
            s.sort(key=cmp_to_key(lambda a, b:  int(random.random() * 3) - 1))
            # Whatever order the nonsense comparison produced, s must still be
            # a permutation of x, so a real sort has to restore x.
            check("an insane function left some permutation", x, s)

            if len(x) >= 2:
                def bad_key(_item):
                    raise RuntimeError
                s = x[:]
                self.assertRaises(RuntimeError, s.sort, key=bad_key)

            x = [Complains(i) for i in x]
            s = x[:]
            random.shuffle(s)
            Complains.maybe_complain = True
            it_complained = False
            try:
                s.sort()
            except RuntimeError:
                it_complained = True
            if it_complained:
                # The aborted sort must have left s as some permutation of x;
                # with complaints switched off, sorting again has to give x.
                Complains.maybe_complain = False
                check("exception during sort left some permutation", x, s)

            s = [Stable(random.randrange(10), i) for i in range(n)]
            # Build the expected order from explicit (key, index) pairs.
            # Sorting (element, index) tuples would not do: distinct Stable
            # objects never compare equal, so tuple comparison is decided by
            # the first item alone and the index would never break a tie.
            x = sorted(s, key=lambda e: (e.key, e.index))
            check("stability", x, s)

        # check() never raises, so turn any mismatch it counted into a failure.
        self.assertEqual(
            nerrors, errors_before,
            "check() reported %d mismatch(es); see printed output"
            % (nerrors - errors_before))
130
131#==============================================================================
132
class TestBugs(unittest.TestCase):
    """Regression tests: mutating a list while it is being sorted must raise."""

    def test_bug453523(self):
        # bug 453523 -- list.sort() crasher.
        # A failure here will most likely show up as a core dump.
        # Mutating the list from inside a comparison should raise ValueError.
        victims = []

        class Saboteur:
            def __lt__(self, other):
                # Shrink the list most of the time it is non-empty, grow it
                # otherwise; then answer the comparison at random.
                if victims and random.random() < 0.75:
                    victims.pop()
                else:
                    victims.append(3)
                return random.random() < 0.5

        victims.extend(Saboteur() for _ in range(50))
        with self.assertRaises(ValueError):
            victims.sort()

    def test_undetected_mutation(self):
        # Python 2.4a1 did not always detect mutation
        memorywaster = []
        for _ in range(20):
            target = [1, 2]

            def append_then_pop(x, y):
                # Leaves the length unchanged, but still touches the list.
                target.append(3)
                target.pop()
                return (x > y) - (x < y)

            def append_then_clear(x, y):
                target.append(3)
                del target[:]
                return (x > y) - (x < y)

            # Both comparators are run, in this order, against the same list.
            for mutating_cmp in (append_then_pop, append_then_clear):
                with self.assertRaises(ValueError):
                    target.sort(key=cmp_to_key(mutating_cmp))

            # Nest one level deeper on every pass.
            memorywaster = [memorywaster]
167
168#==============================================================================
169
class TestDecorateSortUndecorate(unittest.TestCase):
    """Tests for the ``key=`` and ``reverse=`` arguments of list.sort()."""

    def test_decorated(self):
        """A key function and an equivalent cmp function give the same order."""
        data = 'The quick Brown fox Jumped over The lazy Dog'.split()
        copy = data[:]
        random.shuffle(data)
        data.sort(key=str.lower)
        def my_cmp(x, y):
            xlower, ylower = x.lower(), y.lower()
            return (xlower > ylower) - (xlower < ylower)
        copy.sort(key=cmp_to_key(my_cmp))
        # The only words that tie case-insensitively are the two identical
        # "The" strings, so both sorts must produce exactly the same list.
        self.assertEqual(data, copy)

    def test_baddecorator(self):
        """A key function that requires two arguments makes sort() raise TypeError."""
        data = 'The quick Brown fox Jumped over The lazy Dog'.split()
        self.assertRaises(TypeError, data.sort, key=lambda x,y: 0)

    def test_stability(self):
        """Sorting on the first field only keeps ties in their input order."""
        data = [(random.randrange(100), i) for i in range(200)]
        copy = data[:]
        data.sort(key=lambda t: t[0])   # sort on the random first field
        copy.sort()                     # sort using both fields
        self.assertEqual(data, copy)    # should get the same result

    def test_key_with_exception(self):
        """A key function that raises leaves the list unchanged."""
        # Verify that the wrapper has been removed
        data = list(range(-2, 2))
        dup = data[:]
        # 1/x raises ZeroDivisionError for the element 0.
        self.assertRaises(ZeroDivisionError, data.sort, key=lambda x: 1/x)
        self.assertEqual(data, dup)

    def test_key_with_mutation(self):
        """A key function that rewrites the list makes sort() raise ValueError."""
        data = list(range(10))
        def k(x):
            del data[:]
            data[:] = range(20)
            return x
        self.assertRaises(ValueError, data.sort, key=k)

    def test_key_with_mutating_del(self):
        """Key objects whose __del__ rewrites the list make sort() raise ValueError."""
        data = list(range(10))
        class SortKiller(object):
            def __init__(self, x):
                pass
            def __del__(self):
                del data[:]
                data[:] = range(20)
            def __lt__(self, other):
                return id(self) < id(other)
        self.assertRaises(ValueError, data.sort, key=SortKiller)

    def test_key_with_mutating_del_and_exception(self):
        """A key constructor that raises propagates its RuntimeError out of sort()."""
        data = list(range(10))
        ## dup = data[:]
        class SortKiller(object):
            def __init__(self, x):
                # Fail part-way through key computation (for elements > 2).
                if x > 2:
                    raise RuntimeError
            def __del__(self):
                del data[:]
                data[:] = list(range(20))
        self.assertRaises(RuntimeError, data.sort, key=SortKiller)
        ## major honking subtlety: we *can't* do:
        ##
        ## self.assertEqual(data, dup)
        ##
        ## because there is a reference to a SortKiller in the
        ## traceback and by the time it dies we're outside the call to
        ## .sort() and so the list protection gimmicks are out of
        ## date (this cost some brain cells to figure out...).

    def test_reverse(self):
        """reverse=True sorts a shuffled range into descending order."""
        data = list(range(100))
        random.shuffle(data)
        data.sort(reverse=True)
        self.assertEqual(data, list(range(99,-1,-1)))

    def test_reverse_stability(self):
        """reverse=True agrees with a reversed comparison, including on ties."""
        data = [(random.randrange(100), i) for i in range(200)]
        copy1 = data[:]
        copy2 = data[:]
        def my_cmp(x, y):
            x0, y0 = x[0], y[0]
            return (x0 > y0) - (x0 < y0)
        def my_cmp_reversed(x, y):
            x0, y0 = x[0], y[0]
            return (y0 > x0) - (y0 < x0)
        data.sort(key=cmp_to_key(my_cmp), reverse=True)
        copy1.sort(key=cmp_to_key(my_cmp_reversed))
        self.assertEqual(data, copy1)
        copy2.sort(key=lambda x: x[0], reverse=True)
        self.assertEqual(data, copy2)
261
262#==============================================================================
263
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
266