1"""Stuff to parse WAVE files.
2
3Usage.
4
5Reading WAVE files:
6      f = wave.open(file, 'r')
7where file is either the name of a file or an open file pointer.
8The open file pointer must have methods read(), seek(), and close().
When the setpos() and rewind() methods are not used, the seek()
method is not necessary.
11
12This returns an instance of a class with the following public methods:
13      getnchannels()  -- returns number of audio channels (1 for
14                         mono, 2 for stereo)
15      getsampwidth()  -- returns sample width in bytes
16      getframerate()  -- returns sampling frequency
17      getnframes()    -- returns number of audio frames
18      getcomptype()   -- returns compression type ('NONE' for linear samples)
      getcompname()   -- returns human-readable version of
                         compression type ('not compressed' for linear samples)
21      getparams()     -- returns a tuple consisting of all of the
22                         above in the above order
23      getmarkers()    -- returns None (for compatibility with the
24                         aifc module)
25      getmark(id)     -- raises an error since the mark does not
26                         exist (for compatibility with the aifc module)
27      readframes(n)   -- returns at most n frames of audio
28      rewind()        -- rewind to the beginning of the audio stream
29      setpos(pos)     -- seek to the specified position
30      tell()          -- return the current position
31      close()         -- close the instance (make it unusable)
32The position returned by tell() and the position given to setpos()
33are compatible and have nothing to do with the actual position in the
34file.
35The close() method is called automatically when the class instance
36is destroyed.
37
38Writing WAVE files:
39      f = wave.open(file, 'w')
40where file is either the name of a file or an open file pointer.
41The open file pointer must have methods write(), tell(), seek(), and
42close().
43
44This returns an instance of a class with the following public methods:
45      setnchannels(n) -- set the number of channels
46      setsampwidth(n) -- set the sample width
47      setframerate(n) -- set the frame rate
48      setnframes(n)   -- set the number of frames
49      setcomptype(type, name)
50                      -- set the compression type and the
51                         human-readable compression type
52      setparams(tuple)
53                      -- set all parameters at once
54      tell()          -- return current position in output file
55      writeframesraw(data)
                      -- write audio frames without patching up the
                         file header
58      writeframes(data)
59                      -- write audio frames and patch up the file header
60      close()         -- patch up the file header and close the
61                         output file
62You should set the parameters before the first writeframesraw or
63writeframes.  The total number of frames does not need to be set,
64but when it is set to the correct value, the header does not have to
65be patched up.
It is best to first set all parameters, possibly including the
compression type, and then write audio frames using writeframesraw.
68When all frames have been written, either call writeframes('') or
69close() to patch up the sizes in the header.
70The close() method is called automatically when the class instance
71is destroyed.
72"""
73
74import __builtin__
75
76__all__ = ["open", "openfp", "Error"]
77
class Error(Exception):
    """Raised by this module for malformed WAVE data or invalid API usage."""
80
# Value of the wFormatTag header field for the only format this module
# reads and writes.
WAVE_FORMAT_PCM = 0x0001

# array typecodes indexed by sample width in bytes; None marks widths for
# which no typecode is listed here.
_array_fmts = (None, 'b', 'h', None, 'i')
84
85import struct
86import sys
87from chunk import Chunk
88
def _byteswap3(data):
    """Return a copy of data with the byte order of every 3-byte group reversed.

    The middle byte of each group stays where it is; only the first and
    third bytes trade places.
    """
    leading = data[0::3]
    trailing = data[2::3]
    swapped = bytearray(data)
    swapped[0::3] = trailing
    swapped[2::3] = leading
    return bytes(swapped)
