# test_index.py revision ec1a0b3abe08fb9a3952e8f48231cda1f6d9b1f3
1import unittest
2from test import test_support
3import operator
4from sys import maxint
# maxsize mirrors test_support.MAX_Py_ssize_t; minsize is one below its
# negation.  The overflow tests compare slice bounds against both.
maxsize = test_support.MAX_Py_ssize_t
minsize = -(maxsize + 1)
7
class oldstyle:
    """Index provider declared without a base class.

    ``__index__`` reports whatever the test stored in the ``ind``
    attribute; nothing is validated here.
    """

    def __index__(self):
        # ``ind`` is assigned by the individual tests before use.
        value = self.ind
        return value
11
class newstyle(object):
    """Index provider derived from ``object``.

    ``__index__`` reports whatever the test stored in the ``ind``
    attribute; nothing is validated here.
    """

    def __index__(self):
        # ``ind`` is assigned by the individual tests before use.
        value = self.ind
        return value
15
class TrapInt(int):
    """int subclass whose ``__index__`` hands back the instance itself.

    The result is therefore still a ``TrapInt``, not a plain ``int``.
    """

    def __index__(self):
        # Return the very same object -- no conversion to a base int.
        return self
19
class TrapLong(long):
    """long subclass whose ``__index__`` hands back the instance itself.

    The result is therefore still a ``TrapLong``, not a plain ``long``.
    """

    def __index__(self):
        # Return the very same object -- no conversion to a base long.
        return self
