1from test.test_support import findfile, TESTFN, unlink
2import unittest
3import array
4import io
5import pickle
6import sys
7import base64
8
class UnseekableIO(file):
    """A real file object that refuses to report or change its position.

    Only ``seek()`` and ``tell()`` are overridden; both raise
    ``io.UnsupportedOperation``.  Everything else is inherited unchanged
    from ``file``, so reading and writing behave like an ordinary file.
    """

    def seek(self, *args, **kwargs):
        # Repositioning is deliberately unavailable.
        raise io.UnsupportedOperation()

    def tell(self):
        # Reporting the position is deliberately unavailable as well.
        raise io.UnsupportedOperation()
15
def fromhex(s):
    """Decode the hexadecimal string *s* into a byte string.

    Any whitespace in *s* (spaces, tabs, newlines) is ignored, so long
    literals can be grouped and wrapped freely.  Hex digits may be upper-
    or lower-case.  Raises whatever ``base64.b16decode`` raises for
    non-hex characters or an odd number of digits.
    """
    # str.split() with no argument splits on every run of whitespace;
    # casefold=True lets b16decode accept lower-case digits as well.
    return base64.b16decode(''.join(s.split()), casefold=True)
18
def byteswap2(data):
    """Return *data* with the two bytes of every 2-byte item swapped.

    ``len(data)`` must be a multiple of 2; otherwise ``ValueError`` is
    raised.
    """
    out = bytearray(data)
    # Even positions take the odd bytes and vice versa.
    out[0::2], out[1::2] = data[1::2], data[0::2]
    return bytes(out)
24
def byteswap3(data):
    """Return *data* with the byte order of every 3-byte item reversed.

    The middle byte of each triple stays in place; only the outer two
    trade places.  ``len(data)`` must be a multiple of 3; otherwise
    ``ValueError`` is raised by the extended-slice assignment.
    """
    out = bytearray(data)
    out[0::3], out[2::3] = data[2::3], data[0::3]
    return bytes(out)
30
def byteswap4(data):
    """Return *data* with the byte order of every 4-byte item reversed.

    The item width is always exactly 4 bytes.  The previous
    ``array('i')`` implementation used the platform's C ``int`` size,
    which the ``array`` module only guarantees to be at least 2 bytes.
    ``len(data)`` must be a multiple of 4; otherwise ``ValueError`` is
    raised.
    """
    out = bytearray(data)
    # Byte k of each item moves to position 3 - k.  Any length that is not a
    # multiple of 4 makes the first assignment's sizes differ (ValueError).
    out[0::4] = data[3::4]
    out[1::4] = data[2::4]
    out[2::4] = data[1::4]
    out[3::4] = data[0::4]
    return bytes(out)
36
37
class AudioTests:
    """Mixin with fixtures and parameter checks shared by audio tests.

    It relies on ``assertEqual`` from a co-base (presumably
    ``unittest.TestCase``).  Tests store the reader they open in ``self.f``
    and the writer in ``self.fout`` so that ``tearDown`` can close them.
    """

    # Expected value of ``file.closed`` for a caller-supplied file object
    # once the audio object wrapping it has been closed.
    close_fd = False

    def setUp(self):
        self.f = self.fout = None

    def tearDown(self):
        """Close any open reader and writer, then delete TESTFN.

        Every step runs even if an earlier ``close()`` raises.  A failing
        reader can therefore neither leak the writer nor leave TESTFN behind
        for later tests.  Exceptions are not swallowed; they propagate once
        the remaining cleanup has run.
        """
        try:
            if self.f is not None:
                self.f.close()
        finally:
            try:
                if self.fout is not None:
                    self.fout.close()
            finally:
                unlink(TESTFN)

    def check_params(self, f, nchannels, sampwidth, framerate, nframes,
                     comptype, compname):
        """Assert that audio object *f* reports exactly the given parameters.

        Each getter is checked individually.  Then ``getparams()`` is compared
        with the tuple of all six values.  Finally the params object must
        survive a pickle round trip under every available protocol.
        """
        self.assertEqual(f.getnchannels(), nchannels)
        self.assertEqual(f.getsampwidth(), sampwidth)
        self.assertEqual(f.getframerate(), framerate)
        self.assertEqual(f.getnframes(), nframes)
        self.assertEqual(f.getcomptype(), comptype)
        self.assertEqual(f.getcompname(), compname)

        params = f.getparams()
        self.assertEqual(params,
                (nchannels, sampwidth, framerate, nframes, comptype, compname))

        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            dump = pickle.dumps(params, proto)
            self.assertEqual(pickle.loads(dump), params)
