1// Copyright 2015 Google Inc. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// output_sse.h: optimized SSE4.2 specializations of the templates in output.h. 16 17#ifndef GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ 18#define GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ 19 20#include "output.h" 21 22#include <smmintrin.h> 23 24namespace gemmlowp { 25 26template <> 27struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, 28 RegBufferInt32<4>> { 29 typedef RegBufferInt32<4> InputType; 30 typedef RegBufferUint8<4> OutputType; 31 32 typedef OutputStageSaturatingCastToUint8 OutputStage; 33 34 OutputStageEvalBufferImpl(const OutputStage&) {} 35 36 OutputType Eval(InputType input) const { 37 OutputType output; 38 __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]); 39 __m128i res_8 = _mm_packus_epi16(res_16, res_16); 40 output.reg[0] = _mm_cvtsi128_si32(res_8); 41 return output; 42 } 43}; 44 45template <> 46struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, 47 RegBufferInt32<8>> { 48 typedef RegBufferInt32<8> InputType; 49 typedef RegBufferUint8<8> OutputType; 50 51 typedef OutputStageSaturatingCastToUint8 OutputStage; 52 53 OutputStageEvalBufferImpl(const OutputStage&) {} 54 55 OutputType Eval(InputType input) const { 56 OutputType output; 57 __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[1]); 58 __m128i res_8 = _mm_packus_epi16(res_16, res_16); 59 output.reg[0] = _mm_extract_epi32(res_8, 0); 60 output.reg[1] = _mm_extract_epi32(res_8, 1); 61 return output; 62 } 63}; 64 65template <> 66struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, 67 RegBufferInt32<16>> { 68 typedef RegBufferInt32<16> InputType; 69 typedef RegBufferUint8<16> OutputType; 70 71 typedef OutputStageSaturatingCastToUint8 OutputStage; 72 73 OutputStageEvalBufferImpl(const OutputStage&) {} 74 75 OutputType Eval(InputType input) const { 76 OutputType output; 77 __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]); 78 __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]); 79 output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1); 80 return output; 81 } 82}; 83 84template <> 85struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, 86 RegBufferInt32<32>> { 87 typedef RegBufferInt32<32> InputType; 88 typedef RegBufferUint8<32> OutputType; 89 90 typedef OutputStageSaturatingCastToUint8 OutputStage; 91 92 OutputStageEvalBufferImpl(const OutputStage&) {} 93 94 OutputType Eval(InputType input) const { 95 OutputType output; 96 __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]); 97 __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]); 98 output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1); 99 __m128i res_16_2 = _mm_packs_epi32(input.reg[4], input.reg[5]); 100 __m128i res_16_3 = _mm_packs_epi32(input.reg[6], input.reg[7]); 101 output.reg[1] = _mm_packus_epi16(res_16_2, res_16_3); 102 return output; 103 } 104}; 105 106template <typename DstType> 107struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> { 108 static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row, 109 int col) { 110 if (DstType::kOrder == MapOrder::ColMajor) { 111 StoreInt32x4(dst->data(row, col), src.buf.reg[0]); 112 } else { 113 *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]); 114 *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]); 115 *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]); 116 *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]); 117 } 118 } 119}; 120 121template <typename DstType> 122struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> { 123 static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row, 124 int col) { 125 if (DstType::kOrder == MapOrder::ColMajor) { 126 StoreInt32x4(dst->data(row, col), src.buf.reg[0]); 127 StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]); 128 } else { 129 *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]); 130 *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]); 131 *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]); 132 *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]); 133 *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]); 134 *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]); 135 *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]); 136 *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]); 137 } 138 } 139}; 140 141inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) { 142 __m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]); 143 __m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]); 144 __m128i t2 = _mm_unpackhi_epi32(src.buf.reg[0], src.buf.reg[1]); 145 __m128i t3 = _mm_unpackhi_epi32(src.buf.reg[2], src.buf.reg[3]); 146 147 RegBlockInt32<4, 4> result; 148 result.buf.reg[0] = _mm_unpacklo_epi64(t0, t1); 149 result.buf.reg[1] = _mm_unpackhi_epi64(t0, t1); 150 result.buf.reg[2] = _mm_unpacklo_epi64(t2, t3); 151 result.buf.reg[3] = _mm_unpackhi_epi64(t2, t3); 152 return result; 153} 154 155template <typename DstType> 156struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> { 157 static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row, 158 int col) { 159 if (DstType::kOrder == MapOrder::ColMajor) { 160 for (int i = 0; i < 4; i++) { 161 StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]); 162 } 163 } else { 164 const auto transpose = Transpose(src); 165 for (int i = 0; i < 4; i++) { 166 StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]); 167 } 168 } 169 } 170}; 171 172template <typename DstType> 173struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> { 174 static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row, 175 int col) { 176 if (DstType::kOrder == MapOrder::ColMajor) { 177 for (int i = 0; i < 4; i++) { 178 StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]); 179 StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]); 180 } 181 } else { 182 RegBlockInt32<4, 4> top; 183 top.buf.reg[0] = src.buf.reg[0]; 184 top.buf.reg[1] = src.buf.reg[2]; 185 top.buf.reg[2] = src.buf.reg[4]; 186 top.buf.reg[3] = src.buf.reg[6]; 187 const auto transpose_top = Transpose(top); 188 for (int i = 0; i < 4; i++) { 189 StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]); 190 } 191 RegBlockInt32<4, 4> bottom; 192 bottom.buf.reg[0] = src.buf.reg[1]; 193 bottom.buf.reg[1] = src.buf.reg[3]; 194 bottom.buf.reg[2] = src.buf.reg[5]; 195 bottom.buf.reg[3] = src.buf.reg[7]; 196 const auto transpose_bottom = Transpose(bottom); 197 for (int i = 0; i < 4; i++) { 198 StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]); 199 } 200 } 201 } 202}; 203 204template <typename DstType> 205struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> { 206 static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row, 207 int col) { 208 if (DstType::kOrder == MapOrder::ColMajor) { 209 for (int i = 0; i < 8; i++) { 210 StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]); 211 StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]); 212 } 213 } else { 214 RegBlockInt32<4, 4> top_left; 215 top_left.buf.reg[0] = src.buf.reg[0]; 216 top_left.buf.reg[1] = src.buf.reg[2]; 217 top_left.buf.reg[2] = src.buf.reg[4]; 218 top_left.buf.reg[3] = src.buf.reg[6]; 219 const auto transpose_top_left = Transpose(top_left); 220 for (int i = 0; i < 4; i++) { 221 StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]); 222 } 223 RegBlockInt32<4, 4> bottom_left; 224 bottom_left.buf.reg[0] = src.buf.reg[1]; 225 bottom_left.buf.reg[1] = src.buf.reg[3]; 226 bottom_left.buf.reg[2] = src.buf.reg[5]; 227 bottom_left.buf.reg[3] = src.buf.reg[7]; 228 const auto transpose_bottom_left = Transpose(bottom_left); 229 for (int i = 0; i < 4; i++) { 230 StoreInt32x4(dst->data(row + 4 + i, col), 231 transpose_bottom_left.buf.reg[i]); 232 } 233 RegBlockInt32<4, 4> top_right; 234 top_right.buf.reg[0] = src.buf.reg[8]; 235 top_right.buf.reg[1] = src.buf.reg[10]; 236 top_right.buf.reg[2] = src.buf.reg[12]; 237 top_right.buf.reg[3] = src.buf.reg[14]; 238 const auto transpose_top_right = Transpose(top_right); 239 for (int i = 0; i < 4; i++) { 240 StoreInt32x4(dst->data(row + i, col + 4), 241 transpose_top_right.buf.reg[i]); 242 } 243 RegBlockInt32<4, 4> bottom_right; 244 bottom_right.buf.reg[0] = src.buf.reg[9]; 245 bottom_right.buf.reg[1] = src.buf.reg[11]; 246 bottom_right.buf.reg[2] = src.buf.reg[13]; 247 bottom_right.buf.reg[3] = src.buf.reg[15]; 248 const auto transpose_bottom_right = Transpose(bottom_right); 249 for (int i = 0; i < 4; i++) { 250 StoreInt32x4(dst->data(row + 4 + i, col + 4), 251 transpose_bottom_right.buf.reg[i]); 252 } 253 } 254 } 255}; 256 257template <typename DstType> 258struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> { 259 static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row, 260 int col) { 261 if (DstType::kOrder == MapOrder::ColMajor) { 262 *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]); 263 *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]); 264 *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]); 265 *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]); 266 } else { 267 StoreInt32x4(dst->data(row, col), src.buf.reg[0]); 268 } 269 } 270}; 271 272template <typename DstType> 273struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> { 274 static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row, 275 int col) { 276 const std::uint32_t src_reg = src.buf.reg[0]; 277 for (int i = 0; i < 4; i++) { 278 *dst->data(row + i, col) = (src_reg >> (8 * i)); 279 } 280 } 281}; 282 283template <typename DstType> 284struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> { 285 static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row, 286 int col) { 287 for (int i = 0; i < 4; i++) { 288 *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i)); 289 } 290 for (int i = 0; i < 4; i++) { 291 *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i)); 292 } 293 } 294}; 295 296template <typename DstType> 297struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> { 298 static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row, 299 int col) { 300 for (int i = 0; i < 4; i++) { 301 *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i)); 302 } 303 } 304}; 305 306template <typename DstType> 307struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> { 308 static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row, 309 int col) { 310 std::uint8_t buf[16]; 311 StoreUint8x16(buf, src.buf.reg[0]); 312 for (int c = 0; c < 4; c++) { 313 for (int r = 0; r < 4; r++) { 314 *dst->data(row + r, col + c) = buf[r + 4 * c]; 315 } 316 } 317 } 318}; 319 320template <typename DstType> 321struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> { 322 static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row, 323 int col) { 324 std::uint8_t buf[32]; 325 StoreUint8x16(buf, src.buf.reg[0]); 326 StoreUint8x16(buf + 16, src.buf.reg[1]); 327 for (int c = 0; c < 4; c++) { 328 for (int r = 0; r < 8; r++) { 329 *dst->data(row + r, col + c) = buf[r + 8 * c]; 330 } 331 } 332 } 333}; 334 335template <typename DstType> 336struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> { 337 static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row, 338 int col) { 339 std::uint8_t buf[64]; 340 StoreUint8x16(buf, src.buf.reg[0]); 341 StoreUint8x16(buf + 16, src.buf.reg[1]); 342 StoreUint8x16(buf + 32, src.buf.reg[2]); 343 StoreUint8x16(buf + 48, src.buf.reg[3]); 344 for (int c = 0; c < 8; c++) { 345 for (int r = 0; r < 8; r++) { 346 *dst->data(row + r, col + c) = buf[r + 8 * c]; 347 } 348 } 349 } 350}; 351 352} // namespace gemmlowp 353 354#endif // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_ 355