1"""Unit tests for the memoryview
2
3   Some tests are in test_bytes. Many tests that require _testbuffer.ndarray
4   are in test_buffer.
5"""
6
7import unittest
8import test.support
9import sys
10import gc
11import weakref
12import array
13import io
14import copy
15import pickle
16
17
class AbstractMemoryTests:
    """Mixin with memoryview tests shared by every source/view combination.

    Concrete subclasses (mixed with unittest.TestCase) must provide:

    * ``ro_type`` / ``rw_type`` -- callables that build a read-only or a
      writable buffer object from ``self._source``; either may be ``None``
      when the variant has no such type.
    * ``getitem_type`` -- callable mapping one source byte (as ``bytes``) to
      the bytes expected for that item in ``tobytes()`` output.
    * ``itemsize`` and ``format`` -- the expected memoryview attributes.
    * ``_view(obj)`` -- returns the memoryview under test for ``obj``.
    * ``_check_contents(tp, obj, contents)`` -- asserts that the viewed part
      of ``obj`` equals ``tp(contents)``.
    """

    # Bytes the buffer objects under test are built from.  Variants may
    # override this (e.g. with padding) as long as the view sees b"abcdef".
    source_bytes = b"abcdef"

    @property
    def _source(self):
        """Return the bytes used to construct every buffer object."""
        return self.source_bytes

    @property
    def _types(self):
        """Return an iterator over the configured buffer factories.

        ``None`` entries (no read-only or no writable type) are dropped.
        A fresh iterator is built on every access, so nested loops over
        ``self._types`` do not interfere with each other.
        """
        return filter(None, [self.ro_type, self.rw_type])

    def check_getitem_with_type(self, tp):
        """Check indexing, bounds and key-type errors on a view of ``tp``.

        Also verifies that dropping the view restores the refcount of the
        underlying object, i.e. the view does not leak a reference.
        """
        b = tp(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        self.assertEqual(m[0], ord(b"a"))
        self.assertIsInstance(m[0], int)
        self.assertEqual(m[5], ord(b"f"))
        self.assertEqual(m[-1], ord(b"f"))
        self.assertEqual(m[-6], ord(b"a"))
        # Bounds checking
        self.assertRaises(IndexError, lambda: m[6])
        self.assertRaises(IndexError, lambda: m[-7])
        self.assertRaises(IndexError, lambda: m[sys.maxsize])
        self.assertRaises(IndexError, lambda: m[-sys.maxsize])
        # Type checking
        self.assertRaises(TypeError, lambda: m[None])
        self.assertRaises(TypeError, lambda: m[0.0])
        self.assertRaises(TypeError, lambda: m["a"])
        # Drop the view before comparing refcounts on the source object.
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_getitem(self):
        for tp in self._types:
            self.check_getitem_with_type(tp)

    def test_iter(self):
        # Iteration must yield the same items as positional indexing.
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            self.assertEqual(list(m), [m[i] for i in range(len(m))])

    def test_setitem_readonly(self):
        if not self.ro_type:
            self.skipTest("no read-only type to test")
        b = self.ro_type(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        def setitem(value):
            m[0] = value
        # Every kind of value must be refused by a read-only view.
        self.assertRaises(TypeError, setitem, b"a")
        self.assertRaises(TypeError, setitem, 65)
        self.assertRaises(TypeError, setitem, memoryview(b"a"))
        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_setitem_writable(self):
        if not self.rw_type:
            self.skipTest("no writable type to test")
        tp = self.rw_type
        b = tp(self._source)
        oldrefcount = sys.getrefcount(b)
        m = self._view(b)
        m[0] = ord(b'1')
        self._check_contents(tp, b, b"1bcdef")
        m[0:1] = tp(b"0")
        self._check_contents(tp, b, b"0bcdef")
        m[1:3] = tp(b"12")
        self._check_contents(tp, b, b"012def")
        m[1:1] = tp(b"")
        self._check_contents(tp, b, b"012def")
        m[:] = tp(b"abcdef")
        self._check_contents(tp, b, b"abcdef")

        # Overlapping copies of a view into itself
        m[0:3] = m[2:5]
        self._check_contents(tp, b, b"cdedef")
        m[:] = tp(b"abcdef")
        m[2:5] = m[0:3]
        self._check_contents(tp, b, b"ababcf")

        def setitem(key, value):
            m[key] = tp(value)
        # Bounds checking
        self.assertRaises(IndexError, setitem, 6, b"a")
        self.assertRaises(IndexError, setitem, -7, b"a")
        self.assertRaises(IndexError, setitem, sys.maxsize, b"a")
        self.assertRaises(IndexError, setitem, -sys.maxsize, b"a")
        # Wrong index/slice types
        self.assertRaises(TypeError, setitem, 0.0, b"a")
        self.assertRaises(TypeError, setitem, (0,), b"a")
        self.assertRaises(TypeError, setitem, (slice(0,1,1), 0), b"a")
        self.assertRaises(TypeError, setitem, (0, slice(0,1,1)), b"a")
        self.assertRaises(TypeError, setitem, "a", b"a")
        # Not implemented: multidimensional slices
        slices = (slice(0,1,1), slice(0,1,2))
        self.assertRaises(NotImplementedError, setitem, slices, b"a")
        # Trying to resize the memory object.  A bytes value is only an
        # item of the right type for format 'c' (where the wrong length is a
        # ValueError); for other formats the value type itself is rejected.
        exc = ValueError if m.format == 'c' else TypeError
        self.assertRaises(exc, setitem, 0, b"")
        self.assertRaises(exc, setitem, 0, b"ab")
        self.assertRaises(ValueError, setitem, slice(1,1), b"a")
        self.assertRaises(ValueError, setitem, slice(0,2), b"a")

        m = None
        self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_delitem(self):
        # Items of a memoryview can never be deleted.
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            with self.assertRaises(TypeError):
                del m[1]
            with self.assertRaises(TypeError):
                del m[1:4]

    def test_tobytes(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            b = m.tobytes()
            # This calls self.getitem_type() on each separate byte of b"abcdef"
            expected = b"".join(
                self.getitem_type(bytes([c])) for c in b"abcdef")
            self.assertEqual(b, expected)
            self.assertIsInstance(b, bytes)

    def test_tolist(self):
        for tp in self._types:
            m = self._view(tp(self._source))
            l = m.tolist()
            self.assertEqual(l, list(b"abcdef"))

    def test_compare(self):
        # memoryviews can compare for equality with other objects
        # having the buffer interface.
        for tp in self._types:
            m = self._view(tp(self._source))
            for tp_comp in self._types:
                self.assertTrue(m == tp_comp(b"abcdef"))
                self.assertFalse(m != tp_comp(b"abcdef"))
                self.assertFalse(m == tp_comp(b"abcde"))
                self.assertTrue(m != tp_comp(b"abcde"))
                self.assertFalse(m == tp_comp(b"abcde1"))
                self.assertTrue(m != tp_comp(b"abcde1"))
            self.assertTrue(m == m)
            self.assertTrue(m == m[:])
            self.assertTrue(m[0:6] == m[:])
            self.assertFalse(m[0:5] == m)

            # Comparison with objects which don't support the buffer API
            self.assertFalse(m == "abcdef")
            self.assertTrue(m != "abcdef")
            self.assertFalse("abcdef" == m)
            self.assertTrue("abcdef" != m)

            # Unordered comparisons
            for c in (m, b"abcdef"):
                self.assertRaises(TypeError, lambda: m < c)
                self.assertRaises(TypeError, lambda: c <= m)
                self.assertRaises(TypeError, lambda: m >= c)
                self.assertRaises(TypeError, lambda: c > m)

    def check_attributes_with_type(self, tp):
        """Check the shape/format attributes of a view of ``tp``.

        Returns the view so callers can assert on ``readonly``.
        """
        m = self._view(tp(self._source))
        self.assertEqual(m.format, self.format)
        self.assertEqual(m.itemsize, self.itemsize)
        self.assertEqual(m.ndim, 1)
        self.assertEqual(m.shape, (6,))
        self.assertEqual(len(m), 6)
        self.assertEqual(m.strides, (self.itemsize,))
        self.assertEqual(m.suboffsets, ())
        return m

    def test_attributes_readonly(self):
        if not self.ro_type:
            self.skipTest("no read-only type to test")
        m = self.check_attributes_with_type(self.ro_type)
        self.assertEqual(m.readonly, True)

    def test_attributes_writable(self):
        if not self.rw_type:
            self.skipTest("no writable type to test")
        m = self.check_attributes_with_type(self.rw_type)
        self.assertEqual(m.readonly, False)

    def test_getbuffer(self):
        # Test PyObject_GetBuffer() on a memoryview object.
        for tp in self._types:
            b = tp(self._source)
            oldrefcount = sys.getrefcount(b)
            m = self._view(b)
            oldviewrefcount = sys.getrefcount(m)
            # str(obj, encoding) reads obj through the buffer protocol.
            s = str(m, "utf-8")
            self._check_contents(tp, b, s.encode("utf-8"))
            self.assertEqual(sys.getrefcount(m), oldviewrefcount)
            m = None
            self.assertEqual(sys.getrefcount(b), oldrefcount)

    def test_gc(self):
        for tp in self._types:
            if not isinstance(tp, type):
                # If tp is a factory rather than a plain type, skip
                continue

            class MyView():
                def __init__(self, base):
                    self.m = memoryview(base)
            class MySource(tp):
                pass
            class MyObject:
                pass

            # Create a reference cycle through a memoryview object.
            # This exercises mbuf_clear().
            b = MySource(tp(b'abc'))
            m = self._view(b)
            o = MyObject()
            b.m = m
            b.o = o
            wr = weakref.ref(o)
            b = m = o = None
            # The cycle must be broken
            gc.collect()
            self.assertIsNone(wr())

            # This exercises memory_clear().
            m = MyView(tp(b'abc'))
            o = MyObject()
            m.x = m
            m.o = o
            wr = weakref.ref(o)
            m = o = None
            # The cycle must be broken
            gc.collect()
            self.assertIsNone(wr())

    def _check_released(self, m, tp):
        """Assert that the released view ``m`` refuses buffer access.

        Every data/attribute access must raise ValueError mentioning
        "released"; str(), repr() and equality comparisons must still work.
        """
        check = self.assertRaisesRegex(ValueError, "released")
        with check: bytes(m)
        with check: m.tobytes()
        with check: m.tolist()
        with check: m[0]
        with check: m[0] = b'x'
        with check: len(m)
        with check: m.format
        with check: m.itemsize
        with check: m.ndim
        with check: m.readonly
        with check: m.shape
        with check: m.strides
        with check:
            with m:
                pass
        # str() and repr() still function
        self.assertIn("released memory", str(m))
        self.assertIn("released memory", repr(m))
        self.assertEqual(m, m)
        self.assertNotEqual(m, memoryview(tp(self._source)))
        self.assertNotEqual(m, tp(self._source))

    def test_contextmanager(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            # Leaving the with-block releases the view.
            with m as cm:
                self.assertIs(cm, m)
            self._check_released(m, tp)
            m = self._view(b)
            # Can release explicitly inside the context manager
            with m:
                m.release()

    def test_release(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            m.release()
            self._check_released(m, tp)
            # Can be called a second time (it's a no-op)
            m.release()
            self._check_released(m, tp)

    def test_writable_readonly(self):
        # Issue #10451: memoryview incorrectly exposes a readonly
        # buffer as writable causing a segfault if using mmap
        tp = self.ro_type
        if tp is None:
            self.skipTest("no read-only type to test")
        b = tp(self._source)
        m = self._view(b)
        i = io.BytesIO(b'ZZZZ')
        self.assertRaises(TypeError, i.readinto, m)

    def test_getbuf_fail(self):
        self.assertRaises(TypeError, self._view, {})

    def test_hash(self):
        # Memoryviews of readonly (hashable) types are hashable, and they
        # hash as hash(obj.tobytes()).
        tp = self.ro_type
        if tp is None:
            self.skipTest("no read-only type to test")
        b = tp(self._source)
        m = self._view(b)
        self.assertEqual(hash(m), hash(b"abcdef"))
        # Releasing the memoryview keeps the stored hash value (as with weakrefs)
        m.release()
        self.assertEqual(hash(m), hash(b"abcdef"))
        # Hashing a memoryview for the first time after it is released
        # results in an error (as with weakrefs).
        m = self._view(b)
        m.release()
        self.assertRaises(ValueError, hash, m)

    def test_hash_writable(self):
        # Memoryviews of writable types are unhashable
        tp = self.rw_type
        if tp is None:
            self.skipTest("no writable type to test")
        b = tp(self._source)
        m = self._view(b)
        self.assertRaises(ValueError, hash, m)

    def test_weakref(self):
        # Check memoryviews are weakrefable
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            L = []
            def callback(wr, b=b):
                L.append(b)
            wr = weakref.ref(m, callback)
            self.assertIs(wr(), m)
            del m
            test.support.gc_collect()
            self.assertIs(wr(), None)
            # Fail cleanly (rather than with IndexError) if the weakref
            # callback was not invoked exactly once.
            self.assertEqual(len(L), 1)
            self.assertIs(L[0], b)

    def test_reversed(self):
        for tp in self._types:
            b = tp(self._source)
            m = self._view(b)
            aslist = list(reversed(m.tolist()))
            self.assertEqual(list(reversed(m)), aslist)
            self.assertEqual(list(reversed(m)), list(m[::-1]))

    def test_issue22668(self):
        # Views derived by casting/slicing must keep their own format even
        # after the view they came from is deleted or re-cast.
        a = array.array('H', [256, 256, 256, 256])
        x = memoryview(a)
        m = x.cast('B')
        b = m.cast('H')
        c = b[0:2]
        d = memoryview(b)

        del b

        self.assertEqual(c[0], 256)
        self.assertEqual(d[0], 256)
        self.assertEqual(c.format, "H")
        self.assertEqual(d.format, "H")

        _ = m.cast('I')
        self.assertEqual(c[0], 256)
        self.assertEqual(d[0], 256)
        self.assertEqual(c.format, "H")
        self.assertEqual(d.format, "H")
385
386
387# Variations on source objects for the buffer: bytes-like objects, then arrays
388# with itemsize > 1.
389# NOTE: support for multi-dimensional objects is unimplemented.
390
class BaseBytesMemoryTests(AbstractMemoryTests):
    """Source variant: bytes (read-only) and bytearray (writable) objects."""

    # Expected memoryview attributes for one-byte unsigned items.
    format = 'B'
    itemsize = 1

    # Buffer factories; ``bytes`` also turns one source byte into the
    # bytes expected from tobytes() for that item.
    ro_type = bytes
    rw_type = bytearray
    getitem_type = bytes
397
class BaseArrayMemoryTests(AbstractMemoryTests):
    """Source variant: array.array('i') objects, i.e. itemsize > 1.

    No read-only array type exists, so only the writable factory is set.
    """

    ro_type = None
    format = 'i'
    itemsize = array.array('i').itemsize

    def rw_type(self, b):
        # One 'i' item per source byte value.
        return array.array('i', list(b))

    def getitem_type(self, b):
        # Machine representation of the 'i' items built from ``b``.
        return array.array('i', list(b)).tobytes()

    @unittest.skip('XXX test should be adapted for non-byte buffers')
    def test_getbuffer(self):
        pass

    @unittest.skip('XXX NotImplementedError: tolist() only supports byte views')
    def test_tolist(self):
        pass
412
413
414# Variations on indirection levels: memoryview, slice of memoryview,
415# slice of slice of memoryview.
416# This is important to test allocation subtleties.
417
class BaseMemoryviewTests:
    """View variant: a plain memoryview over the whole source object."""

    def _view(self, obj):
        """Return a memoryview spanning all of ``obj``."""
        view = memoryview(obj)
        return view

    def _check_contents(self, tp, obj, contents):
        """Assert that ``obj`` as a whole equals ``tp(contents)``."""
        expected = tp(contents)
        self.assertEqual(obj, expected)
424
class BaseMemorySliceTests:
    """View variant: a single slice of a memoryview.

    The source is padded with one sentinel byte on each side and the view
    covers items 1 through 6 only.
    """

    source_bytes = b"XabcdefY"

    def _view(self, obj):
        """Return ``memoryview(obj)[1:7]``."""
        return memoryview(obj)[1:7]

    def _check_contents(self, tp, obj, contents):
        """Assert that items 1..6 of ``obj`` equal ``tp(contents)``."""
        viewed = obj[1:7]
        self.assertEqual(viewed, tp(contents))

    def test_refs(self):
        # Slicing a memoryview must not leak a reference to the parent view.
        for factory in self._types:
            parent = memoryview(factory(self._source))
            before = sys.getrefcount(parent)
            parent[1:2]
            self.assertEqual(sys.getrefcount(parent), before)
441
class BaseMemorySliceSliceTests:
    """View variant: a slice of a slice of a memoryview.

    Uses the same padded source as the single-slice variant; the two
    slicing steps together select items 1 through 6.
    """

    source_bytes = b"XabcdefY"

    def _view(self, obj):
        """Return ``memoryview(obj)[:7][1:]``."""
        base = memoryview(obj)
        first = base[:7]
        return first[1:]

    def _check_contents(self, tp, obj, contents):
        """Assert that items 1..6 of ``obj`` equal ``tp(contents)``."""
        viewed = obj[1:7]
        self.assertEqual(viewed, tp(contents))
451
452
453# Concrete test classes
454
class BytesMemoryviewTest(unittest.TestCase,
                          BaseMemoryviewTests, BaseBytesMemoryTests):
    """Plain memoryviews over bytes and bytearray objects."""

    def test_constructor(self):
        # Calling memoryview() without an argument fails regardless of the
        # source type, so check it once rather than on every iteration.
        self.assertRaises(TypeError, memoryview)
        for tp in self._types:
            ob = tp(self._source)
            # Both the positional and the keyword form must build a view
            # that actually exposes ob's contents (assertTrue on a view only
            # checked that it was non-empty).
            self.assertEqual(memoryview(ob), ob)
            self.assertEqual(memoryview(object=ob), ob)
            self.assertRaises(TypeError, memoryview, ob, ob)
            self.assertRaises(TypeError, memoryview, argument=ob)
            self.assertRaises(TypeError, memoryview, ob, argument=True)
467
class ArrayMemoryviewTest(unittest.TestCase,
                          BaseMemoryviewTests, BaseArrayMemoryTests):
    """Plain memoryviews over array.array('i') objects."""

    def test_array_assign(self):
        # Issue #4569: segfault when mutating a memoryview with itemsize != 1
        target = array.array('i', range(10))
        view = memoryview(target)
        replacement = array.array('i', reversed(range(10)))
        view[:] = replacement
        self.assertEqual(target, replacement)
478
479
class BytesMemorySliceTest(unittest.TestCase,
                           BaseMemorySliceTests, BaseBytesMemoryTests):
    """Single-slice memoryviews over bytes and bytearray objects."""
483
class ArrayMemorySliceTest(unittest.TestCase,
                           BaseMemorySliceTests, BaseArrayMemoryTests):
    """Single-slice memoryviews over array.array('i') objects."""
487
class BytesMemorySliceSliceTest(unittest.TestCase,
                                BaseMemorySliceSliceTests, BaseBytesMemoryTests):
    """Slice-of-slice memoryviews over bytes and bytearray objects."""
491
class ArrayMemorySliceSliceTest(unittest.TestCase,
                                BaseMemorySliceSliceTests, BaseArrayMemoryTests):
    """Slice-of-slice memoryviews over array.array('i') objects."""
495
496
497class OtherTest(unittest.TestCase):
498    def test_ctypes_cast(self):
499        # Issue 15944: Allow all source formats when casting to bytes.
500        ctypes = test.support.import_module("ctypes")
501        p6 = bytes(ctypes.c_double(0.6))
502
503        d = ctypes.c_double()
504        m = memoryview(d).cast("B")
505        m[:2] = p6[:2]
506        m[2:] = p6[2:]
507        self.assertEqual(d.value, 0.6)
508
509        for format in "Bbc":
510            with self.subTest(format):
511                d = ctypes.c_double()
512                m = memoryview(d).cast(format)
513                m[:2] = memoryview(p6).cast(format)[:2]
514                m[2:] = memoryview(p6).cast(format)[2:]
515                self.assertEqual(d.value, 0.6)
516
517    def test_memoryview_hex(self):
518        # Issue #9951: memoryview.hex() segfaults with non-contiguous buffers.
519        x = b'0' * 200000
520        m1 = memoryview(x)
521        m2 = m1[::-1]
522        self.assertEqual(m2.hex(), '30' * 200000)
523
524    def test_copy(self):
525        m = memoryview(b'abc')
526        with self.assertRaises(TypeError):
527            copy.copy(m)
528
529    def test_pickle(self):
530        m = memoryview(b'abc')
531        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
532            with self.assertRaises(TypeError):
533                pickle.dumps(m, proto)
534
535
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
538