1/*
2 *  Copyright (c) 2012 The WebM project authors. All Rights Reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include <arm_neon.h>
12
13#include "vp8/encoder/denoising.h"
14#include "vpx_mem/vpx_mem.h"
15#include "./vp8_rtcd.h"
16
17/*
18 * The filter function was modified to reduce the computational complexity.
19 *
20 * Step 1:
21 *  Instead of applying tap coefficients for each pixel, we calculated the
22 *  pixel adjustments vs. pixel diff value ahead of time.
23 *     adjustment = filtered_value - current_raw
24 *                = (filter_coefficient * diff + 128) >> 8
25 *  where
26 *     filter_coefficient = (255 << 8) / (256 + ((abs_diff * 330) >> 3));
27 *     filter_coefficient += filter_coefficient /
28 *                           (3 + motion_magnitude_adjustment);
29 *     filter_coefficient is clamped to 0 ~ 255.
30 *
31 * Step 2:
 *  The adjustment vs. diff curve becomes flat very quickly when diff increases.
33 *  This allowed us to use only several levels to approximate the curve without
34 *  changing the filtering algorithm too much.
35 *  The adjustments were further corrected by checking the motion magnitude.
36 *  The levels used are:
37 *      diff          level       adjustment w/o       adjustment w/
38 *                               motion correction    motion correction
39 *      [-255, -16]     3              -6                   -7
40 *      [-15, -8]       2              -4                   -5
41 *      [-7, -4]        1              -3                   -4
42 *      [-3, 3]         0              diff                 diff
43 *      [4, 7]          1               3                    4
44 *      [8, 15]         2               4                    5
45 *      [16, 255]       3               6                    7
46 */
47
/* Denoise one 16x16 luma macroblock.
 *
 * Each pixel of running_avg_y is produced by nudging the source pixel
 * (sig) toward the motion-compensated running average (mc_running_avg_y)
 * by an adjustment chosen from the precomputed level table described at
 * the top of this file. If the total adjustment over the block is too
 * large, a second weaker pass pulls running_avg_y back toward sig; if it
 * is still too large, COPY_BLOCK is returned so the caller copies the raw
 * source block instead. On success the filtered result is copied back
 * into sig and FILTER_BLOCK is returned.
 *
 * All three planes are 16x16 with their respective strides.
 * motion_magnitude selects the adjustment strength; increase_denoising
 * selects the more aggressive level-1 step and sum threshold.
 */
