1# Tests for rich comparisons
2
3import unittest
4from test import test_support
5
6import operator
7
class Number:
    """Wrapper around a single value that implements every rich comparison.

    Each rich comparison method applies the matching operator to the
    wrapped value ``self.x`` and ``other`` and returns that result
    unchanged.  ``__cmp__`` raises ``test_support.TestFailed`` so that a
    test fails loudly if three-way comparison is used instead of one of
    the rich comparison methods.
    """

    def __init__(self, x):
        # x is the wrapped value every comparison below delegates to.
        self.x = x

    def __lt__(self, other):
        return self.x < other

    def __le__(self, other):
        return self.x <= other

    def __eq__(self, other):
        return self.x == other

    def __ne__(self, other):
        return self.x != other

    def __gt__(self, other):
        return self.x > other

    def __ge__(self, other):
        return self.x >= other

    def __cmp__(self, other):
        # All six rich comparisons are defined above, so reaching this
        # method is reported as a test failure.
        raise test_support.TestFailed("Number.__cmp__() should not be called")

    def __repr__(self):
        return "Number(%r)" % (self.x, )
36
class Vector:
    """Sequence wrapper whose rich comparisons return itemwise results.

    Comparing a Vector with another Vector, or with any other object that
    supports ``len()`` and iteration and has the same length, returns a
    new Vector holding the result of comparing each pair of items.
    Operands of different length raise ``ValueError``.

    A Vector cannot be hashed, raises ``TypeError`` from ``__nonzero__``,
    and raises ``test_support.TestFailed`` from ``__cmp__``.
    """

    def __init__(self, data):
        # data is the underlying sequence; it is stored, not copied.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, i, v):
        self.data[i] = v

    __hash__ = None # Vectors cannot be hashed

    def __nonzero__(self):
        # A comparison result is a Vector of booleans, not a single truth
        # value, so truth-testing is refused outright.
        raise TypeError("Vectors cannot be used in Boolean contexts")

    def __cmp__(self, other):
        # All six rich comparisons are defined below, so reaching this
        # method is reported as a test failure.
        raise test_support.TestFailed("Vector.__cmp__() should not be called")

    def __repr__(self):
        return "Vector(%r)" % (self.data, )

    def __lt__(self, other):
        return Vector([a < b for a, b in zip(self.data, self.__cast(other))])

    def __le__(self, other):
        return Vector([a <= b for a, b in zip(self.data, self.__cast(other))])

    def __eq__(self, other):
        return Vector([a == b for a, b in zip(self.data, self.__cast(other))])

    def __ne__(self, other):
        return Vector([a != b for a, b in zip(self.data, self.__cast(other))])

    def __gt__(self, other):
        return Vector([a > b for a, b in zip(self.data, self.__cast(other))])

    def __ge__(self, other):
        return Vector([a >= b for a, b in zip(self.data, self.__cast(other))])

    def __cast(self, other):
        """Return the raw items of `other`, checking the length matches.

        A Vector operand is unwrapped to its ``data``; anything else is
        returned unchanged.  Raises ``ValueError`` when the lengths differ.
        """
        if isinstance(other, Vector):
            other = other.data
        if len(self.data) != len(other):
            raise ValueError("Cannot compare vectors of different length")
        return other
