1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5// MSVC++ requires this to be set before any other includes to get M_PI.
6#define _USE_MATH_DEFINES
7
8#include "media/filters/wsola_internals.h"
9
10#include <algorithm>
11#include <cmath>
12#include <limits>
13
14#include "base/logging.h"
15#include "base/memory/scoped_ptr.h"
16#include "media/base/audio_bus.h"
17
18namespace media {
19
20namespace internal {
21
22bool InInterval(int n, Interval q) {
23  return n >= q.first && n <= q.second;
24}
25
// Sums, over the first |channels| entries of the three arrays, the normalized
// correlation
//   dot_prod_a_b[k] / sqrt(energy_a[k] * energy_b[k] + epsilon).
// The small epsilon keeps the denominator non-zero when either energy is
// zero. Returns 0 when |channels| is not positive.
float MultiChannelSimilarityMeasure(const float* dot_prod_a_b,
                                    const float* energy_a,
                                    const float* energy_b,
                                    int channels) {
  const float kEpsilon = 1e-12f;
  float total = 0.0f;
  int ch = 0;
  while (ch < channels) {
    const float energy_product = energy_a[ch] * energy_b[ch] + kEpsilon;
    total += dot_prod_a_b[ch] / sqrt(energy_product);
    ++ch;
  }
  return total;
}
38
39void MultiChannelDotProduct(const AudioBus* a,
40                            int frame_offset_a,
41                            const AudioBus* b,
42                            int frame_offset_b,
43                            int num_frames,
44                            float* dot_product) {
45  DCHECK_EQ(a->channels(), b->channels());
46  DCHECK_GE(frame_offset_a, 0);
47  DCHECK_GE(frame_offset_b, 0);
48  DCHECK_LE(frame_offset_a + num_frames, a->frames());
49  DCHECK_LE(frame_offset_b + num_frames, b->frames());
50
51  memset(dot_product, 0, sizeof(*dot_product) * a->channels());
52  for (int k = 0; k < a->channels(); ++k) {
53    const float* ch_a = a->channel(k) + frame_offset_a;
54    const float* ch_b = b->channel(k) + frame_offset_b;
55    for (int n = 0; n < num_frames; ++n) {
56      dot_product[k] += *ch_a++ * *ch_b++;
57    }
58  }
59}
60
61void MultiChannelMovingBlockEnergies(const AudioBus* input,
62                                     int frames_per_block,
63                                     float* energy) {
64  int num_blocks = input->frames() - (frames_per_block - 1);
65  int channels = input->channels();
66
67  for (int k = 0; k < input->channels(); ++k) {
68    const float* input_channel = input->channel(k);
69
70    energy[k] = 0;
71
72    // First block of channel |k|.
73    for (int m = 0; m < frames_per_block; ++m) {
74      energy[k] += input_channel[m] * input_channel[m];
75    }
76
77    const float* slide_out = input_channel;
78    const float* slide_in = input_channel + frames_per_block;
79    for (int n = 1; n < num_blocks; ++n, ++slide_in, ++slide_out) {
80      energy[k + n * channels] = energy[k + (n - 1) * channels] - *slide_out *
81          *slide_out + *slide_in * *slide_in;
82    }
83  }
84}
85
// Fits the parabola f(x) = a * x^2 + b * x + c through the three points
//   (-1, y_values[0]), (0, y_values[1]), (1, y_values[2])
// and reports its stationary point: |*extremum| receives the x-coordinate and
// |*extremum_value| receives f at that coordinate. Callers are expected to
// pass y_values[0] <= y_values[1] >= y_values[2], in which case the
// stationary point is the maximum. When the quadratic coefficient is exactly
// zero, the middle sample (x = 0) is reported instead.
void QuadraticInterpolation(const float* y_values,
                            float* extremum,
                            float* extremum_value) {
  const float left = y_values[0];
  const float middle = y_values[1];
  const float right = y_values[2];

  const float a = 0.5f * (right + left) - middle;
  if (a == 0.f) {
    // No curvature: the three points are colinear (within floating-point
    // error), so there is no parabola vertex to solve for.
    *extremum = 0;
    *extremum_value = middle;
    return;
  }

  const float b = 0.5f * (right - left);
  const float c = middle;
  const float vertex = -b / (2.f * a);
  *extremum = vertex;
  *extremum_value = a * vertex * vertex + b * vertex + c;
}
107
108int DecimatedSearch(int decimation,
109                    Interval exclude_interval,
110                    const AudioBus* target_block,
111                    const AudioBus* search_segment,
112                    const float* energy_target_block,
113                    const float* energy_candidate_blocks) {
114  int channels = search_segment->channels();
115  int block_size = target_block->frames();
116  int num_candidate_blocks = search_segment->frames() - (block_size - 1);
117  scoped_ptr<float[]> dot_prod(new float[channels]);
118  float similarity[3];  // Three elements for cubic interpolation.
119
120  int n = 0;
121  MultiChannelDotProduct(target_block, 0, search_segment, n, block_size,
122                         dot_prod.get());
123  similarity[0] = MultiChannelSimilarityMeasure(
124      dot_prod.get(), energy_target_block,
125      &energy_candidate_blocks[n * channels], channels);
126
127  // Set the starting point as optimal point.
128  float best_similarity = similarity[0];
129  int optimal_index = 0;
130
131  n += decimation;
132  if (n >= num_candidate_blocks) {
133    return 0;
134  }
135
136  MultiChannelDotProduct(target_block, 0, search_segment, n, block_size,
137                         dot_prod.get());
138  similarity[1] = MultiChannelSimilarityMeasure(
139      dot_prod.get(), energy_target_block,
140      &energy_candidate_blocks[n * channels], channels);
141
142  n += decimation;
143  if (n >= num_candidate_blocks) {
144    // We cannot do any more sampling. Compare these two values and return the
145    // optimal index.
146    return similarity[1] > similarity[0] ? decimation : 0;
147  }
148
149  for (; n < num_candidate_blocks; n += decimation) {
150    MultiChannelDotProduct(target_block, 0, search_segment, n, block_size,
151                           dot_prod.get());
152
153    similarity[2] = MultiChannelSimilarityMeasure(
154        dot_prod.get(), energy_target_block,
155        &energy_candidate_blocks[n * channels], channels);
156
157    if ((similarity[1] > similarity[0] && similarity[1] >= similarity[2]) ||
158        (similarity[1] >= similarity[0] && similarity[1] > similarity[2])) {
159      // A local maximum is found. Do a cubic interpolation for a better
160      // estimate of candidate maximum.
161      float normalized_candidate_index;
162      float candidate_similarity;
163      QuadraticInterpolation(similarity, &normalized_candidate_index,
164                             &candidate_similarity);
165
166      int candidate_index = n - decimation + static_cast<int>(
167          normalized_candidate_index * decimation +  0.5f);
168      if (candidate_similarity > best_similarity &&
169          !InInterval(candidate_index, exclude_interval)) {
170        optimal_index = candidate_index;
171        best_similarity = candidate_similarity;
172      }
173    } else if (n + decimation >= num_candidate_blocks &&
174               similarity[2] > best_similarity &&
175               !InInterval(n, exclude_interval)) {
176      // If this is the end-point and has a better similarity-measure than
177      // optimal, then we accept it as optimal point.
178      optimal_index = n;
179      best_similarity = similarity[2];
180    }
181    memmove(similarity, &similarity[1], 2 * sizeof(*similarity));
182  }
183  return optimal_index;
184}
185
186int FullSearch(int low_limit,
187               int high_limit,
188               Interval exclude_interval,
189               const AudioBus* target_block,
190               const AudioBus* search_block,
191               const float* energy_target_block,
192               const float* energy_candidate_blocks) {
193  int channels = search_block->channels();
194  int block_size = target_block->frames();
195  scoped_ptr<float[]> dot_prod(new float[channels]);
196
197  float best_similarity = std::numeric_limits<float>::min();
198  int optimal_index = 0;
199
200  for (int n = low_limit; n <= high_limit; ++n) {
201    if (InInterval(n, exclude_interval)) {
202      continue;
203    }
204    MultiChannelDotProduct(target_block, 0, search_block, n, block_size,
205                           dot_prod.get());
206
207    float similarity = MultiChannelSimilarityMeasure(
208        dot_prod.get(), energy_target_block,
209        &energy_candidate_blocks[n * channels], channels);
210
211    if (similarity > best_similarity) {
212      best_similarity = similarity;
213      optimal_index = n;
214    }
215  }
216
217  return optimal_index;
218}
219
220int OptimalIndex(const AudioBus* search_block,
221                 const AudioBus* target_block,
222                 Interval exclude_interval) {
223  int channels = search_block->channels();
224  DCHECK_EQ(channels, target_block->channels());
225  int target_size = target_block->frames();
226  int num_candidate_blocks = search_block->frames() - (target_size - 1);
227
228  // This is a compromise between complexity reduction and search accuracy. I
229  // don't have a proof that down sample of order 5 is optimal. One can compute
230  // a decimation factor that minimizes complexity given the size of
231  // |search_block| and |target_block|. However, my experiments show the rate of
232  // missing the optimal index is significant. This value is chosen
233  // heuristically based on experiments.
234  const int kSearchDecimation = 5;
235
236  scoped_ptr<float[]> energy_target_block(new float[channels]);
237  scoped_ptr<float[]> energy_candidate_blocks(
238      new float[channels * num_candidate_blocks]);
239
240  // Energy of all candid frames.
241  MultiChannelMovingBlockEnergies(search_block, target_size,
242                                  energy_candidate_blocks.get());
243
244  // Energy of target frame.
245  MultiChannelDotProduct(target_block, 0, target_block, 0,
246                         target_size, energy_target_block.get());
247
248  int optimal_index = DecimatedSearch(kSearchDecimation,
249                                      exclude_interval, target_block,
250                                      search_block, energy_target_block.get(),
251                                      energy_candidate_blocks.get());
252
253  int lim_low = std::max(0, optimal_index - kSearchDecimation);
254  int lim_high = std::min(num_candidate_blocks - 1,
255                          optimal_index + kSearchDecimation);
256  return FullSearch(lim_low, lim_high, exclude_interval, target_block,
257                    search_block, energy_target_block.get(),
258                    energy_candidate_blocks.get());
259}
260
// Fills |window[0 .. window_length - 1]| with raised-cosine weights
//   window[n] = 0.5 * (1 - cos(2 * pi * n / window_length)).
// |window| must have room for |window_length| floats.
void GetSymmetricHanningWindow(int window_length, float* window) {
  const float radians_per_sample = 2.0f * M_PI / window_length;
  for (int n = 0; n < window_length; ++n) {
    const float phase = n * radians_per_sample;
    window[n] = 0.5f * (1.0f - cosf(phase));
  }
}
266
267}  // namespace internal
268
269}  // namespace media
270
271