1/*
2 *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
12#define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
13
14// MSVC++ requires this to be set before any other includes to get M_PI.
15#define _USE_MATH_DEFINES
16
17#include <math.h>
18#include <vector>
19
20#include "webrtc/common_audio/lapped_transform.h"
21#include "webrtc/common_audio/channel_buffer.h"
22#include "webrtc/modules/audio_processing/beamformer/beamformer.h"
23#include "webrtc/modules/audio_processing/beamformer/complex_matrix.h"
24#include "webrtc/system_wrappers/include/scoped_vector.h"
25
26namespace webrtc {
27
// Enhances sound sources coming directly in front of a uniform linear array
// and suppresses sound sources coming from all other directions. Operates on
// multichannel signals and produces single-channel output. The target
// direction can be set in the constructor and changed later with AimAt().
//
// The implemented nonlinear postfilter algorithm is taken from "A Robust
// Nonlinear Beamforming Postprocessor" by Bastiaan Kleijn.
//
// Usage: construct, call Initialize(), then call ProcessChunk() once per chunk.
// The FFT-domain work happens in ProcessAudioBlock(), which this class
// implements as a LappedTransform::Callback.
class NonlinearBeamformer
  : public Beamformer<float>,
    public LappedTransform::Callback {
 public:
  // Half-width of the beam, in radians. The value is defined out of line (in
  // the .cc file).
  static const float kHalfBeamWidthRadians;

  // |array_geometry| is expected to hold one Point per microphone (input
  // channel). |target_direction| is the direction to enhance. Its default is
  // SphericalPointf(pi / 2, 0, 1); what those three values mean depends on
  // SphericalPointf's field order (presumably azimuth, elevation, radius --
  // confirm against the SphericalPointf definition).
  explicit NonlinearBeamformer(
      const std::vector<Point>& array_geometry,
      SphericalPointf target_direction =
          SphericalPointf(static_cast<float>(M_PI) / 2.f, 0.f, 1.f));

  // Sample rate corresponds to the lower band.
  // Needs to be called before the NonlinearBeamformer can be used.
  void Initialize(int chunk_size_ms, int sample_rate_hz) override;

  // Process one time-domain chunk of audio. The audio is expected to be split
  // into frequency bands inside the ChannelBuffer. The number of channels must
  // match the array geometry passed to the constructor. The number of frames
  // presumably must match the chunk configured via Initialize(). The same
  // ChannelBuffer can be passed in as |input| and |output|.
  void ProcessChunk(const ChannelBuffer<float>& input,
                    ChannelBuffer<float>* output) override;

  // Changes the target direction to |target_direction|.
  void AimAt(const SphericalPointf& target_direction) override;

  // Returns whether |spherical_point| is considered to lie inside the beam
  // (presumably judged against kHalfBeamWidthRadians -- see the .cc file).
  bool IsInBeam(const SphericalPointf& spherical_point) override;

  // After each block is processed, |is_target_present_| is set to true if the
  // target signal is present and to false otherwise. Call this method to find
  // out whether the data is target signal or interference, and process it
  // accordingly.
  bool is_target_present() override { return is_target_present_; }

 protected:
  // Process one frequency-domain block of audio. This is where the fun
  // happens. Implements LappedTransform::Callback.
  void ProcessAudioBlock(const complex<float>* const* input,
                         size_t num_input_channels,
                         size_t num_freq_bins,
                         size_t num_output_channels,
                         complex<float>* const* output) override;

 private:
  // Lets the named unit test access private members.
  FRIEND_TEST_ALL_PREFIXES(NonlinearBeamformerTest,
                           InterfAnglesTakeAmbiguityIntoAccount);

  // Shorthands for the matrix and complex types used throughout this class.
  typedef Matrix<float> MatrixF;
  typedef ComplexMatrix<float> ComplexMatrixF;
  typedef complex<float> complex_f;

  // Initialization helpers. Each one fills the correspondingly named members
  // declared below: the correction bin ranges, the interferer angles, the
  // delay-sum masks and the covariance matrices. NOTE(review): presumably
  // invoked from Initialize() and/or AimAt(); confirm in the .cc file.
  void InitLowFrequencyCorrectionRanges();
  void InitHighFrequencyCorrectionRanges();
  void InitInterfAngles();
  void InitDelaySumMasks();
  void InitTargetCovMats();
  void InitDiffuseCovMats();
  void InitInterfCovMats();
  void NormalizeCovMats();

  // Calculates postfilter masks that minimize the mean squared error of our
  // estimation of the desired signal.
  float CalculatePostfilterMask(const ComplexMatrixF& interf_cov_mat,
                                float rpsiw,
                                float ratio_rxiw_rxim,
                                float rmxi_r);

  // Prevents the postfilter masks from degenerating too quickly (a cause of
  // musical noise). NOTE(review): presumably these work on |time_smooth_mask_|
  // and |final_mask_| respectively; confirm in the .cc file.
  void ApplyMaskTimeSmoothing();
  void ApplyMaskFrequencySmoothing();

  // The postfilter masks are unreliable at low frequencies. Calculates a better
  // mask by averaging mid-low frequency values.
  void ApplyLowFrequencyCorrection();

  // Postfilter masks are also unreliable at high frequencies. Average mid-high
  // frequency masks to calculate a single mask per block which can be applied
  // in the time-domain. Further, we average these block-masks over a chunk,
  // resulting in one postfilter mask per audio chunk. This allows us to skip
  // both transforming and blocking the high-frequency signal.
  void ApplyHighFrequencyCorrection();

  // Compute the means needed for the above frequency correction, over the bin
  // range |start_bin|..|end_bin|. NOTE(review): check the .cc file to see
  // whether |end_bin| is inclusive.
  float MaskRangeMean(size_t start_bin, size_t end_bin);

  // Applies both sets of masks to |input| and stores the result in |output|.
  void ApplyMasks(const complex_f* const* input, complex_f* const* output);

  // NOTE(review): presumably updates |is_target_present_| and
  // |interference_blocks_count_| from the current mask; confirm in the .cc file.
  void EstimateTargetPresence();

  static const size_t kFftSize = 256;
  // Number of non-redundant bins of a real-input FFT of length kFftSize.
  static const size_t kNumFreqBins = kFftSize / 2 + 1;

  // Deals with the fft transform and blocking.
  size_t chunk_length_;
  rtc::scoped_ptr<LappedTransform> lapped_transform_;
  // Window of kFftSize samples (presumably the one handed to the lapped
  // transform).
  float window_[kFftSize];

  // Parameters exposed to the user.
  const size_t num_input_channels_;
  int sample_rate_hz_;

  // NOTE: const members are initialized in declaration order, so keep this
  // order in sync with the constructor's initializer list.
  const std::vector<Point> array_geometry_;
  // The normal direction of the array if it has one and it is in the xy-plane.
  const rtc::Optional<Point> array_normal_;

  // Minimum spacing between microphone pairs.
  const float min_mic_spacing_;

  // Calculated based on user-input and constants in the .cc file.
  size_t low_mean_start_bin_;
  size_t low_mean_end_bin_;
  size_t high_mean_start_bin_;
  size_t high_mean_end_bin_;

  // Quickly varying mask updated every block.
  float new_mask_[kNumFreqBins];
  // Time smoothed mask.
  float time_smooth_mask_[kNumFreqBins];
  // Time and frequency smoothed mask.
  float final_mask_[kNumFreqBins];

  float target_angle_radians_;
  // Angles of the interferer scenarios.
  std::vector<float> interf_angles_radians_;
  // The angle between the target and the interferer scenarios.
  const float away_radians_;

  // Array of length |kNumFreqBins|, Matrix of size |1| x |num_input_channels_|.
  ComplexMatrixF delay_sum_masks_[kNumFreqBins];
  // Presumably normalized versions of |delay_sum_masks_|, with the same
  // dimensions.
  ComplexMatrixF normalized_delay_sum_masks_[kNumFreqBins];

  // Arrays of length |kNumFreqBins|, Matrix of size |num_input_channels_| x
  // |num_input_channels_|. Note that |uniform_cov_mat_| is also a per-bin
  // array despite its singular name.
  ComplexMatrixF target_cov_mats_[kNumFreqBins];
  ComplexMatrixF uniform_cov_mat_[kNumFreqBins];
  // Array of length |kNumFreqBins|, Matrix of size |num_input_channels_| x
  // |num_input_channels_|. ScopedVector has a size equal to the number of
  // interferer scenarios.
  ScopedVector<ComplexMatrixF> interf_cov_mats_[kNumFreqBins];

  // Of length |kNumFreqBins|.
  float wave_numbers_[kNumFreqBins];

  // Preallocated for ProcessAudioBlock()
  // Of length |kNumFreqBins|.
  float rxiws_[kNumFreqBins];
  // The vector has a size equal to the number of interferer scenarios.
  std::vector<float> rpsiws_[kNumFreqBins];

  // The microphone normalization factor.
  ComplexMatrixF eig_m_;

  // For processing the high-frequency input signal. Presumably holds the
  // single per-chunk mask produced by ApplyHighFrequencyCorrection().
  float high_pass_postfilter_mask_;

  // True when the target signal is present.
  bool is_target_present_;
  // Number of blocks after which the data is considered interference if the
  // mask does not pass |kMaskSignalThreshold| (defined in the .cc file).
  size_t hold_target_blocks_;
  // Number of blocks since the last mask that passed |kMaskSignalThreshold|.
  size_t interference_blocks_count_;
};
197
198}  // namespace webrtc
199
200#endif  // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
201