86
# Maps each comparison name to the three ways of spelling it: the operator
# syntax itself (wrapped in a lambda) and the two names under which the
# operator module exposes the same comparison.
opmap = {
    "lt": (lambda lhs, rhs: lhs < rhs, operator.lt, operator.__lt__),
    "le": (lambda lhs, rhs: lhs <= rhs, operator.le, operator.__le__),
    "eq": (lambda lhs, rhs: lhs == rhs, operator.eq, operator.__eq__),
    "ne": (lambda lhs, rhs: lhs != rhs, operator.ne, operator.__ne__),
    "gt": (lambda lhs, rhs: lhs > rhs, operator.gt, operator.__gt__),
    "ge": (lambda lhs, rhs: lhs >= rhs, operator.ge, operator.__ge__),
}
95
class VectorTest(unittest.TestCase):
    """Checks comparisons whose result is a Vector of itemwise outcomes."""

    def checkfail(self, error, opname, *args):
        """Assert that every spelling of `opname` raises `error` on `args`."""
        for op in opmap[opname]:
            self.assertRaises(error, op, *args)

    def checkequal(self, opname, a, b, expres):
        """Assert that every spelling of `opname` yields `expres` itemwise.

        `expres` is the list of expected bool results, one per item.
        """
        for op in opmap[opname]:
            realres = op(a, b)
            # can't use assertEqual(realres, expres) here: Vector.__eq__
            # returns a Vector, and a Vector refuses to be truth-tested
            self.assertEqual(len(realres), len(expres))
            for i, expected in enumerate(expres):
                # results are bool, so we can compare by identity here
                self.assertIs(realres[i], expected)

    def test_mixed(self):
        # check that comparisons involving Vector objects
        # which return rich results (i.e. Vectors with itemwise
        # comparison results) work
        a = Vector(range(2))
        b = Vector(range(3))
        # all comparisons should fail for different length
        for opname in opmap:
            self.checkfail(ValueError, opname, a, b)

        a = range(5)
        b = 5 * [2]
        # try mixed arguments (but not (a, b) as that won't return a bool vector)
        args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
        for (a, b) in args:
            self.checkequal("lt", a, b, [True,  True,  False, False, False])
            self.checkequal("le", a, b, [True,  True,  True,  False, False])
            self.checkequal("eq", a, b, [False, False, True,  False, False])
            self.checkequal("ne", a, b, [True,  True,  False, True,  True ])
            self.checkequal("gt", a, b, [False, False, False, True,  True ])
            self.checkequal("ge", a, b, [False, False, True,  True,  True ])

            for ops in opmap.itervalues():
                for op in ops:
                    # calls __nonzero__, which should fail
                    self.assertRaises(TypeError, bool, op(a, b))
137
class NumberTest(unittest.TestCase):
    """Checks that Number operands compare exactly like the ints they wrap."""

    def test_basic(self):
        # Check that comparisons involving Number objects
        # give the same results as comparing the
        # corresponding ints
        for a in xrange(3):
            for b in xrange(3):
                for typea in (int, Number):
                    for typeb in (int, Number):
                        if typea is int and typeb is int:
                            continue # the combination int, int is useless
                        ta = typea(a)
                        tb = typeb(b)
                        for ops in opmap.itervalues():
                            for op in ops:
                                realoutcome = op(a, b)
                                testoutcome = op(ta, tb)
                                self.assertEqual(realoutcome, testoutcome)

    def checkvalue(self, opname, a, b, expres):
        """Assert `opname` on a and b gives exactly the bool `expres`.

        Every int/Number combination of the two operands is tried with
        every spelling of the operator.
        """
        for typea in (int, Number):
            for typeb in (int, Number):
                ta = typea(a)
                tb = typeb(b)
                for op in opmap[opname]:
                    realres = op(ta, tb)
                    # unwrap the result if it carries an "x" attribute
                    realres = getattr(realres, "x", realres)
                    self.assertIs(realres, expres,
                                  "%s(%r, %r)" % (opname, ta, tb))

    def test_values(self):
        # check all operators and all comparison results
        opnames = ("lt", "le", "eq", "ne", "gt", "ge")
        # each row: a, b, expected outcomes in the order of opnames
        table = (
            (0, 0, (False, True,  True,  False, False, True )),
            (0, 1, (True,  True,  False, True,  False, False)),
            (1, 0, (False, False, False, True,  True,  True )),
        )
        for a, b, outcomes in table:
            for opname, expres in zip(opnames, outcomes):
                self.checkvalue(opname, a, b, expres)
