1import unittest
2import operator
3import sys
4import pickle
5
6from test import support
7
class G:
    """Sequence using __getitem__ only.

    Indexing is forwarded to the wrapped *seqn*, so ``iter()`` falls back to
    the legacy sequence protocol (indices 0, 1, 2, ... until IndexError).
    """

    def __init__(self, seqn):
        self.seqn = seqn

    def __getitem__(self, i):
        wrapped = self.seqn
        return wrapped[i]
14
class I:
    """Sequence using the iterator protocol (__iter__ returns self)."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0  # position of the next element to hand out

    def __iter__(self):
        return self

    def __next__(self):
        position = self.i
        if position < len(self.seqn):
            item = self.seqn[position]
            self.i = position + 1
            return item
        raise StopIteration
27
class Ig:
    """Sequence using the iterator protocol via a generator-function __iter__.

    Every ``iter()`` call returns a fresh generator over *seqn*, so the
    object can be iterated more than once.
    """

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0  # not consulted by __iter__

    def __iter__(self):
        source = iter(self.seqn)
        while True:
            try:
                item = next(source)
            except StopIteration:
                return
            yield item
36
class X:
    """Missing __getitem__ and __iter__.

    It defines __next__, but with neither __iter__ nor __getitem__ it is not
    iterable: ``iter(X(...))`` raises TypeError.
    """

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __next__(self):
        idx = self.i
        if idx < len(self.seqn):
            result = self.seqn[idx]
            self.i = idx + 1
            return result
        raise StopIteration
47
class E:
    """Test propagation of exceptions.

    An iterator whose __next__ always raises ZeroDivisionError, so consumers
    can check that the error escapes unchanged.
    """

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        divisor = 0
        3 // divisor  # always raises ZeroDivisionError
57
class N:
    """Iterator missing __next__().

    __iter__ returns self, but self has no __next__, so ``iter()`` rejects
    the result with TypeError.
    """

    def __init__(self, seqn):
        self.seqn, self.i = seqn, 0

    def __iter__(self):
        return self
65
class PickleTest:
    """Mixin for unittest.TestCase subclasses that checks iterator pickling."""

    # Helper to check picklability
    def check_pickle(self, itorg, seq):
        """Assert that iterator *itorg* round-trips through every pickle protocol.

        For each protocol, an unpickled copy of *itorg* must have the same
        type and yield exactly *seq*.  A second copy is advanced by one item,
        pickled again mid-iteration, and must then yield ``seq[1:]``.

        The host class must provide ``subTest`` and the ``assert*`` methods
        of ``unittest.TestCase``.
        """
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            # Report which protocol failed instead of a bare assertion.
            with self.subTest(proto=proto):
                d = pickle.dumps(itorg, proto)
                it = pickle.loads(d)
                self.assertEqual(type(itorg), type(it))
                self.assertEqual(list(it), seq)

                it = pickle.loads(d)
                try:
                    next(it)
                except StopIteration:
                    # An iterator that is exhausted immediately can only
                    # match an empty expected sequence (checking seq[1:]
                    # would wrongly accept a one-element seq).
                    self.assertFalse(seq)
                    continue
                # Pickle again after partial consumption: the restored
                # iterator must resume after the first item.
                d = pickle.dumps(it, proto)
                it = pickle.loads(d)
                self.assertEqual(list(it), seq[1:])
84
class EnumerateTestCase(unittest.TestCase, PickleTest):
    """Tests for the built-in ``enumerate``.

    Subclasses re-run the whole suite against a different factory or input
    by overriding ``enum`` and/or ``seq``/``res``.  ``res`` is the list of
    ``(index, item)`` tuples expected from ``list(self.enum(self.seq))``.
    """

    # ``enumerate`` is a type rather than a function, so looking it up
    # through ``self`` does not bind it as a method.
    enum = enumerate
    seq, res = 'abc', [(0, 'a'), (1, 'b'), (2, 'c')]

    def test_basicfunction(self):
        # The factory returns exact instances of itself...
        self.assertIs(type(self.enum(self.seq)), self.enum)
        e = self.enum(self.seq)
        # ...which are their own iterators.
        self.assertIs(iter(e), e)
        self.assertEqual(list(self.enum(self.seq)), self.res)
        # Only checks that the attribute access succeeds; the value may be
        # None (e.g. for a subclass defined without a docstring).
        self.enum.__doc__

    def test_pickle(self):
        self.check_pickle(self.enum(self.seq), self.res)

    def test_getitemseqn(self):
        # Input supporting only the __getitem__ sequence protocol.
        self.assertEqual(list(self.enum(G(self.seq))), self.res)
        e = self.enum(G(''))
        self.assertRaises(StopIteration, next, e)

    def test_iteratorseqn(self):
        # Input that is a hand-written iterator (__iter__ + __next__).
        self.assertEqual(list(self.enum(I(self.seq))), self.res)
        e = self.enum(I(''))
        self.assertRaises(StopIteration, next, e)

    def test_iteratorgenerator(self):
        # Input whose __iter__ is a generator function.
        self.assertEqual(list(self.enum(Ig(self.seq))), self.res)
        e = self.enum(Ig(''))
        self.assertRaises(StopIteration, next, e)

    def test_noniterable(self):
        # X has __next__ but neither __iter__ nor __getitem__.
        self.assertRaises(TypeError, self.enum, X(self.seq))

    def test_illformediterable(self):
        # N.__iter__ returns an object that has no __next__.
        self.assertRaises(TypeError, self.enum, N(self.seq))

    def test_exception_propagation(self):
        # E.__next__ raises ZeroDivisionError; it must escape unchanged.
        self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq)))

    def test_argumentcheck(self):
        self.assertRaises(TypeError, self.enum) # no arguments
        self.assertRaises(TypeError, self.enum, 1) # wrong type (not iterable)
        self.assertRaises(TypeError, self.enum, 'abc', 'a') # wrong type
        self.assertRaises(TypeError, self.enum, 'abc', 2, 3) # too many arguments

    @support.cpython_only
    def test_tuple_reuse(self):
        # Tests an implementation detail where tuple is reused
        # whenever nothing else holds a reference to it.
        # Materialised in a list, every result tuple stays alive, so all
        # ids are distinct.
        self.assertEqual(len(set(map(id, list(enumerate(self.seq))))), len(self.seq))
        # Consumed one at a time, each tuple is released before the next one
        # is produced, so a single id is seen (none for an empty seq).
        self.assertEqual(len(set(map(id, enumerate(self.seq)))), min(1, len(self.seq)))
136
class MyEnum(enumerate):
    """Trivial ``enumerate`` subclass, used to run the tests on a subclass."""
139
class SubclassTestCase(EnumerateTestCase):
    """Run the enumerate tests against the ``MyEnum`` subclass."""

    enum = MyEnum
143
class TestEmpty(EnumerateTestCase):
    """Run the enumerate tests on an empty input."""

    seq = ''
    res = []
147
class TestBig(EnumerateTestCase):
    """Run the enumerate tests on a long range (9995 elements)."""

    seq = range(10, 20000, 2)
    # Expected pairs built with zip rather than enumerate, so the expected
    # value does not depend on the code under test.
    res = list(zip(range(len(seq)), seq))
152
class TestReversed(unittest.TestCase, PickleTest):
    """Tests for the built-in ``reversed``."""

    def test_simple(self):
        # Sequence-protocol class: __len__ + __getitem__, no __reversed__.
        class A:
            def __getitem__(self, i):
                if i < 5:
                    return str(i)
                raise StopIteration
            def __len__(self):
                return 5
        for data in 'abc', range(5), tuple(enumerate('abc')), A(), range(1,17,5):
            self.assertEqual(list(data)[::-1], list(reversed(data)))
        if hasattr(dict, '__reversed__'):
            # Python 3.8+ dicts define __reversed__, so reversed() walks the
            # keys in reverse insertion order.
            d = dict.fromkeys('abcde')
            self.assertEqual(list(d)[::-1], list(reversed(d)))
        else:
            # Older dicts are mappings, not sequences, despite having
            # __len__ and __getitem__, so reversed() must reject them.
            self.assertRaises(TypeError, reversed, {})
        # don't allow keyword arguments
        self.assertRaises(TypeError, reversed, [], a=1)

    def test_range_optimization(self):
        # reversed() of a range yields the same iterator type as iter().
        x = range(1)
        self.assertEqual(type(reversed(x)), type(iter(x)))

    def test_len(self):
        # length_hint() reports the remaining item count: len(s) before
        # consumption, 0 once the reversed iterator is exhausted.
        for s in ('hello', tuple('hello'), list('hello'), range(5)):
            self.assertEqual(operator.length_hint(reversed(s)), len(s))
            r = reversed(s)
            list(r)
            self.assertEqual(operator.length_hint(r), 0)
        # Only the first __len__ call succeeds; an error raised by __len__
        # during a later length_hint() must propagate.
        class SeqWithWeirdLen:
            called = False
            def __len__(self):
                if not self.called:
                    self.called = True
                    return 10
                raise ZeroDivisionError
            def __getitem__(self, index):
                return index
        r = reversed(SeqWithWeirdLen())
        self.assertRaises(ZeroDivisionError, operator.length_hint, r)

    def test_gc(self):
        # Create a reference cycle through a reversed object (s.r -> r,
        # with r referring back to s).  Nothing is asserted; the test only
        # checks that building the cycle works.
        class Seq:
            def __len__(self):
                return 10
            def __getitem__(self, index):
                return index
        s = Seq()
        r = reversed(s)
        s.r = r

    def test_args(self):
        self.assertRaises(TypeError, reversed)
        self.assertRaises(TypeError, reversed, [], 'extra')

    @unittest.skipUnless(hasattr(sys, 'getrefcount'), 'test needs sys.getrefcount()')
    def test_bug1229429(self):
        # this bug was never in reversed, it was in
        # PyObject_CallMethod, and reversed_new calls that sometimes.
        def f():
            pass
        r = f.__reversed__ = object()
        rc = sys.getrefcount(r)
        for i in range(10):
            try:
                reversed(f)
            except TypeError:
                pass
            else:
                self.fail("non-callable __reversed__ didn't raise!")
        # Repeated failing calls must not leak references to r.
        self.assertEqual(rc, sys.getrefcount(r))

    def test_objmethods(self):
        # Objects must have __len__() and __getitem__() implemented.
        class NoLen(object):
            def __getitem__(self, i): return 1
        nl = NoLen()
        self.assertRaises(TypeError, reversed, nl)

        class NoGetItem(object):
            def __len__(self): return 2
        ngi = NoGetItem()
        self.assertRaises(TypeError, reversed, ngi)

        # Setting __reversed__ = None explicitly opts out of reversal.
        class Blocked(object):
            def __getitem__(self, i): return 1
            def __len__(self): return 2
            __reversed__ = None
        b = Blocked()
        self.assertRaises(TypeError, reversed, b)

    def test_pickle(self):
        for data in 'abc', range(5), tuple(enumerate('abc')), range(1,17,5):
            self.check_pickle(reversed(data), list(data)[::-1])
245
246
class EnumerateStartTestCase(EnumerateTestCase):
    """Base for running the enumerate tests with a ``start`` argument.

    Subclasses may set ``enum`` to a function that calls ``enumerate`` with
    a start value.  ``self.enum`` is then not a type, so the inherited
    ``type(...) is self.enum`` check cannot apply.  This override checks
    against ``enumerate`` directly instead.
    """

    def test_basicfunction(self):
        e = self.enum(self.seq)
        # The factory must still hand back a genuine enumerate object...
        self.assertIs(type(e), enumerate)
        # ...that is its own iterator.
        self.assertIs(iter(e), e)
        self.assertEqual(list(self.enum(self.seq)), self.res)
253
254
class TestStart(EnumerateStartTestCase):
    """Run the enumerate tests with ``start=11``."""

    def enum(self, i):
        return enumerate(i, start=11)

    seq = 'abc'
    res = [(11, 'a'), (12, 'b'), (13, 'c')]
259
260
class TestLongStart(EnumerateStartTestCase):
    """Run the enumerate tests with a start value above ``sys.maxsize``."""

    def enum(self, i):
        return enumerate(i, start=sys.maxsize + 1)

    seq = 'abc'
    res = [
        (sys.maxsize + 1, 'a'),
        (sys.maxsize + 2, 'b'),
        (sys.maxsize + 3, 'c'),
    ]
266
267
if __name__ == '__main__':
    # Script entry point: run every TestCase defined in this module.
    unittest.main()
270