/*
 *  Copyright (c) 2016 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <arm_neon.h>
#include <assert.h>

#include "./vpx_dsp_rtcd.h"
#include "vpx/vpx_integer.h"
#include "vpx_dsp/arm/transpose_neon.h"

extern const int16_t vpx_rv[];

static uint8x8_t average_k_out(const uint8x8_t a2, const uint8x8_t a1,
                               const uint8x8_t v0, const uint8x8_t b1,
                               const uint8x8_t b2) {
  const uint8x8_t k1 = vrhadd_u8(a2, a1);
  const uint8x8_t k2 = vrhadd_u8(b2, b1);
  const uint8x8_t k3 = vrhadd_u8(k1, k2);
  return vrhadd_u8(k3, v0);
}

static uint8x8_t generate_mask(const uint8x8_t a2, const uint8x8_t a1,
                               const uint8x8_t v0, const uint8x8_t b1,
                               const uint8x8_t b2, const uint8x8_t filter) {
  const uint8x8_t a2_v0 = vabd_u8(a2, v0);
  const uint8x8_t a1_v0 = vabd_u8(a1, v0);
  const uint8x8_t b1_v0 = vabd_u8(b1, v0);
  const uint8x8_t b2_v0 = vabd_u8(b2, v0);

  uint8x8_t max = vmax_u8(a2_v0, a1_v0);
  max = vmax_u8(b1_v0, max);
  max = vmax_u8(b2_v0, max);
  return vclt_u8(max, filter);
}

static uint8x8_t generate_output(const uint8x8_t a2, const uint8x8_t a1,
                                 const uint8x8_t v0, const uint8x8_t b1,
                                 const uint8x8_t b2, const uint8x8_t filter) {
  const uint8x8_t k_out = average_k_out(a2, a1, v0, b1, b2);
  const uint8x8_t mask = generate_mask(a2, a1, v0, b1, b2, filter);

  return vbsl_u8(mask, k_out, v0);
}

// Same functions but for uint8x16_t.
static uint8x16_t average_k_outq(const uint8x16_t a2, const uint8x16_t a1,
                                 const uint8x16_t v0, const uint8x16_t b1,
                                 const uint8x16_t b2) {
  const uint8x16_t k1 = vrhaddq_u8(a2, a1);
  const uint8x16_t k2 = vrhaddq_u8(b2, b1);
  const uint8x16_t k3 = vrhaddq_u8(k1, k2);
  return vrhaddq_u8(k3, v0);
}

static uint8x16_t generate_maskq(const uint8x16_t a2, const uint8x16_t a1,
                                 const uint8x16_t v0, const uint8x16_t b1,
                                 const uint8x16_t b2, const uint8x16_t filter) {
  const uint8x16_t a2_v0 = vabdq_u8(a2, v0);
  const uint8x16_t a1_v0 = vabdq_u8(a1, v0);
  const uint8x16_t b1_v0 = vabdq_u8(b1, v0);
  const uint8x16_t b2_v0 = vabdq_u8(b2, v0);

  uint8x16_t max = vmaxq_u8(a2_v0, a1_v0);
  max = vmaxq_u8(b1_v0, max);
  max = vmaxq_u8(b2_v0, max);
  return vcltq_u8(max, filter);
}

static uint8x16_t generate_outputq(const uint8x16_t a2, const uint8x16_t a1,
                                   const uint8x16_t v0, const uint8x16_t b1,
                                   const uint8x16_t b2,
                                   const uint8x16_t filter) {
  const uint8x16_t k_out = average_k_outq(a2, a1, v0, b1, b2);
  const uint8x16_t mask = generate_maskq(a2, a1, v0, b1, b2, filter);

  return vbslq_u8(mask, k_out, v0);
}

void vpx_post_proc_down_and_across_mb_row_neon(uint8_t *src_ptr,
                                               uint8_t *dst_ptr, int src_stride,
                                               int dst_stride, int cols,
                                               uint8_t *f, int size) {
  uint8_t *src, *dst;
  int row;
  int col;

  // Process a stripe of macroblocks. The stripe will be a multiple of 16 (for
  // Y) or 8 (for U/V) wide (cols) and the height (size) will be 16 (for Y) or
  // 8 (for U/V).
  assert((size == 8 || size == 16) && cols % 8 == 0);

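  // Vertical pass: each output pixel is a rounded, centre-weighted average of
  // the pixel and the two rows above and below it (average_k_out*), and it
  // only replaces the source pixel where all four neighbours differ from it
  // by less than the per-column filter limit (generate_mask*).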
  // While columns of length 16 can be processed, load them.
  for (col = 0; col < cols - 8; col += 16) {
    uint8x16_t a0, a1, a2, a3, a4, a5, a6, a7;
    src = src_ptr - 2 * src_stride;
    dst = dst_ptr;

    a0 = vld1q_u8(src);
    src += src_stride;
    a1 = vld1q_u8(src);
    src += src_stride;
    a2 = vld1q_u8(src);
    src += src_stride;
    a3 = vld1q_u8(src);
    src += src_stride;

    for (row = 0; row < size; row += 4) {
      uint8x16_t v_out_0, v_out_1, v_out_2, v_out_3;
      const uint8x16_t filterq = vld1q_u8(f + col);

      a4 = vld1q_u8(src);
      src += src_stride;
      a5 = vld1q_u8(src);
      src += src_stride;
      a6 = vld1q_u8(src);
      src += src_stride;
      a7 = vld1q_u8(src);
      src += src_stride;

      v_out_0 = generate_outputq(a0, a1, a2, a3, a4, filterq);
      v_out_1 = generate_outputq(a1, a2, a3, a4, a5, filterq);
      v_out_2 = generate_outputq(a2, a3, a4, a5, a6, filterq);
      v_out_3 = generate_outputq(a3, a4, a5, a6, a7, filterq);

      vst1q_u8(dst, v_out_0);
      dst += dst_stride;
      vst1q_u8(dst, v_out_1);
      dst += dst_stride;
      vst1q_u8(dst, v_out_2);
      dst += dst_stride;
      vst1q_u8(dst, v_out_3);
      dst += dst_stride;

      // Rotate over to the next slot.
      a0 = a4;
      a1 = a5;
      a2 = a6;
      a3 = a7;
    }

    src_ptr += 16;
    dst_ptr += 16;
  }

  // Clean up any left over column of length 8.
  if (col != cols) {
    uint8x8_t a0, a1, a2, a3, a4, a5, a6, a7;
    src = src_ptr - 2 * src_stride;
    dst = dst_ptr;

    a0 = vld1_u8(src);
    src += src_stride;
    a1 = vld1_u8(src);
    src += src_stride;
    a2 = vld1_u8(src);
    src += src_stride;
    a3 = vld1_u8(src);
    src += src_stride;

    for (row = 0; row < size; row += 4) {
      uint8x8_t v_out_0, v_out_1, v_out_2, v_out_3;
      const uint8x8_t filter = vld1_u8(f + col);

      a4 = vld1_u8(src);
      src += src_stride;
      a5 = vld1_u8(src);
      src += src_stride;
      a6 = vld1_u8(src);
      src += src_stride;
      a7 = vld1_u8(src);
      src += src_stride;

      v_out_0 = generate_output(a0, a1, a2, a3, a4, filter);
      v_out_1 = generate_output(a1, a2, a3, a4, a5, filter);
      v_out_2 = generate_output(a2, a3, a4, a5, a6, filter);
      v_out_3 = generate_output(a3, a4, a5, a6, a7, filter);

      vst1_u8(dst, v_out_0);
      dst += dst_stride;
      vst1_u8(dst, v_out_1);
      dst += dst_stride;
      vst1_u8(dst, v_out_2);
      dst += dst_stride;
      vst1_u8(dst, v_out_3);
      dst += dst_stride;

      // Rotate over to the next slot.
      a0 = a4;
      a1 = a5;
      a2 = a6;
      a3 = a7;
    }

    // Not strictly necessary but makes resetting dst_ptr easier.
    dst_ptr += 8;
  }

  dst_ptr -= cols;

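  // Horizontal pass: reuse the same 5-tap helpers by transposing 8x8 blocks on
  // load and transposing the results back on store. The borders are extended
  // by duplicating the edge columns (a2 = a1 = a0 on the left, b6 = b7 = b5 on
  // the right).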
  for (row = 0; row < size; row += 8) {
    uint8x8_t a0, a1, a2, a3;
    uint8x8_t b0, b1, b2, b3, b4, b5, b6, b7;

    src = dst_ptr;
    dst = dst_ptr;

    // Load 8 values, transpose 4 of them, and discard 2 because they will be
    // reloaded later.
    load_and_transpose_u8_4x8(src, dst_stride, &a0, &a1, &a2, &a3);
    a3 = a1;
    a2 = a1 = a0;  // Extend left border.

    src += 2;

    for (col = 0; col < cols; col += 8) {
      uint8x8_t v_out_0, v_out_1, v_out_2, v_out_3, v_out_4, v_out_5, v_out_6,
          v_out_7;
      // Although the filter is meant to be applied vertically and is instead
      // being applied horizontally here it's OK because it's set in blocks of
      // 8 (or 16).
      const uint8x8_t filter = vld1_u8(f + col);

      load_and_transpose_u8_8x8(src, dst_stride, &b0, &b1, &b2, &b3, &b4, &b5,
                                &b6, &b7);

      if (col + 8 == cols) {
        // Last block of the row. Extend the right border (b5).
        b6 = b7 = b5;
      }

      v_out_0 = generate_output(a0, a1, a2, a3, b0, filter);
      v_out_1 = generate_output(a1, a2, a3, b0, b1, filter);
      v_out_2 = generate_output(a2, a3, b0, b1, b2, filter);
      v_out_3 = generate_output(a3, b0, b1, b2, b3, filter);
      v_out_4 = generate_output(b0, b1, b2, b3, b4, filter);
      v_out_5 = generate_output(b1, b2, b3, b4, b5, filter);
      v_out_6 = generate_output(b2, b3, b4, b5, b6, filter);
      v_out_7 = generate_output(b3, b4, b5, b6, b7, filter);

      transpose_and_store_u8_8x8(dst, dst_stride, v_out_0, v_out_1, v_out_2,
                                 v_out_3, v_out_4, v_out_5, v_out_6, v_out_7);

      a0 = b4;
      a1 = b5;
      a2 = b6;
      a3 = b7;

      src += 8;
      dst += 8;
    }

    dst_ptr += 8 * dst_stride;
  }
}

// sum += x;
// sumsq += x * y;
static void accumulate_sum_sumsq(const int16x4_t x, const int32x4_t xy,
                                 int16x4_t *const sum, int32x4_t *const sumsq) {
  const int16x4_t zero = vdup_n_s16(0);
  const int32x4_t zeroq = vdupq_n_s32(0);

  // Add in the first set because vext doesn't work with '0'.
  *sum = vadd_s16(*sum, x);
  *sumsq = vaddq_s32(*sumsq, xy);

  // Shift x and xy to the right and sum. vext requires an immediate.
  *sum = vadd_s16(*sum, vext_s16(zero, x, 1));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 1));

  *sum = vadd_s16(*sum, vext_s16(zero, x, 2));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 2));

  *sum = vadd_s16(*sum, vext_s16(zero, x, 3));
  *sumsq = vaddq_s32(*sumsq, vextq_s32(zeroq, xy, 3));
}

// Generate mask based on (sumsq * 15 - sum * sum < flimit)
static uint16x4_t calculate_mask(const int16x4_t sum, const int32x4_t sumsq,
                                 const int32x4_t f, const int32x4_t fifteen) {
  const int32x4_t a = vmulq_s32(sumsq, fifteen);
  const int32x4_t b = vmlsl_s16(a, sum, sum);
  const uint32x4_t mask32 = vcltq_s32(b, f);
  return vmovn_u32(mask32);
}

static uint8x8_t combine_mask(const int16x4_t sum_low, const int16x4_t sum_high,
                              const int32x4_t sumsq_low,
                              const int32x4_t sumsq_high, const int32x4_t f) {
  const int32x4_t fifteen = vdupq_n_s32(15);
  const uint16x4_t mask16_low = calculate_mask(sum_low, sumsq_low, f, fifteen);
  const uint16x4_t mask16_high =
      calculate_mask(sum_high, sumsq_high, f, fifteen);
  return vmovn_u16(vcombine_u16(mask16_low, mask16_high));
}

// Apply filter of (8 + sum + s[c]) >> 4.
static uint8x8_t filter_pixels(const int16x8_t sum, const uint8x8_t s) {
  const int16x8_t s16 = vreinterpretq_s16_u16(vmovl_u8(s));
  const int16x8_t sum_s = vaddq_s16(sum, s16);

  return vqrshrun_n_s16(sum_s, 4);
}

void vpx_mbpost_proc_across_ip_neon(uint8_t *src, int pitch, int rows, int cols,
                                    int flimit) {
  int row, col;
  const int32x4_t f = vdupq_n_s32(flimit);

  assert(cols % 8 == 0);

  for (row = 0; row < rows; ++row) {
    // Sum the first 8 elements, which are extended from s[0].
    // sumsq gets primed with +16.
    int sumsq = src[0] * src[0] * 9 + 16;
    int sum = src[0] * 9;

    uint8x8_t left_context, s, right_context;
    int16x4_t sum_low, sum_high;
    int32x4_t sumsq_low, sumsq_high;

    // Sum (+square) the next 6 elements.
    // Skip [0] because it's included above.
    for (col = 1; col <= 6; ++col) {
      sumsq += src[col] * src[col];
      sum += src[col];
    }

    // Prime the sums. Later the loop uses the _high values to prime the new
    // vectors.
    sumsq_high = vdupq_n_s32(sumsq);
    sum_high = vdup_n_s16(sum);

    // Manually extend the left border.
    left_context = vdup_n_u8(src[0]);

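    // Each iteration slides the 15-wide window 8 pixels to the right. Per
    // lane, x = s[c + 7] - s[c - 8] and y = s[c + 7] + s[c - 8], so sum += x
    // and sumsq += x * y, because (r - l) * (r + l) == r * r - l * l.
    // accumulate_sum_sumsq() turns the per-lane deltas into running totals
    // within each vector.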
    for (col = 0; col < cols; col += 8) {
      uint8x8_t mask, output;
      int16x8_t x, y;
      int32x4_t xy_low, xy_high;

      s = vld1_u8(src + col);

      if (col + 8 == cols) {
        // Last block of the row. Extend the right border.
        right_context = vdup_n_u8(src[col + 7]);
      } else {
        right_context = vld1_u8(src + col + 7);
      }

      x = vreinterpretq_s16_u16(vsubl_u8(right_context, left_context));
      y = vreinterpretq_s16_u16(vaddl_u8(right_context, left_context));
      xy_low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
      xy_high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

      // Catch up to the last summed value.
      sum_low = vdup_lane_s16(sum_high, 3);
      sumsq_low = vdupq_lane_s32(vget_high_s32(sumsq_high), 1);

      accumulate_sum_sumsq(vget_low_s16(x), xy_low, &sum_low, &sumsq_low);

      // Need to do this sequentially because we need the final value from
      // sum_low.
      sum_high = vdup_lane_s16(sum_low, 3);
      sumsq_high = vdupq_lane_s32(vget_high_s32(sumsq_low), 1);

      accumulate_sum_sumsq(vget_high_s16(x), xy_high, &sum_high, &sumsq_high);

      mask = combine_mask(sum_low, sum_high, sumsq_low, sumsq_high, f);

      output = filter_pixels(vcombine_s16(sum_low, sum_high), s);
      output = vbsl_u8(mask, output, s);

      vst1_u8(src + col, output);

      left_context = s;
    }

    src += pitch;
  }
}

// Apply filter of (vpx_rv + sum + s[c]) >> 4.
static uint8x8_t filter_pixels_rv(const int16x8_t sum, const uint8x8_t s,
                                  const int16x8_t rv) {
  const int16x8_t s16 = vreinterpretq_s16_u16(vmovl_u8(s));
  const int16x8_t sum_s = vaddq_s16(sum, s16);
  const int16x8_t rounded = vaddq_s16(sum_s, rv);

  return vqshrun_n_s16(rounded, 4);
}

void vpx_mbpost_proc_down_neon(uint8_t *dst, int pitch, int rows, int cols,
                               int flimit) {
  int row, col, i;
  const int32x4_t f = vdupq_n_s32(flimit);
  uint8x8_t below_context = vdup_n_u8(0);

  // 8 columns are processed at a time.
  // If rows is less than 8 the bottom border extension fails.
  assert(cols % 8 == 0);
  assert(rows >= 8);

  // Load and keep the first 8 values in memory. Process a vertical stripe that
  // is 8 wide.
  for (col = 0; col < cols; col += 8) {
    uint8x8_t s, above_context[8];
    int16x8_t sum, sum_tmp;
    int32x4_t sumsq_low, sumsq_high;

    // Load and extend the top border.
    s = vld1_u8(dst);
    for (i = 0; i < 8; i++) {
      above_context[i] = s;
    }

    sum_tmp = vreinterpretq_s16_u16(vmovl_u8(s));

    // sum * 9
    sum = vmulq_n_s16(sum_tmp, 9);

    // (sum * 9) * sum == sum * sum * 9
    sumsq_low = vmull_s16(vget_low_s16(sum), vget_low_s16(sum_tmp));
    sumsq_high = vmull_s16(vget_high_s16(sum), vget_high_s16(sum_tmp));

    // Load and discard the next 6 values to prime sum and sumsq.
    for (i = 1; i <= 6; ++i) {
      const uint8x8_t a = vld1_u8(dst + i * pitch);
      const int16x8_t b = vreinterpretq_s16_u16(vmovl_u8(a));
      sum = vaddq_s16(sum, b);

      sumsq_low = vmlal_s16(sumsq_low, vget_low_s16(b), vget_low_s16(b));
      sumsq_high = vmlal_s16(sumsq_high, vget_high_s16(b), vget_high_s16(b));
    }

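    // Walk down the rows, sliding the 15-row window: add the row seven below
    // and subtract the row eight above (x = below - above,
    // sumsq += x * (below + above)). The mask is the same
    // sumsq * 15 - sum * sum < flimit test, and the output is rounded with
    // values from the vpx_rv table before the final >> 4.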
    for (row = 0; row < rows; ++row) {
      uint8x8_t mask, output;
      int16x8_t x, y;
      int32x4_t xy_low, xy_high;

      s = vld1_u8(dst + row * pitch);

      // Load the row seven below this one, or reuse the previous value to
      // extend the bottom border.
      if (row + 7 < rows) {
        below_context = vld1_u8(dst + (row + 7) * pitch);
      }

      x = vreinterpretq_s16_u16(vsubl_u8(below_context, above_context[0]));
      y = vreinterpretq_s16_u16(vaddl_u8(below_context, above_context[0]));
      xy_low = vmull_s16(vget_low_s16(x), vget_low_s16(y));
      xy_high = vmull_s16(vget_high_s16(x), vget_high_s16(y));

      sum = vaddq_s16(sum, x);

      sumsq_low = vaddq_s32(sumsq_low, xy_low);
      sumsq_high = vaddq_s32(sumsq_high, xy_high);

      mask = combine_mask(vget_low_s16(sum), vget_high_s16(sum), sumsq_low,
                          sumsq_high, f);

      output = filter_pixels_rv(sum, s, vld1q_s16(vpx_rv + (row & 127)));
      output = vbsl_u8(mask, output, s);

      vst1_u8(dst + row * pitch, output);

      above_context[0] = above_context[1];
      above_context[1] = above_context[2];
      above_context[2] = above_context[3];
      above_context[3] = above_context[4];
      above_context[4] = above_context[5];
      above_context[5] = above_context[6];
      above_context[6] = above_context[7];
      above_context[7] = s;
    }

    dst += 8;
  }
}