decode_audio_op.cc revision fcc9a6ed272d6599d38ae59ae215cff786ad1bea
1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15
16#include <stdlib.h>
17
18#include <cstdio>
19#include <set>
20
21#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
22#include "tensorflow/core/framework/op.h"
23#include "tensorflow/core/framework/op_kernel.h"
24#include "tensorflow/core/lib/io/path.h"
25#include "tensorflow/core/lib/strings/str_util.h"
26#include "tensorflow/core/lib/strings/strcat.h"
27#include "tensorflow/core/platform/env.h"
28#include "tensorflow/core/platform/logging.h"
29
30namespace tensorflow {
31namespace ffmpeg {
32namespace {
33
// The complete set of audio file formats that are supported by the op. These
// strings are defined by FFmpeg and documented here:
// https://www.ffmpeg.org/ffmpeg-formats.html
// Both the characters and the pointers are const so the table cannot be
// modified after initialization.
const char* const kValidFileFormats[] = {"mp3", "ogg", "wav"};
38
39// Writes binary data to a file.
40Status WriteFile(const string& filename, tensorflow::StringPiece contents) {
41  Env& env = *Env::Default();
42  std::unique_ptr<WritableFile> file;
43  TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file));
44  TF_RETURN_IF_ERROR(file->Append(contents));
45  TF_RETURN_IF_ERROR(file->Close());
46  return Status::OK();
47}
48
49// Cleans up a file on destruction.
50class FileDeleter {
51 public:
52  explicit FileDeleter(const string& filename) : filename_(filename) {}
53  ~FileDeleter() {
54    Env& env = *Env::Default();
55    env.DeleteFile(filename_);
56  }
57
58 private:
59  const string filename_;
60};
61
62}  // namespace
63
64class DecodeAudioOp : public OpKernel {
65 public:
66  explicit DecodeAudioOp(OpKernelConstruction* context)
67      : OpKernel(context) {
68    OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
69    file_format_ = str_util::Lowercase(file_format_);
70    const std::set<string> valid_file_formats(
71        kValidFileFormats,
72        kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
73    OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1,
74                errors::InvalidArgument(
75                    "file_format arg must be in {",
76                    str_util::Join(valid_file_formats, ", "), "}."));
77
78    OP_REQUIRES_OK(
79        context, context->GetAttr("samples_per_second", &samples_per_second_));
80    OP_REQUIRES(context, samples_per_second_ > 0,
81                errors::InvalidArgument("samples_per_second must be > 0."));
82
83    OP_REQUIRES_OK(
84        context, context->GetAttr("channel_count", &channel_count_));
85    OP_REQUIRES(context, channel_count_ > 0,
86                errors::InvalidArgument("channel_count must be > 0."));
87  }
88
89  void Compute(OpKernelContext* context) override {
90    // Get and verify the input data.
91    OP_REQUIRES(
92        context, context->num_inputs() == 1,
93        errors::InvalidArgument("DecodeAudio requires exactly one input."));
94    const Tensor& contents = context->input(0);
95    OP_REQUIRES(
96        context, TensorShapeUtils::IsScalar(contents.shape()),
97        errors::InvalidArgument("contents must be scalar but got shape ",
98                                contents.shape().DebugString()));
99
100    // Write the input data to a temp file.
101    const tensorflow::StringPiece file_contents = contents.scalar<string>()();
102    const string input_filename = GetTempFilename(file_format_);
103    OP_REQUIRES_OK(context, WriteFile(input_filename, file_contents));
104    FileDeleter deleter(input_filename);
105
106    // Run FFmpeg on the data and verify results.
107    std::vector<float> output_samples;
108    Status result =
109        ffmpeg::ReadAudioFile(input_filename, file_format_, samples_per_second_,
110                              channel_count_, &output_samples);
111    if (result.code() == error::Code::NOT_FOUND) {
112      OP_REQUIRES(
113          context, result.ok(),
114          errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg "
115                              "can be found at http://www.ffmpeg.org."));
116    } else if (result.code() == error::UNKNOWN) {
117      LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message()
118                 << "'. Returning empty tensor.";
119      Tensor* output = nullptr;
120      OP_REQUIRES_OK(
121          context, context->allocate_output(0, TensorShape({0, 0}), &output));
122      return;
123    } else {
124      OP_REQUIRES_OK(context, result);
125    }
126    OP_REQUIRES(
127        context, !output_samples.empty(),
128        errors::Unknown("No output created by FFmpeg."));
129    OP_REQUIRES(
130        context, output_samples.size() % channel_count_ == 0,
131        errors::Unknown("FFmpeg created non-integer number of audio frames."));
132
133    // Copy the output data to the output Tensor.
134    Tensor* output = nullptr;
135    const int64 frame_count = output_samples.size() / channel_count_;
136    OP_REQUIRES_OK(
137        context, context->allocate_output(
138            0, TensorShape({frame_count, channel_count_}), &output));
139    auto matrix = output->tensor<float, 2>();
140    for (int32 frame = 0; frame < frame_count; ++frame) {
141      for (int32 channel = 0; channel < channel_count_; ++channel) {
142        matrix(frame, channel) =
143            output_samples[frame * channel_count_ + channel];
144      }
145    }
146  }
147
148 private:
149  string file_format_;
150  int32 samples_per_second_;
151  int32 channel_count_;
152};
153
154REGISTER_KERNEL_BUILDER(Name("DecodeAudio").Device(DEVICE_CPU), DecodeAudioOp);
155
156REGISTER_OP("DecodeAudio")
157    .Input("contents: string")
158    .Output("sampled_audio: float")
159    .Attr("file_format: string")
160    .Attr("samples_per_second: int")
161    .Attr("channel_count: int")
162    .Doc(R"doc(
163Processes the contents of an audio file into a tensor using FFmpeg to decode
164the file.
165
166One row of the tensor is created for each channel in the audio file. Each
167channel contains audio samples starting at the beginning of the audio and
168having `1/samples_per_second` time between them. If the `channel_count` is
169different from the contents of the file, channels will be merged or created.
170
171contents: The binary audio file contents.
172sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0
173    is time and dimension 1 is the channel. If ffmpeg fails to decode the audio
174    then an empty tensor will be returned.
175file_format: A string describing the audio file format. This can be "wav" or
176    "mp3".
177samples_per_second: The number of samples per second that the audio should have.
178channel_count: The number of channels of audio to read.
179)doc");
180
181}  // namespace ffmpeg
182}  // namespace tensorflow
183