1/*
2 *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "webrtc/modules/audio_processing/transient/wpd_node.h"
12
13#include <assert.h>
14#include <math.h>
15#include <string.h>
16
17#include "webrtc/base/scoped_ptr.h"
18#include "webrtc/common_audio/fir_filter.h"
19#include "webrtc/modules/audio_processing/transient/dyadic_decimator.h"
20
21namespace webrtc {
22
23WPDNode::WPDNode(size_t length,
24                 const float* coefficients,
25                 size_t coefficients_length)
26    : // The data buffer has parent data length to be able to contain and filter
27      // it.
28      data_(new float[2 * length + 1]),
29      length_(length),
30      filter_(FIRFilter::Create(coefficients,
31                                coefficients_length,
32                                2 * length + 1)) {
33  assert(length > 0 && coefficients && coefficients_length > 0);
34  memset(data_.get(), 0.f, (2 * length + 1) * sizeof(data_[0]));
35}
36
37WPDNode::~WPDNode() {}
38
39int WPDNode::Update(const float* parent_data, size_t parent_data_length) {
40  if (!parent_data || (parent_data_length / 2) != length_) {
41    return -1;
42  }
43
44  // Filter data.
45  filter_->Filter(parent_data, parent_data_length, data_.get());
46
47  // Decimate data.
48  const bool kOddSequence = true;
49  size_t output_samples = DyadicDecimate(
50      data_.get(), parent_data_length, kOddSequence, data_.get(), length_);
51  if (output_samples != length_) {
52    return -1;
53  }
54
55  // Get abs to all values.
56  for (size_t i = 0; i < length_; ++i) {
57    data_[i] = fabs(data_[i]);
58  }
59
60  return 0;
61}
62
63int WPDNode::set_data(const float* new_data, size_t length) {
64  if (!new_data || length != length_) {
65    return -1;
66  }
67  memcpy(data_.get(), new_data, length * sizeof(data_[0]));
68  return 0;
69}
70
71}  // namespace webrtc
72