int vp8_denoiser_filter_neon(unsigned char *mc_running_avg_y,
                             int mc_running_avg_y_stride,
                             unsigned char *running_avg_y,
                             int running_avg_y_stride, unsigned char *sig,
                             int sig_stride, unsigned int motion_magnitude,
                             int increase_denoising) {
  /* If motion_magnitude is small, making the denoiser more aggressive by
   * increasing the adjustment for each level, level1 adjustment is
   * increased, the deltas stay the same.
   */
  int shift_inc =
      (increase_denoising && motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD)
          ? 1
          : 0;
  const uint8x16_t v_level1_adjustment = vmovq_n_u8(
      (motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD) ? 4 + shift_inc : 3);
  const uint8x16_t v_delta_level_1_and_2 = vdupq_n_u8(1);
  const uint8x16_t v_delta_level_2_and_3 = vdupq_n_u8(2);
  const uint8x16_t v_level1_threshold = vmovq_n_u8(4 + shift_inc);
  const uint8x16_t v_level2_threshold = vdupq_n_u8(8);
  const uint8x16_t v_level3_threshold = vdupq_n_u8(16);
  int64x2_t v_sum_diff_total = vdupq_n_s64(0);

  /* Go over lines: one 16-pixel row per iteration. */
  int r;
  for (r = 0; r < 16; ++r) {
    /* Load inputs. */
    const uint8x16_t v_sig = vld1q_u8(sig);
    const uint8x16_t v_mc_running_avg_y = vld1q_u8(mc_running_avg_y);

    /* Calculate absolute difference and sign masks (pos: avg > sig,
     * neg: sig > avg). */
    const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg_y);
    const uint8x16_t v_diff_pos_mask = vcltq_u8(v_sig, v_mc_running_avg_y);
    const uint8x16_t v_diff_neg_mask = vcgtq_u8(v_sig, v_mc_running_avg_y);

    /* Figure out which level that put us in (all-ones mask per lane when
     * abs_diff reaches the level's threshold). */
    const uint8x16_t v_level1_mask = vcleq_u8(v_level1_threshold, v_abs_diff);
    const uint8x16_t v_level2_mask = vcleq_u8(v_level2_threshold, v_abs_diff);
    const uint8x16_t v_level3_mask = vcleq_u8(v_level3_threshold, v_abs_diff);

    /* Calculate absolute adjustments for level 1, 2 and 3: start at the
     * level-1 value and add the inter-level deltas for each threshold
     * crossed. */
    const uint8x16_t v_level2_adjustment =
        vandq_u8(v_level2_mask, v_delta_level_1_and_2);
    const uint8x16_t v_level3_adjustment =
        vandq_u8(v_level3_mask, v_delta_level_2_and_3);
    const uint8x16_t v_level1and2_adjustment =
        vaddq_u8(v_level1_adjustment, v_level2_adjustment);
    const uint8x16_t v_level1and2and3_adjustment =
        vaddq_u8(v_level1and2_adjustment, v_level3_adjustment);

    /* Figure adjustment absolute value by selecting between the absolute
     * difference if in level0 or the value for level 1, 2 and 3.
     */
    const uint8x16_t v_abs_adjustment =
        vbslq_u8(v_level1_mask, v_level1and2and3_adjustment, v_abs_diff);

    /* Calculate positive and negative adjustments. Apply them to the signal
     * and accumulate them. Adjustments are less than eight and the maximum
     * sum of them (7 * 16) can fit in a signed char.
     */
    const uint8x16_t v_pos_adjustment =
        vandq_u8(v_diff_pos_mask, v_abs_adjustment);
    const uint8x16_t v_neg_adjustment =
        vandq_u8(v_diff_neg_mask, v_abs_adjustment);

    /* Saturating add/sub keeps the result within [0, 255]. */
    uint8x16_t v_running_avg_y = vqaddq_u8(v_sig, v_pos_adjustment);
    v_running_avg_y = vqsubq_u8(v_running_avg_y, v_neg_adjustment);

    /* Store results. */
    vst1q_u8(running_avg_y, v_running_avg_y);

    /* Sum all the accumulators to have the sum of all pixel differences
     * for this macroblock (pairwise widen s8 -> s16 -> s32 -> s64).
     */
    {
      const int8x16_t v_sum_diff =
          vqsubq_s8(vreinterpretq_s8_u8(v_pos_adjustment),
                    vreinterpretq_s8_u8(v_neg_adjustment));

      const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff);

      const int32x4_t fedc_ba98_7654_3210 =
          vpaddlq_s16(fe_dc_ba_98_76_54_32_10);

      const int64x2_t fedcba98_76543210 = vpaddlq_s32(fedc_ba98_7654_3210);

      v_sum_diff_total = vqaddq_s64(v_sum_diff_total, fedcba98_76543210);
    }

    /* Update pointers for next iteration. */
    sig += sig_stride;
    mc_running_avg_y += mc_running_avg_y_stride;
    running_avg_y += running_avg_y_stride;
  }

  /* Too many adjustments => copy block. */
  {
    /* NOTE(review): only the low 32 bits of the saturated 64-bit sum are
     * read here; per-pixel adjustments are <= 7 over 256 pixels, so the
     * sum always fits. */
    int64x1_t x = vqadd_s64(vget_high_s64(v_sum_diff_total),
                            vget_low_s64(v_sum_diff_total));
    int sum_diff = vget_lane_s32(vabs_s32(vreinterpret_s32_s64(x)), 0);
    int sum_diff_thresh = SUM_DIFF_THRESHOLD;

    if (increase_denoising) sum_diff_thresh = SUM_DIFF_THRESHOLD_HIGH;
    if (sum_diff > sum_diff_thresh) {
      // Before returning to copy the block (i.e., apply no denoising),
      // check if we can still apply some (weaker) temporal filtering to
      // this block, that would otherwise not be denoised at all. Simplest
      // is to apply an additional adjustment to running_avg_y to bring it
      // closer to sig. The adjustment is capped by a maximum delta, and
      // chosen such that in most cases the resulting sum_diff will be
      // within the acceptable range given by sum_diff_thresh.

      // The delta is set by the excess of absolute pixel diff over the
      // threshold.
      int delta = ((sum_diff - sum_diff_thresh) >> 8) + 1;
      // Only apply the adjustment for max delta up to 3.
      if (delta < 4) {
        const uint8x16_t k_delta = vmovq_n_u8(delta);
        /* Rewind the row pointers to re-walk the whole block. */
        sig -= sig_stride * 16;
        mc_running_avg_y -= mc_running_avg_y_stride * 16;
        running_avg_y -= running_avg_y_stride * 16;
        for (r = 0; r < 16; ++r) {
          uint8x16_t v_running_avg_y = vld1q_u8(running_avg_y);
          const uint8x16_t v_sig = vld1q_u8(sig);
          const uint8x16_t v_mc_running_avg_y = vld1q_u8(mc_running_avg_y);

          /* Calculate absolute difference and sign masks. */
          const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg_y);
          const uint8x16_t v_diff_pos_mask =
              vcltq_u8(v_sig, v_mc_running_avg_y);
          const uint8x16_t v_diff_neg_mask =
              vcgtq_u8(v_sig, v_mc_running_avg_y);
          // Clamp absolute difference to delta to get the adjustment.
          const uint8x16_t v_abs_adjustment = vminq_u8(v_abs_diff, (k_delta));

          const uint8x16_t v_pos_adjustment =
              vandq_u8(v_diff_pos_mask, v_abs_adjustment);
          const uint8x16_t v_neg_adjustment =
              vandq_u8(v_diff_neg_mask, v_abs_adjustment);

          /* Opposite signs to the first pass: move running_avg_y back
           * toward sig. */
          v_running_avg_y = vqsubq_u8(v_running_avg_y, v_pos_adjustment);
          v_running_avg_y = vqaddq_u8(v_running_avg_y, v_neg_adjustment);

          /* Store results. */
          vst1q_u8(running_avg_y, v_running_avg_y);

          /* Accumulate this pass's (sign-flipped) adjustments into the
           * running sum so the recheck below sees the corrected total. */
          {
            const int8x16_t v_sum_diff =
                vqsubq_s8(vreinterpretq_s8_u8(v_neg_adjustment),
                          vreinterpretq_s8_u8(v_pos_adjustment));

            const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff);
            const int32x4_t fedc_ba98_7654_3210 =
                vpaddlq_s16(fe_dc_ba_98_76_54_32_10);
            const int64x2_t fedcba98_76543210 =
                vpaddlq_s32(fedc_ba98_7654_3210);

            v_sum_diff_total = vqaddq_s64(v_sum_diff_total, fedcba98_76543210);
          }
          /* Update pointers for next iteration. */
          sig += sig_stride;
          mc_running_avg_y += mc_running_avg_y_stride;
          running_avg_y += running_avg_y_stride;
        }
        {
          // Update the sum of all pixel differences of this MB.
          x = vqadd_s64(vget_high_s64(v_sum_diff_total),
                        vget_low_s64(v_sum_diff_total));
          sum_diff = vget_lane_s32(vabs_s32(vreinterpret_s32_s64(x)), 0);

          if (sum_diff > sum_diff_thresh) {
            return COPY_BLOCK;
          }
        }
      } else {
        return COPY_BLOCK;
      }
    }
  }

  /* Tell above level that block was filtered: rewind and copy the
   * denoised result back over the source block. */
  running_avg_y -= running_avg_y_stride * 16;
  sig -= sig_stride * 16;

  vp8_copy_mem16x16(running_avg_y, running_avg_y_stride, sig, sig_stride);

  return FILTER_BLOCK;
}
236
/* Denoise one 8x8 chroma (U or V) block.
 *
 * Same scheme as vp8_denoiser_filter_neon, adapted to 8x8: rows are
 * processed two at a time by packing two 8-pixel rows into one 16-lane
 * vector. Before filtering, blocks whose pixel sum is close to the
 * mid-gray average (128 * 8 * 8) are skipped entirely (COPY_BLOCK) to
 * avoid disturbing near-flat color. Returns FILTER_BLOCK when the block
 * was denoised (result also copied back into sig), COPY_BLOCK otherwise.
 */
