1/* 2 * Copyright (c) 2016 The WebM project authors. All Rights Reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11#include "third_party/googletest/src/include/gtest/gtest.h" 12 13#include "./vp9_rtcd.h" 14#include "test/acm_random.h" 15#include "test/buffer.h" 16#include "test/register_state_check.h" 17#include "vpx_ports/vpx_timer.h" 18 19namespace { 20 21using ::libvpx_test::ACMRandom; 22using ::libvpx_test::Buffer; 23 24typedef void (*TemporalFilterFunc)(const uint8_t *a, unsigned int stride, 25 const uint8_t *b, unsigned int w, 26 unsigned int h, int filter_strength, 27 int filter_weight, unsigned int *accumulator, 28 uint16_t *count); 29 30// Calculate the difference between 'a' and 'b', sum in blocks of 9, and apply 31// filter based on strength and weight. Store the resulting filter amount in 32// 'count' and apply it to 'b' and store it in 'accumulator'. 33void reference_filter(const Buffer<uint8_t> &a, const Buffer<uint8_t> &b, int w, 34 int h, int filter_strength, int filter_weight, 35 Buffer<unsigned int> *accumulator, 36 Buffer<uint16_t> *count) { 37 Buffer<int> diff_sq = Buffer<int>(w, h, 0); 38 diff_sq.Set(0); 39 40 int rounding = 0; 41 if (filter_strength > 0) { 42 rounding = 1 << (filter_strength - 1); 43 } 44 45 // Calculate all the differences. Avoids re-calculating a bunch of extra 46 // values. 47 for (int height = 0; height < h; ++height) { 48 for (int width = 0; width < w; ++width) { 49 int diff = a.TopLeftPixel()[height * a.stride() + width] - 50 b.TopLeftPixel()[height * b.stride() + width]; 51 diff_sq.TopLeftPixel()[height * diff_sq.stride() + width] = diff * diff; 52 } 53 } 54 55 // For any given point, sum the neighboring values and calculate the 56 // modifier. 57 for (int height = 0; height < h; ++height) { 58 for (int width = 0; width < w; ++width) { 59 // Determine how many values are being summed. 60 int summed_values = 9; 61 62 if (height == 0 || height == (h - 1)) { 63 summed_values -= 3; 64 } 65 66 if (width == 0 || width == (w - 1)) { 67 if (summed_values == 6) { // corner 68 summed_values -= 2; 69 } else { 70 summed_values -= 3; 71 } 72 } 73 74 // Sum the diff_sq of the surrounding values. 75 int sum = 0; 76 for (int idy = -1; idy <= 1; ++idy) { 77 for (int idx = -1; idx <= 1; ++idx) { 78 const int y = height + idy; 79 const int x = width + idx; 80 81 // If inside the border. 82 if (y >= 0 && y < h && x >= 0 && x < w) { 83 sum += diff_sq.TopLeftPixel()[y * diff_sq.stride() + x]; 84 } 85 } 86 } 87 88 sum *= 3; 89 sum /= summed_values; 90 sum += rounding; 91 sum >>= filter_strength; 92 93 // Clamp the value and invert it. 94 if (sum > 16) sum = 16; 95 sum = 16 - sum; 96 97 sum *= filter_weight; 98 99 count->TopLeftPixel()[height * count->stride() + width] += sum; 100 accumulator->TopLeftPixel()[height * accumulator->stride() + width] += 101 sum * b.TopLeftPixel()[height * b.stride() + width]; 102 } 103 } 104} 105 106class TemporalFilterTest : public ::testing::TestWithParam<TemporalFilterFunc> { 107 public: 108 virtual void SetUp() { 109 filter_func_ = GetParam(); 110 rnd_.Reset(ACMRandom::DeterministicSeed()); 111 } 112 113 protected: 114 TemporalFilterFunc filter_func_; 115 ACMRandom rnd_; 116}; 117 118TEST_P(TemporalFilterTest, SizeCombinations) { 119 // Depending on subsampling this function may be called with values of 8 or 16 120 // for width and height, in any combination. 121 Buffer<uint8_t> a = Buffer<uint8_t>(16, 16, 8); 122 123 const int filter_weight = 2; 124 const int filter_strength = 6; 125 126 for (int width = 8; width <= 16; width += 8) { 127 for (int height = 8; height <= 16; height += 8) { 128 // The second buffer must not have any border. 129 Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0); 130 Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0); 131 Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0); 132 Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0); 133 Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0); 134 135 a.Set(&rnd_, &ACMRandom::Rand8); 136 b.Set(&rnd_, &ACMRandom::Rand8); 137 138 accum_ref.Set(rnd_.Rand8()); 139 accum_chk.CopyFrom(accum_ref); 140 count_ref.Set(rnd_.Rand8()); 141 count_chk.CopyFrom(count_ref); 142 reference_filter(a, b, width, height, filter_strength, filter_weight, 143 &accum_ref, &count_ref); 144 ASM_REGISTER_STATE_CHECK( 145 filter_func_(a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width, 146 height, filter_strength, filter_weight, 147 accum_chk.TopLeftPixel(), count_chk.TopLeftPixel())); 148 EXPECT_TRUE(accum_chk.CheckValues(accum_ref)); 149 EXPECT_TRUE(count_chk.CheckValues(count_ref)); 150 if (HasFailure()) { 151 printf("Width: %d Height: %d\n", width, height); 152 count_chk.PrintDifference(count_ref); 153 accum_chk.PrintDifference(accum_ref); 154 return; 155 } 156 } 157 } 158} 159 160TEST_P(TemporalFilterTest, CompareReferenceRandom) { 161 for (int width = 8; width <= 16; width += 8) { 162 for (int height = 8; height <= 16; height += 8) { 163 Buffer<uint8_t> a = Buffer<uint8_t>(width, height, 8); 164 // The second buffer must not have any border. 165 Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0); 166 Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0); 167 Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0); 168 Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0); 169 Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0); 170 171 for (int filter_strength = 0; filter_strength <= 6; ++filter_strength) { 172 for (int filter_weight = 0; filter_weight <= 2; ++filter_weight) { 173 for (int repeat = 0; repeat < 10; ++repeat) { 174 a.Set(&rnd_, &ACMRandom::Rand8); 175 b.Set(&rnd_, &ACMRandom::Rand8); 176 177 accum_ref.Set(rnd_.Rand8()); 178 accum_chk.CopyFrom(accum_ref); 179 count_ref.Set(rnd_.Rand8()); 180 count_chk.CopyFrom(count_ref); 181 reference_filter(a, b, width, height, filter_strength, 182 filter_weight, &accum_ref, &count_ref); 183 ASM_REGISTER_STATE_CHECK(filter_func_( 184 a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width, height, 185 filter_strength, filter_weight, accum_chk.TopLeftPixel(), 186 count_chk.TopLeftPixel())); 187 EXPECT_TRUE(accum_chk.CheckValues(accum_ref)); 188 EXPECT_TRUE(count_chk.CheckValues(count_ref)); 189 if (HasFailure()) { 190 printf("Weight: %d Strength: %d\n", filter_weight, 191 filter_strength); 192 count_chk.PrintDifference(count_ref); 193 accum_chk.PrintDifference(accum_ref); 194 return; 195 } 196 } 197 } 198 } 199 } 200 } 201} 202 203TEST_P(TemporalFilterTest, DISABLED_Speed) { 204 Buffer<uint8_t> a = Buffer<uint8_t>(16, 16, 8); 205 206 const int filter_weight = 2; 207 const int filter_strength = 6; 208 209 for (int width = 8; width <= 16; width += 8) { 210 for (int height = 8; height <= 16; height += 8) { 211 // The second buffer must not have any border. 212 Buffer<uint8_t> b = Buffer<uint8_t>(width, height, 0); 213 Buffer<unsigned int> accum_ref = Buffer<unsigned int>(width, height, 0); 214 Buffer<unsigned int> accum_chk = Buffer<unsigned int>(width, height, 0); 215 Buffer<uint16_t> count_ref = Buffer<uint16_t>(width, height, 0); 216 Buffer<uint16_t> count_chk = Buffer<uint16_t>(width, height, 0); 217 218 a.Set(&rnd_, &ACMRandom::Rand8); 219 b.Set(&rnd_, &ACMRandom::Rand8); 220 221 accum_chk.Set(0); 222 count_chk.Set(0); 223 224 vpx_usec_timer timer; 225 vpx_usec_timer_start(&timer); 226 for (int i = 0; i < 10000; ++i) { 227 filter_func_(a.TopLeftPixel(), a.stride(), b.TopLeftPixel(), width, 228 height, filter_strength, filter_weight, 229 accum_chk.TopLeftPixel(), count_chk.TopLeftPixel()); 230 } 231 vpx_usec_timer_mark(&timer); 232 const int elapsed_time = static_cast<int>(vpx_usec_timer_elapsed(&timer)); 233 printf("Temporal filter %dx%d time: %5d us\n", width, height, 234 elapsed_time); 235 } 236 } 237} 238 239INSTANTIATE_TEST_CASE_P(C, TemporalFilterTest, 240 ::testing::Values(&vp9_temporal_filter_apply_c)); 241 242#if HAVE_SSE4_1 243INSTANTIATE_TEST_CASE_P(SSE4_1, TemporalFilterTest, 244 ::testing::Values(&vp9_temporal_filter_apply_sse4_1)); 245#endif // HAVE_SSE4_1 246} // namespace 247