// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_SSE.h: optimized SSE specializations of the templates
// in fixedpoint.h.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_

#include <smmintrin.h>
#include "fixedpoint.h"

namespace gemmlowp {

// SSE intrinsics are not finely typed: there is a single __m128i vector
// type that does not distinguish between "int32x4" and "int16x8" use
// cases, unlike the NEON equivalents. Because we initially focused on
// int32x4, we specialized these fixedpoint templates directly for __m128i,
// hardcoding the int32x4 semantics and leaving no room for int16x8
// semantics. We amend that by adding a separate data type, int16x8_m128i,
// that wraps __m128i while being a distinct type.
struct int16x8_m128i {
  int16x8_m128i() {}
  explicit int16x8_m128i(__m128i w) : v(w) {}
  ~int16x8_m128i() {}

  __m128i v;
};

template <>
struct FixedPointRawTypeTraits<__m128i> {
  typedef std::int32_t ScalarRawType;
  static const int kLanes = 4;
};

template <>
struct FixedPointRawTypeTraits<int16x8_m128i> {
  typedef std::int16_t ScalarRawType;
  static const int kLanes = 8;
};
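
// Note on dispatch: because int16x8_m128i is a distinct C++ type, the
// specializations below can overload on it. For example, Add() on __m128i
// lowers to _mm_add_epi32 (four int32 lanes), while Add() on int16x8_m128i
// lowers to _mm_add_epi16 (eight int16 lanes), even though both are backed
// by the same 128-bit register.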

template <>
inline __m128i BitAnd(__m128i a, __m128i b) {
  return _mm_and_si128(a, b);
}

template <>
inline int16x8_m128i BitAnd(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_and_si128(a.v, b.v));
}

template <>
inline __m128i BitOr(__m128i a, __m128i b) {
  return _mm_or_si128(a, b);
}

template <>
inline int16x8_m128i BitOr(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_or_si128(a.v, b.v));
}

template <>
inline __m128i BitXor(__m128i a, __m128i b) {
  return _mm_xor_si128(a, b);
}

template <>
inline int16x8_m128i BitXor(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_xor_si128(a.v, b.v));
}

template <>
inline __m128i BitNot(__m128i a) {
  return _mm_andnot_si128(a, _mm_set1_epi32(-1));
}

template <>
inline int16x8_m128i BitNot(int16x8_m128i a) {
  return int16x8_m128i(_mm_andnot_si128(a.v, _mm_set1_epi16(-1)));
}

template <>
inline __m128i Add(__m128i a, __m128i b) {
  return _mm_add_epi32(a, b);
}

template <>
inline int16x8_m128i Add(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_add_epi16(a.v, b.v));
}

template <>
inline __m128i Mul(__m128i a, __m128i b) {
  return _mm_mullo_epi32(a, b);
}

template <>
inline int16x8_m128i Mul(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_mullo_epi16(a.v, b.v));
}

template <>
inline __m128i Sub(__m128i a, __m128i b) {
  return _mm_sub_epi32(a, b);
}

template <>
inline int16x8_m128i Sub(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_sub_epi16(a.v, b.v));
}

template <>
inline __m128i Neg(__m128i a) {
  return _mm_sign_epi32(a, _mm_set1_epi32(-1));
}

template <>
inline int16x8_m128i Neg(int16x8_m128i a) {
  return int16x8_m128i(_mm_sign_epi16(a.v, _mm_set1_epi16(-1)));
}

template <>
inline __m128i ShiftLeft(__m128i a, int offset) {
  return _mm_slli_epi32(a, offset);
}

template <>
inline int16x8_m128i ShiftLeft(int16x8_m128i a, int offset) {
  return int16x8_m128i(_mm_slli_epi16(a.v, offset));
}

template <>
inline __m128i ShiftRight(__m128i a, int offset) {
  return _mm_srai_epi32(a, offset);
}

template <>
inline int16x8_m128i ShiftRight(int16x8_m128i a, int offset) {
  return int16x8_m128i(_mm_srai_epi16(a.v, offset));
}

template <>
inline __m128i SelectUsingMask(__m128i if_mask, __m128i then_val,
                               __m128i else_val) {
  // borrowed from Intel's arm_neon_sse.h header.
  return _mm_or_si128(_mm_and_si128(if_mask, then_val),
                      _mm_andnot_si128(if_mask, else_val));
}

template <>
inline int16x8_m128i SelectUsingMask(int16x8_m128i if_mask,
                                     int16x8_m128i then_val,
                                     int16x8_m128i else_val) {
  // borrowed from Intel's arm_neon_sse.h header.
  return int16x8_m128i(SelectUsingMask(if_mask.v, then_val.v, else_val.v));
}

template <>
inline __m128i MaskIfEqual(__m128i a, __m128i b) {
  return _mm_cmpeq_epi32(a, b);
}

template <>
inline int16x8_m128i MaskIfEqual(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_cmpeq_epi16(a.v, b.v));
}

template <>
inline __m128i MaskIfNotEqual(__m128i a, __m128i b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int16x8_m128i MaskIfNotEqual(int16x8_m128i a, int16x8_m128i b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline __m128i MaskIfZero(__m128i a) {
  return MaskIfEqual(a, _mm_set1_epi32(0));
}

template <>
inline int16x8_m128i MaskIfZero(int16x8_m128i a) {
  return MaskIfEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
}

template <>
inline __m128i MaskIfNonZero(__m128i a) {
  return MaskIfNotEqual(a, _mm_set1_epi32(0));
}

template <>
inline int16x8_m128i MaskIfNonZero(int16x8_m128i a) {
  return MaskIfNotEqual(a, int16x8_m128i(_mm_set1_epi16(0)));
}

template <>
inline __m128i MaskIfGreaterThan(__m128i a, __m128i b) {
  return _mm_cmpgt_epi32(a, b);
}

template <>
inline int16x8_m128i MaskIfGreaterThan(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_cmpgt_epi16(a.v, b.v));
}

template <>
inline __m128i MaskIfLessThan(__m128i a, __m128i b) {
  return _mm_cmplt_epi32(a, b);
}

template <>
inline int16x8_m128i MaskIfLessThan(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_cmplt_epi16(a.v, b.v));
}

template <>
inline __m128i MaskIfGreaterThanOrEqual(__m128i a, __m128i b) {
  return BitNot(MaskIfLessThan(a, b));
}

template <>
inline int16x8_m128i MaskIfGreaterThanOrEqual(int16x8_m128i a,
                                              int16x8_m128i b) {
  return BitNot(MaskIfLessThan(a, b));
}

template <>
inline __m128i MaskIfLessThanOrEqual(__m128i a, __m128i b) {
  return BitNot(MaskIfGreaterThan(a, b));
}

template <>
inline int16x8_m128i MaskIfLessThanOrEqual(int16x8_m128i a, int16x8_m128i b) {
  return BitNot(MaskIfGreaterThan(a, b));
}

/* Assumptions:
   - All and Any are used on masks.
   - Masks are all_ones for true lanes, all_zeroes otherwise.
   Hence, All means all 128 bits set, and Any means any bit set.
*/

template <>
inline bool All(__m128i a) {
  // True iff every one of the 128 bits is set, i.e. ~a & all_ones == 0.
  // (_mm_testc_si128(a, a) would trivially return 1 for any input.)
  return _mm_testc_si128(a, _mm_set1_epi32(-1));
}

template <>
inline bool All(int16x8_m128i a) {
  return _mm_testc_si128(a.v, _mm_set1_epi16(-1));
}

template <>
inline bool Any(__m128i a) {
  return !_mm_testz_si128(a, a);
}

template <>
inline bool Any(int16x8_m128i a) {
  return !_mm_testz_si128(a.v, a.v);
}

template <>
inline __m128i RoundingHalfSum(__m128i a, __m128i b) {
  /* __m128i round_bit_mask, a_over_2, b_over_2, round_bit, sum; */
  /* We divide the inputs before the add to avoid the overflow and the costly
   * test of checking if an overflow occurred on signed add. */
  /* round_bit_mask = _mm_set1_epi32(1); */
  /* a_over_2 = _mm_srai_epi32(a, 1); */
  /* b_over_2 = _mm_srai_epi32(b, 1); */
  /* sum = Add(a_over_2, b_over_2); */
  /* round_bit = _mm_sign_epi32(BitAnd(BitOr(a, b), round_bit_mask), sum); */
  /* return Add(sum, round_bit); */

  /* Other possibility: detect the overflow and xor the sign bit if an
   * overflow happened. */
  __m128i one, sign_bit_mask, sum, rounded_half_sum, overflow, result;
  one = _mm_set1_epi32(1);
  sign_bit_mask = _mm_set1_epi32(0x80000000);
  sum = Add(a, b);
  rounded_half_sum = _mm_srai_epi32(Add(sum, one), 1);
  overflow =
      BitAnd(BitAnd(BitXor(a, rounded_half_sum), BitXor(b, rounded_half_sum)),
             sign_bit_mask);
  result = BitXor(rounded_half_sum, overflow);
  return result;
}
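
// Worked example of the overflow correction above (int32 lanes): with
// a = b = INT32_MAX = 0x7fffffff, the wrapped sum is 0xfffffffe (-2), and
// (sum + 1) >> 1 gives 0xffffffff (-1). Since a and b are positive but the
// rounded half-sum came out negative, the sign bits of (a ^ result) and
// (b ^ result) are both set, so the overflow mask is 0x80000000; xoring it
// back in yields 0x7fffffff, the correct rounded average of the two inputs.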

template <>
inline int16x8_m128i RoundingHalfSum(int16x8_m128i a, int16x8_m128i b) {
  // Idea: go to unsigned to use _mm_avg_epu16,
  // borrowed from Intel's arm_neon_sse.h header.
  __m128i constant_neg_32768 = _mm_set1_epi16(-32768);
  __m128i a_unsigned = _mm_sub_epi16(a.v, constant_neg_32768);
  __m128i b_unsigned = _mm_sub_epi16(b.v, constant_neg_32768);
  __m128i avg_unsigned = _mm_avg_epu16(a_unsigned, b_unsigned);
  __m128i avg = _mm_add_epi16(avg_unsigned, constant_neg_32768);
  return int16x8_m128i(avg);
}

template <>
inline __m128i SaturatingRoundingDoublingHighMul(__m128i a, __m128i b) {
  __m128i min, max, saturation_mask, a0_a2, a1_a3, b0_b2, b1_b3;
  __m128i a0b0_a2b2, a1b1_a3b3, a0b0_a2b2_rounded, a1b1_a3b3_rounded;
  __m128i a0b0_a2b2_rounded_2x, a1b1_a3b3_rounded_2x, result;
  __m128i nudge;

  // saturation only happens if a == b == INT_MIN
  min = _mm_set1_epi32(std::numeric_limits<std::int32_t>::min());
  max = _mm_set1_epi32(std::numeric_limits<std::int32_t>::max());
  saturation_mask = BitAnd(MaskIfEqual(a, b), MaskIfEqual(a, min));

  // a = a0 | a1 | a2 | a3
  // b = b0 | b1 | b2 | b3
  a0_a2 = a;
  a1_a3 = _mm_srli_si128(a, 4);
  b0_b2 = b;
  b1_b3 = _mm_srli_si128(b, 4);

  a0b0_a2b2 = _mm_mul_epi32(a0_a2, b0_b2);
  a1b1_a3b3 = _mm_mul_epi32(a1_a3, b1_b3);

  // do the rounding and take into account that it will be doubled
  nudge = _mm_set1_epi64x(1 << 30);
  a0b0_a2b2_rounded = _mm_add_epi64(a0b0_a2b2, nudge);
  a1b1_a3b3_rounded = _mm_add_epi64(a1b1_a3b3, nudge);

  // do the doubling
  a0b0_a2b2_rounded_2x = _mm_slli_epi64(a0b0_a2b2_rounded, 1);
  a1b1_a3b3_rounded_2x = _mm_slli_epi64(a1b1_a3b3_rounded, 1);

  // get the high part of the products
  result = _mm_blend_epi16(_mm_srli_si128(a0b0_a2b2_rounded_2x, 4),
                           a1b1_a3b3_rounded_2x, 0xcc);

  // saturate those which overflowed: the saturated value is int32 max,
  // matching the scalar reference behavior for a == b == INT_MIN.
  return SelectUsingMask(saturation_mask, max, result);
}

template <>
inline int16x8_m128i SaturatingRoundingDoublingHighMul(int16x8_m128i a,
                                                       int16x8_m128i b) {
  // Idea: use _mm_mulhrs_epi16 then saturate with a bit-operation,
  // borrowed from Intel's arm_neon_sse.h header.
  __m128i result_unsaturated = _mm_mulhrs_epi16(a.v, b.v);
  __m128i saturation_mask =
      _mm_cmpeq_epi16(result_unsaturated, _mm_set1_epi16(0x8000));
  __m128i result = _mm_xor_si128(result_unsaturated, saturation_mask);
  return int16x8_m128i(result);
}
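
// Per-lane sketch of what the int32 version above computes (illustrative
// pseudo-code, not part of the API):
//   int64 ab = (int64)a * (int64)b;          // full-width product
//   int64 rounded = ab + (1 << 30);          // rounding nudge, pre-doubling
//   int32 result = high32bits(2 * rounded);  // high half of 2*a*b
// i.e. a rounded a*b / 2^31, with a == b == INT_MIN clamped to INT32_MAX.
// For the int16 version, _mm_mulhrs_epi16 already computes
// (a * b + (1 << 14)) >> 15 per lane; the only input pair that overflows is
// a == b == -32768, which yields -32768 (0x8000), and xoring that lane with
// the all-ones equality mask flips it to +32767.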

template <>
inline __m128i Dup<__m128i>(std::int32_t x) {
  return _mm_set1_epi32(x);
}

template <>
inline int16x8_m128i Dup<int16x8_m128i>(std::int16_t x) {
  return int16x8_m128i(_mm_set1_epi16(x));
}

// So far this is only needed for int16.
template <>
inline int16x8_m128i SaturatingAdd(int16x8_m128i a, int16x8_m128i b) {
  return int16x8_m128i(_mm_adds_epi16(a.v, b.v));
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_SSE_H_