1# Tests for rich comparisons
2
3import unittest
4from test import support
5
6import operator
7
class Number:
    """Wrap a value and forward every rich comparison to it.

    Each comparison method evaluates the same operator with the wrapped
    value ``self.x`` on the left and ``other`` unchanged on the right, so the
    result is whatever that underlying comparison produces.  ``__cmp__`` is
    a tripwire: it fails the test run if anything ever invokes it.
    """

    def __init__(self, x):
        self.x = x

    def __lt__(self, other):
        return operator.lt(self.x, other)

    def __le__(self, other):
        return operator.le(self.x, other)

    def __eq__(self, other):
        return operator.eq(self.x, other)

    def __ne__(self, other):
        return operator.ne(self.x, other)

    def __gt__(self, other):
        return operator.gt(self.x, other)

    def __ge__(self, other):
        return operator.ge(self.x, other)

    def __cmp__(self, other):
        raise support.TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return "Number({!r})".format(self.x)
36
class Vector:
    """Sequence wrapper whose rich comparisons work item by item.

    Every comparison returns a new Vector holding the per-item results
    instead of a single bool.  The other operand may be another Vector or
    any sized iterable.  A length mismatch raises ValueError.  Truth-testing
    a Vector raises TypeError, and Vectors are unhashable.
    """

    __hash__ = None  # Vectors cannot be hashed

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    def __bool__(self):
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        raise support.TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return "Vector({!r})".format(self.data)

    def _elementwise(self, compare, other):
        # Pair items positionally and wrap the per-pair results in a Vector.
        peer = self.__cast(other)
        return Vector([compare(mine, theirs)
                       for mine, theirs in zip(self.data, peer)])

    def __lt__(self, other):
        return self._elementwise(operator.lt, other)

    def __le__(self, other):
        return self._elementwise(operator.le, other)

    def __eq__(self, other):
        return self._elementwise(operator.eq, other)

    def __ne__(self, other):
        return self._elementwise(operator.ne, other)

    def __gt__(self, other):
        return self._elementwise(operator.gt, other)

    def __ge__(self, other):
        return self._elementwise(operator.ge, other)

    def __cast(self, other):
        # Unwrap a Vector operand, then insist on matching lengths.
        peer = other.data if isinstance(other, Vector) else other
        if len(self.data) != len(peer):
            raise ValueError("Cannot compare vectors of different length")
        return peer
86
# Each comparison name maps to three equivalent spellings of that comparison:
# the operator syntax itself, the operator-module function, and the
# operator-module dunder alias.  The test helpers below loop over all three.
opmap = {
    name: (via_syntax,
           getattr(operator, name),
           getattr(operator, "__%s__" % name))
    for name, via_syntax in (
        ("lt", lambda a, b: a < b),
        ("le", lambda a, b: a <= b),
        ("eq", lambda a, b: a == b),
        ("ne", lambda a, b: a != b),
        ("gt", lambda a, b: a > b),
        ("ge", lambda a, b: a >= b),
    )
}
95
class VectorTest(unittest.TestCase):
    """Comparisons involving Vectors, whose results are Vectors of bools."""

    def checkfail(self, error, opname, *args):
        """Assert that every spelling of *opname* in opmap raises *error*."""
        for op in opmap[opname]:
            self.assertRaises(error, op, *args)

    def checkequal(self, opname, a, b, expres):
        """Assert that every spelling of *opname* maps (a, b) to *expres*.

        *expres* is a list of bools.  ``assertEqual(realres, expres)`` cannot
        be used: it evaluates ``realres == expres``, which yields another
        Vector, and truth-testing a Vector raises TypeError.  So compare the
        lengths first, then the items one by one.
        """
        for op in opmap[opname]:
            realres = op(a, b)
            self.assertEqual(len(realres), len(expres))
            for got, want in zip(realres, expres):
                # The items are bools, so identity is the strictest check.
                # assertIs also reports both values on failure.
                self.assertIs(got, want)

    def test_mixed(self):
        # check that comparisons involving Vector objects
        # which return rich results (i.e. Vectors with itemwise
        # comparison results) work
        a = Vector(range(2))
        b = Vector(range(3))
        # all comparisons should fail for different length
        for opname in opmap:
            self.checkfail(ValueError, opname, a, b)

        a = list(range(5))
        b = 5 * [2]
        # try mixed arguments (but not (a, b) as that won't return a bool vector)
        args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
        for left, right in args:
            self.checkequal("lt", left, right, [True,  True,  False, False, False])
            self.checkequal("le", left, right, [True,  True,  True,  False, False])
            self.checkequal("eq", left, right, [False, False, True,  False, False])
            self.checkequal("ne", left, right, [True,  True,  False, True,  True ])
            self.checkequal("gt", left, right, [False, False, False, True,  True ])
            self.checkequal("ge", left, right, [False, False, True,  True,  True ])

            for ops in opmap.values():
                for op in ops:
                    # calls __bool__, which should fail
                    self.assertRaises(TypeError, bool, op(left, right))
