/*
 * Copyright (c) 2014 The WebM project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include <arm_neon.h>
#include <assert.h>

#include "./vpx_dsp_rtcd.h"
#include "./vpx_config.h"

#include "vpx/vpx_integer.h"
#include "vpx_dsp/arm/mem_neon.h"
#include "vpx_dsp/arm/sum_neon.h"
#include "vpx_ports/mem.h"

// The variance helper functions use int16_t for sum. 8 values are accumulated
// and then added (at which point they expand up to int32_t). To avoid
// overflow, there can be no more than 32767 / 255 ~= 128 values accumulated
// in each column. For a 32x32 buffer, this results in 32 / 8 = 4 values per
// row * 32 rows = 128. Asserts have been added to each function to guard
// against exceeding this limit.

// Process a block of width 4 four rows at a time.
static void variance_neon_w4x4(const uint8_t *a, int a_stride, const uint8_t *b,
                               int b_stride, int h, uint32_t *sse, int *sum) {
  int i;
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_lo_s32 = vdupq_n_s32(0);
  int32x4_t sse_hi_s32 = vdupq_n_s32(0);

  // Since width is only 4, sum_s16 only loads a half row per loop.
  assert(h <= 256);

  for (i = 0; i < h; i += 4) {
    // load_unaligned_u8q() gathers four 4-byte rows into one 16-byte vector.
    const uint8x16_t a_u8 = load_unaligned_u8q(a, a_stride);
    const uint8x16_t b_u8 = load_unaligned_u8q(b, b_stride);
    const uint16x8_t diff_lo_u16 =
        vsubl_u8(vget_low_u8(a_u8), vget_low_u8(b_u8));
    const uint16x8_t diff_hi_u16 =
        vsubl_u8(vget_high_u8(a_u8), vget_high_u8(b_u8));

    const int16x8_t diff_lo_s16 = vreinterpretq_s16_u16(diff_lo_u16);
    const int16x8_t diff_hi_s16 = vreinterpretq_s16_u16(diff_hi_u16);

    sum_s16 = vaddq_s16(sum_s16, diff_lo_s16);
    sum_s16 = vaddq_s16(sum_s16, diff_hi_s16);

    sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_low_s16(diff_lo_s16),
                           vget_low_s16(diff_lo_s16));
    sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_high_s16(diff_lo_s16),
                           vget_high_s16(diff_lo_s16));

    sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_low_s16(diff_hi_s16),
                           vget_low_s16(diff_hi_s16));
    sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_high_s16(diff_hi_s16),
                           vget_high_s16(diff_hi_s16));

    a += 4 * a_stride;
    b += 4 * b_stride;
  }

  *sum = vget_lane_s32(horizontal_add_int16x8(sum_s16), 0);
  *sse = vget_lane_u32(horizontal_add_uint32x4(vreinterpretq_u32_s32(
                           vaddq_s32(sse_lo_s32, sse_hi_s32))),
                       0);
}
// Process a block of any size where the width is divisible by 16.
static void variance_neon_w16(const uint8_t *a, int a_stride, const uint8_t *b,
                              int b_stride, int w, int h, uint32_t *sse,
                              int *sum) {
  int i, j;
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_lo_s32 = vdupq_n_s32(0);
  int32x4_t sse_hi_s32 = vdupq_n_s32(0);

  // The loop loads 16 values at a time but doubles them up when accumulating
  // into sum_s16.
  assert(w / 8 * h <= 128);

  for (i = 0; i < h; ++i) {
    for (j = 0; j < w; j += 16) {
      const uint8x16_t a_u8 = vld1q_u8(a + j);
      const uint8x16_t b_u8 = vld1q_u8(b + j);

      const uint16x8_t diff_lo_u16 =
          vsubl_u8(vget_low_u8(a_u8), vget_low_u8(b_u8));
      const uint16x8_t diff_hi_u16 =
          vsubl_u8(vget_high_u8(a_u8), vget_high_u8(b_u8));

      const int16x8_t diff_lo_s16 = vreinterpretq_s16_u16(diff_lo_u16);
      const int16x8_t diff_hi_s16 = vreinterpretq_s16_u16(diff_hi_u16);

      sum_s16 = vaddq_s16(sum_s16, diff_lo_s16);
      sum_s16 = vaddq_s16(sum_s16, diff_hi_s16);

      sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_low_s16(diff_lo_s16),
                             vget_low_s16(diff_lo_s16));
      sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_high_s16(diff_lo_s16),
                             vget_high_s16(diff_lo_s16));

      sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_low_s16(diff_hi_s16),
                             vget_low_s16(diff_hi_s16));
      sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_high_s16(diff_hi_s16),
                             vget_high_s16(diff_hi_s16));
    }
    a += a_stride;
    b += b_stride;
  }

  *sum = vget_lane_s32(horizontal_add_int16x8(sum_s16), 0);
  *sse = vget_lane_u32(horizontal_add_uint32x4(vreinterpretq_u32_s32(
                           vaddq_s32(sse_lo_s32, sse_hi_s32))),
                       0);
}

// Process a block of width 8 two rows at a time.
static void variance_neon_w8x2(const uint8_t *a, int a_stride, const uint8_t *b,
                               int b_stride, int h, uint32_t *sse, int *sum) {
  int i = 0;
  int16x8_t sum_s16 = vdupq_n_s16(0);
  int32x4_t sse_lo_s32 = vdupq_n_s32(0);
  int32x4_t sse_hi_s32 = vdupq_n_s32(0);

  // Each column has its own accumulator entry in sum_s16.
  assert(h <= 128);

  do {
    const uint8x8_t a_0_u8 = vld1_u8(a);
    const uint8x8_t a_1_u8 = vld1_u8(a + a_stride);
    const uint8x8_t b_0_u8 = vld1_u8(b);
    const uint8x8_t b_1_u8 = vld1_u8(b + b_stride);
    const uint16x8_t diff_0_u16 = vsubl_u8(a_0_u8, b_0_u8);
    const uint16x8_t diff_1_u16 = vsubl_u8(a_1_u8, b_1_u8);
    const int16x8_t diff_0_s16 = vreinterpretq_s16_u16(diff_0_u16);
    const int16x8_t diff_1_s16 = vreinterpretq_s16_u16(diff_1_u16);
    sum_s16 = vaddq_s16(sum_s16, diff_0_s16);
    sum_s16 = vaddq_s16(sum_s16, diff_1_s16);
    sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_low_s16(diff_0_s16),
                           vget_low_s16(diff_0_s16));
    sse_lo_s32 = vmlal_s16(sse_lo_s32, vget_low_s16(diff_1_s16),
                           vget_low_s16(diff_1_s16));
    sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_high_s16(diff_0_s16),
                           vget_high_s16(diff_0_s16));
    sse_hi_s32 = vmlal_s16(sse_hi_s32, vget_high_s16(diff_1_s16),
                           vget_high_s16(diff_1_s16));
    a += a_stride + a_stride;
    b += b_stride + b_stride;
    i += 2;
  } while (i < h);

  *sum = vget_lane_s32(horizontal_add_int16x8(sum_s16), 0);
  *sse = vget_lane_u32(horizontal_add_uint32x4(vreinterpretq_u32_s32(
                           vaddq_s32(sse_lo_s32, sse_hi_s32))),
                       0);
}

void vpx_get8x8var_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                        int b_stride, unsigned int *sse, int *sum) {
  variance_neon_w8x2(a, a_stride, b, b_stride, 8, sse, sum);
}

void vpx_get16x16var_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                          int b_stride, unsigned int *sse, int *sum) {
  variance_neon_w16(a, a_stride, b, b_stride, 16, 16, sse, sum);
}

#define varianceNxM(n, m, shift)                                            \
  unsigned int vpx_variance##n##x##m##_neon(const uint8_t *a, int a_stride, \
                                            const uint8_t *b, int b_stride, \
                                            unsigned int *sse) {            \
    int sum;                                                                \
    if (n == 4)                                                             \
      variance_neon_w4x4(a, a_stride, b, b_stride, m, sse, &sum);           \
    else if (n == 8)                                                        \
      variance_neon_w8x2(a, a_stride, b, b_stride, m, sse, &sum);           \
    else                                                                    \
      variance_neon_w16(a, a_stride, b, b_stride, n, m, sse, &sum);         \
    if (n * m < 16 * 16)                                                    \
      return *sse - ((sum * sum) >> shift);                                 \
    else                                                                    \
      return *sse - (uint32_t)(((int64_t)sum * sum) >> shift);              \
  }
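
// As an illustration (hand-written, not generated code), varianceNxM(16, 16, 8)
// expands to a function equivalent to:
//
//   unsigned int vpx_variance16x16_neon(const uint8_t *a, int a_stride,
//                                       const uint8_t *b, int b_stride,
//                                       unsigned int *sse) {
//     int sum;
//     variance_neon_w16(a, a_stride, b, b_stride, 16, 16, sse, &sum);
//     return *sse - (uint32_t)(((int64_t)sum * sum) >> 8);
//   }
//
// The shift is log2(n * m), so the subtraction computes sse - sum^2 / (n * m).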
varianceNxM(4, 4, 4);
varianceNxM(4, 8, 5);
varianceNxM(8, 4, 5);
varianceNxM(8, 8, 6);
varianceNxM(8, 16, 7);
varianceNxM(16, 8, 7);
varianceNxM(16, 16, 8);
varianceNxM(16, 32, 9);
varianceNxM(32, 16, 9);
varianceNxM(32, 32, 10);

unsigned int vpx_variance32x64_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;
  // Split into two 32x32 halves to stay within the int16_t accumulator limit.
  variance_neon_w16(a, a_stride, b, b_stride, 32, 32, &sse1, &sum1);
  variance_neon_w16(a + (32 * a_stride), a_stride, b + (32 * b_stride),
                    b_stride, 32, 32, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 11);
}

unsigned int vpx_variance64x32_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;
  // Split into two 64x16 halves to stay within the int16_t accumulator limit.
  variance_neon_w16(a, a_stride, b, b_stride, 64, 16, &sse1, &sum1);
  variance_neon_w16(a + (16 * a_stride), a_stride, b + (16 * b_stride),
                    b_stride, 64, 16, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 11);
}

unsigned int vpx_variance64x64_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;

  // Split into four 64x16 strips to stay within the int16_t accumulator
  // limit.
  variance_neon_w16(a, a_stride, b, b_stride, 64, 16, &sse1, &sum1);
  variance_neon_w16(a + (16 * a_stride), a_stride, b + (16 * b_stride),
                    b_stride, 64, 16, &sse2, &sum2);
  sse1 += sse2;
  sum1 += sum2;

  variance_neon_w16(a + (16 * 2 * a_stride), a_stride, b + (16 * 2 * b_stride),
                    b_stride, 64, 16, &sse2, &sum2);
  sse1 += sse2;
  sum1 += sum2;

  variance_neon_w16(a + (16 * 3 * a_stride), a_stride, b + (16 * 3 * b_stride),
                    b_stride, 64, 16, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 12);
}
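
// The 32x64, 64x32 and 64x64 kernels above split the block into strips so
// that each variance_neon_w16() call stays within the int16_t accumulator
// limit, then merge the partial results. This is valid because both SSE and
// sum are additive over a partition of the block:
//
//   variance = (sse1 + sse2 + ...) - (sum1 + sum2 + ...)^2 / (w * h)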
unsigned int vpx_mse16x16_neon(const unsigned char *src_ptr, int source_stride,
                               const unsigned char *ref_ptr, int recon_stride,
                               unsigned int *sse) {
  int i;
  int16x4_t d22s16, d23s16, d24s16, d25s16, d26s16, d27s16, d28s16, d29s16;
  int64x1_t d0s64;
  uint8x16_t q0u8, q1u8, q2u8, q3u8;
  int32x4_t q7s32, q8s32, q9s32, q10s32;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int64x2_t q1s64;

  q7s32 = vdupq_n_s32(0);
  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  // Each iteration accumulates squared differences for two 16-pixel rows.
  for (i = 0; i < 8; i++) {  // mse16x16_neon_loop
    q0u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q1u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q2u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    q3u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;

    q11u16 = vsubl_u8(vget_low_u8(q0u8), vget_low_u8(q2u8));
    q12u16 = vsubl_u8(vget_high_u8(q0u8), vget_high_u8(q2u8));
    q13u16 = vsubl_u8(vget_low_u8(q1u8), vget_low_u8(q3u8));
    q14u16 = vsubl_u8(vget_high_u8(q1u8), vget_high_u8(q3u8));

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q7s32 = vmlal_s16(q7s32, d22s16, d22s16);
    q8s32 = vmlal_s16(q8s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);

    d26s16 = vreinterpret_s16_u16(vget_low_u16(q13u16));
    d27s16 = vreinterpret_s16_u16(vget_high_u16(q13u16));
    q7s32 = vmlal_s16(q7s32, d26s16, d26s16);
    q8s32 = vmlal_s16(q8s32, d27s16, d27s16);

    d28s16 = vreinterpret_s16_u16(vget_low_u16(q14u16));
    d29s16 = vreinterpret_s16_u16(vget_high_u16(q14u16));
    q9s32 = vmlal_s16(q9s32, d28s16, d28s16);
    q10s32 = vmlal_s16(q10s32, d29s16, d29s16);
  }

  // Reduce the four 32-bit accumulators to a single 64-bit total.
  q7s32 = vaddq_s32(q7s32, q8s32);
  q9s32 = vaddq_s32(q9s32, q10s32);
  q10s32 = vaddq_s32(q7s32, q9s32);

  q1s64 = vpaddlq_s32(q10s32);
  d0s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d0s64), 0);
  return vget_lane_u32(vreinterpret_u32_s64(d0s64), 0);
}

unsigned int vpx_get4x4sse_cs_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride) {
  int16x4_t d22s16, d24s16, d26s16, d28s16;
  int64x1_t d0s64;
  uint8x8_t d0u8, d1u8, d2u8, d3u8, d4u8, d5u8, d6u8, d7u8;
  int32x4_t q7s32, q8s32, q9s32, q10s32;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int64x2_t q1s64;

  // Load four rows of source and reference. Each load reads 8 bytes, but
  // only the first 4 pixels of each row contribute to the result below.
  d0u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d4u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d1u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d5u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d2u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d6u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d3u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d7u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;

  q11u16 = vsubl_u8(d0u8, d4u8);
  q12u16 = vsubl_u8(d1u8, d5u8);
  q13u16 = vsubl_u8(d2u8, d6u8);
  q14u16 = vsubl_u8(d3u8, d7u8);

  d22s16 = vget_low_s16(vreinterpretq_s16_u16(q11u16));
  d24s16 = vget_low_s16(vreinterpretq_s16_u16(q12u16));
  d26s16 = vget_low_s16(vreinterpretq_s16_u16(q13u16));
  d28s16 = vget_low_s16(vreinterpretq_s16_u16(q14u16));

  q7s32 = vmull_s16(d22s16, d22s16);
  q8s32 = vmull_s16(d24s16, d24s16);
  q9s32 = vmull_s16(d26s16, d26s16);
  q10s32 = vmull_s16(d28s16, d28s16);

  q7s32 = vaddq_s32(q7s32, q8s32);
  q9s32 = vaddq_s32(q9s32, q10s32);
  q9s32 = vaddq_s32(q7s32, q9s32);

  q1s64 = vpaddlq_s32(q9s32);
  d0s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  return vget_lane_u32(vreinterpret_u32_s64(d0s64), 0);
}
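
// Example usage (an illustrative sketch, not part of this file): `src`,
// `ref` and the strides are hypothetical caller-provided buffers. Each
// vpx_varianceNxM_neon() returns the block variance and writes the sum of
// squared differences through `sse`; for an 8x8 block the relationship is
// var == sse - (sum * sum) >> 6, since 6 == log2(8 * 8).
//
//   unsigned int sse;
//   const unsigned int var =
//       vpx_variance8x8_neon(src, src_stride, ref, ref_stride, &sse);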