1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Tests for third_party.tensorflow.contrib.ffmpeg.decode_audio_op.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os.path 22 23import six 24 25from tensorflow.contrib import ffmpeg 26from tensorflow.python.framework import dtypes 27from tensorflow.python.ops import array_ops 28from tensorflow.python.platform import resource_loader 29from tensorflow.python.platform import test 30 31 32class DecodeAudioOpTest(test.TestCase): 33 34 def _loadFileAndTest(self, filename, file_format, duration_sec, 35 samples_per_second, channel_count, 36 samples_per_second_tensor=None, feed_dict=None, 37 stream=None): 38 """Loads an audio file and validates the output tensor. 39 40 Args: 41 filename: The filename of the input file. 42 file_format: The format of the input file. 43 duration_sec: The duration of the audio contained in the file in seconds. 44 samples_per_second: The desired sample rate in the output tensor. 45 channel_count: The desired channel count in the output tensor. 46 samples_per_second_tensor: The value to pass to the corresponding 47 parameter in the instantiated `decode_audio` op. If not 48 provided, will default to a constant value of 49 `samples_per_second`. Useful for providing a placeholder. 50 feed_dict: Used when evaluating the `decode_audio` op. If not 51 provided, will be empty. Useful when providing a placeholder for 52 `samples_per_second_tensor`. 53 stream: A string specifying which stream from the content file 54 should be decoded. The default value is '' which leaves the 55 decision to ffmpeg. 56 """ 57 if samples_per_second_tensor is None: 58 samples_per_second_tensor = samples_per_second 59 with self.test_session(): 60 path = os.path.join(resource_loader.get_data_files_path(), 'testdata', 61 filename) 62 with open(path, 'rb') as f: 63 contents = f.read() 64 65 audio_op = ffmpeg.decode_audio( 66 contents, 67 file_format=file_format, 68 samples_per_second=samples_per_second_tensor, 69 channel_count=channel_count, stream=stream) 70 audio = audio_op.eval(feed_dict=feed_dict or {}) 71 self.assertEqual(len(audio.shape), 2) 72 self.assertNear( 73 duration_sec * samples_per_second, 74 audio.shape[0], 75 # Duration should be specified within 10%: 76 0.1 * audio.shape[0]) 77 self.assertEqual(audio.shape[1], channel_count) 78 79 def testStreamIdentifier(self): 80 # mono_16khz_mp3_32khz_aac.mp4 was generated from: 81 # ffmpeg -i tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3.mp4 \ 82 # -i tensorflow/contrib/ffmpeg/testdata/mono_32khz_aac.mp4 \ 83 # -strict -2 -map 0:a -map 1:a \ 84 # tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4 85 self._loadFileAndTest('mono_16khz_mp3_32khz_aac.mp4', 'mp4', 2.77, 20000, 86 1, stream='0') 87 self._loadFileAndTest('mono_16khz_mp3_32khz_aac.mp4', 'mp4', 2.77, 20000, 88 1, stream='1') 89 90 def testMonoMp3(self): 91 self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 1) 92 self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 2) 93 94 def testMonoMp4Mp3Codec(self): 95 # mp3 compressed audio streams in mp4 container. 96 self._loadFileAndTest('mono_16khz_mp3.mp4', 'mp4', 2.77, 20000, 1) 97 self._loadFileAndTest('mono_16khz_mp3.mp4', 'mp4', 2.77, 20000, 2) 98 99 def testMonoMp4AacCodec(self): 100 # aac compressed audio streams in mp4 container. 101 self._loadFileAndTest('mono_32khz_aac.mp4', 'mp4', 2.77, 20000, 1) 102 self._loadFileAndTest('mono_32khz_aac.mp4', 'mp4', 2.77, 20000, 2) 103 104 def testStereoMp3(self): 105 self._loadFileAndTest('stereo_48khz.mp3', 'mp3', 0.79, 50000, 1) 106 self._loadFileAndTest('stereo_48khz.mp3', 'mp3', 0.79, 20000, 2) 107 108 def testStereoMp4Mp3Codec(self): 109 # mp3 compressed audio streams in mp4 container. 110 self._loadFileAndTest('stereo_48khz_mp3.mp4', 'mp4', 0.79, 50000, 1) 111 self._loadFileAndTest('stereo_48khz_mp3.mp4', 'mp4', 0.79, 20000, 2) 112 113 def testStereoMp4AacCodec(self): 114 # aac compressed audio streams in mp4 container. 115 self._loadFileAndTest('stereo_48khz_aac.mp4', 'mp4', 0.79, 50000, 1) 116 self._loadFileAndTest('stereo_48khz_aac.mp4', 'mp4', 0.79, 20000, 2) 117 118 def testMonoWav(self): 119 self._loadFileAndTest('mono_10khz.wav', 'wav', 0.57, 5000, 1) 120 self._loadFileAndTest('mono_10khz.wav', 'wav', 0.57, 10000, 4) 121 122 def testOgg(self): 123 self._loadFileAndTest('mono_10khz.ogg', 'ogg', 0.57, 10000, 1) 124 125 def testInvalidFile(self): 126 with self.test_session(): 127 contents = 'invalid file' 128 audio_op = ffmpeg.decode_audio( 129 contents, 130 file_format='wav', 131 samples_per_second=10000, 132 channel_count=2) 133 audio = audio_op.eval() 134 self.assertEqual(audio.shape, (0, 0)) 135 136 def testSampleRatePlaceholder(self): 137 placeholder = array_ops.placeholder(dtypes.int32) 138 self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 1, 139 samples_per_second_tensor=placeholder, 140 feed_dict={placeholder: 20000}) 141 142 def testSampleRateBadType(self): 143 placeholder = array_ops.placeholder(dtypes.float32) 144 with self.assertRaises(TypeError): 145 self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1, 146 samples_per_second_tensor=placeholder, 147 feed_dict={placeholder: 20000.0}) 148 149 def testSampleRateBadValue_Zero(self): 150 placeholder = array_ops.placeholder(dtypes.int32) 151 with six.assertRaisesRegex(self, Exception, 152 r'samples_per_second must be positive'): 153 self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1, 154 samples_per_second_tensor=placeholder, 155 feed_dict={placeholder: 0}) 156 157 def testSampleRateBadValue_Negative(self): 158 placeholder = array_ops.placeholder(dtypes.int32) 159 with six.assertRaisesRegex(self, Exception, 160 r'samples_per_second must be positive'): 161 self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000.0, 1, 162 samples_per_second_tensor=placeholder, 163 feed_dict={placeholder: -2}) 164 165 def testInvalidFileFormat(self): 166 with six.assertRaisesRegex(self, Exception, 167 r'file_format must be one of'): 168 self._loadFileAndTest('mono_16khz.mp3', 'docx', 0.57, 20000, 1) 169 170 def testStaticShapeInference_ConstantChannelCount(self): 171 with self.test_session(): 172 audio_op = ffmpeg.decode_audio(b'~~~ wave ~~~', 173 file_format='wav', 174 samples_per_second=44100, 175 channel_count=2) 176 self.assertEqual([None, 2], audio_op.shape.as_list()) 177 178 def testStaticShapeInference_NonConstantChannelCount(self): 179 with self.test_session(): 180 channel_count = array_ops.placeholder(dtypes.int32) 181 audio_op = ffmpeg.decode_audio(b'~~~ wave ~~~', 182 file_format='wav', 183 samples_per_second=44100, 184 channel_count=channel_count) 185 self.assertEqual([None, None], audio_op.shape.as_list()) 186 187 def testStaticShapeInference_ZeroChannelCountInvalid(self): 188 with self.test_session(): 189 with six.assertRaisesRegex(self, Exception, 190 r'channel_count must be positive'): 191 ffmpeg.decode_audio(b'~~~ wave ~~~', 192 file_format='wav', 193 samples_per_second=44100, 194 channel_count=0) 195 196 def testStaticShapeInference_NegativeChannelCountInvalid(self): 197 with self.test_session(): 198 with six.assertRaisesRegex(self, Exception, 199 r'channel_count must be positive'): 200 ffmpeg.decode_audio(b'~~~ wave ~~~', 201 file_format='wav', 202 samples_per_second=44100, 203 channel_count=-2) 204 205 206if __name__ == '__main__': 207 test.main() 208