23
24class BaseTestCase(unittest.TestCase):
25    def setUp(self):
26        self.o = oldstyle()
27        self.n = newstyle()
28
29    def test_basic(self):
30        self.o.ind = -2
31        self.n.ind = 2
32        self.assertEqual(operator.index(self.o), -2)
33        self.assertEqual(operator.index(self.n), 2)
34
35    def test_slice(self):
36        self.o.ind = 1
37        self.n.ind = 2
38        slc = slice(self.o, self.o, self.o)
39        check_slc = slice(1, 1, 1)
40        self.assertEqual(slc.indices(self.o), check_slc.indices(1))
41        slc = slice(self.n, self.n, self.n)
42        check_slc = slice(2, 2, 2)
43        self.assertEqual(slc.indices(self.n), check_slc.indices(2))
44
45    def test_wrappers(self):
46        self.o.ind = 4
47        self.n.ind = 5
48        self.assertEqual(6 .__index__(), 6)
49        self.assertEqual(-7L.__index__(), -7)
50        self.assertEqual(self.o.__index__(), 4)
51        self.assertEqual(self.n.__index__(), 5)
52        self.assertEqual(True.__index__(), 1)
53        self.assertEqual(False.__index__(), 0)
54
55    def test_subclasses(self):
56        r = range(10)
57        self.assertEqual(r[TrapInt(5):TrapInt(10)], r[5:10])
58        self.assertEqual(r[TrapLong(5):TrapLong(10)], r[5:10])
59        self.assertEqual(slice(TrapInt()).indices(0), (0,0,1))
60        self.assertEqual(slice(TrapLong(0)).indices(0), (0,0,1))
61
62    def test_error(self):
63        self.o.ind = 'dumb'
64        self.n.ind = 'bad'
65        self.assertRaises(TypeError, operator.index, self.o)
66        self.assertRaises(TypeError, operator.index, self.n)
67        self.assertRaises(TypeError, slice(self.o).indices, 0)
68        self.assertRaises(TypeError, slice(self.n).indices, 0)
69
70
71class SeqTestCase(unittest.TestCase):
72    # This test case isn't run directly. It just defines common tests
73    # to the different sequence types below
74    def setUp(self):
75        self.o = oldstyle()
76        self.n = newstyle()
77        self.o2 = oldstyle()
78        self.n2 = newstyle()
79
80    def test_index(self):
81        self.o.ind = -2
82        self.n.ind = 2
83        self.assertEqual(self.seq[self.n], self.seq[2])
84        self.assertEqual(self.seq[self.o], self.seq[-2])
85
86    def test_slice(self):
87        self.o.ind = 1
88        self.o2.ind = 3
89        self.n.ind = 2
90        self.n2.ind = 4
91        self.assertEqual(self.seq[self.o:self.o2], self.seq[1:3])
92        self.assertEqual(self.seq[self.n:self.n2], self.seq[2:4])
93
94    def test_slice_bug7532(self):
95        seqlen = len(self.seq)
96        self.o.ind = int(seqlen * 1.5)
97        self.n.ind = seqlen + 2
98        self.assertEqual(self.seq[self.o:], self.seq[0:0])
99        self.assertEqual(self.seq[:self.o], self.seq)
100        self.assertEqual(self.seq[self.n:], self.seq[0:0])
101        self.assertEqual(self.seq[:self.n], self.seq)
102        if isinstance(self.seq, ClassicSeq):
103            return
104        # These tests fail for ClassicSeq (see bug #7532)
105        self.o2.ind = -seqlen - 2
106        self.n2.ind = -int(seqlen * 1.5)
107        self.assertEqual(self.seq[self.o2:], self.seq)
108        self.assertEqual(self.seq[:self.o2], self.seq[0:0])
109        self.assertEqual(self.seq[self.n2:], self.seq)
110        self.assertEqual(self.seq[:self.n2], self.seq[0:0])
111
112    def test_repeat(self):
113        self.o.ind = 3
114        self.n.ind = 2
115        self.assertEqual(self.seq * self.o, self.seq * 3)
116        self.assertEqual(self.seq * self.n, self.seq * 2)
117        self.assertEqual(self.o * self.seq, self.seq * 3)
118        self.assertEqual(self.n * self.seq, self.seq * 2)
119
120    def test_wrappers(self):
121        self.o.ind = 4
122        self.n.ind = 5
123        self.assertEqual(self.seq.__getitem__(self.o), self.seq[4])
124        self.assertEqual(self.seq.__mul__(self.o), self.seq * 4)
125        self.assertEqual(self.seq.__rmul__(self.o), self.seq * 4)
126        self.assertEqual(self.seq.__getitem__(self.n), self.seq[5])
127        self.assertEqual(self.seq.__mul__(self.n), self.seq * 5)
128        self.assertEqual(self.seq.__rmul__(self.n), self.seq * 5)
129
130    def test_subclasses(self):
131        self.assertEqual(self.seq[TrapInt()], self.seq[0])
132        self.assertEqual(self.seq[TrapLong()], self.seq[0])
133
134    def test_error(self):
135        self.o.ind = 'dumb'
136        self.n.ind = 'bad'
137        indexobj = lambda x, obj: obj.seq[x]
138        self.assertRaises(TypeError, indexobj, self.o, self)
139        self.assertRaises(TypeError, indexobj, self.n, self)
140        sliceobj = lambda x, obj: obj.seq[x:]
141        self.assertRaises(TypeError, sliceobj, self.o, self)
142        self.assertRaises(TypeError, sliceobj, self.n, self)
143
144
class ListTestCase(SeqTestCase):
    """Runs the shared sequence tests on a list, plus the mutating
    operations (item assignment, item deletion, in-place repeat)."""

    seq = [0, 10, 20, 30, 40, 50]

    def test_setdelitem(self):
        self.o.ind, self.n.ind = -2, 2
        chars = list('ab!cdefghi!j')
        # Statement syntax: drop both '!' markers, then overwrite two slots.
        del chars[self.o]
        del chars[self.n]
        chars[self.o] = 'X'
        chars[self.n] = 'Y'
        self.assertEqual(chars, list('abYdefghXj'))

        # Same operations through the explicit special methods.
        nums = [5, 6, 7, 8, 9, 10, 11]
        nums.__setitem__(self.n, "here")
        self.assertEqual(nums, [5, 6, "here", 8, 9, 10, 11])
        nums.__delitem__(self.n)
        self.assertEqual(nums, [5, 6, 8, 9, 10, 11])

    def test_inplace_repeat(self):
        self.o.ind, self.n.ind = 2, 3
        pair = [6, 4]
        pair *= self.o
        self.assertEqual(pair, [6, 4, 6, 4])
        pair *= self.n
        self.assertEqual(pair, [6, 4, 6, 4] * 3)

        # __imul__ must hand back the same list object it mutated.
        original = [5, 6, 7, 8, 9, 11]
        returned = original.__imul__(self.n)
        self.assertIs(returned, original)
        self.assertEqual(original, [5, 6, 7, 8, 9, 11] * 3)
177
178
class _BaseSeq:
    """Small sequence wrapper that forwards to an internal list.

    Declared without a base class; subclasses decide whether ``object``
    is mixed in.
    """

    def __init__(self, iterable):
        # Copy the input into a list owned by this wrapper.
        items = list(iterable)
        self._list = items

    def __repr__(self):
        backing = self._list
        return repr(backing)

    def __eq__(self, other):
        # Equality is decided by comparing the backing list with ``other``.
        return self._list == other

    def __len__(self):
        backing = self._list
        return len(backing)

    def __mul__(self, n):
        # Rebuild through the runtime class so a subclass repeats into
        # an instance of its own type.
        make = self.__class__
        return make(self._list * n)
    __rmul__ = __mul__

    def __getitem__(self, index):
        # Plain delegation; the list decides how ``index`` is interpreted.
        backing = self._list
        return backing[index]
199
200
class _GetSliceMixin:
    """Mixin adding ``__getslice__`` on top of a ``_list`` attribute."""

    def __getslice__(self, i, j):
        # Forward both bounds untouched to the backing list's own hook.
        delegate = self._list.__getslice__
        return delegate(i, j)
205
206
class ClassicSeq(_BaseSeq):
    """_BaseSeq as-is, with no ``object`` base mixed in."""


class NewSeq(_BaseSeq, object):
    """_BaseSeq with ``object`` added as a base."""