67
68
class AudioWriteTests(AudioTests):
    """Write-side tests shared by the audio module test cases.

    Concrete subclasses are expected to define these attributes, all read
    below: ``module`` (the audio module under test, whose ``open`` is
    called), ``nchannels``, ``sampwidth``, ``framerate``, ``comptype``,
    ``compname``, ``nframes``, and ``frames`` (the raw frame data written).
    """

    def create_file(self, testfile):
        """Open *testfile* for writing and apply the class's audio params.

        *testfile* is passed straight to ``self.module.open``; the tests
        below pass either a path or an already-open file object.  The writer
        is stored in ``self.fout`` so that ``tearDown`` closes it, and is
        also returned.  Setting the frame count is left to the caller.
        """
        f = self.fout = self.module.open(testfile, 'wb')
        f.setnchannels(self.nchannels)
        f.setsampwidth(self.sampwidth)
        f.setframerate(self.framerate)
        f.setcomptype(self.comptype, self.compname)
        return f

    def check_file(self, testfile, nframes, frames):
        """Reopen *testfile* for reading and verify its header and data.

        Asserts that channel count, sample width and frame rate match the
        class attributes, that the header reports *nframes* frames, and that
        reading *nframes* frames returns exactly *frames*.  The reader is
        closed even when an assertion fails.
        """
        f = self.module.open(testfile, 'rb')
        try:
            self.assertEqual(f.getnchannels(), self.nchannels)
            self.assertEqual(f.getsampwidth(), self.sampwidth)
            self.assertEqual(f.getframerate(), self.framerate)
            self.assertEqual(f.getnframes(), nframes)
            self.assertEqual(f.readframes(nframes), frames)
        finally:
            f.close()

    def test_write_params(self):
        # The writer must report the parameters it was configured with.
        # They are checked while it is still open, after the frames are written.
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(self.frames)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.nframes, self.comptype, self.compname)
        f.close()

    def test_write(self):
        # Round trip: write the frames to TESTFN by path, then read them back.
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(self.frames)
        f.close()

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_incompleted_write(self):
        # Write into an already-open file after a 13-byte prefix
        # (len(b'ababagalamaga') == 13), declaring one frame more than is
        # actually written.  The readback expects the true count
        # (self.nframes), so the header must be corrected when the writer is
        # closed.  The prefix must be left untouched.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes + 1)
            f.writeframes(self.frames)
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_multiple_writes(self):
        # Same prefixed layout, but the frames arrive in two writeframes()
        # calls: everything except the last frame, then the last frame.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes)
            # Bytes per frame: one sample per channel.
            framesize = self.nchannels * self.sampwidth
            f.writeframes(self.frames[:-framesize])
            f.writeframes(self.frames[-framesize:])
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_overflowed_write(self):
        # Declare one frame fewer than is written.  The readback expects all
        # self.nframes frames, so the header must be updated on close.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes - 1)
            f.writeframes(self.frames)
            f.close()

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes, self.frames)

    def test_unseekable_read(self):
        # Write normally, then read back through a file object whose
        # tell()/seek() raise io.UnsupportedOperation.
        f = self.create_file(TESTFN)
        f.setnframes(self.nframes)
        f.writeframes(self.frames)
        f.close()

        with UnseekableIO(TESTFN, 'rb') as testfile:
            self.check_file(testfile, self.nframes, self.frames)

    def test_unseekable_write(self):
        # Write to an unseekable file with the frame count declared
        # correctly up front, then read the result back by path.
        with UnseekableIO(TESTFN, 'wb') as testfile:
            f = self.create_file(testfile)
            f.setnframes(self.nframes)
            f.writeframes(self.frames)
            f.close()
            # NOTE(review): presumably cleared so tearDown does not touch the
            # writer again after the ``with`` has closed the underlying file.
            # The other unseekable tests leave it set; confirm whether this
            # is still needed.
            self.fout = None

        self.check_file(TESTFN, self.nframes, self.frames)

    def test_unseekable_incompleted_write(self):
        # Declare one frame more than is written on an unseekable file, so
        # the header cannot be patched afterwards.  IOError from
        # writeframes()/close() is tolerated.  The readback expects the header
        # to keep the declared nframes + 1, while reading still yields just
        # the frames that were written.
        # NOTE(review): presumably the IOError is io.UnsupportedOperation,
        # which derives from IOError in Python 2; confirm.
        with UnseekableIO(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes + 1)
            try:
                f.writeframes(self.frames)
            except IOError:
                pass
            try:
                f.close()
            except IOError:
                pass

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            self.check_file(testfile, self.nframes + 1, self.frames)

    def test_unseekable_overflowed_write(self):
        # Declare one frame fewer than is written on an unseekable file.
        # The readback expects the header to keep nframes - 1, and reading
        # that many frames yields every frame except the last.
        with UnseekableIO(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            f = self.create_file(testfile)
            f.setnframes(self.nframes - 1)
            try:
                f.writeframes(self.frames)
            except IOError:
                pass
            try:
                f.close()
            except IOError:
                pass

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            framesize = self.nchannels * self.sampwidth
            self.check_file(testfile, self.nframes - 1, self.frames[:-framesize])
199
200
class AudioTestsWithSourceFile(AudioTests):
    """Read-side tests against a reference sound file from ``audiodata``.

    Concrete subclasses are expected to define these attributes, all read
    below:
    ``module``; ``sndfilename`` (resolved with ``findfile`` in the
    ``audiodata`` subdirectory); ``sndfilenframes`` (the frame count of
    that file); ``nchannels``, ``sampwidth``, ``framerate``, ``comptype``
    and ``compname``; and ``frames``/``nframes``, the expected data at the
    start of the file, with ``frames`` holding ``nframes`` frames.
    """

    @classmethod
    def setUpClass(cls):
        # Resolve the reference file's path once per test class.
        cls.sndfilepath = findfile(cls.sndfilename, subdir='audiodata')

    def test_read_params(self):
        # Parameters read from the reference file must match the expected
        # class attributes; the frame count is compared with sndfilenframes.
        f = self.f = self.module.open(self.sndfilepath)
        # NOTE(review): disabled check.  Presumably not every reader exposes
        # getfp() with a usable .name; confirm before re-enabling or
        # deleting it.
        #self.assertEqual(f.getfp().name, self.sndfilepath)
        self.check_params(f, self.nchannels, self.sampwidth, self.framerate,
                          self.sndfilenframes, self.comptype, self.compname)

    def test_close(self):
        # Closing an audio object opened on a caller-supplied file must leave
        # that file's closed state equal to self.close_fd, for both a reader
        # and a writer.  A writer with no parameters set is expected to raise
        # module.Error from close(); a second close() must not raise.
        with open(self.sndfilepath, 'rb') as testfile:
            f = self.f = self.module.open(testfile)
            self.assertFalse(testfile.closed)
            f.close()
            self.assertEqual(testfile.closed, self.close_fd)
        with open(TESTFN, 'wb') as testfile:
            fout = self.fout = self.module.open(testfile, 'wb')
            self.assertFalse(testfile.closed)
            with self.assertRaises(self.module.Error):
                fout.close()
            self.assertEqual(testfile.closed, self.close_fd)
            fout.close() # do nothing

    def test_read(self):
        # Exercise readframes(), tell(), rewind() and setpos() on the
        # reference file.  Positions returned by tell() count frames.
        # setpos() must reject -1 and getnframes() + 1.
        framesize = self.nchannels * self.sampwidth
        # Expected frames 0-1 and frames 2-3 of the reference data.
        chunk1 = self.frames[:2 * framesize]
        chunk2 = self.frames[2 * framesize: 4 * framesize]
        f = self.f = self.module.open(self.sndfilepath)
        self.assertEqual(f.readframes(0), b'')
        self.assertEqual(f.tell(), 0)
        self.assertEqual(f.readframes(2), chunk1)
        f.rewind()
        pos0 = f.tell()
        self.assertEqual(pos0, 0)
        self.assertEqual(f.readframes(2), chunk1)
        pos2 = f.tell()
        self.assertEqual(pos2, 2)
        self.assertEqual(f.readframes(2), chunk2)
        f.setpos(pos2)
        self.assertEqual(f.readframes(2), chunk2)
        f.setpos(pos0)
        self.assertEqual(f.readframes(2), chunk1)
        with self.assertRaises(self.module.Error):
            f.setpos(-1)
        with self.assertRaises(self.module.Error):
            f.setpos(f.getnframes() + 1)

    def test_copy(self):
        # Copy the reference file to TESTFN in chunks of 1, 2, 3, ... frames
        # (the last read may return fewer).  Then check that the copy has the
        # same params and the same frame data.
        f = self.f = self.module.open(self.sndfilepath)
        fout = self.fout = self.module.open(TESTFN, 'wb')
        fout.setparams(f.getparams())
        i = 0
        n = f.getnframes()
        while n > 0:
            i += 1
            fout.writeframes(f.readframes(i))
            n -= i
        fout.close()
        fout = self.fout = self.module.open(TESTFN, 'rb')
        f.rewind()
        self.assertEqual(f.getparams(), fout.getparams())
        self.assertEqual(f.readframes(f.getnframes()),
                         fout.readframes(fout.getnframes()))

    def test_read_not_from_start(self):
        # Build TESTFN as a 13-byte prefix (len(b'ababagalamaga') == 13)
        # followed by the reference file.  Then open a reader on a file object
        # positioned just after the prefix; the audio data must be parsed
        # relative to that position.
        with open(TESTFN, 'wb') as testfile:
            testfile.write(b'ababagalamaga')
            with open(self.sndfilepath, 'rb') as f:
                testfile.write(f.read())

        with open(TESTFN, 'rb') as testfile:
            self.assertEqual(testfile.read(13), b'ababagalamaga')
            f = self.module.open(testfile, 'rb')
            try:
                self.assertEqual(f.getnchannels(), self.nchannels)
                self.assertEqual(f.getsampwidth(), self.sampwidth)
                self.assertEqual(f.getframerate(), self.framerate)
                self.assertEqual(f.getnframes(), self.sndfilenframes)
                self.assertEqual(f.readframes(self.nframes), self.frames)
            finally:
                f.close()
285