1/*
2 *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "webrtc/modules/audio_coding/neteq/audio_classifier.h"
12
13#include <math.h>
14#include <stdio.h>
15#include <stdlib.h>
16#include <string.h>
17
18#include <string>
19#include <iostream>
20
21#include "webrtc/base/scoped_ptr.h"
22
23int main(int argc, char* argv[]) {
24  if (argc != 5) {
25    std::cout << "Usage: " << argv[0] <<
26        " channels output_type <input file name> <output file name> "
27        << std::endl << std::endl;
28    std::cout << "Where channels can be 1 (mono) or 2 (interleaved stereo),";
29    std::cout << " outputs can be 1 (classification (boolean)) or 2";
30    std::cout << " (classification and music probability (float)),"
31        << std::endl;
32    std::cout << "and the sampling frequency is assumed to be 48 kHz."
33        << std::endl;
34    return -1;
35  }
36
37  const int kFrameSizeSamples = 960;
38  int channels = atoi(argv[1]);
39  if (channels < 1 || channels > 2) {
40    std::cout << "Disallowed number of channels  " << channels << std::endl;
41    return -1;
42  }
43
44  int outputs = atoi(argv[2]);
45  if (outputs < 1 || outputs > 2) {
46    std::cout << "Disallowed number of outputs  " << outputs << std::endl;
47    return -1;
48  }
49
50  const int data_size = channels * kFrameSizeSamples;
51  rtc::scoped_ptr<int16_t[]> in(new int16_t[data_size]);
52
53  std::string input_filename = argv[3];
54  std::string output_filename = argv[4];
55
56  std::cout << "Input file: " << input_filename << std::endl;
57  std::cout << "Output file: " << output_filename << std::endl;
58
59  FILE* in_file = fopen(input_filename.c_str(), "rb");
60  if (!in_file) {
61    std::cout << "Cannot open input file " << input_filename << std::endl;
62    return -1;
63  }
64
65  FILE* out_file = fopen(output_filename.c_str(), "wb");
66  if (!out_file) {
67    std::cout << "Cannot open output file " << output_filename << std::endl;
68    return -1;
69  }
70
71  webrtc::AudioClassifier classifier;
72  int frame_counter = 0;
73  int music_counter = 0;
74  while (fread(in.get(), sizeof(*in.get()),
75               data_size, in_file) == (size_t) data_size) {
76    bool is_music = classifier.Analysis(in.get(), data_size, channels);
77    if (!fwrite(&is_music, sizeof(is_music), 1, out_file)) {
78       std::cout << "Error writing." << std::endl;
79       return -1;
80    }
81    if (is_music) {
82      music_counter++;
83    }
84    std::cout << "frame " << frame_counter << " decision " << is_music;
85    if (outputs == 2) {
86      float music_prob = classifier.music_probability();
87      if (!fwrite(&music_prob, sizeof(music_prob), 1, out_file)) {
88        std::cout << "Error writing." << std::endl;
89        return -1;
90      }
91      std::cout << " music prob " << music_prob;
92    }
93    std::cout << std::endl;
94    frame_counter++;
95  }
96  std::cout << frame_counter << " frames processed." << std::endl;
97  if (frame_counter > 0) {
98    float music_percentage = music_counter / static_cast<float>(frame_counter);
99    std::cout <<  music_percentage <<  " percent music." << std::endl;
100  }
101
102  fclose(in_file);
103  fclose(out_file);
104  return 0;
105}
106