decode_audio_op.cc revision a824c3c87226345de45432c329c2b21e17854d0e
1// Copyright 2016 Google Inc. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15 16#include <stdlib.h> 17 18#include <cstdio> 19#include <set> 20 21#include "tensorflow/contrib/ffmpeg/default/ffmpeg_lib.h" 22#include "tensorflow/core/framework/op.h" 23#include "tensorflow/core/framework/op_kernel.h" 24#include "tensorflow/core/lib/io/path.h" 25#include "tensorflow/core/lib/strings/str_util.h" 26#include "tensorflow/core/lib/strings/strcat.h" 27#include "tensorflow/core/platform/env.h" 28 29namespace tensorflow { 30namespace ffmpeg { 31namespace { 32 33// The complete set of audio file formats that are supported by the op. These 34// strings are defined by FFmpeg and documented here: 35// https://www.ffmpeg.org/ffmpeg-formats.html 36const char* kValidFileFormats[] = {"mp3", "ogg", "wav"}; 37 38// Writes binary data to a file. 39Status WriteFile(const string& filename, tensorflow::StringPiece contents) { 40 Env& env = *Env::Default(); 41 WritableFile* file = nullptr; 42 TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file)); 43 std::unique_ptr<WritableFile> file_deleter(file); 44 TF_RETURN_IF_ERROR(file->Append(contents)); 45 TF_RETURN_IF_ERROR(file->Close()); 46 return Status::OK(); 47} 48 49// Cleans up a file on destruction. 50class FileDeleter { 51 public: 52 explicit FileDeleter(const string& filename) : filename_(filename) {} 53 ~FileDeleter() { 54 Env& env = *Env::Default(); 55 env.DeleteFile(filename_); 56 } 57 58 private: 59 const string filename_; 60}; 61 62} // namespace 63 64class DecodeAudioOp : public OpKernel { 65 public: 66 explicit DecodeAudioOp(OpKernelConstruction* context) 67 : OpKernel(context) { 68 OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_)); 69 file_format_ = str_util::Lowercase(file_format_); 70 const std::set<string> valid_file_formats( 71 kValidFileFormats, 72 kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats)); 73 OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1, 74 errors::InvalidArgument( 75 "file_format arg must be in {", 76 str_util::Join(valid_file_formats, ", "), "}.")); 77 78 OP_REQUIRES_OK( 79 context, context->GetAttr("samples_per_second", &samples_per_second_)); 80 OP_REQUIRES(context, samples_per_second_ > 0, 81 errors::InvalidArgument("samples_per_second must be > 0.")); 82 83 OP_REQUIRES_OK( 84 context, context->GetAttr("channel_count", &channel_count_)); 85 OP_REQUIRES(context, channel_count_ > 0, 86 errors::InvalidArgument("channel_count must be > 0.")); 87 } 88 89 void Compute(OpKernelContext* context) override { 90 // Get and verify the input data. 91 OP_REQUIRES( 92 context, context->num_inputs() == 1, 93 errors::InvalidArgument("DecodeAudio requires exactly one input.")); 94 const Tensor& contents = context->input(0); 95 OP_REQUIRES( 96 context, TensorShapeUtils::IsScalar(contents.shape()), 97 errors::InvalidArgument("contents must be scalar but got shape ", 98 contents.shape().DebugString())); 99 100 // Write the input data to a temp file. 101 const tensorflow::StringPiece file_contents = contents.scalar<string>()(); 102 const string input_filename = GetTempFilename(file_format_); 103 OP_REQUIRES_OK(context, WriteFile(input_filename, file_contents)); 104 FileDeleter deleter(input_filename); 105 106 // Run FFmpeg on the data and verify results. 107 std::vector<float> output_samples; 108 Status result = 109 ffmpeg::ReadAudioFile(input_filename, file_format_, samples_per_second_, 110 channel_count_, &output_samples); 111 if (result.code() == error::Code::NOT_FOUND) { 112 OP_REQUIRES( 113 context, result.ok(), 114 errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg " 115 "can be found at http://www.ffmpeg.org.")); 116 } else { 117 OP_REQUIRES_OK(context, result); 118 } 119 OP_REQUIRES( 120 context, !output_samples.empty(), 121 errors::Unknown("No output created by FFmpeg.")); 122 OP_REQUIRES( 123 context, output_samples.size() % channel_count_ == 0, 124 errors::Unknown("FFmpeg created non-integer number of audio frames.")); 125 126 // Copy the output data to the output Tensor. 127 Tensor* output = nullptr; 128 const int64 frame_count = output_samples.size() / channel_count_; 129 OP_REQUIRES_OK( 130 context, context->allocate_output( 131 0, TensorShape({frame_count, channel_count_}), &output)); 132 auto matrix = output->tensor<float, 2>(); 133 for (int32 frame = 0; frame < frame_count; ++frame) { 134 for (int32 channel = 0; channel < channel_count_; ++channel) { 135 matrix(frame, channel) = 136 output_samples[frame * channel_count_ + channel]; 137 } 138 } 139 } 140 141 private: 142 string file_format_; 143 int32 samples_per_second_; 144 int32 channel_count_; 145}; 146 147REGISTER_KERNEL_BUILDER(Name("DecodeAudio").Device(DEVICE_CPU), DecodeAudioOp); 148 149REGISTER_OP("DecodeAudio") 150 .Input("contents: string") 151 .Output("sampled_audio: float") 152 .Attr("file_format: string") 153 .Attr("samples_per_second: int") 154 .Attr("channel_count: int") 155 .Doc(R"doc( 156Processes the contents of an audio file into a tensor using FFmpeg to decode 157the file. 158 159One row of the tensor is created for each channel in the audio file. Each 160channel contains audio samples starting at the beginning of the audio and 161having `1/samples_per_second` time between them. If the `channel_count` is 162different from the contents of the file, channels will be merged or created. 163 164contents: The binary audio file contents. 165sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0 166 is time and dimension 1 is the channel. 167file_format: A string describing the audio file format. This can be "wav" or 168 "mp3". 169samples_per_second: The number of samples per second that the audio should have. 170channel_count: The number of channels of audio to read. 171)doc"); 172 173} // namespace ffmpeg 174} // namespace tensorflow 175