1/*
2 *  Copyright (c) 2016 The WebM project authors. All Rights Reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "third_party/googletest/src/include/gtest/gtest.h"
12
13#include "./vp9_rtcd.h"
14#include "test/acm_random.h"
15#include "test/buffer.h"
16#include "test/register_state_check.h"
17#include "vpx_ports/vpx_timer.h"
18
19namespace {
20
21using ::libvpx_test::ACMRandom;
22using ::libvpx_test::Buffer;
23
// Signature shared by every temporal filter implementation under test.
// 'frame1' is addressed with 'stride'. 'frame2' is passed without a stride,
// which is why the tests allocate that block with no border. The tests expect
// the results to be added into 'accumulator' and 'count' the same way
// reference_filter() does.
typedef void (*TemporalFilterFunc)(const uint8_t *frame1,
                                   unsigned int stride,
                                   const uint8_t *frame2,
                                   unsigned int block_width,
                                   unsigned int block_height,
                                   int strength,
                                   int filter_weight,
                                   unsigned int *accumulator,
                                   uint16_t *count);
29
30// Calculate the difference between 'a' and 'b', sum in blocks of 9, and apply
31// filter based on strength and weight. Store the resulting filter amount in
32// 'count' and apply it to 'b' and store it in 'accumulator'.
33void reference_filter(const Buffer<uint8_t> &a, const Buffer<uint8_t> &b, int w,
34                      int h, int filter_strength, int filter_weight,
35                      Buffer<unsigned int> *accumulator,
36                      Buffer<uint16_t> *count) {
37  Buffer<int> diff_sq = Buffer<int>(w, h, 0);
38  diff_sq.Set(0);
39
40  int rounding = 0;
41  if (filter_strength > 0) {
42    rounding = 1 << (filter_strength - 1);
43  }
44
45  // Calculate all the differences. Avoids re-calculating a bunch of extra
46  // values.
47  for (int height = 0; height < h; ++height) {
48    for (int width = 0; width < w; ++width) {
49      int diff = a.TopLeftPixel()[height * a.stride() + width] -
50                 b.TopLeftPixel()[height * b.stride() + width];
51      diff_sq.TopLeftPixel()[height * diff_sq.stride() + width] = diff * diff;
52    }
53  }
54
55  // For any given point, sum the neighboring values and calculate the
56  // modifier.
57  for (int height = 0; height < h; ++height) {
58    for (int width = 0; width < w; ++width) {
59      // Determine how many values are being summed.
60      int summed_values = 9;
61
62      if (height == 0 || height == (h - 1)) {
63        summed_values -= 3;
64      }
65
66      if (width == 0 || width == (w - 1)) {
67        if (summed_values == 6) {  // corner
68          summed_values -= 2;
69        } else {
70          summed_values -= 3;
71        }
72      }
73
74      // Sum the diff_sq of the surrounding values.
75      int sum = 0;
76      for (int idy = -1; idy <= 1; ++idy) {
77        for (int idx = -1; idx <= 1; ++idx) {
78          const int y = height + idy;
79          const int x = width + idx;
80
81          // If inside the border.
82          if (y >= 0 && y < h && x >= 0 && x < w) {
83            sum += diff_sq.TopLeftPixel()[y * diff_sq.stride() + x];
84          }
85        }
86      }
87
88      sum *= 3;
89      sum /= summed_values;
90      sum += rounding;
91      sum >>= filter_strength;
92
93      // Clamp the value and invert it.
94      if (sum > 16) sum = 16;
95      sum = 16 - sum;
96
97      sum *= filter_weight;
98
99      count->TopLeftPixel()[height * count->stride() + width] += sum;
100      accumulator->TopLeftPixel()[height * accumulator->stride() + width] +=
101          sum * b.TopLeftPixel()[height * b.stride() + width];
102    }
103  }
104}
105
106class TemporalFilterTest : public ::testing::TestWithParam<TemporalFilterFunc> {
107 public:
108  virtual void SetUp() {
109    filter_func_ = GetParam();
110    rnd_.Reset(ACMRandom::DeterministicSeed());
111  }
112
113 protected:
114  TemporalFilterFunc filter_func_;
115  ACMRandom rnd_;
116};
117
118TEST_P(TemporalFilterTest, SizeCombinations) {
119  // Depending on subsampling this function may be called with values of 8 or 16
120  // for width and height, in any combination.
121  Buffer<uint8_t> a = Buffer<uint8_t>(16, 16, 8);
122
123  const int filter_weight = 2;
124  const int filter_strength = 6;
125
126  for (int width = 8; width <= 16; width += 8) {
127    for (int height = 8; height <= 16; height += 8) {
128      // The second buffer must not have any border.
129      Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0);
130      Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0);
131      Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0);
132      Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0);
133      Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0);
134
135      a.Set(&rnd_, &ACMRandom::Rand8);
136      b.Set(&rnd_, &ACMRandom::Rand8);
137
138      accum_ref.Set(rnd_.Rand8());
139      accum_chk.CopyFrom(accum_ref);
140      count_ref.Set(rnd_.Rand8());
141      count_chk.CopyFrom(count_ref);
142      reference_filter(a, b, width, height, filter_strength, filter_weight,
143                       &accum_ref, &count_ref);
144      ASM_REGISTER_STATE_CHECK(
145          filter_func_(a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width,
146                       height, filter_strength, filter_weight,
147                       accum_chk.TopLeftPixel(), count_chk.TopLeftPixel()));
148      EXPECT_TRUE(accum_chk.CheckValues(accum_ref));
149      EXPECT_TRUE(count_chk.CheckValues(count_ref));
150      if (HasFailure()) {
151        printf("Width: %d Height: %d\n", width, height);
152        count_chk.PrintDifference(count_ref);
153        accum_chk.PrintDifference(accum_ref);
154        return;
155      }
156    }
157  }
158}
159
160TEST_P(TemporalFilterTest, CompareReferenceRandom) {
161  for (int width = 8; width <= 16; width += 8) {
162    for (int height = 8; height <= 16; height += 8) {
163      Buffer<uint8_t> a = Buffer<uint8_t>(width, height, 8);
164      // The second buffer must not have any border.
165      Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0);
166      Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0);
167      Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0);
168      Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0);
169      Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0);
170
171      for (int filter_strength = 0; filter_strength <= 6; ++filter_strength) {
172        for (int filter_weight = 0; filter_weight <= 2; ++filter_weight) {
173          for (int repeat = 0; repeat < 10; ++repeat) {
174            a.Set(&rnd_, &ACMRandom::Rand8);
175            b.Set(&rnd_, &ACMRandom::Rand8);
176
177            accum_ref.Set(rnd_.Rand8());
178            accum_chk.CopyFrom(accum_ref);
179            count_ref.Set(rnd_.Rand8());
180            count_chk.CopyFrom(count_ref);
181            reference_filter(a, b, width, height, filter_strength,
182                             filter_weight, &accum_ref, &count_ref);
183            ASM_REGISTER_STATE_CHECK(filter_func_(
184                a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width, height,
185                filter_strength, filter_weight, accum_chk.TopLeftPixel(),
186                count_chk.TopLeftPixel()));
187            EXPECT_TRUE(accum_chk.CheckValues(accum_ref));
188            EXPECT_TRUE(count_chk.CheckValues(count_ref));
189            if (HasFailure()) {
190              printf("Weight: %d Strength: %d\n", filter_weight,
191                     filter_strength);
192              count_chk.PrintDifference(count_ref);
193              accum_chk.PrintDifference(accum_ref);
194              return;
195            }
196          }
197        }
198      }
199    }
200  }
201}
202
203TEST_P(TemporalFilterTest, DISABLED_Speed) {
204  Buffer<uint8_t> a = Buffer<uint8_t>(16, 16, 8);
205
206  const int filter_weight = 2;
207  const int filter_strength = 6;
208
209  for (int width = 8; width <= 16; width += 8) {
210    for (int height = 8; height <= 16; height += 8) {
211      // The second buffer must not have any border.
212      Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0);
213      Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0);
214      Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0);
215      Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0);
216      Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0);
217
218      a.Set(&rnd_, &ACMRandom::Rand8);
219      b.Set(&rnd_, &ACMRandom::Rand8);
220
221      accum_chk.Set(0);
222      count_chk.Set(0);
223
224      vpx_usec_timer timer;
225      vpx_usec_timer_start(&timer);
226      for (int i = 0; i < 10000; ++i) {
227        filter_func_(a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width,
228                     height, filter_strength, filter_weight,
229                     accum_chk.TopLeftPixel(), count_chk.TopLeftPixel());
230      }
231      vpx_usec_timer_mark(&timer);
232      const int elapsed_time = static_cast<int>(vpx_usec_timer_elapsed(&timer));
233      printf("Temporal filter %dx%d time: %5d us\n", width, height,
234             elapsed_time);
235    }
236  }
237}
238
// Run every TEST_P above against the portable C implementation.
INSTANTIATE_TEST_CASE_P(C, TemporalFilterTest,
                        ::testing::Values(&vp9_temporal_filter_apply_c));
// The SSE4.1 implementation is only registered when HAVE_SSE4_1 is set.
#if HAVE_SSE4_1
INSTANTIATE_TEST_CASE_P(SSE4_1, TemporalFilterTest,
                        ::testing::Values(&vp9_temporal_filter_apply_sse4_1));
#endif  // HAVE_SSE4_1
246}  // namespace
247