94
class Wave_read:
    """Reader for uncompressed (PCM) RIFF/WAVE audio.

    Variables used in this class:

    These variables are available to the user through appropriate
    methods of this class:
    _file -- the outer RIFF chunk wrapping the open file (the file needs
              read(), close(), and seek()); set through initfp() and
              returned by getfp()
    _nchannels -- the number of audio channels
              available through the getnchannels() method
    _nframes -- the number of audio frames
              available through the getnframes() method
    _sampwidth -- the number of bytes per audio sample
              available through the getsampwidth() method
    _framerate -- the sampling frequency
              available through the getframerate() method
    _comptype -- the compression type (always 'NONE')
              available through the getcomptype() method
    _compname -- the human-readable compression type
              available through the getcompname() method
    _soundpos -- the position in the audio stream, in frames
              available through the tell() method, set through the
              setpos() method

    These variables are used internally only:
    _fmt_chunk_read -- 1 iff the fmt chunk has been read
    _data_seek_needed -- 1 iff the data chunk must be repositioned
              before the next readframes()
    _data_chunk -- instantiation of a chunk class for the data chunk
    _framesize -- size of one frame in the file, in bytes
    _convert -- optional callable applied to the raw frame data by
              readframes(); initfp() sets it to None
    _i_opened_the_file -- the file object opened by __init__() (and
              therefore closed by close()), or None
    """

    # Class-level defaults keep close(), and therefore __del__(), safe on
    # an instance whose __init__() never ran.
    _file = None
    _i_opened_the_file = None

    def initfp(self, file):
        """Parse the RIFF/WAVE headers of the open file object `file`.

        Chunks are scanned until the data chunk is found; chunks other
        than fmt and data are skipped.  Raises Error when the file is not
        a RIFF/WAVE file, when the data chunk precedes the fmt chunk, or
        when either chunk is missing.
        """
        self._convert = None
        self._soundpos = 0
        self._file = Chunk(file, bigendian = 0)
        if self._file.getname() != b'RIFF':
            raise Error('file does not start with RIFF id')
        if self._file.read(4) != b'WAVE':
            raise Error('not a WAVE file')
        self._fmt_chunk_read = 0
        self._data_chunk = None
        while 1:
            self._data_seek_needed = 1
            try:
                chunk = Chunk(self._file, bigendian = 0)
            except EOFError:
                # No more chunks in the RIFF container.
                break
            chunkname = chunk.getname()
            if chunkname == b'fmt ':
                self._read_fmt_chunk(chunk)
                self._fmt_chunk_read = 1
            elif chunkname == b'data':
                if not self._fmt_chunk_read:
                    raise Error('data chunk before fmt chunk')
                self._data_chunk = chunk
                # _framesize is non-zero: _read_fmt_chunk() rejects zero
                # channels and zero sample width.
                self._nframes = chunk.chunksize // self._framesize
                # The stream is now positioned at the first frame.
                self._data_seek_needed = 0
                break
            chunk.skip()
        if not self._fmt_chunk_read or not self._data_chunk:
            raise Error('fmt chunk and/or data chunk missing')

    def __init__(self, f):
        """Open `f`, a file name or an already open file object, for reading."""
        self._i_opened_the_file = None
        if isinstance(f, basestring):
            f = __builtin__.open(f, 'rb')
            self._i_opened_the_file = f
        # else, assume it is an open file object already
        try:
            self.initfp(f)
        except:
            # Only close what we opened ourselves, then re-raise unchanged.
            if self._i_opened_the_file:
                f.close()
            raise

    def __del__(self):
        self.close()

    #
    # User visible methods.
    #
    def getfp(self):
        """Return the outer RIFF chunk object, or None after close()."""
        return self._file

    def rewind(self):
        """Move the read position back to the first frame."""
        self._data_seek_needed = 1
        self._soundpos = 0

    def close(self):
        """Make the instance unusable; close the file if __init__() opened it."""
        self._file = None
        file = self._i_opened_the_file
        if file:
            self._i_opened_the_file = None
            file.close()

    def tell(self):
        """Return the current position, in frames."""
        return self._soundpos

    def getnchannels(self):
        return self._nchannels

    def getnframes(self):
        return self._nframes

    def getsampwidth(self):
        return self._sampwidth

    def getframerate(self):
        return self._framerate

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

    def getparams(self):
        """Return (nchannels, sampwidth, framerate, nframes, comptype, compname)."""
        return (self.getnchannels(), self.getsampwidth(),
                self.getframerate(), self.getnframes(),
                self.getcomptype(), self.getcompname())

    def getmarkers(self):
        """Return None (present for compatibility with the aifc module)."""
        return None

    def getmark(self, id):
        """Always raise Error (present for compatibility with the aifc module)."""
        raise Error('no marks')

    def setpos(self, pos):
        """Set the read position to frame `pos`.

        Raises Error unless 0 <= pos <= getnframes().  The underlying
        seek is deferred until the next readframes() call.
        """
        if pos < 0 or pos > self._nframes:
            raise Error('position not in range')
        self._soundpos = pos
        self._data_seek_needed = 1

    def readframes(self, nframes):
        """Read and return at most `nframes` frames of audio as a byte string."""
        if self._data_seek_needed:
            self._data_chunk.seek(0, 0)
            pos = self._soundpos * self._framesize
            if pos:
                self._data_chunk.seek(pos, 0)
            self._data_seek_needed = 0
        if nframes == 0:
            return b''
        if self._sampwidth in (2, 4) and sys.byteorder == 'big':
            # unfortunately the fromfile() method does not take
            # something that only looks like a file object, so
            # we have to reach into the innards of the chunk object
            import array
            chunk = self._data_chunk
            data = array.array(_array_fmts[self._sampwidth])
            assert data.itemsize == self._sampwidth
            nitems = nframes * self._nchannels
            # Never read past the end of the data chunk.
            if nitems * self._sampwidth > chunk.chunksize - chunk.size_read:
                nitems = (chunk.chunksize - chunk.size_read) // self._sampwidth
            data.fromfile(chunk.file.file, nitems)
            # "tell" data chunk how much was read
            chunk.size_read = chunk.size_read + nitems * self._sampwidth
            # do the same for the outermost chunk
            chunk = chunk.file
            chunk.size_read = chunk.size_read + nitems * self._sampwidth
            data.byteswap()
            data = data.tostring()
        else:
            data = self._data_chunk.read(nframes * self._framesize)
            if self._sampwidth == 3 and sys.byteorder == 'big':
                data = _byteswap3(data)
        if self._convert and data:
            data = self._convert(data)
        # Advance by the number of whole frames actually obtained.
        self._soundpos = self._soundpos + len(data) // (self._nchannels * self._sampwidth)
        return data

    #
    # Internal methods.
    #

    def _read_fmt_chunk(self, chunk):
        """Decode the fmt chunk and record the audio parameters.

        Raises Error when the format tag is not WAVE_FORMAT_PCM, or when
        the header declares zero channels or a zero sample width (either
        would make the frame size zero and later cause a division by zero).
        """
        (wFormatTag, self._nchannels, self._framerate,
         dwAvgBytesPerSec, wBlockAlign) = struct.unpack('<HHLLH', chunk.read(14))
        if wFormatTag != WAVE_FORMAT_PCM:
            raise Error('unknown format: %r' % (wFormatTag,))
        sampwidth = struct.unpack('<H', chunk.read(2))[0]
        # Bits per sample, rounded up to whole bytes.
        self._sampwidth = (sampwidth + 7) // 8
        if not self._sampwidth:
            raise Error('bad sample width')
        if not self._nchannels:
            raise Error('bad # of channels')
        self._framesize = self._nchannels * self._sampwidth
        self._comptype = 'NONE'
        self._compname = 'not compressed'