int vp8_denoiser_filter_uv_neon(unsigned char *mc_running_avg,
                                int mc_running_avg_stride,
                                unsigned char *running_avg,
                                int running_avg_stride, unsigned char *sig,
                                int sig_stride, unsigned int motion_magnitude,
                                int increase_denoising) {
  /* If motion_magnitude is small, making the denoiser more aggressive by
   * increasing the adjustment for each level, level1 adjustment is
   * increased, the deltas stay the same.
   */
  int shift_inc =
      (increase_denoising && motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD_UV)
          ? 1
          : 0;
  const uint8x16_t v_level1_adjustment = vmovq_n_u8(
      (motion_magnitude <= MOTION_MAGNITUDE_THRESHOLD_UV) ? 4 + shift_inc : 3);

  const uint8x16_t v_delta_level_1_and_2 = vdupq_n_u8(1);
  const uint8x16_t v_delta_level_2_and_3 = vdupq_n_u8(2);
  const uint8x16_t v_level1_threshold = vmovq_n_u8(4 + shift_inc);
  const uint8x16_t v_level2_threshold = vdupq_n_u8(8);
  const uint8x16_t v_level3_threshold = vdupq_n_u8(16);
  int64x2_t v_sum_diff_total = vdupq_n_s64(0);
  int r;

  {
    uint16x4_t v_sum_block = vdup_n_u16(0);

    // Avoid denoising color signal if it's close to the average level:
    // sum all 64 source pixels and bail out when the sum is within
    // SUM_DIFF_FROM_AVG_THRESH_UV of a flat mid-gray (128) block.
    for (r = 0; r < 8; ++r) {
      const uint8x8_t v_sig = vld1_u8(sig);
      const uint16x4_t _76_54_32_10 = vpaddl_u8(v_sig);
      v_sum_block = vqadd_u16(v_sum_block, _76_54_32_10);
      sig += sig_stride;
    }
    sig -= sig_stride * 8; /* rewind after the scan */
    {
      const uint32x2_t _7654_3210 = vpaddl_u16(v_sum_block);
      const uint64x1_t _76543210 = vpaddl_u32(_7654_3210);
      const int sum_block = vget_lane_s32(vreinterpret_s32_u64(_76543210), 0);
      if (abs(sum_block - (128 * 8 * 8)) < SUM_DIFF_FROM_AVG_THRESH_UV) {
        return COPY_BLOCK;
      }
    }
  }

  /* Go over lines: two 8-pixel rows per iteration, packed low/high. */
  for (r = 0; r < 4; ++r) {
    /* Load inputs. */
    const uint8x8_t v_sig_lo = vld1_u8(sig);
    const uint8x8_t v_sig_hi = vld1_u8(&sig[sig_stride]);
    const uint8x16_t v_sig = vcombine_u8(v_sig_lo, v_sig_hi);
    const uint8x8_t v_mc_running_avg_lo = vld1_u8(mc_running_avg);
    const uint8x8_t v_mc_running_avg_hi =
        vld1_u8(&mc_running_avg[mc_running_avg_stride]);
    const uint8x16_t v_mc_running_avg =
        vcombine_u8(v_mc_running_avg_lo, v_mc_running_avg_hi);
    /* Calculate absolute difference and sign masks. */
    const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg);
    const uint8x16_t v_diff_pos_mask = vcltq_u8(v_sig, v_mc_running_avg);
    const uint8x16_t v_diff_neg_mask = vcgtq_u8(v_sig, v_mc_running_avg);

    /* Figure out which level that put us in. */
    const uint8x16_t v_level1_mask = vcleq_u8(v_level1_threshold, v_abs_diff);
    const uint8x16_t v_level2_mask = vcleq_u8(v_level2_threshold, v_abs_diff);
    const uint8x16_t v_level3_mask = vcleq_u8(v_level3_threshold, v_abs_diff);

    /* Calculate absolute adjustments for level 1, 2 and 3. */
    const uint8x16_t v_level2_adjustment =
        vandq_u8(v_level2_mask, v_delta_level_1_and_2);
    const uint8x16_t v_level3_adjustment =
        vandq_u8(v_level3_mask, v_delta_level_2_and_3);
    const uint8x16_t v_level1and2_adjustment =
        vaddq_u8(v_level1_adjustment, v_level2_adjustment);
    const uint8x16_t v_level1and2and3_adjustment =
        vaddq_u8(v_level1and2_adjustment, v_level3_adjustment);

    /* Figure adjustment absolute value by selecting between the absolute
     * difference if in level0 or the value for level 1, 2 and 3.
     */
    const uint8x16_t v_abs_adjustment =
        vbslq_u8(v_level1_mask, v_level1and2and3_adjustment, v_abs_diff);

    /* Calculate positive and negative adjustments. Apply them to the signal
     * and accumulate them. Adjustments are less than eight and the maximum
     * sum of them (7 * 16) can fit in a signed char.
     */
    const uint8x16_t v_pos_adjustment =
        vandq_u8(v_diff_pos_mask, v_abs_adjustment);
    const uint8x16_t v_neg_adjustment =
        vandq_u8(v_diff_neg_mask, v_abs_adjustment);

    /* Saturating add/sub keeps the result within [0, 255]. */
    uint8x16_t v_running_avg = vqaddq_u8(v_sig, v_pos_adjustment);
    v_running_avg = vqsubq_u8(v_running_avg, v_neg_adjustment);

    /* Store results: low half to this row, high half to the next. */
    vst1_u8(running_avg, vget_low_u8(v_running_avg));
    vst1_u8(&running_avg[running_avg_stride], vget_high_u8(v_running_avg));

    /* Sum all the accumulators to have the sum of all pixel differences
     * for this macroblock (pairwise widen s8 -> s16 -> s32 -> s64).
     */
    {
      const int8x16_t v_sum_diff =
          vqsubq_s8(vreinterpretq_s8_u8(v_pos_adjustment),
                    vreinterpretq_s8_u8(v_neg_adjustment));

      const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff);

      const int32x4_t fedc_ba98_7654_3210 =
          vpaddlq_s16(fe_dc_ba_98_76_54_32_10);

      const int64x2_t fedcba98_76543210 = vpaddlq_s32(fedc_ba98_7654_3210);

      v_sum_diff_total = vqaddq_s64(v_sum_diff_total, fedcba98_76543210);
    }

    /* Update pointers for next iteration (two rows consumed). */
    sig += sig_stride * 2;
    mc_running_avg += mc_running_avg_stride * 2;
    running_avg += running_avg_stride * 2;
  }

  /* Too many adjustments => copy block. */
  {
    /* NOTE(review): only the low 32 bits of the saturated 64-bit sum are
     * read here; per-pixel adjustments are <= 7 over 64 pixels, so the
     * sum always fits. */
    int64x1_t x = vqadd_s64(vget_high_s64(v_sum_diff_total),
                            vget_low_s64(v_sum_diff_total));
    int sum_diff = vget_lane_s32(vabs_s32(vreinterpret_s32_s64(x)), 0);
    int sum_diff_thresh = SUM_DIFF_THRESHOLD_UV;
    if (increase_denoising) sum_diff_thresh = SUM_DIFF_THRESHOLD_HIGH_UV;
    if (sum_diff > sum_diff_thresh) {
      // Before returning to copy the block (i.e., apply no denoising),
      // check if we can still apply some (weaker) temporal filtering to
      // this block, that would otherwise not be denoised at all. Simplest
      // is to apply an additional adjustment to running_avg to bring it
      // closer to sig. The adjustment is capped by a maximum delta, and
      // chosen such that in most cases the resulting sum_diff will be
      // within the acceptable range given by sum_diff_thresh.

      // The delta is set by the excess of absolute pixel diff over the
      // threshold.
      int delta = ((sum_diff - sum_diff_thresh) >> 8) + 1;
      // Only apply the adjustment for max delta up to 3.
      if (delta < 4) {
        const uint8x16_t k_delta = vmovq_n_u8(delta);
        /* Rewind the row pointers to re-walk the whole block. */
        sig -= sig_stride * 8;
        mc_running_avg -= mc_running_avg_stride * 8;
        running_avg -= running_avg_stride * 8;
        for (r = 0; r < 4; ++r) {
          const uint8x8_t v_sig_lo = vld1_u8(sig);
          const uint8x8_t v_sig_hi = vld1_u8(&sig[sig_stride]);
          const uint8x16_t v_sig = vcombine_u8(v_sig_lo, v_sig_hi);
          const uint8x8_t v_mc_running_avg_lo = vld1_u8(mc_running_avg);
          const uint8x8_t v_mc_running_avg_hi =
              vld1_u8(&mc_running_avg[mc_running_avg_stride]);
          const uint8x16_t v_mc_running_avg =
              vcombine_u8(v_mc_running_avg_lo, v_mc_running_avg_hi);
          /* Calculate absolute difference and sign masks. */
          const uint8x16_t v_abs_diff = vabdq_u8(v_sig, v_mc_running_avg);
          const uint8x16_t v_diff_pos_mask = vcltq_u8(v_sig, v_mc_running_avg);
          const uint8x16_t v_diff_neg_mask = vcgtq_u8(v_sig, v_mc_running_avg);
          // Clamp absolute difference to delta to get the adjustment.
          const uint8x16_t v_abs_adjustment = vminq_u8(v_abs_diff, (k_delta));

          const uint8x16_t v_pos_adjustment =
              vandq_u8(v_diff_pos_mask, v_abs_adjustment);
          const uint8x16_t v_neg_adjustment =
              vandq_u8(v_diff_neg_mask, v_abs_adjustment);
          const uint8x8_t v_running_avg_lo = vld1_u8(running_avg);
          const uint8x8_t v_running_avg_hi =
              vld1_u8(&running_avg[running_avg_stride]);
          uint8x16_t v_running_avg =
              vcombine_u8(v_running_avg_lo, v_running_avg_hi);

          /* Opposite signs to the first pass: move running_avg back
           * toward sig. */
          v_running_avg = vqsubq_u8(v_running_avg, v_pos_adjustment);
          v_running_avg = vqaddq_u8(v_running_avg, v_neg_adjustment);

          /* Store results. */
          vst1_u8(running_avg, vget_low_u8(v_running_avg));
          vst1_u8(&running_avg[running_avg_stride],
                  vget_high_u8(v_running_avg));

          /* Accumulate this pass's (sign-flipped) adjustments into the
           * running sum so the recheck below sees the corrected total. */
          {
            const int8x16_t v_sum_diff =
                vqsubq_s8(vreinterpretq_s8_u8(v_neg_adjustment),
                          vreinterpretq_s8_u8(v_pos_adjustment));

            const int16x8_t fe_dc_ba_98_76_54_32_10 = vpaddlq_s8(v_sum_diff);
            const int32x4_t fedc_ba98_7654_3210 =
                vpaddlq_s16(fe_dc_ba_98_76_54_32_10);
            const int64x2_t fedcba98_76543210 =
                vpaddlq_s32(fedc_ba98_7654_3210);

            v_sum_diff_total = vqaddq_s64(v_sum_diff_total, fedcba98_76543210);
          }
          /* Update pointers for next iteration. */
          sig += sig_stride * 2;
          mc_running_avg += mc_running_avg_stride * 2;
          running_avg += running_avg_stride * 2;
        }
        {
          // Update the sum of all pixel differences of this MB.
          x = vqadd_s64(vget_high_s64(v_sum_diff_total),
                        vget_low_s64(v_sum_diff_total));
          sum_diff = vget_lane_s32(vabs_s32(vreinterpret_s32_s64(x)), 0);

          if (sum_diff > sum_diff_thresh) {
            return COPY_BLOCK;
          }
        }
      } else {
        return COPY_BLOCK;
      }
    }
  }

  /* Tell above level that block was filtered: rewind and copy the
   * denoised result back over the source block. */
  running_avg -= running_avg_stride * 8;
  sig -= sig_stride * 8;

  vp8_copy_mem8x8(running_avg, running_avg_stride, sig, sig_stride);

  return FILTER_BLOCK;
}
461