decode_audio_op.cc revision a824c3c87226345de45432c329c2b21e17854d0e
1// Copyright 2016 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15
#include <stdlib.h>

#include <algorithm>
#include <cstdio>
#include <memory>
#include <set>
#include <vector>

#include "tensorflow/contrib/ffmpeg/default/ffmpeg_lib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
28
29namespace tensorflow {
30namespace ffmpeg {
31namespace {
32
// The complete set of audio file formats that are supported by the op. These
// strings are defined by FFmpeg and documented here:
// https://www.ffmpeg.org/ffmpeg-formats.html
// Both the array elements and the characters they point to are const, so this
// table cannot be modified at runtime.
const char* const kValidFileFormats[] = {"mp3", "ogg", "wav"};
37
38// Writes binary data to a file.
39Status WriteFile(const string& filename, tensorflow::StringPiece contents) {
40  Env& env = *Env::Default();
41  WritableFile* file = nullptr;
42  TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file));
43  std::unique_ptr<WritableFile> file_deleter(file);
44  TF_RETURN_IF_ERROR(file->Append(contents));
45  TF_RETURN_IF_ERROR(file->Close());
46  return Status::OK();
47}
48
49// Cleans up a file on destruction.
50class FileDeleter {
51 public:
52  explicit FileDeleter(const string& filename) : filename_(filename) {}
53  ~FileDeleter() {
54    Env& env = *Env::Default();
55    env.DeleteFile(filename_);
56  }
57
58 private:
59  const string filename_;
60};
61
62}  // namespace
63
64class DecodeAudioOp : public OpKernel {
65 public:
66  explicit DecodeAudioOp(OpKernelConstruction* context)
67      : OpKernel(context) {
68    OP_REQUIRES_OK(context, context->GetAttr("file_format", &file_format_));
69    file_format_ = str_util::Lowercase(file_format_);
70    const std::set<string> valid_file_formats(
71        kValidFileFormats,
72        kValidFileFormats + TF_ARRAYSIZE(kValidFileFormats));
73    OP_REQUIRES(context, valid_file_formats.count(file_format_) == 1,
74                errors::InvalidArgument(
75                    "file_format arg must be in {",
76                    str_util::Join(valid_file_formats, ", "), "}."));
77
78    OP_REQUIRES_OK(
79        context, context->GetAttr("samples_per_second", &samples_per_second_));
80    OP_REQUIRES(context, samples_per_second_ > 0,
81                errors::InvalidArgument("samples_per_second must be > 0."));
82
83    OP_REQUIRES_OK(
84        context, context->GetAttr("channel_count", &channel_count_));
85    OP_REQUIRES(context, channel_count_ > 0,
86                errors::InvalidArgument("channel_count must be > 0."));
87  }
88
89  void Compute(OpKernelContext* context) override {
90    // Get and verify the input data.
91    OP_REQUIRES(
92        context, context->num_inputs() == 1,
93        errors::InvalidArgument("DecodeAudio requires exactly one input."));
94    const Tensor& contents = context->input(0);
95    OP_REQUIRES(
96        context, TensorShapeUtils::IsScalar(contents.shape()),
97        errors::InvalidArgument("contents must be scalar but got shape ",
98                                contents.shape().DebugString()));
99
100    // Write the input data to a temp file.
101    const tensorflow::StringPiece file_contents = contents.scalar<string>()();
102    const string input_filename = GetTempFilename(file_format_);
103    OP_REQUIRES_OK(context, WriteFile(input_filename, file_contents));
104    FileDeleter deleter(input_filename);
105
106    // Run FFmpeg on the data and verify results.
107    std::vector<float> output_samples;
108    Status result =
109        ffmpeg::ReadAudioFile(input_filename, file_format_, samples_per_second_,
110                              channel_count_, &output_samples);
111    if (result.code() == error::Code::NOT_FOUND) {
112      OP_REQUIRES(
113          context, result.ok(),
114          errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg "
115                              "can be found at http://www.ffmpeg.org."));
116    } else {
117      OP_REQUIRES_OK(context, result);
118    }
119    OP_REQUIRES(
120        context, !output_samples.empty(),
121        errors::Unknown("No output created by FFmpeg."));
122    OP_REQUIRES(
123        context, output_samples.size() % channel_count_ == 0,
124        errors::Unknown("FFmpeg created non-integer number of audio frames."));
125
126    // Copy the output data to the output Tensor.
127    Tensor* output = nullptr;
128    const int64 frame_count = output_samples.size() / channel_count_;
129    OP_REQUIRES_OK(
130        context, context->allocate_output(
131            0, TensorShape({frame_count, channel_count_}), &output));
132    auto matrix = output->tensor<float, 2>();
133    for (int32 frame = 0; frame < frame_count; ++frame) {
134      for (int32 channel = 0; channel < channel_count_; ++channel) {
135        matrix(frame, channel) =
136            output_samples[frame * channel_count_ + channel];
137      }
138    }
139  }
140
141 private:
142  string file_format_;
143  int32 samples_per_second_;
144  int32 channel_count_;
145};
146
147REGISTER_KERNEL_BUILDER(Name("DecodeAudio").Device(DEVICE_CPU), DecodeAudioOp);
148
149REGISTER_OP("DecodeAudio")
150    .Input("contents: string")
151    .Output("sampled_audio: float")
152    .Attr("file_format: string")
153    .Attr("samples_per_second: int")
154    .Attr("channel_count: int")
155    .Doc(R"doc(
156Processes the contents of an audio file into a tensor using FFmpeg to decode
157the file.
158
159One row of the tensor is created for each channel in the audio file. Each
160channel contains audio samples starting at the beginning of the audio and
161having `1/samples_per_second` time between them. If the `channel_count` is
162different from the contents of the file, channels will be merged or created.
163
164contents: The binary audio file contents.
165sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0
166    is time and dimension 1 is the channel.
167file_format: A string describing the audio file format. This can be "wav" or
168    "mp3".
169samples_per_second: The number of samples per second that the audio should have.
170channel_count: The number of channels of audio to read.
171)doc");
172
173}  // namespace ffmpeg
174}  // namespace tensorflow
175