1// Copyright 2015 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// output_sse.h: optimized SSE4.2 specializations of the templates in output.h.
16
17#ifndef GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
18#define GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
19
20#include "output.h"
21
22#include <smmintrin.h>
23
24namespace gemmlowp {
25
26template <>
27struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
28                                 RegBufferInt32<4>> {
29  typedef RegBufferInt32<4> InputType;
30  typedef RegBufferUint8<4> OutputType;
31
32  typedef OutputStageSaturatingCastToUint8 OutputStage;
33
34  OutputStageEvalBufferImpl(const OutputStage&) {}
35
36  OutputType Eval(InputType input) const {
37    OutputType output;
38    __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[0]);
39    __m128i res_8 = _mm_packus_epi16(res_16, res_16);
40    output.reg[0] = _mm_cvtsi128_si32(res_8);
41    return output;
42  }
43};
44
45template <>
46struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
47                                 RegBufferInt32<8>> {
48  typedef RegBufferInt32<8> InputType;
49  typedef RegBufferUint8<8> OutputType;
50
51  typedef OutputStageSaturatingCastToUint8 OutputStage;
52
53  OutputStageEvalBufferImpl(const OutputStage&) {}
54
55  OutputType Eval(InputType input) const {
56    OutputType output;
57    __m128i res_16 = _mm_packs_epi32(input.reg[0], input.reg[1]);
58    __m128i res_8 = _mm_packus_epi16(res_16, res_16);
59    output.reg[0] = _mm_extract_epi32(res_8, 0);
60    output.reg[1] = _mm_extract_epi32(res_8, 1);
61    return output;
62  }
63};
64
65template <>
66struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
67                                 RegBufferInt32<16>> {
68  typedef RegBufferInt32<16> InputType;
69  typedef RegBufferUint8<16> OutputType;
70
71  typedef OutputStageSaturatingCastToUint8 OutputStage;
72
73  OutputStageEvalBufferImpl(const OutputStage&) {}
74
75  OutputType Eval(InputType input) const {
76    OutputType output;
77    __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]);
78    __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]);
79    output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1);
80    return output;
81  }
82};
83
84template <>
85struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
86                                 RegBufferInt32<32>> {
87  typedef RegBufferInt32<32> InputType;
88  typedef RegBufferUint8<32> OutputType;
89
90  typedef OutputStageSaturatingCastToUint8 OutputStage;
91
92  OutputStageEvalBufferImpl(const OutputStage&) {}
93
94  OutputType Eval(InputType input) const {
95    OutputType output;
96    __m128i res_16_0 = _mm_packs_epi32(input.reg[0], input.reg[1]);
97    __m128i res_16_1 = _mm_packs_epi32(input.reg[2], input.reg[3]);
98    output.reg[0] = _mm_packus_epi16(res_16_0, res_16_1);
99    __m128i res_16_2 = _mm_packs_epi32(input.reg[4], input.reg[5]);
100    __m128i res_16_3 = _mm_packs_epi32(input.reg[6], input.reg[7]);
101    output.reg[1] = _mm_packus_epi16(res_16_2, res_16_3);
102    return output;
103  }
104};
105
106template <typename DstType>
107struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
108  static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
109                  int col) {
110    if (DstType::kOrder == MapOrder::ColMajor) {
111      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
112    } else {
113      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
114      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
115      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
116      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
117    }
118  }
119};
120
121template <typename DstType>
122struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
123  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
124                  int col) {
125    if (DstType::kOrder == MapOrder::ColMajor) {
126      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
127      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
128    } else {
129      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
130      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
131      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
132      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
133      *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
134      *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
135      *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
136      *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
137    }
138  }
139};
140
141inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
142  __m128i t0 = _mm_unpacklo_epi32(src.buf.reg[0], src.buf.reg[1]);
143  __m128i t1 = _mm_unpacklo_epi32(src.buf.reg[2], src.buf.reg[3]);
144  __m128i t2 = _mm_unpackhi_epi32(src.buf.reg[0], src.buf.reg[1]);
145  __m128i t3 = _mm_unpackhi_epi32(src.buf.reg[2], src.buf.reg[3]);
146
147  RegBlockInt32<4, 4> result;
148  result.buf.reg[0] = _mm_unpacklo_epi64(t0, t1);
149  result.buf.reg[1] = _mm_unpackhi_epi64(t0, t1);
150  result.buf.reg[2] = _mm_unpacklo_epi64(t2, t3);
151  result.buf.reg[3] = _mm_unpackhi_epi64(t2, t3);
152  return result;
153}
154
155template <typename DstType>
156struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
157  static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
158                  int col) {
159    if (DstType::kOrder == MapOrder::ColMajor) {
160      for (int i = 0; i < 4; i++) {
161        StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
162      }
163    } else {
164      const auto transpose = Transpose(src);
165      for (int i = 0; i < 4; i++) {
166        StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
167      }
168    }
169  }
170};
171
172template <typename DstType>
173struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
174  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
175                  int col) {
176    if (DstType::kOrder == MapOrder::ColMajor) {
177      for (int i = 0; i < 4; i++) {
178        StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
179        StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
180      }
181    } else {
182      RegBlockInt32<4, 4> top;
183      top.buf.reg[0] = src.buf.reg[0];
184      top.buf.reg[1] = src.buf.reg[2];
185      top.buf.reg[2] = src.buf.reg[4];
186      top.buf.reg[3] = src.buf.reg[6];
187      const auto transpose_top = Transpose(top);
188      for (int i = 0; i < 4; i++) {
189        StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
190      }
191      RegBlockInt32<4, 4> bottom;
192      bottom.buf.reg[0] = src.buf.reg[1];
193      bottom.buf.reg[1] = src.buf.reg[3];
194      bottom.buf.reg[2] = src.buf.reg[5];
195      bottom.buf.reg[3] = src.buf.reg[7];
196      const auto transpose_bottom = Transpose(bottom);
197      for (int i = 0; i < 4; i++) {
198        StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
199      }
200    }
201  }
202};
203
204template <typename DstType>
205struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
206  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
207                  int col) {
208    if (DstType::kOrder == MapOrder::ColMajor) {
209      for (int i = 0; i < 8; i++) {
210        StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
211        StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
212      }
213    } else {
214      RegBlockInt32<4, 4> top_left;
215      top_left.buf.reg[0] = src.buf.reg[0];
216      top_left.buf.reg[1] = src.buf.reg[2];
217      top_left.buf.reg[2] = src.buf.reg[4];
218      top_left.buf.reg[3] = src.buf.reg[6];
219      const auto transpose_top_left = Transpose(top_left);
220      for (int i = 0; i < 4; i++) {
221        StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
222      }
223      RegBlockInt32<4, 4> bottom_left;
224      bottom_left.buf.reg[0] = src.buf.reg[1];
225      bottom_left.buf.reg[1] = src.buf.reg[3];
226      bottom_left.buf.reg[2] = src.buf.reg[5];
227      bottom_left.buf.reg[3] = src.buf.reg[7];
228      const auto transpose_bottom_left = Transpose(bottom_left);
229      for (int i = 0; i < 4; i++) {
230        StoreInt32x4(dst->data(row + 4 + i, col),
231                     transpose_bottom_left.buf.reg[i]);
232      }
233      RegBlockInt32<4, 4> top_right;
234      top_right.buf.reg[0] = src.buf.reg[8];
235      top_right.buf.reg[1] = src.buf.reg[10];
236      top_right.buf.reg[2] = src.buf.reg[12];
237      top_right.buf.reg[3] = src.buf.reg[14];
238      const auto transpose_top_right = Transpose(top_right);
239      for (int i = 0; i < 4; i++) {
240        StoreInt32x4(dst->data(row + i, col + 4),
241                     transpose_top_right.buf.reg[i]);
242      }
243      RegBlockInt32<4, 4> bottom_right;
244      bottom_right.buf.reg[0] = src.buf.reg[9];
245      bottom_right.buf.reg[1] = src.buf.reg[11];
246      bottom_right.buf.reg[2] = src.buf.reg[13];
247      bottom_right.buf.reg[3] = src.buf.reg[15];
248      const auto transpose_bottom_right = Transpose(bottom_right);
249      for (int i = 0; i < 4; i++) {
250        StoreInt32x4(dst->data(row + 4 + i, col + 4),
251                     transpose_bottom_right.buf.reg[i]);
252      }
253    }
254  }
255};
256
257template <typename DstType>
258struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
259  static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
260                  int col) {
261    if (DstType::kOrder == MapOrder::ColMajor) {
262      *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
263      *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
264      *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
265      *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
266    } else {
267      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
268    }
269  }
270};
271
272template <typename DstType>
273struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
274  static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
275                  int col) {
276    const std::uint32_t src_reg = src.buf.reg[0];
277    for (int i = 0; i < 4; i++) {
278      *dst->data(row + i, col) = (src_reg >> (8 * i));
279    }
280  }
281};
282
283template <typename DstType>
284struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
285  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
286                  int col) {
287    for (int i = 0; i < 4; i++) {
288      *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
289    }
290    for (int i = 0; i < 4; i++) {
291      *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
292    }
293  }
294};
295
296template <typename DstType>
297struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
298  static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
299                  int col) {
300    for (int i = 0; i < 4; i++) {
301      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
302    }
303  }
304};
305
306template <typename DstType>
307struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
308  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
309                  int col) {
310    std::uint8_t buf[16];
311    StoreUint8x16(buf, src.buf.reg[0]);
312    for (int c = 0; c < 4; c++) {
313      for (int r = 0; r < 4; r++) {
314        *dst->data(row + r, col + c) = buf[r + 4 * c];
315      }
316    }
317  }
318};
319
320template <typename DstType>
321struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
322  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
323                  int col) {
324    std::uint8_t buf[32];
325    StoreUint8x16(buf, src.buf.reg[0]);
326    StoreUint8x16(buf + 16, src.buf.reg[1]);
327    for (int c = 0; c < 4; c++) {
328      for (int r = 0; r < 8; r++) {
329        *dst->data(row + r, col + c) = buf[r + 8 * c];
330      }
331    }
332  }
333};
334
335template <typename DstType>
336struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
337  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
338                  int col) {
339    std::uint8_t buf[64];
340    StoreUint8x16(buf, src.buf.reg[0]);
341    StoreUint8x16(buf + 16, src.buf.reg[1]);
342    StoreUint8x16(buf + 32, src.buf.reg[2]);
343    StoreUint8x16(buf + 48, src.buf.reg[3]);
344    for (int c = 0; c < 8; c++) {
345      for (int r = 0; r < 8; r++) {
346        *dst->data(row + r, col + c) = buf[r + 8 * c];
347      }
348    }
349  }
350};
351
352}  // namespace gemmlowp
353
354#endif  // GEMMLOWP_INTERNAL_OUTPUT_SSE_H_
355