190
class MiscTest(unittest.TestCase):
    """Assorted checks: misbehaving classes, `not`, and recursive objects."""

    def test_misbehavin(self):
        # Misb names its instance parameter self_ so that self inside the
        # methods still refers to this test case.
        class Misb:
            def __lt__(self_, other): return 0
            def __gt__(self_, other): return 0
            def __eq__(self_, other): return 0
            def __le__(self_, other): self.fail("This shouldn't happen")
            def __ge__(self_, other): self.fail("This shouldn't happen")
            def __ne__(self_, other): self.fail("This shouldn't happen")
            def __cmp__(self_, other): raise RuntimeError("expected")
        a = Misb()
        b = Misb()
        self.assertEqual(a<b, 0)
        self.assertEqual(a==b, 0)
        self.assertEqual(a>b, 0)
        self.assertRaises(RuntimeError, cmp, a, b)

    def test_not(self):
        # Check that exceptions in __nonzero__ are properly
        # propagated by the not operator
        class Exc(Exception):
            pass
        class Bad:
            def __nonzero__(self):
                raise Exc

        def do(bad):
            not bad

        for func in (do, operator.not_):
            self.assertRaises(Exc, func, Bad())

    def test_recursion(self):
        # Check that comparison for recursive objects fails gracefully
        from UserList import UserList
        a = UserList()
        b = UserList()
        a.append(b)
        b.append(a)
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)
        self.assertRaises(RuntimeError, operator.lt, a, b)
        self.assertRaises(RuntimeError, operator.le, a, b)
        self.assertRaises(RuntimeError, operator.gt, a, b)
        self.assertRaises(RuntimeError, operator.ge, a, b)

        b.append(17)
        # Even recursive lists of different lengths are different,
        # but they cannot be ordered
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        self.assertRaises(RuntimeError, operator.lt, a, b)
        self.assertRaises(RuntimeError, operator.le, a, b)
        self.assertRaises(RuntimeError, operator.gt, a, b)
        self.assertRaises(RuntimeError, operator.ge, a, b)
        a.append(17)
        self.assertRaises(RuntimeError, operator.eq, a, b)
        self.assertRaises(RuntimeError, operator.ne, a, b)
        # Differing first items (11 vs 12) are inserted ahead of the
        # recursive entries
        a.insert(0, 11)
        b.insert(0, 12)
        self.assertFalse(a == b)
        self.assertTrue(a != b)
        self.assertTrue(a < b)
256
class DictTest(unittest.TestCase):
    """Equality and ordering checks for dicts of complex numbers."""

    def test_dicts(self):
        # Dict == and != must work even when keys and values offer nothing
        # beyond __eq__, __ne__ and __hash__; complex numbers are exactly
        # that kind of object.
        import random
        original = {}
        for _ in range(50):
            original[random.randrange(100)*1j] = random.randrange(100)*1j
        pairs = original.items()
        random.shuffle(pairs)
        # Same contents as `original`, inserted in shuffled order.
        reordered = dict(pairs)
        # Copy that differs in one value: the last shuffled pair.
        last_key, last_value = pairs[-1]
        altered = reordered.copy()
        altered[last_key] = last_value + 1.0
        self.assertTrue(original == original)
        self.assertTrue(original == reordered)
        self.assertTrue(altered == altered)
        self.assertTrue(original != altered)
        # Every spelling of every ordering comparison must raise TypeError.
        for opname in ("lt", "le", "gt", "ge"):
            for compare in opmap[opname]:
                self.assertRaises(TypeError, compare, original, altered)
281
class ListTest(unittest.TestCase):
    """Rich comparison checks for built-in lists."""

    def test_coverage(self):
        # run every comparison operator over lists, first a list against
        # itself, then against a longer list that starts with the same item
        one = [42]
        two = [42, 42]
        cases = (
            (one < one, False),
            (one <= one, True),
            (one == one, True),
            (one != one, False),
            (one > one, False),
            (one >= one, True),
            (one < two, True),
            (one <= two, True),
            (one == two, False),
            (one != two, True),
            (one > two, False),
            (one >= two, False),
        )
        for outcome, expected in cases:
            self.assertIs(outcome, expected)

    def test_badentry(self):
        # an exception raised while comparing two items must propagate
        # out of the enclosing list comparison
        class ItemError(Exception):
            pass

        class Exploding:
            def __eq__(self, other):
                raise ItemError

        left, right = [Exploding()], [Exploding()]

        for compare in opmap["eq"]:
            self.assertRaises(ItemError, compare, left, right)

    def test_goodentry(self):
        # This test exercises the final call to PyObject_RichCompare()
        # in Objects/listobject.c::list_richcompare()
        class AlwaysLess:
            def __lt__(self, other):
                return True

        left, right = [AlwaysLess()], [AlwaysLess()]

        for compare in opmap["lt"]:
            self.assertIs(compare(left, right), True)
328
def test_main():
    """Run every test case in this module.

    DictTest runs separately, inside check_py3k_warnings with a filter for
    the dict inequality comparison DeprecationWarning.
    """
    test_support.run_unittest(VectorTest, NumberTest, MiscTest, ListTest)
    dict_ordering_warning = (
        "dict inequality comparisons not supported in 3.x",
        DeprecationWarning,
    )
    with test_support.check_py3k_warnings(dict_ordering_warning):
        test_support.run_unittest(DictTest)
335
336
# Run the whole suite when this file is executed directly as a script.
if __name__ == "__main__":
    test_main()
339