1/* 2 * Copyright (c) 2013 The WebRTC project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11#include "webrtc/modules/audio_processing/transient/transient_detector.h" 12 13#include <assert.h> 14#include <float.h> 15#include <math.h> 16#include <string.h> 17 18#include "webrtc/modules/audio_processing/transient/common.h" 19#include "webrtc/modules/audio_processing/transient/daubechies_8_wavelet_coeffs.h" 20#include "webrtc/modules/audio_processing/transient/moving_moments.h" 21#include "webrtc/modules/audio_processing/transient/wpd_tree.h" 22 23namespace webrtc { 24 25static const int kTransientLengthMs = 30; 26static const int kChunksAtStartupLeftToDelete = 27 kTransientLengthMs / ts::kChunkSizeMs; 28static const float kDetectThreshold = 16.f; 29 30TransientDetector::TransientDetector(int sample_rate_hz) 31 : samples_per_chunk_(sample_rate_hz * ts::kChunkSizeMs / 1000), 32 last_first_moment_(), 33 last_second_moment_(), 34 chunks_at_startup_left_to_delete_(kChunksAtStartupLeftToDelete), 35 reference_energy_(1.f), 36 using_reference_(false) { 37 assert(sample_rate_hz == ts::kSampleRate8kHz || 38 sample_rate_hz == ts::kSampleRate16kHz || 39 sample_rate_hz == ts::kSampleRate32kHz || 40 sample_rate_hz == ts::kSampleRate48kHz); 41 int samples_per_transient = sample_rate_hz * kTransientLengthMs / 1000; 42 // Adjustment to avoid data loss while downsampling, making 43 // |samples_per_chunk_| and |samples_per_transient| always divisible by 44 // |kLeaves|. 45 samples_per_chunk_ -= samples_per_chunk_ % kLeaves; 46 samples_per_transient -= samples_per_transient % kLeaves; 47 48 tree_leaves_data_length_ = samples_per_chunk_ / kLeaves; 49 wpd_tree_.reset(new WPDTree(samples_per_chunk_, 50 kDaubechies8HighPassCoefficients, 51 kDaubechies8LowPassCoefficients, 52 kDaubechies8CoefficientsLength, 53 kLevels)); 54 for (size_t i = 0; i < kLeaves; ++i) { 55 moving_moments_[i].reset( 56 new MovingMoments(samples_per_transient / kLeaves)); 57 } 58 59 first_moments_.reset(new float[tree_leaves_data_length_]); 60 second_moments_.reset(new float[tree_leaves_data_length_]); 61 62 for (int i = 0; i < kChunksAtStartupLeftToDelete; ++i) { 63 previous_results_.push_back(0.f); 64 } 65} 66 67TransientDetector::~TransientDetector() {} 68 69float TransientDetector::Detect(const float* data, 70 size_t data_length, 71 const float* reference_data, 72 size_t reference_length) { 73 assert(data && data_length == samples_per_chunk_); 74 75 // TODO(aluebs): Check if these errors can logically happen and if not assert 76 // on them. 77 if (wpd_tree_->Update(data, samples_per_chunk_) != 0) { 78 return -1.f; 79 } 80 81 float result = 0.f; 82 83 for (size_t i = 0; i < kLeaves; ++i) { 84 WPDNode* leaf = wpd_tree_->NodeAt(kLevels, i); 85 86 moving_moments_[i]->CalculateMoments(leaf->data(), 87 tree_leaves_data_length_, 88 first_moments_.get(), 89 second_moments_.get()); 90 91 // Add value delayed (Use the last moments from the last call to Detect). 92 float unbiased_data = leaf->data()[0] - last_first_moment_[i]; 93 result += 94 unbiased_data * unbiased_data / (last_second_moment_[i] + FLT_MIN); 95 96 // Add new values. 97 for (size_t j = 1; j < tree_leaves_data_length_; ++j) { 98 unbiased_data = leaf->data()[j] - first_moments_[j - 1]; 99 result += 100 unbiased_data * unbiased_data / (second_moments_[j - 1] + FLT_MIN); 101 } 102 103 last_first_moment_[i] = first_moments_[tree_leaves_data_length_ - 1]; 104 last_second_moment_[i] = second_moments_[tree_leaves_data_length_ - 1]; 105 } 106 107 result /= tree_leaves_data_length_; 108 109 result *= ReferenceDetectionValue(reference_data, reference_length); 110 111 if (chunks_at_startup_left_to_delete_ > 0) { 112 chunks_at_startup_left_to_delete_--; 113 result = 0.f; 114 } 115 116 if (result >= kDetectThreshold) { 117 result = 1.f; 118 } else { 119 // Get proportional value. 120 // Proportion achieved with a squared raised cosine function with domain 121 // [0, kDetectThreshold) and image [0, 1), it's always increasing. 122 const float horizontal_scaling = ts::kPi / kDetectThreshold; 123 const float kHorizontalShift = ts::kPi; 124 const float kVerticalScaling = 0.5f; 125 const float kVerticalShift = 1.f; 126 127 result = (cos(result * horizontal_scaling + kHorizontalShift) 128 + kVerticalShift) * kVerticalScaling; 129 result *= result; 130 } 131 132 previous_results_.pop_front(); 133 previous_results_.push_back(result); 134 135 // In the current implementation we return the max of the current result and 136 // the previous results, so the high results have a width equals to 137 // |transient_length|. 138 return *std::max_element(previous_results_.begin(), previous_results_.end()); 139} 140 141// Looks for the highest slope and compares it with the previous ones. 142// An exponential transformation takes this to the [0, 1] range. This value is 143// multiplied by the detection result to avoid false positives. 144float TransientDetector::ReferenceDetectionValue(const float* data, 145 size_t length) { 146 if (data == NULL) { 147 using_reference_ = false; 148 return 1.f; 149 } 150 static const float kEnergyRatioThreshold = 0.2f; 151 static const float kReferenceNonLinearity = 20.f; 152 static const float kMemory = 0.99f; 153 float reference_energy = 0.f; 154 for (size_t i = 1; i < length; ++i) { 155 reference_energy += data[i] * data[i]; 156 } 157 if (reference_energy == 0.f) { 158 using_reference_ = false; 159 return 1.f; 160 } 161 assert(reference_energy_ != 0); 162 float result = 1.f / (1.f + exp(kReferenceNonLinearity * 163 (kEnergyRatioThreshold - 164 reference_energy / reference_energy_))); 165 reference_energy_ = 166 kMemory * reference_energy_ + (1.f - kMemory) * reference_energy; 167 168 using_reference_ = true; 169 170 return result; 171} 172 173} // namespace webrtc 174