decode_audio_op.cc revision fcc9a6ed272d6599d38ae59ae215cff786ad1bea
1// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15 16#include <stdlib.h> 17 18#include <cstdio> 19#include <set> 20 21#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" 22#include "tensorflow/core/framework/op.h" 23#include "tensorflow/core/framework/op_kernel.h" 24#include "tensorflow/core/lib/io/path.h" 25#include "tensorflow/core/lib/strings/str_util.h" 26#include "tensorflow/core/lib/strings/strcat.h" 27#include "tensorflow/core/platform/env.h" 28#include "tensorflow/core/platform/logging.h" 29 30namespace tensorflow { 31namespace ffmpeg { 32namespace { 33 34// The complete set of audio file formats that are supported by the op. These 35// strings are defined by FFmpeg and documented here: 36// https://www.ffmpeg.org/ffmpeg-formats.html 37const char* kValidFileFormats[] = {"mp3", "ogg", "wav"}; 38 39// Writes binary data to a file. 40Status WriteFile(const string& filename, tensorflow::StringPiece contents) { 41 Env& env = *Env::Default(); 42 std::unique_ptr<WritableFile> file; 43 TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file)); 44 TF_RETURN_IF_ERROR(file->Append(contents)); 45 TF_RETURN_IF_ERROR(file->Close()); 46 return Status::OK(); 47} 48 49// Cleans up a file on destruction. 50class FileDeleter { 51 public: 52 explicit FileDeleter(const string& filename) : filename_(filename) {} 53 ~FileDeleter() { 54 Env& env = *Env::Default(); 55 env.DeleteFile(filename_); 56 } 57 58 private: 59 const string filename_; 60}; 61 62} // namespace 63 64class DecodeAudioOp : public OpKernel { 65 public: 66 explicit DecodeAudioOp(OpKernelConstruction* context) 67 : OpKernel(context) { 68 OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_)); 69 file_format_ = str_util::Lowercase(file_format_); 70 const std::set<string> valid_file_formats( 71 kValidFileFormats, 72 kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats)); 73 OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1, 74 errors::InvalidArgument( 75 "file_format arg must be in {", 76 str_util::Join(valid_file_formats, ", "), "}.")); 77 78 OP_REQUIRES_OK( 79 context, context->GetAttr("samples_per_second", &samples_per_second_)); 80 OP_REQUIRES(context, samples_per_second_ > 0, 81 errors::InvalidArgument("samples_per_second must be > 0.")); 82 83 OP_REQUIRES_OK( 84 context, context->GetAttr("channel_count", &channel_count_)); 85 OP_REQUIRES(context, channel_count_ > 0, 86 errors::InvalidArgument("channel_count must be > 0.")); 87 } 88 89 void Compute(OpKernelContext* context) override { 90 // Get and verify the input data. 91 OP_REQUIRES( 92 context, context->num_inputs() == 1, 93 errors::InvalidArgument("DecodeAudio requires exactly one input.")); 94 const Tensor& contents = context->input(0); 95 OP_REQUIRES( 96 context, TensorShapeUtils::IsScalar(contents.shape()), 97 errors::InvalidArgument("contents must be scalar but got shape ", 98 contents.shape().DebugString())); 99 100 // Write the input data to a temp file. 101 const tensorflow::StringPiece file_contents = contents.scalar<string>()(); 102 const string input_filename = GetTempFilename(file_format_); 103 OP_REQUIRES_OK(context, WriteFile(input_filename, file_contents)); 104 FileDeleter deleter(input_filename); 105 106 // Run FFmpeg on the data and verify results. 107 std::vector<float> output_samples; 108 Status result = 109 ffmpeg::ReadAudioFile(input_filename, file_format_, samples_per_second_, 110 channel_count_, &output_samples); 111 if (result.code() == error::Code::NOT_FOUND) { 112 OP_REQUIRES( 113 context, result.ok(), 114 errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg " 115 "can be found at http://www.ffmpeg.org.")); 116 } else if (result.code() == error::UNKNOWN) { 117 LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message() 118 << "'. Returning empty tensor."; 119 Tensor* output = nullptr; 120 OP_REQUIRES_OK( 121 context, context->allocate_output(0, TensorShape({0, 0}), &output)); 122 return; 123 } else { 124 OP_REQUIRES_OK(context, result); 125 } 126 OP_REQUIRES( 127 context, !output_samples.empty(), 128 errors::Unknown("No output created by FFmpeg.")); 129 OP_REQUIRES( 130 context, output_samples.size() % channel_count_ == 0, 131 errors::Unknown("FFmpeg created non-integer number of audio frames.")); 132 133 // Copy the output data to the output Tensor. 134 Tensor* output = nullptr; 135 const int64 frame_count = output_samples.size() / channel_count_; 136 OP_REQUIRES_OK( 137 context, context->allocate_output( 138 0, TensorShape({frame_count, channel_count_}), &output)); 139 auto matrix = output->tensor<float, 2>(); 140 for (int32 frame = 0; frame < frame_count; ++frame) { 141 for (int32 channel = 0; channel < channel_count_; ++channel) { 142 matrix(frame, channel) = 143 output_samples[frame * channel_count_ + channel]; 144 } 145 } 146 } 147 148 private: 149 string file_format_; 150 int32 samples_per_second_; 151 int32 channel_count_; 152}; 153 154REGISTER_KERNEL_BUILDER(Name("DecodeAudio").Device(DEVICE_CPU), DecodeAudioOp); 155 156REGISTER_OP("DecodeAudio") 157 .Input("contents: string") 158 .Output("sampled_audio: float") 159 .Attr("file_format: string") 160 .Attr("samples_per_second: int") 161 .Attr("channel_count: int") 162 .Doc(R"doc( 163Processes the contents of an audio file into a tensor using FFmpeg to decode 164the file. 165 166One row of the tensor is created for each channel in the audio file. Each 167channel contains audio samples starting at the beginning of the audio and 168having `1/samples_per_second` time between them. If the `channel_count` is 169different from the contents of the file, channels will be merged or created. 170 171contents: The binary audio file contents. 172sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0 173 is time and dimension 1 is the channel. If ffmpeg fails to decode the audio 174 then an empty tensor will be returned. 175file_format: A string describing the audio file format. This can be "wav" or 176 "mp3". 177samples_per_second: The number of samples per second that the audio should have. 178channel_count: The number of channels of audio to read. 179)doc"); 180 181} // namespace ffmpeg 182} // namespace tensorflow 183