/*
 *  Copyright (c) 2010 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <immintrin.h>

#include "./vpx_dsp_rtcd.h"
#include "vpx_dsp/x86/convolve.h"
#include "vpx_dsp/x86/convolve_avx2.h"
#include "vpx_ports/mem.h"

// filters for 16_h8
DECLARE_ALIGNED(32, static const uint8_t, filt1_global_avx2[32]) = {
  0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8,
  0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8
};

DECLARE_ALIGNED(32, static const uint8_t, filt2_global_avx2[32]) = {
  2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10,
  2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10
};

DECLARE_ALIGNED(32, static const uint8_t, filt3_global_avx2[32]) = {
  4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12,
  4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12
};

DECLARE_ALIGNED(32, static const uint8_t, filt4_global_avx2[32]) = {
  6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14,
  6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14
};
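// A sketch of how these tables are used (see vpx_filter_block1d16_h8_x_avx2
// below): each one is a _mm256_shuffle_epi8 mask that, for every output pixel
// in a 128-bit lane, gathers one pair of adjacent bytes from the 16 loaded
// source bytes. For output pixel x, filt1 picks bytes (x, x + 1), filt2 picks
// (x + 2, x + 3), filt3 picks (x + 4, x + 5) and filt4 picks (x + 6, x + 7),
// i.e. the byte pairs that convolve8_16_avx2() multiply-accumulates against
// the filter tap pairs (0,1), (2,3), (4,5) and (6,7) prepared by
// shuffle_filter_avx2().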

static INLINE void vpx_filter_block1d16_h8_x_avx2(
    const uint8_t *src_ptr, ptrdiff_t src_pixels_per_line, uint8_t *output_ptr,
    ptrdiff_t output_pitch, uint32_t output_height, const int16_t *filter,
    const int avg) {
  __m128i outReg1, outReg2;
  __m256i outReg32b1, outReg32b2;
  unsigned int i;
  ptrdiff_t src_stride, dst_stride;
  __m256i f[4], filt[4], s[4];

  shuffle_filter_avx2(filter, f);
  filt[0] = _mm256_load_si256((__m256i const *)filt1_global_avx2);
  filt[1] = _mm256_load_si256((__m256i const *)filt2_global_avx2);
  filt[2] = _mm256_load_si256((__m256i const *)filt3_global_avx2);
  filt[3] = _mm256_load_si256((__m256i const *)filt4_global_avx2);

  // multiply the size of the source and destination stride by two
  src_stride = src_pixels_per_line << 1;
  dst_stride = output_pitch << 1;
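
  // Conceptually this implements the standard vpx_dsp 8-tap horizontal
  // filter, i.e. for each output pixel
  //   out[x] = clip_pixel(ROUND_POWER_OF_TWO(
  //       src[x - 3] * filter[0] + ... + src[x + 4] * filter[7], FILTER_BITS))
  // The loop below produces 2 rows x 16 pixels per iteration: the shuffles
  // pair up source bytes so convolve8_16_avx2() can multiply-accumulate them
  // against the tap pairs in f[]. When avg is non-zero the result is averaged
  // with the pixels already in output_ptr (used by the _avg_ wrappers below).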
  for (i = output_height; i > 1; i -= 2) {
    __m256i srcReg;

    // load 16 bytes from each of the 2 source rows
    srcReg =
        _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(src_ptr - 3)));
    srcReg = _mm256_inserti128_si256(
        srcReg,
        _mm_loadu_si128((const __m128i *)(src_ptr + src_pixels_per_line - 3)),
        1);

    // filter the source buffer
    s[0] = _mm256_shuffle_epi8(srcReg, filt[0]);
    s[1] = _mm256_shuffle_epi8(srcReg, filt[1]);
    s[2] = _mm256_shuffle_epi8(srcReg, filt[2]);
    s[3] = _mm256_shuffle_epi8(srcReg, filt[3]);
    outReg32b1 = convolve8_16_avx2(s, f);

    // read the next 16 bytes of the 2 rows
    // (part of them was already covered by the earlier load)
    srcReg =
        _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)(src_ptr + 5)));
    srcReg = _mm256_inserti128_si256(
        srcReg,
        _mm_loadu_si128((const __m128i *)(src_ptr + src_pixels_per_line + 5)),
        1);

    // filter the source buffer
    s[0] = _mm256_shuffle_epi8(srcReg, filt[0]);
    s[1] = _mm256_shuffle_epi8(srcReg, filt[1]);
    s[2] = _mm256_shuffle_epi8(srcReg, filt[2]);
    s[3] = _mm256_shuffle_epi8(srcReg, filt[3]);
    outReg32b2 = convolve8_16_avx2(s, f);

    // shrink each 16-bit value to 8 bits; the low and high 64 bits of each
    // lane contain the first and second convolve results respectively
    outReg32b1 = _mm256_packus_epi16(outReg32b1, outReg32b2);

    src_ptr += src_stride;

    // average if necessary
    outReg1 = _mm256_castsi256_si128(outReg32b1);
    outReg2 = _mm256_extractf128_si256(outReg32b1, 1);
    if (avg) {
      outReg1 = _mm_avg_epu8(outReg1, _mm_load_si128((__m128i *)output_ptr));
      outReg2 = _mm_avg_epu8(
          outReg2, _mm_load_si128((__m128i *)(output_ptr + output_pitch)));
    }

    // save 16 bytes
    _mm_store_si128((__m128i *)output_ptr, outReg1);

    // save the next 16 bytes
    _mm_store_si128((__m128i *)(output_ptr + output_pitch), outReg2);

    output_ptr += dst_stride;
  }

  // if the number of output rows is odd, process the last row
  // (16 bytes) on its own
  if (i > 0) {
    __m128i srcReg;

    // load the first 16 bytes of the last row
    srcReg = _mm_loadu_si128((const __m128i *)(src_ptr - 3));

    // filter the source buffer
    s[0] = _mm256_castsi128_si256(
        _mm_shuffle_epi8(srcReg, _mm256_castsi256_si128(filt[0])));
    s[1] = _mm256_castsi128_si256(
        _mm_shuffle_epi8(srcReg, _mm256_castsi256_si128(filt[1])));
    s[2] = _mm256_castsi128_si256(
        _mm_shuffle_epi8(srcReg, _mm256_castsi256_si128(filt[2])));
    s[3] = _mm256_castsi128_si256(
        _mm_shuffle_epi8(srcReg, _mm256_castsi256_si128(filt[3])));
    outReg1 = convolve8_8_avx2(s, f);

    // read the next 16 bytes
    // (part of them was already covered by the earlier load)
    srcReg = _mm_loadu_si128((const __m128i *)(src_ptr + 5));

    // filter the source buffer
    s[0] = _mm256_castsi128_si256(
        _mm_shuffle_epi8(srcReg, _mm256_castsi256_si128(filt[0])));
    s[1] = _mm256_castsi128_si256(
        _mm_shuffle_epi8(srcReg, _mm256_castsi256_si128(filt[1])));
    s[2] = _mm256_castsi128_si256(
        _mm_shuffle_epi8(srcReg, _mm256_castsi256_si128(filt[2])));
    s[3] = _mm256_castsi128_si256(
        _mm_shuffle_epi8(srcReg, _mm256_castsi256_si128(filt[3])));
    outReg2 = convolve8_8_avx2(s, f);

    // shrink each 16-bit value to 8 bits; the low and high 64 bits contain
    // the first and second convolve results respectively
    outReg1 = _mm_packus_epi16(outReg1, outReg2);

    // average if necessary
    if (avg) {
      outReg1 = _mm_avg_epu8(outReg1, _mm_load_si128((__m128i *)output_ptr));
    }

    // save 16 bytes
    _mm_store_si128((__m128i *)output_ptr, outReg1);
  }
}

static void vpx_filter_block1d16_h8_avx2(
    const uint8_t *src_ptr, ptrdiff_t src_stride, uint8_t *output_ptr,
    ptrdiff_t dst_stride, uint32_t output_height, const int16_t *filter) {
  vpx_filter_block1d16_h8_x_avx2(src_ptr, src_stride, output_ptr, dst_stride,
                                 output_height, filter, 0);
}

static void vpx_filter_block1d16_h8_avg_avx2(
    const uint8_t *src_ptr, ptrdiff_t src_stride, uint8_t *output_ptr,
    ptrdiff_t dst_stride, uint32_t output_height, const int16_t *filter) {
  vpx_filter_block1d16_h8_x_avx2(src_ptr, src_stride, output_ptr, dst_stride,
                                 output_height, filter, 1);
}

static INLINE void vpx_filter_block1d16_v8_x_avx2(
    const uint8_t *src_ptr, ptrdiff_t src_pitch, uint8_t *output_ptr,
    ptrdiff_t out_pitch, uint32_t output_height, const int16_t *filter,
    const int avg) {
  __m128i outReg1, outReg2;
  __m256i srcRegHead1;
  unsigned int i;
  ptrdiff_t src_stride, dst_stride;
  __m256i f[4], s1[4], s2[4];

  shuffle_filter_avx2(filter, f);

  // multiply the size of the source and destination stride by two
  src_stride = src_pitch << 1;
  dst_stride = out_pitch << 1;
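
  // Conceptually this is the standard 8-tap vertical filter. With src_ptr
  // pointing 3 rows above the first output row (the FUN_CONV_1D(vert, ...)
  // callers at the bottom of this file pass src - src_stride * 3), each
  // output pixel is
  //   out[x] = clip_pixel(ROUND_POWER_OF_TWO(
  //       src[0 * pitch + x] * filter[0] + ... +
  //           src[7 * pitch + x] * filter[7], FILTER_BITS))
  // The block below preloads the first 7 rows; the main loop then loads 2 new
  // rows and emits 2 output rows of 16 pixels per iteration, averaging with
  // the destination when avg is non-zero.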

  {
    __m128i s[6];
    __m256i s32b[6];

    // load 16 bytes from each of the first 7 rows, src_pitch apart
    s[0] = _mm_loadu_si128((const __m128i *)(src_ptr + 0 * src_pitch));
    s[1] = _mm_loadu_si128((const __m128i *)(src_ptr + 1 * src_pitch));
    s[2] = _mm_loadu_si128((const __m128i *)(src_ptr + 2 * src_pitch));
    s[3] = _mm_loadu_si128((const __m128i *)(src_ptr + 3 * src_pitch));
    s[4] = _mm_loadu_si128((const __m128i *)(src_ptr + 4 * src_pitch));
    s[5] = _mm_loadu_si128((const __m128i *)(src_ptr + 5 * src_pitch));
    srcRegHead1 = _mm256_castsi128_si256(
        _mm_loadu_si128((const __m128i *)(src_ptr + 6 * src_pitch)));

    // place each pair of consecutive rows in the same 256-bit register
    s32b[0] = _mm256_inserti128_si256(_mm256_castsi128_si256(s[0]), s[1], 1);
    s32b[1] = _mm256_inserti128_si256(_mm256_castsi128_si256(s[1]), s[2], 1);
    s32b[2] = _mm256_inserti128_si256(_mm256_castsi128_si256(s[2]), s[3], 1);
    s32b[3] = _mm256_inserti128_si256(_mm256_castsi128_si256(s[3]), s[4], 1);
    s32b[4] = _mm256_inserti128_si256(_mm256_castsi128_si256(s[4]), s[5], 1);
    s32b[5] = _mm256_inserti128_si256(_mm256_castsi128_si256(s[5]),
                                      _mm256_castsi256_si128(srcRegHead1), 1);

    // merge every two consecutive registers except the last one;
    // the first lanes contain values for filtering odd rows (1,3,5...) and
    // the second lanes contain values for filtering even rows (2,4,6...)
    s1[0] = _mm256_unpacklo_epi8(s32b[0], s32b[1]);
    s2[0] = _mm256_unpackhi_epi8(s32b[0], s32b[1]);
    s1[1] = _mm256_unpacklo_epi8(s32b[2], s32b[3]);
    s2[1] = _mm256_unpackhi_epi8(s32b[2], s32b[3]);
    s1[2] = _mm256_unpacklo_epi8(s32b[4], s32b[5]);
    s2[2] = _mm256_unpackhi_epi8(s32b[4], s32b[5]);
  }
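
  // The main loop keeps a sliding window of unpacked row pairs in s1[0..2] /
  // s2[0..2] (built above), forms one new pair s1[3] / s2[3] from two freshly
  // loaded rows each iteration, and shifts the window down by two rows at the
  // bottom of the loop so already-loaded rows are reused across iterations.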

  for (i = output_height; i > 1; i -= 2) {
    __m256i srcRegHead2, srcRegHead3;

    // load the next 2 rows of 16 bytes and place every two consecutive rows
    // in the same 256-bit register
    srcRegHead2 = _mm256_castsi128_si256(
        _mm_loadu_si128((const __m128i *)(src_ptr + 7 * src_pitch)));
    srcRegHead1 = _mm256_inserti128_si256(
        srcRegHead1, _mm256_castsi256_si128(srcRegHead2), 1);
    srcRegHead3 = _mm256_castsi128_si256(
        _mm_loadu_si128((const __m128i *)(src_ptr + 8 * src_pitch)));
    srcRegHead2 = _mm256_inserti128_si256(
        srcRegHead2, _mm256_castsi256_si128(srcRegHead3), 1);

    // merge the two new consecutive registers;
    // the first lane contains values for filtering odd rows (1,3,5...) and
    // the second lane contains values for filtering even rows (2,4,6...)
    s1[3] = _mm256_unpacklo_epi8(srcRegHead1, srcRegHead2);
    s2[3] = _mm256_unpackhi_epi8(srcRegHead1, srcRegHead2);

    s1[0] = convolve8_16_avx2(s1, f);
    s2[0] = convolve8_16_avx2(s2, f);

    // shrink each 16-bit value to 8 bits; the low and high 64 bits of each
    // lane contain the first and second convolve results respectively
    s1[0] = _mm256_packus_epi16(s1[0], s2[0]);

    src_ptr += src_stride;

    // average if necessary
    outReg1 = _mm256_castsi256_si128(s1[0]);
    outReg2 = _mm256_extractf128_si256(s1[0], 1);
    if (avg) {
      outReg1 = _mm_avg_epu8(outReg1, _mm_load_si128((__m128i *)output_ptr));
      outReg2 = _mm_avg_epu8(
          outReg2, _mm_load_si128((__m128i *)(output_ptr + out_pitch)));
    }

    // save 16 bytes
    _mm_store_si128((__m128i *)output_ptr, outReg1);

    // save the next 16 bytes
    _mm_store_si128((__m128i *)(output_ptr + out_pitch), outReg2);

    output_ptr += dst_stride;

    // shift down by two rows
    s1[0] = s1[1];
    s2[0] = s2[1];
    s1[1] = s1[2];
    s2[1] = s2[2];
    s1[2] = s1[3];
    s2[2] = s2[3];
    srcRegHead1 = srcRegHead3;
  }

  // if the number of output rows is odd, process the last row
  // (16 bytes) on its own
  if (i > 0) {
    // load the last 16 bytes
    const __m128i srcRegHead2 =
        _mm_loadu_si128((const __m128i *)(src_ptr + src_pitch * 7));

    // merge the last two source rows together
    s1[0] = _mm256_castsi128_si256(
        _mm_unpacklo_epi8(_mm256_castsi256_si128(srcRegHead1), srcRegHead2));
    s2[0] = _mm256_castsi128_si256(
        _mm_unpackhi_epi8(_mm256_castsi256_si128(srcRegHead1), srcRegHead2));

    outReg1 = convolve8_8_avx2(s1, f);
    outReg2 = convolve8_8_avx2(s2, f);

    // shrink each 16-bit value to 8 bits; the low and high 64 bits contain
    // the first and second convolve results respectively
    outReg1 = _mm_packus_epi16(outReg1, outReg2);

    // average if necessary
    if (avg) {
      outReg1 = _mm_avg_epu8(outReg1, _mm_load_si128((__m128i *)output_ptr));
    }

    // save 16 bytes
    _mm_store_si128((__m128i *)output_ptr, outReg1);
  }
}

static void vpx_filter_block1d16_v8_avx2(const uint8_t *src_ptr,
                                         ptrdiff_t src_stride, uint8_t *dst_ptr,
                                         ptrdiff_t dst_stride, uint32_t height,
                                         const int16_t *filter) {
  vpx_filter_block1d16_v8_x_avx2(src_ptr, src_stride, dst_ptr, dst_stride,
                                 height, filter, 0);
}

static void vpx_filter_block1d16_v8_avg_avx2(
    const uint8_t *src_ptr, ptrdiff_t src_stride, uint8_t *dst_ptr,
    ptrdiff_t dst_stride, uint32_t height, const int16_t *filter) {
  vpx_filter_block1d16_v8_x_avx2(src_ptr, src_stride, dst_ptr, dst_stride,
                                 height, filter, 1);
}
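// Only the 16-wide 8-tap (and 8-tap averaging) paths above have dedicated
// AVX2 implementations. The block below maps every other variant required by
// the convolve macros (the 8- and 4-wide 8-tap filters and all bilinear
// 2-tap filters) onto their SSSE3 counterparts.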
#if HAVE_AVX2 && HAVE_SSSE3
filter8_1dfunction vpx_filter_block1d4_v8_ssse3;
#if ARCH_X86_64
filter8_1dfunction vpx_filter_block1d8_v8_intrin_ssse3;
filter8_1dfunction vpx_filter_block1d8_h8_intrin_ssse3;
filter8_1dfunction vpx_filter_block1d4_h8_intrin_ssse3;
#define vpx_filter_block1d8_v8_avx2 vpx_filter_block1d8_v8_intrin_ssse3
#define vpx_filter_block1d8_h8_avx2 vpx_filter_block1d8_h8_intrin_ssse3
#define vpx_filter_block1d4_h8_avx2 vpx_filter_block1d4_h8_intrin_ssse3
#else  // ARCH_X86
filter8_1dfunction vpx_filter_block1d8_v8_ssse3;
filter8_1dfunction vpx_filter_block1d8_h8_ssse3;
filter8_1dfunction vpx_filter_block1d4_h8_ssse3;
#define vpx_filter_block1d8_v8_avx2 vpx_filter_block1d8_v8_ssse3
#define vpx_filter_block1d8_h8_avx2 vpx_filter_block1d8_h8_ssse3
#define vpx_filter_block1d4_h8_avx2 vpx_filter_block1d4_h8_ssse3
#endif  // ARCH_X86_64
filter8_1dfunction vpx_filter_block1d8_v8_avg_ssse3;
filter8_1dfunction vpx_filter_block1d8_h8_avg_ssse3;
filter8_1dfunction vpx_filter_block1d4_v8_avg_ssse3;
filter8_1dfunction vpx_filter_block1d4_h8_avg_ssse3;
#define vpx_filter_block1d8_v8_avg_avx2 vpx_filter_block1d8_v8_avg_ssse3
#define vpx_filter_block1d8_h8_avg_avx2 vpx_filter_block1d8_h8_avg_ssse3
#define vpx_filter_block1d4_v8_avg_avx2 vpx_filter_block1d4_v8_avg_ssse3
#define vpx_filter_block1d4_h8_avg_avx2 vpx_filter_block1d4_h8_avg_ssse3
filter8_1dfunction vpx_filter_block1d16_v2_ssse3;
filter8_1dfunction vpx_filter_block1d16_h2_ssse3;
filter8_1dfunction vpx_filter_block1d8_v2_ssse3;
filter8_1dfunction vpx_filter_block1d8_h2_ssse3;
filter8_1dfunction vpx_filter_block1d4_v2_ssse3;
filter8_1dfunction vpx_filter_block1d4_h2_ssse3;
#define vpx_filter_block1d4_v8_avx2 vpx_filter_block1d4_v8_ssse3
#define vpx_filter_block1d16_v2_avx2 vpx_filter_block1d16_v2_ssse3
#define vpx_filter_block1d16_h2_avx2 vpx_filter_block1d16_h2_ssse3
#define vpx_filter_block1d8_v2_avx2 vpx_filter_block1d8_v2_ssse3
#define vpx_filter_block1d8_h2_avx2 vpx_filter_block1d8_h2_ssse3
#define vpx_filter_block1d4_v2_avx2 vpx_filter_block1d4_v2_ssse3
#define vpx_filter_block1d4_h2_avx2 vpx_filter_block1d4_h2_ssse3
filter8_1dfunction vpx_filter_block1d16_v2_avg_ssse3;
filter8_1dfunction vpx_filter_block1d16_h2_avg_ssse3;
filter8_1dfunction vpx_filter_block1d8_v2_avg_ssse3;
filter8_1dfunction vpx_filter_block1d8_h2_avg_ssse3;
filter8_1dfunction vpx_filter_block1d4_v2_avg_ssse3;
filter8_1dfunction vpx_filter_block1d4_h2_avg_ssse3;
#define vpx_filter_block1d16_v2_avg_avx2 vpx_filter_block1d16_v2_avg_ssse3
#define vpx_filter_block1d16_h2_avg_avx2 vpx_filter_block1d16_h2_avg_ssse3
#define vpx_filter_block1d8_v2_avg_avx2 vpx_filter_block1d8_v2_avg_ssse3
#define vpx_filter_block1d8_h2_avg_avx2 vpx_filter_block1d8_h2_avg_ssse3
#define vpx_filter_block1d4_v2_avg_avx2 vpx_filter_block1d4_v2_avg_ssse3
#define vpx_filter_block1d4_h2_avg_avx2 vpx_filter_block1d4_h2_avg_ssse3
// void vpx_convolve8_horiz_avx2(const uint8_t *src, ptrdiff_t src_stride,
//                               uint8_t *dst, ptrdiff_t dst_stride,
//                               const InterpKernel *filter, int x0_q4,
//                               int32_t x_step_q4, int y0_q4, int y_step_q4,
//                               int w, int h);
// void vpx_convolve8_vert_avx2(const uint8_t *src, ptrdiff_t src_stride,
//                              uint8_t *dst, ptrdiff_t dst_stride,
//                              const InterpKernel *filter, int x0_q4,
//                              int32_t x_step_q4, int y0_q4, int y_step_q4,
//                              int w, int h);
// void vpx_convolve8_avg_horiz_avx2(const uint8_t *src, ptrdiff_t src_stride,
//                                   uint8_t *dst, ptrdiff_t dst_stride,
//                                   const InterpKernel *filter, int x0_q4,
//                                   int32_t x_step_q4, int y0_q4,
//                                   int y_step_q4, int w, int h);
// void vpx_convolve8_avg_vert_avx2(const uint8_t *src, ptrdiff_t src_stride,
//                                  uint8_t *dst, ptrdiff_t dst_stride,
//                                  const InterpKernel *filter, int x0_q4,
//                                  int32_t x_step_q4, int y0_q4,
//                                  int y_step_q4, int w, int h);
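// FUN_CONV_1D (defined in vpx_dsp/x86/convolve.h) expands to the entry points
// prototyped above. Broadly, each generated function walks the block in
// column strips and dispatches on the strip width (16/8/4) and on whether the
// kernel is 8-tap or bilinear to the matching vpx_filter_block1d*_avx2 names
// declared or #defined in this file; see convolve.h for the exact expansion.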
FUN_CONV_1D(horiz, x0_q4, x_step_q4, h, src, , avx2);
FUN_CONV_1D(vert, y0_q4, y_step_q4, v, src - src_stride * 3, , avx2);
FUN_CONV_1D(avg_horiz, x0_q4, x_step_q4, h, src, avg_, avx2);
FUN_CONV_1D(avg_vert, y0_q4, y_step_q4, v, src - src_stride * 3, avg_, avx2);

// void vpx_convolve8_avx2(const uint8_t *src, ptrdiff_t src_stride,
//                         uint8_t *dst, ptrdiff_t dst_stride,
//                         const InterpKernel *filter, int x0_q4,
//                         int32_t x_step_q4, int y0_q4, int y_step_q4,
//                         int w, int h);
// void vpx_convolve8_avg_avx2(const uint8_t *src, ptrdiff_t src_stride,
//                             uint8_t *dst, ptrdiff_t dst_stride,
//                             const InterpKernel *filter, int x0_q4,
//                             int32_t x_step_q4, int y0_q4, int y_step_q4,
//                             int w, int h);
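// FUN_CONV_2D (also from convolve.h) is expected to build these 2-D entry
// points out of the 1-D passes above: a horizontal pass into an intermediate
// buffer that is tall enough for the extra rows the 8-tap vertical filter
// needs, followed by a vertical pass into dst (averaging with dst in the
// avg_ variant). See convolve.h for the actual macro.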
FUN_CONV_2D(, avx2);
FUN_CONV_2D(avg_, avx2);
#endif  // HAVE_AVX2 && HAVE_SSSE3