137
class NumberTest(unittest.TestCase):
    """Comparisons where one or both operands are Number wrappers."""

    def test_basic(self):
        # Check that comparisons involving Number objects give the same
        # results as comparing the corresponding ints.
        for a in range(3):
            for b in range(3):
                for typea in (int, Number):
                    for typeb in (int, Number):
                        if typea is int and typeb is int:
                            continue # the combination int, int is useless
                        ta = typea(a)
                        tb = typeb(b)
                        for ops in opmap.values():
                            for op in ops:
                                realoutcome = op(a, b)
                                testoutcome = op(ta, tb)
                                self.assertEqual(realoutcome, testoutcome)

    def checkvalue(self, opname, a, b, expres):
        """Assert *opname* on every int/Number mix of (a, b) yields *expres*.

        *expres* is a bool, and the result must be that exact object.
        """
        for typea in (int, Number):
            for typeb in (int, Number):
                ta = typea(a)
                tb = typeb(b)
                for op in opmap[opname]:
                    realres = op(ta, tb)
                    # Defensive: unwrap the value if a Number ever comes back
                    # instead of a plain bool.
                    realres = getattr(realres, "x", realres)
                    self.assertIs(realres, expres)

    def test_values(self):
        # check all operators and all comparison results
        self.checkvalue("lt", 0, 0, False)
        self.checkvalue("le", 0, 0, True )
        self.checkvalue("eq", 0, 0, True )
        self.checkvalue("ne", 0, 0, False)
        self.checkvalue("gt", 0, 0, False)
        self.checkvalue("ge", 0, 0, True )

        self.checkvalue("lt", 0, 1, True )
        self.checkvalue("le", 0, 1, True )
        self.checkvalue("eq", 0, 1, False)
        self.checkvalue("ne", 0, 1, True )
        self.checkvalue("gt", 0, 1, False)
        self.checkvalue("ge", 0, 1, False)

        self.checkvalue("lt", 1, 0, False)
        self.checkvalue("le", 1, 0, False)
        self.checkvalue("eq", 1, 0, False)
        self.checkvalue("ne", 1, 0, True )
        self.checkvalue("gt", 1, 0, True )
        self.checkvalue("ge", 1, 0, True )
190
class MiscTest(unittest.TestCase):
    """Assorted rich-comparison edge cases."""

    def test_misbehavin(self):
        # Comparison methods may return non-bool values (here 0).  Each
        # operator must reach only its own method: the <=, >= and != methods
        # fail the test if they are ever called.
        class Misb:
            def __lt__(self_, other): return 0
            def __gt__(self_, other): return 0
            def __eq__(self_, other): return 0
            def __le__(self_, other): self.fail("This shouldn't happen")
            def __ge__(self_, other): self.fail("This shouldn't happen")
            def __ne__(self_, other): self.fail("This shouldn't happen")
        a = Misb()
        b = Misb()
        self.assertEqual(a<b, 0)
        self.assertEqual(a==b, 0)
        self.assertEqual(a>b, 0)

    def test_not(self):
        # Check that exceptions in __bool__ are properly
        # propagated by the not operator
        class Exc(Exception):
            pass
        class Bad:
            def __bool__(self):
                raise Exc

        def do(bad):
            not bad

        for func in (do, operator.not_):
            self.assertRaises(Exc, func, Bad())

    @support.no_tracing
    def test_recursion(self):
        # Check that comparison for recursive objects fails gracefully
        from collections import UserList
        a = UserList()
        b = UserList()
        a.append(b)
        b.append(a)
        self.assertRaises(RecursionError, operator.eq, a, b)
        self.assertRaises(RecursionError, operator.ne, a, b)
        self.assertRaises(RecursionError, operator.lt, a, b)
        self.assertRaises(RecursionError, operator.le, a, b)
        self.assertRaises(RecursionError, operator.gt, a, b)
        self.assertRaises(RecursionError, operator.ge, a, b)

        b.append(17)
        # Even recursive lists of different lengths are different,
        # but they cannot be ordered
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        self.assertRaises(RecursionError, operator.lt, a, b)
        self.assertRaises(RecursionError, operator.le, a, b)
        self.assertRaises(RecursionError, operator.gt, a, b)
        self.assertRaises(RecursionError, operator.ge, a, b)
        a.append(17)
        self.assertRaises(RecursionError, operator.eq, a, b)
        self.assertRaises(RecursionError, operator.ne, a, b)
        # Differing first items decide every comparison without recursing.
        a.insert(0, 11)
        b.insert(0, 12)
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        self.assertTrue(a < b)

    def test_exception_message(self):
        # Unsupported orderings must raise TypeError naming the operator and
        # both operand types, in operand order.
        class Spam:
            pass

        tests = [
            (lambda: 42 < None, r"'<' .* of 'int' and 'NoneType'"),
            (lambda: None < 42, r"'<' .* of 'NoneType' and 'int'"),
            (lambda: 42 > None, r"'>' .* of 'int' and 'NoneType'"),
            (lambda: "foo" < None, r"'<' .* of 'str' and 'NoneType'"),
            (lambda: "foo" >= 666, r"'>=' .* of 'str' and 'int'"),
            (lambda: 42 <= None, r"'<=' .* of 'int' and 'NoneType'"),
            (lambda: 42 >= None, r"'>=' .* of 'int' and 'NoneType'"),
            (lambda: 42 < [], r"'<' .* of 'int' and 'list'"),
            (lambda: () > [], r"'>' .* of 'tuple' and 'list'"),
            (lambda: None >= None, r"'>=' .* of 'NoneType' and 'NoneType'"),
            (lambda: Spam() < 42, r"'<' .* of 'Spam' and 'int'"),
            (lambda: 42 < Spam(), r"'<' .* of 'int' and 'Spam'"),
            (lambda: Spam() <= Spam(), r"'<=' .* of 'Spam' and 'Spam'"),
        ]
        for i, (func, pattern) in enumerate(tests):
            with self.subTest(test=i):
                with self.assertRaisesRegex(TypeError, pattern):
                    func()
