1"""Unit tests for the memoryview
2
3XXX We need more tests! Some tests are in test_bytes
4"""
5
6import unittest
7import sys
8import gc
9import weakref
10import array
11from test import test_support
12import io
13
14
15class AbstractMemoryTests:
16    source_bytes = b"abcdef"
17
18    @property
19    def _source(self):
20        return self.source_bytes
21
22    @property
23    def _types(self):
24        return filter(None, [self.ro_type, self.rw_type])
25
26    def check_getitem_with_type(self, tp):
27        item = self.getitem_type
28        b = tp(self._source)
29        oldrefcount = sys.getrefcount(b)
30        m = self._view(b)
31        self.assertEqual(m[0], item(b"a"))
32        self.assertIsInstance(m[0], bytes)
33        self.assertEqual(m[5], item(b"f"))
34        self.assertEqual(m[-1], item(b"f"))
35        self.assertEqual(m[-6], item(b"a"))
36        # Bounds checking
37        self.assertRaises(IndexError, lambda: m[6])
38        self.assertRaises(IndexError, lambda: m[-7])
39        self.assertRaises(IndexError, lambda: m[sys.maxsize])
40        self.assertRaises(IndexError, lambda: m[-sys.maxsize])
41        # Type checking
42        self.assertRaises(TypeError, lambda: m[None])
43        self.assertRaises(TypeError, lambda: m[0.0])
44        self.assertRaises(TypeError, lambda: m["a"])
45        m = None
46        self.assertEqual(sys.getrefcount(b), oldrefcount)
47
48    def test_getitem(self):
49        for tp in self._types:
50            self.check_getitem_with_type(tp)
51
52    def test_iter(self):
53        for tp in self._types:
54            b = tp(self._source)
55            m = self._view(b)
56            self.assertEqual(list(m), [m[i] for i in range(len(m))])
57
58    def test_repr(self):
59        for tp in self._types:
60            b = tp(self._source)
61            m = self._view(b)
62            self.assertIsInstance(m.__repr__(), str)
63
64    def test_setitem_readonly(self):
65        if not self.ro_type:
66            return
67        b = self.ro_type(self._source)
68        oldrefcount = sys.getrefcount(b)
69        m = self._view(b)
70        def setitem(value):
71            m[0] = value
72        self.assertRaises(TypeError, setitem, b"a")
73        self.assertRaises(TypeError, setitem, 65)
74        self.assertRaises(TypeError, setitem, memoryview(b"a"))
75        m = None
76        self.assertEqual(sys.getrefcount(b), oldrefcount)
77
78    def test_setitem_writable(self):
79        if not self.rw_type:
80            return
81        tp = self.rw_type
82        b = self.rw_type(self._source)
83        oldrefcount = sys.getrefcount(b)
84        m = self._view(b)
85        m[0] = tp(b"0")
86        self._check_contents(tp, b, b"0bcdef")
87        m[1:3] = tp(b"12")
88        self._check_contents(tp, b, b"012def")
89        m[1:1] = tp(b"")
90        self._check_contents(tp, b, b"012def")
91        m[:] = tp(b"abcdef")
92        self._check_contents(tp, b, b"abcdef")
93
94        # Overlapping copies of a view into itself
95        m[0:3] = m[2:5]
96        self._check_contents(tp, b, b"cdedef")
97        m[:] = tp(b"abcdef")
98        m[2:5] = m[0:3]
99        self._check_contents(tp, b, b"ababcf")
100
101        def setitem(key, value):
102            m[key] = tp(value)
103        # Bounds checking
104        self.assertRaises(IndexError, setitem, 6, b"a")
105        self.assertRaises(IndexError, setitem, -7, b"a")
106        self.assertRaises(IndexError, setitem, sys.maxsize, b"a")
107        self.assertRaises(IndexError, setitem, -sys.maxsize, b"a")
108        # Wrong index/slice types
109        self.assertRaises(TypeError, setitem, 0.0, b"a")
110        self.assertRaises(TypeError, setitem, (0,), b"a")
111        self.assertRaises(TypeError, setitem, "a", b"a")
112        # Trying to resize the memory object
113        self.assertRaises(ValueError, setitem, 0, b"")
114        self.assertRaises(ValueError, setitem, 0, b"ab")
115        self.assertRaises(ValueError, setitem, slice(1,1), b"a")
116        self.assertRaises(ValueError, setitem, slice(0,2), b"a")
117
118        m = None
119        self.assertEqual(sys.getrefcount(b), oldrefcount)
120
121    def test_delitem(self):
122        for tp in self._types:
123            b = tp(self._source)
124            m = self._view(b)
125            with self.assertRaises(TypeError):
126                del m[1]
127            with self.assertRaises(TypeError):
128                del m[1:4]
129
130    def test_tobytes(self):
131        for tp in self._types:
132            m = self._view(tp(self._source))
133            b = m.tobytes()
134            # This calls self.getitem_type() on each separate byte of b"abcdef"
135            expected = b"".join(
136                self.getitem_type(c) for c in b"abcdef")
137            self.assertEqual(b, expected)
138            self.assertIsInstance(b, bytes)
139
140    def test_tolist(self):
141        for tp in self._types:
142            m = self._view(tp(self._source))
143            l = m.tolist()
144            self.assertEqual(l, map(ord, b"abcdef"))
145
146    def test_compare(self):
147        # memoryviews can compare for equality with other objects
148        # having the buffer interface.
149        for tp in self._types:
150            m = self._view(tp(self._source))
151            for tp_comp in self._types:
152                self.assertTrue(m == tp_comp(b"abcdef"))
153                self.assertFalse(m != tp_comp(b"abcdef"))
154                self.assertFalse(m == tp_comp(b"abcde"))
155                self.assertTrue(m != tp_comp(b"abcde"))
156                self.assertFalse(m == tp_comp(b"abcde1"))
157                self.assertTrue(m != tp_comp(b"abcde1"))
158            self.assertTrue(m == m)
159            self.assertTrue(m == m[:])
160            self.assertTrue(m[0:6] == m[:])
161            self.assertFalse(m[0:5] == m)
162
163            # Comparison with objects which don't support the buffer API
164            self.assertFalse(m == u"abcdef")
165            self.assertTrue(m != u"abcdef")
166            self.assertFalse(u"abcdef" == m)
167            self.assertTrue(u"abcdef" != m)
168
169            # Unordered comparisons are unimplemented, and therefore give
170            # arbitrary results (they raise a TypeError in py3k)
171
172    def check_attributes_with_type(self, tp):
173        m = self._view(tp(self._source))
174        self.assertEqual(m.format, self.format)
175        self.assertIsInstance(m.format, str)
176        self.assertEqual(m.itemsize, self.itemsize)
177        self.assertEqual(m.ndim, 1)
178        self.assertEqual(m.shape, (6,))
179        self.assertEqual(len(m), 6)
180        self.assertEqual(m.strides, (self.itemsize,))
181        self.assertEqual(m.suboffsets, None)
182        return m
183
184    def test_attributes_readonly(self):
185        if not self.ro_type:
186            return
187        m = self.check_attributes_with_type(self.ro_type)
188        self.assertEqual(m.readonly, True)
189
190    def test_attributes_writable(self):
191        if not self.rw_type:
192            return
193        m = self.check_attributes_with_type(self.rw_type)
194        self.assertEqual(m.readonly, False)
195
196    # Disabled: unicode uses the old buffer API in 2.x
197
198    #def test_getbuffer(self):
199        ## Test PyObject_GetBuffer() on a memoryview object.
200        #for tp in self._types:
201            #b = tp(self._source)
202            #oldrefcount = sys.getrefcount(b)
203            #m = self._view(b)
204            #oldviewrefcount = sys.getrefcount(m)
205            #s = unicode(m, "utf-8")
206            #self._check_contents(tp, b, s.encode("utf-8"))
207            #self.assertEqual(sys.getrefcount(m), oldviewrefcount)
208            #m = None
209            #self.assertEqual(sys.getrefcount(b), oldrefcount)
210
211    def test_gc(self):
212        for tp in self._types:
213            if not isinstance(tp, type):
214                # If tp is a factory rather than a plain type, skip
215                continue
216
217            class MySource(tp):
218                pass
219            class MyObject:
220                pass
221
222            # Create a reference cycle through a memoryview object
223            b = MySource(tp(b'abc'))
224            m = self._view(b)
225            o = MyObject()
226            b.m = m
227            b.o = o
228            wr = weakref.ref(o)
229            b = m = o = None
230            # The cycle must be broken
231            gc.collect()
232            self.assertTrue(wr() is None, wr())
233
234    def test_writable_readonly(self):
235        # Issue #10451: memoryview incorrectly exposes a readonly
236        # buffer as writable causing a segfault if using mmap
237        tp = self.ro_type
238        if tp is None:
239            return
240        b = tp(self._source)
241        m = self._view(b)
242        i = io.BytesIO(b'ZZZZ')
243        self.assertRaises(TypeError, i.readinto, m)
244
245# Variations on source objects for the buffer: bytes-like objects, then arrays
246# with itemsize > 1.
247# NOTE: support for multi-dimensional objects is unimplemented.
248
249class BaseBytesMemoryTests(AbstractMemoryTests):
250    ro_type = bytes
251    rw_type = bytearray
252    getitem_type = bytes
253    itemsize = 1
254    format = 'B'
255
256# Disabled: array.array() does not support the new buffer API in 2.x
257
258#class BaseArrayMemoryTests(AbstractMemoryTests):
259    #ro_type = None
260    #rw_type = lambda self, b: array.array('i', map(ord, b))
261    #getitem_type = lambda self, b: array.array('i', map(ord, b)).tostring()
262    #itemsize = array.array('i').itemsize
263    #format = 'i'
264
265    #def test_getbuffer(self):
266        ## XXX Test should be adapted for non-byte buffers
267        #pass
268
269    #def test_tolist(self):
270        ## XXX NotImplementedError: tolist() only supports byte views
271        #pass
272
273
274# Variations on indirection levels: memoryview, slice of memoryview,
275# slice of slice of memoryview.
276# This is important to test allocation subtleties.
277
278class BaseMemoryviewTests:
279    def _view(self, obj):
280        return memoryview(obj)
281
282    def _check_contents(self, tp, obj, contents):
283        self.assertEqual(obj, tp(contents))
284
285class BaseMemorySliceTests:
286    source_bytes = b"XabcdefY"
287
288    def _view(self, obj):
289        m = memoryview(obj)
290        return m[1:7]
291
292    def _check_contents(self, tp, obj, contents):
293        self.assertEqual(obj[1:7], tp(contents))
294
295    def test_refs(self):
296        for tp in self._types:
297            m = memoryview(tp(self._source))
298            oldrefcount = sys.getrefcount(m)
299            m[1:2]
300            self.assertEqual(sys.getrefcount(m), oldrefcount)
301
302class BaseMemorySliceSliceTests:
303    source_bytes = b"XabcdefY"
304
305    def _view(self, obj):
306        m = memoryview(obj)
307        return m[:7][1:]
308
309    def _check_contents(self, tp, obj, contents):
310        self.assertEqual(obj[1:7], tp(contents))
311
312
313# Concrete test classes
314
315class BytesMemoryviewTest(unittest.TestCase,
316    BaseMemoryviewTests, BaseBytesMemoryTests):
317
318    def test_constructor(self):
319        for tp in self._types:
320            ob = tp(self._source)
321            self.assertTrue(memoryview(ob))
322            self.assertTrue(memoryview(object=ob))
323            self.assertRaises(TypeError, memoryview)
324            self.assertRaises(TypeError, memoryview, ob, ob)
325            self.assertRaises(TypeError, memoryview, argument=ob)
326            self.assertRaises(TypeError, memoryview, ob, argument=True)
327
328#class ArrayMemoryviewTest(unittest.TestCase,
329    #BaseMemoryviewTests, BaseArrayMemoryTests):
330
331    #def test_array_assign(self):
332        ## Issue #4569: segfault when mutating a memoryview with itemsize != 1
333        #a = array.array('i', range(10))
334        #m = memoryview(a)
335        #new_a = array.array('i', range(9, -1, -1))
336        #m[:] = new_a
337        #self.assertEqual(a, new_a)
338
339
340class BytesMemorySliceTest(unittest.TestCase,
341    BaseMemorySliceTests, BaseBytesMemoryTests):
342    pass
343
344#class ArrayMemorySliceTest(unittest.TestCase,
345    #BaseMemorySliceTests, BaseArrayMemoryTests):
346    #pass
347
348class BytesMemorySliceSliceTest(unittest.TestCase,
349    BaseMemorySliceSliceTests, BaseBytesMemoryTests):
350    pass
351
352#class ArrayMemorySliceSliceTest(unittest.TestCase,
353    #BaseMemorySliceSliceTests, BaseArrayMemoryTests):
354    #pass
355
356
357def test_main():
358    test_support.run_unittest(__name__)
359
360if __name__ == "__main__":
361    test_main()
362