1/*
2 *  Copyright (c) 2016 The WebM project authors. All Rights Reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include <arm_neon.h>
12#include <assert.h>
13
14#include "./vpx_dsp_rtcd.h"
15#include "vpx/vpx_integer.h"
16#include "vpx_dsp/arm/transpose_neon.h"
17
18extern const int16_t vpx_rv[];
19
// Filtered value for the center pixel v0: a tree of rounding halving adds
// that approximates (a2 + a1 + 4 * v0 + b1 + b2) / 8. The pairing order
// matters because each vrhadd rounds, so it is kept fixed.
static uint8x8_t average_k_out(const uint8x8_t a2, const uint8x8_t a1,
                               const uint8x8_t v0, const uint8x8_t b1,
                               const uint8x8_t b2) {
  const uint8x8_t above_avg = vrhadd_u8(a2, a1);
  const uint8x8_t below_avg = vrhadd_u8(b2, b1);
  const uint8x8_t outer_avg = vrhadd_u8(above_avg, below_avg);
  return vrhadd_u8(outer_avg, v0);
}
28
// Per-lane flatness test: all-ones where every neighbor differs from the
// center pixel v0 by less than the filter limit, all-zeros otherwise.
static uint8x8_t generate_mask(const uint8x8_t a2, const uint8x8_t a1,
                               const uint8x8_t v0, const uint8x8_t b1,
                               const uint8x8_t b2, const uint8x8_t filter) {
  const uint8x8_t diff_a2 = vabd_u8(a2, v0);
  const uint8x8_t diff_a1 = vabd_u8(a1, v0);
  const uint8x8_t diff_b1 = vabd_u8(b1, v0);
  const uint8x8_t diff_b2 = vabd_u8(b2, v0);

  // Track the worst (largest) deviation among the four neighbors.
  uint8x8_t worst = vmax_u8(diff_a2, diff_a1);
  worst = vmax_u8(diff_b1, worst);
  worst = vmax_u8(diff_b2, worst);
  return vclt_u8(worst, filter);
}
42
// Select, per lane, the filtered value where the neighborhood is flat enough
// and the untouched pixel v0 everywhere else.
static uint8x8_t generate_output(const uint8x8_t a2, const uint8x8_t a1,
                                 const uint8x8_t v0, const uint8x8_t b1,
                                 const uint8x8_t b2, const uint8x8_t filter) {
  const uint8x8_t flat = generate_mask(a2, a1, v0, b1, b2, filter);
  const uint8x8_t filtered = average_k_out(a2, a1, v0, b1, b2);
  return vbsl_u8(flat, filtered, v0);
}
51
52// Same functions but for uint8x16_t.
// 128-bit variant of average_k_out: rounding-halving-add tree approximating
// (a2 + a1 + 4 * v0 + b1 + b2) / 8. Pairing order preserved for rounding.
static uint8x16_t average_k_outq(const uint8x16_t a2, const uint8x16_t a1,
                                 const uint8x16_t v0, const uint8x16_t b1,
                                 const uint8x16_t b2) {
  const uint8x16_t above_avg = vrhaddq_u8(a2, a1);
  const uint8x16_t below_avg = vrhaddq_u8(b2, b1);
  const uint8x16_t outer_avg = vrhaddq_u8(above_avg, below_avg);
  return vrhaddq_u8(outer_avg, v0);
}
61
// 128-bit variant of generate_mask: all-ones lanes where every neighbor is
// within the filter limit of the center pixel v0.
static uint8x16_t generate_maskq(const uint8x16_t a2, const uint8x16_t a1,
                                 const uint8x16_t v0, const uint8x16_t b1,
                                 const uint8x16_t b2, const uint8x16_t filter) {
  const uint8x16_t diff_a2 = vabdq_u8(a2, v0);
  const uint8x16_t diff_a1 = vabdq_u8(a1, v0);
  const uint8x16_t diff_b1 = vabdq_u8(b1, v0);
  const uint8x16_t diff_b2 = vabdq_u8(b2, v0);

  // Track the worst (largest) deviation among the four neighbors.
  uint8x16_t worst = vmaxq_u8(diff_a2, diff_a1);
  worst = vmaxq_u8(diff_b1, worst);
  worst = vmaxq_u8(diff_b2, worst);
  return vcltq_u8(worst, filter);
}
75
// 128-bit variant of generate_output: filtered value on flat lanes,
// original pixel v0 elsewhere.
static uint8x16_t generate_outputq(const uint8x16_t a2, const uint8x16_t a1,
                                   const uint8x16_t v0, const uint8x16_t b1,
                                   const uint8x16_t b2,
                                   const uint8x16_t filter) {
  const uint8x16_t flat = generate_maskq(a2, a1, v0, b1, b2, filter);
  const uint8x16_t filtered = average_k_outq(a2, a1, v0, b1, b2);
  return vbslq_u8(flat, filtered, v0);
}
85
// Two-pass 5-tap deblocking filter over one row of macroblocks. Pass one
// filters vertically (down), reading src_ptr and writing dst_ptr. Pass two
// then filters the result horizontally (across), in place in dst. Each
// output pixel is a rounded average of itself and four neighbors
// (average_k_out) and is only replaced where every neighbor differs from the
// center by less than the per-column limit in f (generate_mask).
// The caller must supply two rows of border above and below the stripe
// (the first load starts at src_ptr - 2 * src_stride).
void vpx_post_proc_down_and_across_mb_row_neon(uint8_t *src_ptr,
                                               uint8_t *dst_ptr, int src_stride,
                                               int dst_stride, int cols,
                                               uint8_t *f, int size) {
  uint8_t *src, *dst;
  int row;
  int col;

  // Process a stripe of macroblocks. The stripe will be a multiple of 16 (for
  // Y) or 8 (for U/V) wide (cols) and the height (size) will be 16 (for Y) or 8
  // (for U/V).
  assert((size == 8 || size == 16) && cols % 8 == 0);

  // While columns of length 16 can be processed, load them.
  for (col = 0; col < cols - 8; col += 16) {
    uint8x16_t a0, a1, a2, a3, a4, a5, a6, a7;
    // a0..a3 prime the window with rows -2..+1 relative to the first output.
    src = src_ptr - 2 * src_stride;
    dst = dst_ptr;

    a0 = vld1q_u8(src);
    src += src_stride;
    a1 = vld1q_u8(src);
    src += src_stride;
    a2 = vld1q_u8(src);
    src += src_stride;
    a3 = vld1q_u8(src);
    src += src_stride;

    // Produce 4 output rows per iteration from a sliding 8-row window.
    for (row = 0; row < size; row += 4) {
      uint8x16_t v_out_0, v_out_1, v_out_2, v_out_3;
      const uint8x16_t filterq = vld1q_u8(f + col);

      a4 = vld1q_u8(src);
      src += src_stride;
      a5 = vld1q_u8(src);
      src += src_stride;
      a6 = vld1q_u8(src);
      src += src_stride;
      a7 = vld1q_u8(src);
      src += src_stride;

      // The third argument of generate_outputq is the center (output) row.
      v_out_0 = generate_outputq(a0, a1, a2, a3, a4, filterq);
      v_out_1 = generate_outputq(a1, a2, a3, a4, a5, filterq);
      v_out_2 = generate_outputq(a2, a3, a4, a5, a6, filterq);
      v_out_3 = generate_outputq(a3, a4, a5, a6, a7, filterq);

      vst1q_u8(dst, v_out_0);
      dst += dst_stride;
      vst1q_u8(dst, v_out_1);
      dst += dst_stride;
      vst1q_u8(dst, v_out_2);
      dst += dst_stride;
      vst1q_u8(dst, v_out_3);
      dst += dst_stride;

      // Rotate over to the next slot.
      a0 = a4;
      a1 = a5;
      a2 = a6;
      a3 = a7;
    }

    src_ptr += 16;
    dst_ptr += 16;
  }

  // Clean up any left over column of length 8.
  if (col != cols) {
    uint8x8_t a0, a1, a2, a3, a4, a5, a6, a7;
    src = src_ptr - 2 * src_stride;
    dst = dst_ptr;

    a0 = vld1_u8(src);
    src += src_stride;
    a1 = vld1_u8(src);
    src += src_stride;
    a2 = vld1_u8(src);
    src += src_stride;
    a3 = vld1_u8(src);
    src += src_stride;

    for (row = 0; row < size; row += 4) {
      uint8x8_t v_out_0, v_out_1, v_out_2, v_out_3;
      const uint8x8_t filter = vld1_u8(f + col);

      a4 = vld1_u8(src);
      src += src_stride;
      a5 = vld1_u8(src);
      src += src_stride;
      a6 = vld1_u8(src);
      src += src_stride;
      a7 = vld1_u8(src);
      src += src_stride;

      v_out_0 = generate_output(a0, a1, a2, a3, a4, filter);
      v_out_1 = generate_output(a1, a2, a3, a4, a5, filter);
      v_out_2 = generate_output(a2, a3, a4, a5, a6, filter);
      v_out_3 = generate_output(a3, a4, a5, a6, a7, filter);

      vst1_u8(dst, v_out_0);
      dst += dst_stride;
      vst1_u8(dst, v_out_1);
      dst += dst_stride;
      vst1_u8(dst, v_out_2);
      dst += dst_stride;
      vst1_u8(dst, v_out_3);
      dst += dst_stride;

      // Rotate over to the next slot.
      a0 = a4;
      a1 = a5;
      a2 = a6;
      a3 = a7;
    }

    // Not strictly necessary but makes resetting dst_ptr easier.
    dst_ptr += 8;
  }

  // Second pass (across): rewind to the start of the vertically filtered
  // output and filter it horizontally, in place.
  dst_ptr -= cols;

  for (row = 0; row < size; row += 8) {
    uint8x8_t a0, a1, a2, a3;
    uint8x8_t b0, b1, b2, b3, b4, b5, b6, b7;

    // This pass reads and writes dst; src aliases it on purpose.
    src = dst_ptr;
    dst = dst_ptr;

    // Load 8 values, transpose 4 of them, and discard 2 because they will be
    // reloaded later.
    load_and_transpose_u8_4x8(src, dst_stride, &a0, &a1, &a2, &a3);
    // After these two lines a0 = a1 = a2 = column 0 and a3 = column 1, so
    // the first output (center a2) sees a replicated left border.
    a3 = a1;
    a2 = a1 = a0;  // Extend left border.

    // Advance past the two columns already held in a2/a3 so the next 8x8
    // load starts at column 2.
    src += 2;

    for (col = 0; col < cols; col += 8) {
      uint8x8_t v_out_0, v_out_1, v_out_2, v_out_3, v_out_4, v_out_5, v_out_6,
          v_out_7;
      // Although the filter is meant to be applied vertically and is instead
      // being applied horizontally here it's OK because it's set in blocks of 8
      // (or 16).
      const uint8x8_t filter = vld1_u8(f + col);

      // b0..b7 hold columns col+2 .. col+9 as vectors of 8 rows each.
      load_and_transpose_u8_8x8(src, dst_stride, &b0, &b1, &b2, &b3, &b4, &b5,
                                &b6, &b7);

      if (col + 8 == cols) {
        // Last block of columns: b5 is the final real column (cols - 1), so
        // extend the right border from it.
        b6 = b7 = b5;
      }

      v_out_0 = generate_output(a0, a1, a2, a3, b0, filter);
      v_out_1 = generate_output(a1, a2, a3, b0, b1, filter);
      v_out_2 = generate_output(a2, a3, b0, b1, b2, filter);
      v_out_3 = generate_output(a3, b0, b1, b2, b3, filter);
      v_out_4 = generate_output(b0, b1, b2, b3, b4, filter);
      v_out_5 = generate_output(b1, b2, b3, b4, b5, filter);
      v_out_6 = generate_output(b2, b3, b4, b5, b6, filter);
      v_out_7 = generate_output(b3, b4, b5, b6, b7, filter);

      transpose_and_store_u8_8x8(dst, dst_stride, v_out_0, v_out_1, v_out_2,
                                 v_out_3, v_out_4, v_out_5, v_out_6, v_out_7);

      // The last four loaded columns become the left context for the next
      // block of 8 outputs.
      a0 = b4;
      a1 = b5;
      a2 = b6;
      a3 = b7;

      src += 8;
      dst += 8;
    }

    dst_ptr += 8 * dst_stride;
  }
}
262
// Add prefix sums of x / xy into *sum / *sumsq so that, per lane i,
//   sum[i]   += x[0] + ... + x[i]
//   sumsq[i] += xy[0] + ... + xy[i]
// vext needs an immediate shift count, so the four shifts are unrolled, and
// the unshifted vectors are added first because vext cannot take a shift of 0.
static void accumulate_sum_sumsq(const int16x4_t x, const int32x4_t xy,
                                 int16x4_t *const sum, int32x4_t *const sumsq) {
  const int16x4_t z16 = vdup_n_s16(0);
  const int32x4_t z32 = vdupq_n_s32(0);

  // Running 16-bit sums: shift zeros in from the left and accumulate.
  *sum = vadd_s16(*sum, x);
  *sum = vadd_s16(*sum, vext_s16(z16, x, 1));
  *sum = vadd_s16(*sum, vext_s16(z16, x, 2));
  *sum = vadd_s16(*sum, vext_s16(z16, x, 3));

  // Same pattern for the 32-bit sum-of-products.
  *sumsq = vaddq_s32(*sumsq, xy);
  *sumsq = vaddq_s32(*sumsq, vextq_s32(z32, xy, 1));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(z32, xy, 2));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(z32, xy, 3));
}
284
// Per-lane flatness test: all-ones where sumsq * 15 - sum * sum < flimit,
// narrowed to 16 bits per lane.
static uint16x4_t calculate_mask(const int16x4_t sum, const int32x4_t sumsq,
                                 const int32x4_t f, const int32x4_t fifteen) {
  const int32x4_t scaled = vmulq_s32(sumsq, fifteen);
  const int32x4_t variance = vmlsl_s16(scaled, sum, sum);
  return vmovn_u32(vcltq_s32(variance, f));
}
293
// Run the flatness test on both 4-lane halves and narrow the results into a
// single 8-lane byte mask.
static uint8x8_t combine_mask(const int16x4_t sum_low, const int16x4_t sum_high,
                              const int32x4_t sumsq_low,
                              const int32x4_t sumsq_high, const int32x4_t f) {
  const int32x4_t fifteen = vdupq_n_s32(15);
  const uint16x4_t lo = calculate_mask(sum_low, sumsq_low, f, fifteen);
  const uint16x4_t hi = calculate_mask(sum_high, sumsq_high, f, fifteen);
  return vmovn_u16(vcombine_u16(lo, hi));
}
303
// Apply filter of (8 + sum + s[c]) >> 4. The rounding saturating shift
// supplies the +8 and clamps the result to [0, 255].
static uint8x8_t filter_pixels(const int16x8_t sum, const uint8x8_t s) {
  const int16x8_t widened = vreinterpretq_s16_u16(vmovl_u8(s));
  return vqrshrun_n_s16(vaddq_s16(sum, widened), 4);
}
311
// Horizontal (across) noise-reduction filter, applied in place one row at a
// time. For pixel c, sum/sumsq track the 15-tap window src[c-7] .. src[c+7]
// with both borders replicated. Pixels are replaced by
// (8 + sum + src[c]) >> 4 (filter_pixels) only where the flatness test
// sumsq * 15 - sum * sum < flimit (combine_mask) passes.
void vpx_mbpost_proc_across_ip_neon(uint8_t *src, int pitch, int rows, int cols,
                                    int flimit) {
  int row, col;
  const int32x4_t f = vdupq_n_s32(flimit);

  assert(cols % 8 == 0);

  for (row = 0; row < rows; ++row) {
    // Prime sum/sumsq with the window centered one pixel left of the row:
    // nine copies of src[0] (replicated left border) plus src[1..6] below.
    // sumsq gets primed with +16 — presumably to match the C reference's
    // bias; verify against vpx_mbpost_proc_across_ip_c.
    int sumsq = src[0] * src[0] * 9 + 16;
    int sum = src[0] * 9;

    uint8x8_t left_context, s, right_context;
    int16x4_t sum_low, sum_high;
    int32x4_t sumsq_low, sumsq_high;

    // Sum (+square) the next 6 elements.
    // Skip [0] because it's included above.
    for (col = 1; col <= 6; ++col) {
      sumsq += src[col] * src[col];
      sum += src[col];
    }

    // Prime the sums. Later the loop uses the _high values to prime the new
    // vectors.
    sumsq_high = vdupq_n_s32(sumsq);
    sum_high = vdup_n_s16(sum);

    // Manually extend the left border.
    left_context = vdup_n_u8(src[0]);

    for (col = 0; col < cols; col += 8) {
      uint8x8_t mask, output;
      int16x8_t x, y;
      int32x4_t xy_low, xy_high;

      s = vld1_u8(src + col);

      if (col + 8 == cols) {
        // Last block of columns. Extend the right border.
        right_context = vdup_n_u8(src[col + 7]);
      } else {
        right_context = vld1_u8(src + col + 7);
      }

      // As the window slides right one pixel, src[c + 7] enters and
      // src[c - 8] leaves. x = enter - leave updates sum, and
      // x * y = (enter - leave) * (enter + leave) = enter^2 - leave^2
      // updates sumsq.
      x = vreinterpretq_s16_u16(vsubl_u8(right_context, left_context));
      y = vreinterpretq_s16_u16(vaddl_u8(right_context, left_context));
      xy_low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
      xy_high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

      // Catch up to the last sum'd value: broadcast the previous pixel's
      // totals, then add prefix sums of the deltas so each lane holds the
      // totals for its own pixel.
      sum_low = vdup_lane_s16(sum_high, 3);
      sumsq_low = vdupq_lane_s32(vget_high_s32(sumsq_high), 1);

      accumulate_sum_sumsq(vget_low_s16(x), xy_low, &sum_low, &sumsq_low);

      // Need to do this sequentially because we need the max value from
      // sum_low.
      sum_high = vdup_lane_s16(sum_low, 3);
      sumsq_high = vdupq_lane_s32(vget_high_s32(sumsq_low), 1);

      accumulate_sum_sumsq(vget_high_s16(x), xy_high, &sum_high, &sumsq_high);

      mask = combine_mask(sum_low, sum_high, sumsq_low, sumsq_high, f);

      output = filter_pixels(vcombine_s16(sum_low, sum_high), s);
      output = vbsl_u8(mask, output, s);

      vst1_u8(src + col, output);

      // The (filtered-before-store) input pixels become the values leaving
      // the window in the next block.
      left_context = s;
    }

    src += pitch;
  }
}
389
// Apply filter of (vpx_rv + sum + s[c]) >> 4. The table value supplies the
// rounding term, so a non-rounding saturating narrowing shift is used.
static uint8x8_t filter_pixels_rv(const int16x8_t sum, const uint8x8_t s,
                                  const int16x8_t rv) {
  const int16x8_t widened = vreinterpretq_s16_u16(vmovl_u8(s));
  const int16x8_t total = vaddq_s16(vaddq_s16(sum, widened), rv);
  return vqshrun_n_s16(total, 4);
}
399
// Vertical (down) noise-reduction filter, the column-wise counterpart of
// vpx_mbpost_proc_across_ip_neon, processing the image in stripes 8 columns
// wide. sum/sumsq track a 15-tap vertical window (rows r-7 .. r+7, borders
// replicated). Pixels are replaced by (vpx_rv[..] + sum + s) >> 4 where
// sumsq * 15 - sum * sum < flimit; the vpx_rv table supplies a varying
// rounding term (filter_pixels_rv).
void vpx_mbpost_proc_down_neon(uint8_t *dst, int pitch, int rows, int cols,
                               int flimit) {
  int row, col, i;
  const int32x4_t f = vdupq_n_s32(flimit);
  // Zero-init is only a safeguard: with rows >= 8 the first loop iteration
  // always loads a real value before below_context is read.
  uint8x8_t below_context = vdup_n_u8(0);

  // 8 columns are processed at a time.
  // If rows is less than 8 the bottom border extension fails.
  assert(cols % 8 == 0);
  assert(rows >= 8);

  // Load and keep the first 8 values in memory. Process a vertical stripe that
  // is 8 wide.
  for (col = 0; col < cols; col += 8) {
    // above_context is an 8-deep delay line; above_context[0] is always the
    // row leaving the window (row - 8).
    uint8x8_t s, above_context[8];
    int16x8_t sum, sum_tmp;
    int32x4_t sumsq_low, sumsq_high;

    // Load and extend the top border.
    s = vld1_u8(dst);
    for (i = 0; i < 8; i++) {
      above_context[i] = s;
    }

    sum_tmp = vreinterpretq_s16_u16(vmovl_u8(s));

    // sum * 9
    sum = vmulq_n_s16(sum_tmp, 9);

    // (sum * 9) * sum == sum * sum * 9
    sumsq_low = vmull_s16(vget_low_s16(sum), vget_low_s16(sum_tmp));
    sumsq_high = vmull_s16(vget_high_s16(sum), vget_high_s16(sum_tmp));

    // Load and discard the next 6 values to prime sum and sumsq.
    for (i = 1; i <= 6; ++i) {
      const uint8x8_t a = vld1_u8(dst + i * pitch);
      const int16x8_t b = vreinterpretq_s16_u16(vmovl_u8(a));
      sum = vaddq_s16(sum, b);

      sumsq_low = vmlal_s16(sumsq_low, vget_low_s16(b), vget_low_s16(b));
      sumsq_high = vmlal_s16(sumsq_high, vget_high_s16(b), vget_high_s16(b));
    }

    for (row = 0; row < rows; ++row) {
      uint8x8_t mask, output;
      int16x8_t x, y;
      int32x4_t xy_low, xy_high;

      s = vld1_u8(dst + row * pitch);

      // Extend the bottom border: once row + 7 runs past the image, keep
      // re-using the last loaded row.
      if (row + 7 < rows) {
        below_context = vld1_u8(dst + (row + 7) * pitch);
      }

      // Slide the window down one row: row + 7 enters, row - 8 leaves.
      // x = enter - leave updates sum; x * y = enter^2 - leave^2 updates
      // sumsq.
      x = vreinterpretq_s16_u16(vsubl_u8(below_context, above_context[0]));
      y = vreinterpretq_s16_u16(vaddl_u8(below_context, above_context[0]));
      xy_low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
      xy_high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

      sum = vaddq_s16(sum, x);

      sumsq_low = vaddq_s32(sumsq_low, xy_low);
      sumsq_high = vaddq_s32(sumsq_high, xy_high);

      mask = combine_mask(vget_low_s16(sum), vget_high_s16(sum), sumsq_low,
                          sumsq_high, f);

      // NOTE(review): loads 8 consecutive int16s starting at vpx_rv[row & 127];
      // assumes the table has at least 135 entries, as the C reference's
      // vpx_rv[(r & 127) + (c & 7)] access pattern implies — confirm.
      output = filter_pixels_rv(sum, s, vld1q_s16(vpx_rv + (row & 127)));
      output = vbsl_u8(mask, output, s);

      vst1_u8(dst + row * pitch, output);

      // Shift the delay line and append the current row.
      above_context[0] = above_context[1];
      above_context[1] = above_context[2];
      above_context[2] = above_context[3];
      above_context[3] = above_context[4];
      above_context[4] = above_context[5];
      above_context[5] = above_context[6];
      above_context[6] = above_context[7];
      above_context[7] = s;
    }

    dst += 8;
  }
}
486