class ClassicSeqDeprecated(_GetSliceMixin, ClassicSeq):
    """ClassicSeq that additionally defines ``__getslice__``."""


class NewSeqDeprecated(_GetSliceMixin, NewSeq):
    """NewSeq that additionally defines ``__getslice__``."""
211
212
class TupleTestCase(SeqTestCase):
    """Shared sequence tests applied to a tuple."""
    seq = (0, 10, 20, 30, 40, 50)


class StringTestCase(SeqTestCase):
    """Shared sequence tests applied to a str."""
    seq = "this is a test"


class ByteArrayTestCase(SeqTestCase):
    """Shared sequence tests applied to a bytearray."""
    seq = bytearray("this is a test")


class UnicodeTestCase(SeqTestCase):
    """Shared sequence tests applied to a unicode string."""
    seq = u"this is a test"


class ClassicSeqTestCase(SeqTestCase):
    """Shared sequence tests applied to ClassicSeq."""
    seq = ClassicSeq((0, 10, 20, 30, 40, 50))


class NewSeqTestCase(SeqTestCase):
    """Shared sequence tests applied to NewSeq."""
    seq = NewSeq((0, 10, 20, 30, 40, 50))


class ClassicSeqDeprecatedTestCase(SeqTestCase):
    """Shared sequence tests applied to ClassicSeqDeprecated."""
    seq = ClassicSeqDeprecated((0, 10, 20, 30, 40, 50))


class NewSeqDeprecatedTestCase(SeqTestCase):
    """Shared sequence tests applied to NewSeqDeprecated."""
    seq = NewSeqDeprecated((0, 10, 20, 30, 40, 50))
236
237
class XRangeTestCase(unittest.TestCase):
    """xrange must accept an object that provides ``__index__``."""

    def test_xrange(self):
        idx = newstyle()
        idx.ind = 5
        # Index 5 of xrange(1, 20) is 6, via subscript and via the
        # explicit special method.
        numbers = xrange(1, 20)
        self.assertEqual(numbers[idx], 6)
        self.assertEqual(numbers.__getitem__(idx), 6)
245
246class OverflowTestCase(unittest.TestCase):
247
248    def setUp(self):
249        self.pos = 2**100
250        self.neg = -self.pos
251
252    def test_large_longs(self):
253        self.assertEqual(self.pos.__index__(), self.pos)
254        self.assertEqual(self.neg.__index__(), self.neg)
255
256    def _getitem_helper(self, base):
257        class GetItem(base):
258            def __len__(self):
259                return maxint # cannot return long here
260            def __getitem__(self, key):
261                return key
262        x = GetItem()
263        self.assertEqual(x[self.pos], self.pos)
264        self.assertEqual(x[self.neg], self.neg)
265        self.assertEqual(x[self.neg:self.pos].indices(maxsize),
266                         (0, maxsize, 1))
267        self.assertEqual(x[self.neg:self.pos:1].indices(maxsize),
268                         (0, maxsize, 1))
269
270    def _getslice_helper_deprecated(self, base):
271        class GetItem(base):
272            def __len__(self):
273                return maxint # cannot return long here
274            def __getitem__(self, key):
275                return key
276            def __getslice__(self, i, j):
277                return i, j
278        x = GetItem()
279        self.assertEqual(x[self.pos], self.pos)
280        self.assertEqual(x[self.neg], self.neg)
281        self.assertEqual(x[self.neg:self.pos], (maxint+minsize, maxsize))
282        self.assertEqual(x[self.neg:self.pos:1].indices(maxsize),
283                         (0, maxsize, 1))
284
285    def test_getitem(self):
286        self._getitem_helper(object)
287        with test_support.check_py3k_warnings():
288            self._getslice_helper_deprecated(object)
289
290    def test_getitem_classic(self):
291        class Empty: pass
292        # XXX This test fails (see bug #7532)
293        #self._getitem_helper(Empty)
294        with test_support.check_py3k_warnings():
295            self._getslice_helper_deprecated(Empty)
296
297    def test_sequence_repeat(self):
298        self.assertRaises(OverflowError, lambda: "a" * self.pos)
299        self.assertRaises(OverflowError, lambda: "a" * self.neg)
300
301
def test_main():
    """Run every test case defined in this module.

    The two *Deprecated cases are run inside
    test_support.check_py3k_warnings(); the others are run without it.
    """
    regular_cases = (
        BaseTestCase,
        ListTestCase,
        TupleTestCase,
        ByteArrayTestCase,
        StringTestCase,
        UnicodeTestCase,
        ClassicSeqTestCase,
        NewSeqTestCase,
        XRangeTestCase,
        OverflowTestCase,
    )
    test_support.run_unittest(*regular_cases)
    with test_support.check_py3k_warnings():
        deprecated_cases = (
            ClassicSeqDeprecatedTestCase,
            NewSeqDeprecatedTestCase,
        )
        test_support.run_unittest(*deprecated_cases)


# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    test_main()
324