1import unittest
2
class Empty:
    """Bare object with no comparison methods of its own.

    Equality therefore falls back to the defaults inherited from ``object``.
    """

    def __repr__(self):
        # Fixed text so failure messages stay short and readable.
        return '<Empty>'
6
class Cmp:
    """Wrapper that compares equal to whatever its ``arg`` compares equal to.

    Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.
    """

    def __init__(self, arg):
        self.arg = arg

    def __repr__(self):
        # Wrap arg in a 1-tuple so that a tuple-valued arg is formatted as a
        # whole instead of being unpacked into several %-format arguments
        # (which raised TypeError for e.g. Cmp((1, 2)) or Cmp(())).
        return '<Cmp %s>' % (self.arg,)

    def __eq__(self, other):
        # Delegate to arg's own equality. That comparison may in turn dispatch
        # to other's reflected __eq__ (e.g. when other is itself a Cmp).
        return self.arg == other
16
class Anything:
    """Wildcard operand: compares equal to every object and never unequal."""

    def __ne__(self, other):
        # Always the exact inverse of __eq__ below.
        return False

    def __eq__(self, other):
        return True
23
class ComparisonTest(unittest.TestCase):
    """Tests for rich-comparison dispatch and the default __eq__/__ne__."""

    # Numerically equal values (int, float, complex, and a Cmp wrapper) that
    # are all expected to compare equal to one another.
    set1 = [2, 2.0, 2, 2+0j, Cmp(2.0)]
    # Values expected to compare equal only to themselves.
    set2 = [[1], (3,), None, Empty()]
    candidates = set1 + set2

    def test_comparisons(self):
        # Any two members of set1 must be equal; every other pairing is
        # equal only when both operands are the very same object.
        for left in self.candidates:
            for right in self.candidates:
                both_in_set1 = (left in self.set1) and (right in self.set1)
                if both_in_set1 or left is right:
                    check = self.assertEqual
                else:
                    check = self.assertNotEqual
                check(left, right)

    def test_id_comparisons(self):
        # Without a custom __eq__, equality should agree with comparing the
        # operands' id() values.
        objs = []
        for _ in range(10):
            objs.insert(len(objs) // 2, Empty())
        pairs = ((a, b) for a in objs for b in objs)
        for a, b in pairs:
            self.assertEqual(a == b, id(a) == id(b),
                             'a=%r, b=%r' % (a, b))

    def test_ne_defaults_to_not_eq(self):
        # Cmp defines only __eq__; != must be derived as its negation.
        first, same, different = Cmp(1), Cmp(1), Cmp(2)
        self.assertIs(first == same, True)
        self.assertIs(first != same, False)
        self.assertIs(first != different, True)

    def test_ne_high_priority(self):
        """object.__ne__() should allow reflected __ne__() to be tried"""
        log = []

        def logged(label):
            # Build an operator method that records `label` and declines.
            def method(*args):
                log.append(label)
                return NotImplemented
            return method

        class Left:
            # Only __eq__ is overridden; __ne__ comes from object.
            __eq__ = logged('Left.__eq__')

        class Right:
            __eq__ = logged('Right.__eq__')
            __ne__ = logged('Right.__ne__')

        Left() != Right()
        self.assertSequenceEqual(log, ['Left.__eq__', 'Right.__ne__'])

    def test_ne_low_priority(self):
        """object.__ne__() should not invoke reflected __eq__()"""
        log = []

        def logged(label):
            # Build an operator method that records `label` and declines.
            def method(*args):
                log.append(label)
                return NotImplemented
            return method

        class Base:
            # Only __eq__ is overridden; __ne__ comes from object.
            __eq__ = logged('Base.__eq__')

        class Derived(Base):  # Subclassing forces higher priority
            __eq__ = logged('Derived.__eq__')
            __ne__ = logged('Derived.__ne__')

        Base() != Derived()
        self.assertSequenceEqual(log, ['Derived.__ne__', 'Base.__eq__'])

    def test_other_delegation(self):
        """No default delegation between operations except __ne__()"""
        ops = (
            ('__eq__', lambda a, b: a == b),
            ('__lt__', lambda a, b: a < b),
            ('__le__', lambda a, b: a <= b),
            ('__gt__', lambda a, b: a > b),
            ('__ge__', lambda a, b: a >= b),
        )
        for name, func in ops:
            with self.subTest(name):
                def unexpected(*args):
                    self.fail('Unexpected operator method called')

                class C:
                    __ne__ = unexpected

                # Booby-trap every operator except the one under test, so
                # any delegation to another method fails the test.
                for other_name in [n for n, _ in ops if n != name]:
                    setattr(C, other_name, unexpected)

                instance, plain = C(), object()
                if name != '__eq__':
                    self.assertRaises(TypeError, func, instance, plain)
                else:
                    self.assertIs(func(instance, plain), False)

    def test_issue_1393(self):
        # Anything must compare equal in both operand orders, against a
        # function and against a bare object().
        for value in (lambda: None, object()):
            self.assertEqual(value, Anything())
            self.assertEqual(Anything(), value)
121
122
if __name__ == '__main__':
    # Only when executed as a script (not on import): hand off to unittest's
    # command-line runner, which discovers the TestCase classes above.
    unittest.main()
125