1/*
2 *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "webrtc/modules/audio_processing/transient/transient_suppressor.h"
12
13#include <math.h>
14#include <string.h>
15#include <cmath>
16#include <complex>
17#include <deque>
18#include <set>
19
20#include "webrtc/base/scoped_ptr.h"
21#include "webrtc/common_audio/fft4g.h"
22#include "webrtc/common_audio/include/audio_util.h"
23#include "webrtc/common_audio/signal_processing/include/signal_processing_library.h"
24#include "webrtc/modules/audio_processing/transient/common.h"
25#include "webrtc/modules/audio_processing/transient/transient_detector.h"
26#include "webrtc/modules/audio_processing/ns/windows_private.h"
27#include "webrtc/system_wrappers/include/logging.h"
28#include "webrtc/typedefs.h"
29
30namespace webrtc {
31
// Weight of the newest block magnitudes in the one-pole IIR average that
// tracks each channel's spectral mean.
constexpr float kMeanIIRCoefficient = 0.5f;
// Voice probabilities below this value count as "not voiced" when choosing
// between hard and soft restoration.
constexpr float kVoiceThreshold = 0.02f;

// Half-open FFT bin range [kMinVoiceBin, kMaxVoiceBin) treated as the voice
// band by the soft restoration and by the mean-factor curve.
// TODO(aluebs): Check if these values work also for 48kHz.
constexpr size_t kMinVoiceBin = 3;
constexpr size_t kMaxVoiceBin = 60;
38
namespace {

// Approximates the magnitude of the complex value a + bi with the L1 norm
// |a| + |b| rather than sqrt(a^2 + b^2). Both the per-block magnitudes and the
// spectral means they feed are measured this way, so comparisons between them
// stay consistent.
float ComplexMagnitude(float a, float b) {
  const float real_part = std::abs(a);
  const float imag_part = std::abs(b);
  return real_part + imag_part;
}

}  // namespace
46
47TransientSuppressor::TransientSuppressor()
48    : data_length_(0),
49      detection_length_(0),
50      analysis_length_(0),
51      buffer_delay_(0),
52      complex_analysis_length_(0),
53      num_channels_(0),
54      window_(NULL),
55      detector_smoothed_(0.f),
56      keypress_counter_(0),
57      chunks_since_keypress_(0),
58      detection_enabled_(false),
59      suppression_enabled_(false),
60      use_hard_restoration_(false),
61      chunks_since_voice_change_(0),
62      seed_(182),
63      using_reference_(false) {
64}
65
66TransientSuppressor::~TransientSuppressor() {}
67
68int TransientSuppressor::Initialize(int sample_rate_hz,
69                                    int detection_rate_hz,
70                                    int num_channels) {
71  switch (sample_rate_hz) {
72    case ts::kSampleRate8kHz:
73      analysis_length_ = 128u;
74      window_ = kBlocks80w128;
75      break;
76    case ts::kSampleRate16kHz:
77      analysis_length_ = 256u;
78      window_ = kBlocks160w256;
79      break;
80    case ts::kSampleRate32kHz:
81      analysis_length_ = 512u;
82      window_ = kBlocks320w512;
83      break;
84    case ts::kSampleRate48kHz:
85      analysis_length_ = 1024u;
86      window_ = kBlocks480w1024;
87      break;
88    default:
89      return -1;
90  }
91  if (detection_rate_hz != ts::kSampleRate8kHz &&
92      detection_rate_hz != ts::kSampleRate16kHz &&
93      detection_rate_hz != ts::kSampleRate32kHz &&
94      detection_rate_hz != ts::kSampleRate48kHz) {
95    return -1;
96  }
97  if (num_channels <= 0) {
98    return -1;
99  }
100
101  detector_.reset(new TransientDetector(detection_rate_hz));
102  data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000;
103  if (data_length_ > analysis_length_) {
104    assert(false);
105    return -1;
106  }
107  buffer_delay_ = analysis_length_ - data_length_;
108
109  complex_analysis_length_ = analysis_length_ / 2 + 1;
110  assert(complex_analysis_length_ >= kMaxVoiceBin);
111  num_channels_ = num_channels;
112  in_buffer_.reset(new float[analysis_length_ * num_channels_]);
113  memset(in_buffer_.get(),
114         0,
115         analysis_length_ * num_channels_ * sizeof(in_buffer_[0]));
116  detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000;
117  detection_buffer_.reset(new float[detection_length_]);
118  memset(detection_buffer_.get(),
119         0,
120         detection_length_ * sizeof(detection_buffer_[0]));
121  out_buffer_.reset(new float[analysis_length_ * num_channels_]);
122  memset(out_buffer_.get(),
123         0,
124         analysis_length_ * num_channels_ * sizeof(out_buffer_[0]));
125  // ip[0] must be zero to trigger initialization using rdft().
126  size_t ip_length = 2 + sqrtf(analysis_length_);
127  ip_.reset(new size_t[ip_length]());
128  memset(ip_.get(), 0, ip_length * sizeof(ip_[0]));
129  wfft_.reset(new float[complex_analysis_length_ - 1]);
130  memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0]));
131  spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]);
132  memset(spectral_mean_.get(),
133         0,
134         complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0]));
135  fft_buffer_.reset(new float[analysis_length_ + 2]);
136  memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0]));
137  magnitudes_.reset(new float[complex_analysis_length_]);
138  memset(magnitudes_.get(),
139         0,
140         complex_analysis_length_ * sizeof(magnitudes_[0]));
141  mean_factor_.reset(new float[complex_analysis_length_]);
142
143  static const float kFactorHeight = 10.f;
144  static const float kLowSlope = 1.f;
145  static const float kHighSlope = 0.3f;
146  for (size_t i = 0; i < complex_analysis_length_; ++i) {
147    mean_factor_[i] =
148        kFactorHeight /
149            (1.f + exp(kLowSlope * static_cast<int>(i - kMinVoiceBin))) +
150        kFactorHeight /
151            (1.f + exp(kHighSlope * static_cast<int>(kMaxVoiceBin - i)));
152  }
153  detector_smoothed_ = 0.f;
154  keypress_counter_ = 0;
155  chunks_since_keypress_ = 0;
156  detection_enabled_ = false;
157  suppression_enabled_ = false;
158  use_hard_restoration_ = false;
159  chunks_since_voice_change_ = 0;
160  seed_ = 182;
161  using_reference_ = false;
162  return 0;
163}
164
165int TransientSuppressor::Suppress(float* data,
166                                  size_t data_length,
167                                  int num_channels,
168                                  const float* detection_data,
169                                  size_t detection_length,
170                                  const float* reference_data,
171                                  size_t reference_length,
172                                  float voice_probability,
173                                  bool key_pressed) {
174  if (!data || data_length != data_length_ || num_channels != num_channels_ ||
175      detection_length != detection_length_ || voice_probability < 0 ||
176      voice_probability > 1) {
177    return -1;
178  }
179
180  UpdateKeypress(key_pressed);
181  UpdateBuffers(data);
182
183  int result = 0;
184  if (detection_enabled_) {
185    UpdateRestoration(voice_probability);
186
187    if (!detection_data) {
188      // Use the input data  of the first channel if special detection data is
189      // not supplied.
190      detection_data = &in_buffer_[buffer_delay_];
191    }
192
193    float detector_result = detector_->Detect(
194        detection_data, detection_length, reference_data, reference_length);
195    if (detector_result < 0) {
196      return -1;
197    }
198
199    using_reference_ = detector_->using_reference();
200
201    // |detector_smoothed_| follows the |detector_result| when this last one is
202    // increasing, but has an exponential decaying tail to be able to suppress
203    // the ringing of keyclicks.
204    float smooth_factor = using_reference_ ? 0.6 : 0.1;
205    detector_smoothed_ = detector_result >= detector_smoothed_
206                             ? detector_result
207                             : smooth_factor * detector_smoothed_ +
208                                   (1 - smooth_factor) * detector_result;
209
210    for (int i = 0; i < num_channels_; ++i) {
211      Suppress(&in_buffer_[i * analysis_length_],
212               &spectral_mean_[i * complex_analysis_length_],
213               &out_buffer_[i * analysis_length_]);
214    }
215  }
216
217  // If the suppression isn't enabled, we use the in buffer to delay the signal
218  // appropriately. This also gives time for the out buffer to be refreshed with
219  // new data between detection and suppression getting enabled.
220  for (int i = 0; i < num_channels_; ++i) {
221    memcpy(&data[i * data_length_],
222           suppression_enabled_ ? &out_buffer_[i * analysis_length_]
223                                : &in_buffer_[i * analysis_length_],
224           data_length_ * sizeof(*data));
225  }
226  return result;
227}
228
229// This should only be called when detection is enabled. UpdateBuffers() must
230// have been called. At return, |out_buffer_| will be filled with the
231// processed output.
232void TransientSuppressor::Suppress(float* in_ptr,
233                                   float* spectral_mean,
234                                   float* out_ptr) {
235  // Go to frequency domain.
236  for (size_t i = 0; i < analysis_length_; ++i) {
237    // TODO(aluebs): Rename windows
238    fft_buffer_[i] = in_ptr[i] * window_[i];
239  }
240
241  WebRtc_rdft(analysis_length_, 1, fft_buffer_.get(), ip_.get(), wfft_.get());
242
243  // Since WebRtc_rdft puts R[n/2] in fft_buffer_[1], we move it to the end
244  // for convenience.
245  fft_buffer_[analysis_length_] = fft_buffer_[1];
246  fft_buffer_[analysis_length_ + 1] = 0.f;
247  fft_buffer_[1] = 0.f;
248
249  for (size_t i = 0; i < complex_analysis_length_; ++i) {
250    magnitudes_[i] = ComplexMagnitude(fft_buffer_[i * 2],
251                                      fft_buffer_[i * 2 + 1]);
252  }
253  // Restore audio if necessary.
254  if (suppression_enabled_) {
255    if (use_hard_restoration_) {
256      HardRestoration(spectral_mean);
257    } else {
258      SoftRestoration(spectral_mean);
259    }
260  }
261
262  // Update the spectral mean.
263  for (size_t i = 0; i < complex_analysis_length_; ++i) {
264    spectral_mean[i] = (1 - kMeanIIRCoefficient) * spectral_mean[i] +
265                       kMeanIIRCoefficient * magnitudes_[i];
266  }
267
268  // Back to time domain.
269  // Put R[n/2] back in fft_buffer_[1].
270  fft_buffer_[1] = fft_buffer_[analysis_length_];
271
272  WebRtc_rdft(analysis_length_,
273              -1,
274              fft_buffer_.get(),
275              ip_.get(),
276              wfft_.get());
277  const float fft_scaling = 2.f / analysis_length_;
278
279  for (size_t i = 0; i < analysis_length_; ++i) {
280    out_ptr[i] += fft_buffer_[i] * window_[i] * fft_scaling;
281  }
282}
283
284void TransientSuppressor::UpdateKeypress(bool key_pressed) {
285  const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
286  const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
287  const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs;  // 4 seconds.
288
289  if (key_pressed) {
290    keypress_counter_ += kKeypressPenalty;
291    chunks_since_keypress_ = 0;
292    detection_enabled_ = true;
293  }
294  keypress_counter_ = std::max(0, keypress_counter_ - 1);
295
296  if (keypress_counter_ > kIsTypingThreshold) {
297    if (!suppression_enabled_) {
298      LOG(LS_INFO) << "[ts] Transient suppression is now enabled.";
299    }
300    suppression_enabled_ = true;
301    keypress_counter_ = 0;
302  }
303
304  if (detection_enabled_ &&
305      ++chunks_since_keypress_ > kChunksUntilNotTyping) {
306    if (suppression_enabled_) {
307      LOG(LS_INFO) << "[ts] Transient suppression is now disabled.";
308    }
309    detection_enabled_ = false;
310    suppression_enabled_ = false;
311    keypress_counter_ = 0;
312  }
313}
314
315void TransientSuppressor::UpdateRestoration(float voice_probability) {
316  const int kHardRestorationOffsetDelay = 3;
317  const int kHardRestorationOnsetDelay = 80;
318
319  bool not_voiced = voice_probability < kVoiceThreshold;
320
321  if (not_voiced == use_hard_restoration_) {
322    chunks_since_voice_change_ = 0;
323  } else {
324    ++chunks_since_voice_change_;
325
326    if ((use_hard_restoration_ &&
327         chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
328        (!use_hard_restoration_ &&
329         chunks_since_voice_change_ > kHardRestorationOnsetDelay)) {
330      use_hard_restoration_ = not_voiced;
331      chunks_since_voice_change_ = 0;
332    }
333  }
334}
335
336// Shift buffers to make way for new data. Must be called after
337// |detection_enabled_| is updated by UpdateKeypress().
338void TransientSuppressor::UpdateBuffers(float* data) {
339  // TODO(aluebs): Change to ring buffer.
340  memmove(in_buffer_.get(),
341          &in_buffer_[data_length_],
342          (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
343              sizeof(in_buffer_[0]));
344  // Copy new chunk to buffer.
345  for (int i = 0; i < num_channels_; ++i) {
346    memcpy(&in_buffer_[buffer_delay_ + i * analysis_length_],
347           &data[i * data_length_],
348           data_length_ * sizeof(*data));
349  }
350  if (detection_enabled_) {
351    // Shift previous chunk in out buffer.
352    memmove(out_buffer_.get(),
353            &out_buffer_[data_length_],
354            (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
355                sizeof(out_buffer_[0]));
356    // Initialize new chunk in out buffer.
357    for (int i = 0; i < num_channels_; ++i) {
358      memset(&out_buffer_[buffer_delay_ + i * analysis_length_],
359             0,
360             data_length_ * sizeof(out_buffer_[0]));
361    }
362  }
363}
364
365// Restores the unvoiced signal if a click is present.
366// Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
367// the spectral mean. The attenuation depends on |detector_smoothed_|.
368// If a restoration takes place, the |magnitudes_| are updated to the new value.
369void TransientSuppressor::HardRestoration(float* spectral_mean) {
370  const float detector_result =
371      1.f - pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
372  // To restore, we get the peaks in the spectrum. If higher than the previous
373  // spectral mean we adjust them.
374  for (size_t i = 0; i < complex_analysis_length_; ++i) {
375    if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) {
376      // RandU() generates values on [0, int16::max()]
377      const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) /
378          std::numeric_limits<int16_t>::max();
379      const float scaled_mean = detector_result * spectral_mean[i];
380
381      fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] +
382                           scaled_mean * cosf(phase);
383      fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] +
384                               scaled_mean * sinf(phase);
385      magnitudes_[i] = magnitudes_[i] -
386                       detector_result * (magnitudes_[i] - spectral_mean[i]);
387    }
388  }
389}
390
391// Restores the voiced signal if a click is present.
392// Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
393// the spectral mean and that is lower than some function of the current block
394// frequency mean. The attenuation depends on |detector_smoothed_|.
395// If a restoration takes place, the |magnitudes_| are updated to the new value.
396void TransientSuppressor::SoftRestoration(float* spectral_mean) {
397  // Get the spectral magnitude mean of the current block.
398  float block_frequency_mean = 0;
399  for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {
400    block_frequency_mean += magnitudes_[i];
401  }
402  block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin);
403
404  // To restore, we get the peaks in the spectrum. If higher than the
405  // previous spectral mean and lower than a factor of the block mean
406  // we adjust them. The factor is a double sigmoid that has a minimum in the
407  // voice frequency range (300Hz - 3kHz).
408  for (size_t i = 0; i < complex_analysis_length_; ++i) {
409    if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 &&
410        (using_reference_ ||
411         magnitudes_[i] < block_frequency_mean * mean_factor_[i])) {
412      const float new_magnitude =
413          magnitudes_[i] -
414          detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]);
415      const float magnitude_ratio = new_magnitude / magnitudes_[i];
416
417      fft_buffer_[i * 2] *= magnitude_ratio;
418      fft_buffer_[i * 2 + 1] *= magnitude_ratio;
419      magnitudes_[i] = new_magnitude;
420    }
421  }
422}
423
424}  // namespace webrtc
425