/*
 *  Copyright (c) 2012 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <immintrin.h>  // AVX2

void vp9_get16x16var_avx2(const unsigned char *src_ptr,
                          int source_stride,
                          const unsigned char *ref_ptr,
                          int recon_stride,
                          unsigned int *SSE,
                          int *Sum) {
    __m256i src, src_expand_low, src_expand_high, ref, ref_expand_low;
    __m256i ref_expand_high, madd_low, madd_high;
    unsigned int i, src_2strides, ref_2strides;
    __m256i zero_reg = _mm256_set1_epi16(0);
    __m256i sum_ref_src = _mm256_set1_epi16(0);
    __m256i madd_ref_src = _mm256_set1_epi16(0);
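    // Each 16-bit lane of sum_ref_src accumulates at most 16 differences in
    // [-255, 255] (|lane| <= 4080), so the 16-bit sums cannot overflow;
    // madd_ref_src keeps the squared differences in 32-bit lanes.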

    // process two rows per 256-bit register, halving the number of loop
    // iterations compared with the sse2 code
    src_2strides = source_stride << 1;
    ref_2strides = recon_stride << 1;
    for (i = 0; i < 8; i++) {
        src = _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i const *) (src_ptr)));
        src = _mm256_inserti128_si256(src,
              _mm_loadu_si128((__m128i const *)(src_ptr+source_stride)), 1);

        ref = _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i const *) (ref_ptr)));
        ref = _mm256_inserti128_si256(ref,
              _mm_loadu_si128((__m128i const *)(ref_ptr+recon_stride)), 1);

        // expand each byte to a 16-bit lane
        src_expand_low = _mm256_unpacklo_epi8(src, zero_reg);
        src_expand_high = _mm256_unpackhi_epi8(src, zero_reg);

        ref_expand_low = _mm256_unpacklo_epi8(ref, zero_reg);
        ref_expand_high = _mm256_unpackhi_epi8(ref, zero_reg);

        // src - ref
        src_expand_low = _mm256_sub_epi16(src_expand_low, ref_expand_low);
        src_expand_high = _mm256_sub_epi16(src_expand_high, ref_expand_high);

        // madd low (src - ref)
        madd_low = _mm256_madd_epi16(src_expand_low, src_expand_low);
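        // madd of a difference vector with itself squares every 16-bit
        // difference and adds adjacent pairs, giving 32-bit partial SSEs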

        // add high to low
        src_expand_low = _mm256_add_epi16(src_expand_low, src_expand_high);

        // madd high (src - ref)
        madd_high = _mm256_madd_epi16(src_expand_high, src_expand_high);

        sum_ref_src = _mm256_add_epi16(sum_ref_src, src_expand_low);

        // add high to low
        madd_ref_src = _mm256_add_epi32(madd_ref_src,
                       _mm256_add_epi32(madd_low, madd_high));

        src_ptr += src_2strides;
        ref_ptr += ref_2strides;
    }

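    // reduce the vector accumulators to scalars: sign-extend the 16-bit
    // partial sums to 32 bits, then add all lanes of Sum and SSE horizontally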
    {
        __m128i sum_res, madd_res;
        __m128i expand_sum_low, expand_sum_high, expand_sum;
        __m128i expand_madd_low, expand_madd_high, expand_madd;
        __m128i ex_expand_sum_low, ex_expand_sum_high, ex_expand_sum;

        // extract the low lane and add it to the high lane
        sum_res = _mm_add_epi16(_mm256_castsi256_si128(sum_ref_src),
                                _mm256_extractf128_si256(sum_ref_src, 1));

        madd_res = _mm_add_epi32(_mm256_castsi256_si128(madd_ref_src),
                                 _mm256_extractf128_si256(madd_ref_src, 1));

        // place each 16-bit sum in the upper half of a 32-bit lane
        expand_sum_low = _mm_unpacklo_epi16(_mm256_castsi256_si128(zero_reg),
                                            sum_res);
        expand_sum_high = _mm_unpackhi_epi16(_mm256_castsi256_si128(zero_reg),
                                             sum_res);

        // arithmetic shift right by 16 to sign-extend the sums to 32 bits
        expand_sum_low = _mm_srai_epi32(expand_sum_low, 16);
        expand_sum_high = _mm_srai_epi32(expand_sum_high, 16);

        expand_sum = _mm_add_epi32(expand_sum_low, expand_sum_high);

        // expand each 32 bits of the madd result to 64 bits
        expand_madd_low = _mm_unpacklo_epi32(madd_res,
                          _mm256_castsi256_si128(zero_reg));
        expand_madd_high = _mm_unpackhi_epi32(madd_res,
                           _mm256_castsi256_si128(zero_reg));

        expand_madd = _mm_add_epi32(expand_madd_low, expand_madd_high);

        ex_expand_sum_low = _mm_unpacklo_epi32(expand_sum,
                            _mm256_castsi256_si128(zero_reg));
        ex_expand_sum_high = _mm_unpackhi_epi32(expand_sum,
                             _mm256_castsi256_si128(zero_reg));

        ex_expand_sum = _mm_add_epi32(ex_expand_sum_low, ex_expand_sum_high);

        // shift 8 bytes right
        madd_res = _mm_srli_si128(expand_madd, 8);
        sum_res = _mm_srli_si128(ex_expand_sum, 8);

        madd_res = _mm_add_epi32(madd_res, expand_madd);
        sum_res = _mm_add_epi32(sum_res, ex_expand_sum);

        *((int*)SSE) = _mm_cvtsi128_si32(madd_res);

        *((int*)Sum) = _mm_cvtsi128_si32(sum_res);
    }
}

void vp9_get32x32var_avx2(const unsigned char *src_ptr,
                          int source_stride,
                          const unsigned char *ref_ptr,
                          int recon_stride,
                          unsigned int *SSE,
                          int *Sum) {
    __m256i src, src_expand_low, src_expand_high, ref, ref_expand_low;
    __m256i ref_expand_high, madd_low, madd_high;
    unsigned int i;
    __m256i zero_reg = _mm256_set1_epi16(0);
    __m256i sum_ref_src = _mm256_set1_epi16(0);
    __m256i madd_ref_src = _mm256_set1_epi16(0);

    // process a full 32-pixel row (32 elements) per iteration
    for (i = 0; i < 16; i++) {
       src = _mm256_loadu_si256((__m256i const *) (src_ptr));

       ref = _mm256_loadu_si256((__m256i const *) (ref_ptr));

       // expand each byte to a 16-bit lane
       src_expand_low = _mm256_unpacklo_epi8(src, zero_reg);
       src_expand_high = _mm256_unpackhi_epi8(src, zero_reg);

       ref_expand_low = _mm256_unpacklo_epi8(ref, zero_reg);
       ref_expand_high = _mm256_unpackhi_epi8(ref, zero_reg);

       // src - ref
       src_expand_low = _mm256_sub_epi16(src_expand_low, ref_expand_low);
       src_expand_high = _mm256_sub_epi16(src_expand_high, ref_expand_high);

       // madd low (src - ref)
       madd_low = _mm256_madd_epi16(src_expand_low, src_expand_low);

       // add high to low
       src_expand_low = _mm256_add_epi16(src_expand_low, src_expand_high);

       // madd high (src - ref)
       madd_high = _mm256_madd_epi16(src_expand_high, src_expand_high);

       sum_ref_src = _mm256_add_epi16(sum_ref_src, src_expand_low);

       // add high to low
       madd_ref_src = _mm256_add_epi32(madd_ref_src,
                      _mm256_add_epi32(madd_low, madd_high));

       src_ptr += source_stride;
       ref_ptr += recon_stride;
    }

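    // reduce the vector accumulators to scalars; here the horizontal adds stay
    // within each 128-bit lane and the two lanes are only combined at the end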
    {
      __m256i expand_sum_low, expand_sum_high, expand_sum;
      __m256i expand_madd_low, expand_madd_high, expand_madd;
      __m256i ex_expand_sum_low, ex_expand_sum_high, ex_expand_sum;

      // place each 16-bit sum in the upper half of a 32-bit lane
      expand_sum_low = _mm256_unpacklo_epi16(zero_reg, sum_ref_src);
      expand_sum_high = _mm256_unpackhi_epi16(zero_reg, sum_ref_src);

      // arithmetic shift right by 16 to sign-extend the sums to 32 bits
      expand_sum_low = _mm256_srai_epi32(expand_sum_low, 16);
      expand_sum_high = _mm256_srai_epi32(expand_sum_high, 16);

      expand_sum = _mm256_add_epi32(expand_sum_low, expand_sum_high);

      // expand each 32 bits of the madd result to 64 bits
      expand_madd_low = _mm256_unpacklo_epi32(madd_ref_src, zero_reg);
      expand_madd_high = _mm256_unpackhi_epi32(madd_ref_src, zero_reg);

      expand_madd = _mm256_add_epi32(expand_madd_low, expand_madd_high);

      ex_expand_sum_low = _mm256_unpacklo_epi32(expand_sum, zero_reg);
      ex_expand_sum_high = _mm256_unpackhi_epi32(expand_sum, zero_reg);

      ex_expand_sum = _mm256_add_epi32(ex_expand_sum_low, ex_expand_sum_high);

      // shift 8 bytes right
      madd_ref_src = _mm256_srli_si256(expand_madd, 8);
      sum_ref_src = _mm256_srli_si256(ex_expand_sum, 8);

      madd_ref_src = _mm256_add_epi32(madd_ref_src, expand_madd);
      sum_ref_src = _mm256_add_epi32(sum_ref_src, ex_expand_sum);

      // extract the low lane and the high lane and add the results
      *((int*)SSE) = _mm_cvtsi128_si32(_mm256_castsi256_si128(madd_ref_src)) +
          _mm_cvtsi128_si32(_mm256_extractf128_si256(madd_ref_src, 1));

      *((int*)Sum) = _mm_cvtsi128_si32(_mm256_castsi256_si128(sum_ref_src)) +
          _mm_cvtsi128_si32(_mm256_extractf128_si256(sum_ref_src, 1));
    }
}
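
// A minimal usage sketch (illustrative only, not part of this file): callers
// in the variance code typically combine the two outputs into a variance, for
// a 16x16 block roughly as
//
//   unsigned int sse;
//   int sum;
//   vp9_get16x16var_avx2(src, src_stride, ref, ref_stride, &sse, &sum);
//   // sum of squared differences minus (sum^2 / number of pixels)
//   unsigned int var = sse - (unsigned int)(((int64_t)sum * sum) >> 8);
//
// The exact wrapper names and types used by the encoder may differ; this only
// shows how the SSE and Sum outputs relate to the block variance.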