1from test.test_support import have_unicode, run_unittest
2import unittest
3
4
class base_set:
    """Minimal holder for a single element.

    Defines no ``__contains__`` or ``__getitem__`` of its own.
    """

    def __init__(self, el):
        # Remember the one element this object wraps.
        self.el = el
8
class set(base_set):
    """Holder whose membership test is equality with its stored element.

    Note: within this module the name shadows the builtin ``set``.
    """

    def __contains__(self, el):
        # Membership is decided purely by comparing against the stored element.
        is_member = self.el == el
        return is_member
12
class seq(base_set):
    """Holder that exposes its stored element only through ``__getitem__``."""

    def __getitem__(self, n):
        # Index into a one-item list holding the element, so an out-of-range
        # index raises IndexError exactly as a real list would.
        single = [self.el]
        return single[n]
16
17
class TestContains(unittest.TestCase):
    # Exercises the ``in`` / ``not in`` operators on user-defined classes,
    # byte strings, unicode strings and builtin sequence types.

    def test_common_tests(self):
        a = base_set(1)  # defines neither __contains__ nor __getitem__
        b = set(1)       # module-level class defining __contains__
        c = seq(1)       # module-level class defining only __getitem__
        self.assertIn(1, b)
        self.assertNotIn(0, b)
        self.assertIn(1, c)
        self.assertNotIn(0, c)
        # With neither method available, both operators must raise TypeError.
        self.assertRaises(TypeError, lambda: 1 in a)
        self.assertRaises(TypeError, lambda: 1 not in a)

        # test char in string
        self.assertIn('c', 'abc')
        self.assertNotIn('d', 'abc')

        # The empty string is contained in every string, itself included.
        self.assertIn('', '')
        self.assertIn('', 'abc')

        # A non-string left operand is rejected rather than reported absent.
        self.assertRaises(TypeError, lambda: None in 'abc')

    # This test method is only defined when have_unicode is true.
    if have_unicode:
        def test_char_in_unicode(self):
            self.assertIn('c', unicode('abc'))
            self.assertNotIn('d', unicode('abc'))

            # Empty-string containment across every str/unicode pairing.
            self.assertIn('', unicode(''))
            self.assertIn(unicode(''), '')
            self.assertIn(unicode(''), unicode(''))
            self.assertIn('', unicode('abc'))
            self.assertIn(unicode(''), 'abc')
            self.assertIn(unicode(''), unicode('abc'))

            self.assertRaises(TypeError, lambda: None in unicode('abc'))

            # test Unicode char in Unicode
            self.assertIn(unicode('c'), unicode('abc'))
            self.assertNotIn(unicode('d'), unicode('abc'))

            # test Unicode char in string
            self.assertIn(unicode('c'), 'abc')
            self.assertNotIn(unicode('d'), 'abc')

    def test_builtin_sequence_types(self):
        # a collection of tests on builtin sequence types
        a = range(10)
        for i in a:
            self.assertIn(i, a)
        self.assertNotIn(16, a)
        # A sequence is not an element of itself.
        self.assertNotIn(a, a)

        # Repeat the same checks on a tuple holding the same items.
        a = tuple(a)
        for i in a:
            self.assertIn(i, a)
        self.assertNotIn(16, a)
        self.assertNotIn(a, a)

        class Deviant1:
            """Behaves strangely when compared

            This class is designed to make sure that the contains code
            works when the list is modified during the check.
            """
            # Class-level list: it is both the container being searched and
            # the object mutated from inside __cmp__.
            aList = range(15)
            def __cmp__(self, other):
                if other == 12:
                    # Shrink the list while the membership scan is running.
                    self.aList.remove(12)
                    self.aList.remove(13)
                    self.aList.remove(14)
                # Always report "greater than", i.e. never equal.
                return 1

        self.assertNotIn(Deviant1(), Deviant1.aList)

        class Deviant2:
            """Behaves strangely when compared

            This class raises an exception during comparison.  That in
            turn causes the comparison to fail with a TypeError.
            """
            def __cmp__(self, other):
                if other == 4:
                    # Call form of raise; equivalent to the old
                    # ``raise RuntimeError, "gotcha"`` spelling.
                    raise RuntimeError("gotcha")
                # Falls through and returns None for every other operand.
                # NOTE(review): presumably this non-integer result is what
                # triggers the TypeError mentioned above -- confirm.

        # A TypeError from the membership test is tolerated here; any other
        # exception propagates and fails the test.
        try:
            self.assertNotIn(Deviant2(), a)
        except TypeError:
            pass
105
106
def test_main():
    """Run the TestContains cases through the imported ``run_unittest``."""
    case_classes = (TestContains,)
    run_unittest(*case_classes)
109
# Script entry point: run the suite when this file is executed directly,
# but not when it is imported as a module.
if __name__ == '__main__':
    test_main()
112