279
280
class DictTest(unittest.TestCase):
    """Equality and ordering behavior of dicts."""

    def test_dicts(self):
        # Verify that __eq__ and __ne__ work for dicts even if the keys and
        # values don't support anything other than __eq__ and __ne__ (and
        # __hash__).  Complex numbers are a fine example of that.
        import random
        imag1a = {}
        for _ in range(50):
            imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
        items = list(imag1a.items())
        random.shuffle(items)
        # Same key/value pairs inserted in a different order: must be equal.
        imag1b = dict(items)
        # Bump one value so that imag2 differs from imag1a in a single entry.
        k, v = items[-1]
        imag2 = imag1b.copy()
        imag2[k] = v + 1.0
        self.assertEqual(imag1a, imag1a)
        self.assertEqual(imag1a, imag1b)
        self.assertEqual(imag2, imag2)
        self.assertNotEqual(imag1a, imag2)
        # Dicts support equality only; every ordering must raise TypeError.
        for opname in ("lt", "le", "gt", "ge"):
            for op in opmap[opname]:
                self.assertRaises(TypeError, op, imag1a, imag2)
305
class ListTest(unittest.TestCase):
    """List comparisons, including propagation of per-item results."""

    def test_coverage(self):
        # Run all six comparison operators on lists.  First compare a list
        # with itself, then with a longer list that starts with it.
        short = [42]
        longer = [42, 42]
        expectations = (
            (short < short, False),
            (short <= short, True),
            (short == short, True),
            (short != short, False),
            (short > short, False),
            (short >= short, True),
            (short < longer, True),
            (short <= longer, True),
            (short == longer, False),
            (short != longer, True),
            (short > longer, False),
            (short >= longer, False),
        )
        for outcome, wanted in expectations:
            self.assertIs(outcome, wanted)

    def test_badentry(self):
        # An exception raised by an item's __eq__ must escape the list
        # comparison rather than being swallowed.
        class Exc(Exception):
            pass

        class Bad:
            def __eq__(self, other):
                raise Exc

        left, right = [Bad()], [Bad()]
        for compare in opmap["eq"]:
            self.assertRaises(Exc, compare, left, right)

    def test_goodentry(self):
        # This test exercises the final call to PyObject_RichCompare()
        # in Objects/listobject.c::list_richcompare().  The two Good items
        # are distinct objects with the default identity __eq__, so the
        # result comes from comparing them with Good.__lt__.
        class Good:
            def __lt__(self, other):
                return True

        left, right = [Good()], [Good()]
        for compare in opmap["lt"]:
            self.assertIs(compare(left, right), True)
352
353
if __name__ == "__main__":
    # Script entry point: run every TestCase defined in this module.
    unittest.main(module="__main__")
356