1/*
2 *  Copyright (c) 2012 The WebRTC project authors. All Rights Reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "webrtc/modules/audio_processing/noise_suppression_impl.h"
12
13#include "webrtc/modules/audio_processing/audio_buffer.h"
// Compile-time selection of the noise suppression backend. The NS_* macros
// and the NsState alias give the rest of this file a single set of names for
// whichever backend is chosen.
#if defined(WEBRTC_NS_FLOAT)
// Floating point backend (WebRtcNs_*).
#include "webrtc/modules/audio_processing/ns/noise_suppression.h"
#define NS_CREATE WebRtcNs_Create
#define NS_FREE WebRtcNs_Free
#define NS_INIT WebRtcNs_Init
#define NS_SET_POLICY WebRtcNs_set_policy
using NsState = NsHandle;
#elif defined(WEBRTC_NS_FIXED)
// Fixed point backend (WebRtcNsx_*).
#include "webrtc/modules/audio_processing/ns/noise_suppression_x.h"
#define NS_CREATE WebRtcNsx_Create
#define NS_FREE WebRtcNsx_Free
#define NS_INIT WebRtcNsx_Init
#define NS_SET_POLICY WebRtcNsx_set_policy
using NsState = NsxHandle;
#endif
29
30namespace webrtc {
31class NoiseSuppressionImpl::Suppressor {
32 public:
33  explicit Suppressor(int sample_rate_hz) {
34    state_ = NS_CREATE();
35    RTC_CHECK(state_);
36    int error = NS_INIT(state_, sample_rate_hz);
37    RTC_DCHECK_EQ(0, error);
38  }
39  ~Suppressor() {
40    NS_FREE(state_);
41  }
42  NsState* state() { return state_; }
43 private:
44  NsState* state_ = nullptr;
45  RTC_DISALLOW_IMPLICIT_CONSTRUCTORS(Suppressor);
46};
47
48NoiseSuppressionImpl::NoiseSuppressionImpl(rtc::CriticalSection* crit)
49    : crit_(crit) {
50  RTC_DCHECK(crit);
51}
52
53NoiseSuppressionImpl::~NoiseSuppressionImpl() {}
54
55void NoiseSuppressionImpl::Initialize(size_t channels, int sample_rate_hz) {
56  rtc::CritScope cs(crit_);
57  channels_ = channels;
58  sample_rate_hz_ = sample_rate_hz;
59  std::vector<rtc::scoped_ptr<Suppressor>> new_suppressors;
60  if (enabled_) {
61    new_suppressors.resize(channels);
62    for (size_t i = 0; i < channels; i++) {
63      new_suppressors[i].reset(new Suppressor(sample_rate_hz));
64    }
65  }
66  suppressors_.swap(new_suppressors);
67  set_level(level_);
68}
69
70void NoiseSuppressionImpl::AnalyzeCaptureAudio(AudioBuffer* audio) {
71  RTC_DCHECK(audio);
72#if defined(WEBRTC_NS_FLOAT)
73  rtc::CritScope cs(crit_);
74  if (!enabled_) {
75    return;
76  }
77
78  RTC_DCHECK_GE(160u, audio->num_frames_per_band());
79  RTC_DCHECK_EQ(suppressors_.size(), audio->num_channels());
80  for (size_t i = 0; i < suppressors_.size(); i++) {
81    WebRtcNs_Analyze(suppressors_[i]->state(),
82                     audio->split_bands_const_f(i)[kBand0To8kHz]);
83  }
84#endif
85}
86
87void NoiseSuppressionImpl::ProcessCaptureAudio(AudioBuffer* audio) {
88  RTC_DCHECK(audio);
89  rtc::CritScope cs(crit_);
90  if (!enabled_) {
91    return;
92  }
93
94  RTC_DCHECK_GE(160u, audio->num_frames_per_band());
95  RTC_DCHECK_EQ(suppressors_.size(), audio->num_channels());
96  for (size_t i = 0; i < suppressors_.size(); i++) {
97#if defined(WEBRTC_NS_FLOAT)
98    WebRtcNs_Process(suppressors_[i]->state(),
99                     audio->split_bands_const_f(i),
100                     audio->num_bands(),
101                     audio->split_bands_f(i));
102#elif defined(WEBRTC_NS_FIXED)
103    WebRtcNsx_Process(suppressors_[i]->state(),
104                      audio->split_bands_const(i),
105                      audio->num_bands(),
106                      audio->split_bands(i));
107#endif
108  }
109}
110
111int NoiseSuppressionImpl::Enable(bool enable) {
112  rtc::CritScope cs(crit_);
113  if (enabled_ != enable) {
114    enabled_ = enable;
115    Initialize(channels_, sample_rate_hz_);
116  }
117  return AudioProcessing::kNoError;
118}
119
120bool NoiseSuppressionImpl::is_enabled() const {
121  rtc::CritScope cs(crit_);
122  return enabled_;
123}
124
125int NoiseSuppressionImpl::set_level(Level level) {
126  int policy = 1;
127  switch (level) {
128    case NoiseSuppression::kLow:
129      policy = 0;
130      break;
131    case NoiseSuppression::kModerate:
132      policy = 1;
133      break;
134    case NoiseSuppression::kHigh:
135      policy = 2;
136      break;
137    case NoiseSuppression::kVeryHigh:
138      policy = 3;
139      break;
140    default:
141      RTC_NOTREACHED();
142  }
143  rtc::CritScope cs(crit_);
144  level_ = level;
145  for (auto& suppressor : suppressors_) {
146    int error = NS_SET_POLICY(suppressor->state(), policy);
147    RTC_DCHECK_EQ(0, error);
148  }
149  return AudioProcessing::kNoError;
150}
151
152NoiseSuppression::Level NoiseSuppressionImpl::level() const {
153  rtc::CritScope cs(crit_);
154  return level_;
155}
156
157float NoiseSuppressionImpl::speech_probability() const {
158  rtc::CritScope cs(crit_);
159#if defined(WEBRTC_NS_FLOAT)
160  float probability_average = 0.0f;
161  for (auto& suppressor : suppressors_) {
162    probability_average +=
163        WebRtcNs_prior_speech_probability(suppressor->state());
164  }
165  if (!suppressors_.empty()) {
166    probability_average /= suppressors_.size();
167  }
168  return probability_average;
169#elif defined(WEBRTC_NS_FIXED)
170  // TODO(peah): Returning error code as a float! Remove this.
171  // Currently not available for the fixed point implementation.
172  return AudioProcessing::kUnsupportedFunctionError;
173#endif
174}
175}  // namespace webrtc
176