278
class Wave_write:
    """Writer for uncompressed (PCM) RIFF/WAVE audio.

    Variables used in this class:

    These variables are user settable through appropriate methods
    of this class:
    _file -- the open file with methods write(), close(), tell(), seek()
              set through the __init__() method
    _comptype -- the compression type (only 'NONE' is accepted)
              set through the setcomptype() or setparams() method
    _compname -- the human-readable compression type
              set through the setcomptype() or setparams() method
    _nchannels -- the number of audio channels
              set through the setnchannels() or setparams() method
    _sampwidth -- the number of bytes per audio sample
              set through the setsampwidth() or setparams() method
    _framerate -- the sampling frequency
              set through the setframerate() or setparams() method
    _nframes -- the number of audio frames written to the header
              set through the setnframes() or setparams() method

    These variables are used internally only:
    _datalength -- the size of the audio samples written to the header
    _nframeswritten -- the number of frames actually written
    _datawritten -- the size of the audio samples actually written
    _headerwritten -- True once the RIFF/fmt/data headers have been written
    _convert -- optional callable applied to the data by writeframesraw();
              initfp() sets it to None
    _i_opened_the_file -- the file object opened by __init__() (and
              therefore closed by close()), or None
    """

    # Class-level defaults keep close(), and therefore __del__(), from
    # raising AttributeError when __init__() failed before initfp() ran
    # (for example because the output file could not be opened).
    _file = None
    _i_opened_the_file = None

    def __init__(self, f):
        """Open `f`, a file name or an already open file object, for writing."""
        self._i_opened_the_file = None
        if isinstance(f, basestring):
            f = __builtin__.open(f, 'wb')
            self._i_opened_the_file = f
        try:
            self.initfp(f)
        except:
            # Only close what we opened ourselves, then re-raise unchanged.
            if self._i_opened_the_file:
                f.close()
            raise

    def initfp(self, file):
        """Attach the open file object `file` and reset all writer state."""
        self._file = file
        self._convert = None
        self._nchannels = 0
        self._sampwidth = 0
        self._framerate = 0
        self._nframes = 0
        self._nframeswritten = 0
        self._datawritten = 0
        self._datalength = 0
        self._headerwritten = False
        # Defaults, so getcomptype(), getcompname() and getparams() work
        # even when setcomptype() was never called.
        self._comptype = 'NONE'
        self._compname = 'not compressed'

    def __del__(self):
        self.close()

    #
    # User visible methods.
    #
    def setnchannels(self, nchannels):
        """Set the number of channels; raises Error once data has been written."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if nchannels < 1:
            raise Error('bad # of channels')
        self._nchannels = nchannels

    def getnchannels(self):
        if not self._nchannels:
            raise Error('number of channels not set')
        return self._nchannels

    def setsampwidth(self, sampwidth):
        """Set the sample width in bytes (1 to 4)."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if sampwidth < 1 or sampwidth > 4:
            raise Error('bad sample width')
        self._sampwidth = sampwidth

    def getsampwidth(self):
        if not self._sampwidth:
            raise Error('sample width not set')
        return self._sampwidth

    def setframerate(self, framerate):
        """Set the sampling frequency; it must be positive."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if framerate <= 0:
            raise Error('bad frame rate')
        self._framerate = framerate

    def getframerate(self):
        if not self._framerate:
            raise Error('frame rate not set')
        return self._framerate

    def setnframes(self, nframes):
        """Set the frame count to be recorded in the header."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self._nframes = nframes

    def getnframes(self):
        """Return the number of frames written so far."""
        return self._nframeswritten

    def setcomptype(self, comptype, compname):
        """Set the compression type; only 'NONE' is supported."""
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        if comptype not in ('NONE',):
            raise Error('unsupported compression type')
        self._comptype = comptype
        self._compname = compname

    def getcomptype(self):
        return self._comptype

    def getcompname(self):
        return self._compname

    def setparams(self, params):
        """Set all parameters at once.

        `params` is the 6-tuple (nchannels, sampwidth, framerate, nframes,
        comptype, compname).
        """
        nchannels, sampwidth, framerate, nframes, comptype, compname = params
        if self._datawritten:
            raise Error('cannot change parameters after starting to write')
        self.setnchannels(nchannels)
        self.setsampwidth(sampwidth)
        self.setframerate(framerate)
        self.setnframes(nframes)
        self.setcomptype(comptype, compname)

    def getparams(self):
        """Return (nchannels, sampwidth, framerate, nframes, comptype, compname).

        Raises Error while the channel count, sample width or frame rate
        is still unset.
        """
        if not self._nchannels or not self._sampwidth or not self._framerate:
            raise Error('not all parameters set')
        return (self._nchannels, self._sampwidth, self._framerate,
                self._nframes, self._comptype, self._compname)

    def setmark(self, id, pos, name):
        """Always raise Error; marks are not supported."""
        raise Error('setmark() not supported')

    def getmark(self, id):
        """Always raise Error; there are no marks."""
        raise Error('no marks')

    def getmarkers(self):
        """Return None; there are no marks."""
        return None

    def tell(self):
        """Return the number of frames written so far."""
        return self._nframeswritten

    def writeframesraw(self, data):
        """Write audio frames without patching up the file header.

        The header is written first if it has not been written yet.
        """
        self._ensure_header_written(len(data))
        nframes = len(data) // (self._sampwidth * self._nchannels)
        if self._convert:
            data = self._convert(data)
        if self._sampwidth in (2, 4) and sys.byteorder == 'big':
            # Samples are stored little-endian in the file: swap each one
            # before writing on a big-endian host.
            import array
            a = array.array(_array_fmts[self._sampwidth])
            a.fromstring(data)
            data = a
            assert data.itemsize == self._sampwidth
            data.byteswap()
            data.tofile(self._file)
            # len(data) counts array items here, not bytes.
            self._datawritten = self._datawritten + len(data) * self._sampwidth
        else:
            if self._sampwidth == 3 and sys.byteorder == 'big':
                data = _byteswap3(data)
            self._file.write(data)
            self._datawritten = self._datawritten + len(data)
        self._nframeswritten = self._nframeswritten + nframes

    def writeframes(self, data):
        """Write audio frames, then patch the header if its sizes are stale."""
        self.writeframesraw(data)
        if self._datalength != self._datawritten:
            self._patchheader()

    def close(self):
        """Patch up the header, flush, and close the file if we opened it.

        The file opened by __init__() is closed even when writing or
        patching the header fails.
        """
        try:
            if self._file:
                self._ensure_header_written(0)
                if self._datalength != self._datawritten:
                    self._patchheader()
                self._file.flush()
        finally:
            self._file = None
            file = self._i_opened_the_file
            if file:
                self._i_opened_the_file = None
                file.close()

    #
    # Internal methods.
    #

    def _ensure_header_written(self, datasize):
        """Write the header once; raises Error if a parameter is still unset."""
        if not self._headerwritten:
            if not self._nchannels:
                raise Error('# channels not specified')
            if not self._sampwidth:
                raise Error('sample width not specified')
            if not self._framerate:
                raise Error('sampling rate not specified')
            self._write_header(datasize)

    def _write_header(self, initlength):
        """Write the RIFF, fmt and data chunk headers.

        When no frame count was set, it is derived from `initlength`, the
        byte size of the first block of data about to be written.  The
        file offsets of both size fields are remembered for _patchheader().
        """
        assert not self._headerwritten
        self._file.write(b'RIFF')
        if not self._nframes:
            # Floor division: a frame count must stay an integer even under
            # "from __future__ import division".
            self._nframes = initlength // (self._nchannels * self._sampwidth)
        self._datalength = self._nframes * self._nchannels * self._sampwidth
        self._form_length_pos = self._file.tell()
        self._file.write(struct.pack('<L4s4sLHHLLHH4s',
            36 + self._datalength, b'WAVE', b'fmt ', 16,
            WAVE_FORMAT_PCM, self._nchannels, self._framerate,
            self._nchannels * self._framerate * self._sampwidth,
            self._nchannels * self._sampwidth,
            self._sampwidth * 8, b'data'))
        self._data_length_pos = self._file.tell()
        self._file.write(struct.pack('<L', self._datalength))
        self._headerwritten = True

    def _patchheader(self):
        """Rewrite both size fields in the header to match the data written."""
        assert self._headerwritten
        if self._datawritten == self._datalength:
            return
        curpos = self._file.tell()
        self._file.seek(self._form_length_pos, 0)
        self._file.write(struct.pack('<L', 36 + self._datawritten))
        self._file.seek(self._data_length_pos, 0)
        self._file.write(struct.pack('<L', self._datawritten))
        # Return to where writing left off.
        self._file.seek(curpos, 0)
        self._datalength = self._datawritten
503
def open(f, mode=None):
    """Return a Wave_read or Wave_write object for `f`.

    `f` is a file name or an open file object.  When `mode` is None it is
    taken from f.mode if `f` has that attribute, and is 'rb' otherwise.
    'r' and 'rb' select reading, 'w' and 'wb' select writing; any other
    mode raises Error.
    """
    if mode is None:
        mode = f.mode if hasattr(f, 'mode') else 'rb'
    if mode in ('r', 'rb'):
        return Wave_read(f)
    if mode in ('w', 'wb'):
        return Wave_write(f)
    raise Error("mode must be 'r', 'rb', 'w', or 'wb'")

# Alias of open(), kept for backward compatibility.
openfp = open
518