// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// output_msa.h: optimized MSA specializations of the templates in output.h.
17#ifndef GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
18#define GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
19
20#include "output.h"
21
22#include <msa.h>
23
24namespace gemmlowp {
25
// Saturating cast to uint8 for a buffer of 4 int32 values held in one
// MSA register: clamps each value to [0, 255] and packs the 4 results
// as bytes of a single uint32.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferUint8<4> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  // The output stage carries no parameters, so there is nothing to store.
  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 9 bits
    // (this takes full care of non-negative elements: they end up in
    // [0, 255]; negative elements are zeroed out further below).
    v4i32 tmp = __builtin_msa_sat_s_w(input.reg[0], 8);
    // Pack every 32-bit element into 16 bits.
    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
        reinterpret_cast<v8i16>(tmp), reinterpret_cast<v8i16>(tmp)));
    // Detect negative elements with arithmetic shift right (we
    // get a 16-bit mask of all zeroes or all ones for every element).
    v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp), 15);
    // Zero out negative elements (bit-select against the all-ones masks
    // with the immediate 0).
    signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
        reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp), 0));
    // Pack every element into 8 bits.
    tmp = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
        reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
    // Return 4 uint8_t elements as uint32_t.
    output.reg[0] = __builtin_msa_copy_s_w(tmp, 0);
    return output;
  }
};
58
// Saturating cast to uint8 for a buffer of 8 int32 values held in two
// MSA registers: clamps each value to [0, 255] and packs the 8 results
// as bytes of two uint32s. Same trick as the 4-element version above.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferUint8<8> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  // The output stage carries no parameters, so there is nothing to store.
  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 9 bits
    // (this takes full care of non-negative elements).
    v4i32 tmp_lo = __builtin_msa_sat_s_w(input.reg[0], 8);
    v4i32 tmp_hi = __builtin_msa_sat_s_w(input.reg[1], 8);
    // Pack every 32-bit element into 16 bits,
    // combining all 8 elements into one vector.
    tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(
        reinterpret_cast<v8i16>(tmp_hi), reinterpret_cast<v8i16>(tmp_lo)));
    // Detect negative elements with arithmetic shift right (we
    // get a 16-bit mask of all zeroes or all ones for every element).
    v8i16 signs = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp_lo), 15);
    // Zero out negative elements (bit-select against the all-ones masks
    // with the immediate 0).
    signs = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(
        reinterpret_cast<v16u8>(signs), reinterpret_cast<v16u8>(tmp_lo), 0));
    // Pack every element into 8 bits.
    tmp_lo = reinterpret_cast<v4i32>(__builtin_msa_pckev_b(
        reinterpret_cast<v16i8>(signs), reinterpret_cast<v16i8>(signs)));
    // Return 8 uint8_t elements as 2 uint32_t's.
    output.reg[0] = __builtin_msa_copy_s_w(tmp_lo, 0);
    output.reg[1] = __builtin_msa_copy_s_w(tmp_lo, 1);
    return output;
  }
};
94
// Expands to code that saturating-casts the 16 int32 values held in the
// four v4i32 vectors in0..in3 to uint8 and packs them into the single
// v16i8 vector `out`. Uses the same 9-bit-saturate / pack /
// zero-negatives sequence as the Eval() methods above, just applied to
// four input registers at once. (Comments cannot go inside the macro
// body because of the backslash line continuations.)
#define GEMMLOWP_MIPS_SAT_U8_16(out, in0, in1, in2, in3)                     \
  {                                                                          \
    v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 8);                              \
    v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 8);                              \
    v4i32 tmp2 = __builtin_msa_sat_s_w(in2, 8);                              \
    v4i32 tmp3 = __builtin_msa_sat_s_w(in3, 8);                              \
    tmp0 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(                    \
        reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)));      \
    tmp2 = reinterpret_cast<v4i32>(__builtin_msa_pckev_h(                    \
        reinterpret_cast<v8i16>(tmp3), reinterpret_cast<v8i16>(tmp2)));      \
    v8i16 signs0 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp0), 15);  \
    v8i16 signs1 = __builtin_msa_srai_h(reinterpret_cast<v8i16>(tmp2), 15);  \
    signs0 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(                  \
        reinterpret_cast<v16u8>(signs0), reinterpret_cast<v16u8>(tmp0), 0)); \
    signs1 = reinterpret_cast<v8i16>(__builtin_msa_bseli_b(                  \
        reinterpret_cast<v16u8>(signs1), reinterpret_cast<v16u8>(tmp2), 0)); \
    signs0 = reinterpret_cast<v8i16>(__builtin_msa_pckev_b(                  \
        reinterpret_cast<v16i8>(signs1), reinterpret_cast<v16i8>(signs0)));  \
    out = reinterpret_cast<v16i8>(signs0);                                   \
  }
115
// Saturating cast to uint8 for a buffer of 16 int32 values held in four
// MSA registers; the 16 results are packed into one 16-byte register.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferUint8<16> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  // The output stage carries no parameters, so there is nothing to store.
  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
                            input.reg[2], input.reg[3]);
    return output;
  }
};
133
// Saturating cast to uint8 for a buffer of 32 int32 values held in eight
// MSA registers; the 32 results are packed into two 16-byte registers.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferUint8<32> OutputType;

  typedef OutputStageSaturatingCastToUint8 OutputStage;

  // The output stage carries no parameters, so there is nothing to store.
  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_U8_16(output.reg[0], input.reg[0], input.reg[1],
                            input.reg[2], input.reg[3]);
    GEMMLOWP_MIPS_SAT_U8_16(output.reg[1], input.reg[4], input.reg[5],
                            input.reg[6], input.reg[7]);
    return output;
  }
};
153
154#undef GEMMLOWP_MIPS_SAT_U8_16
155
// Saturating cast to int16 for a buffer of 4 int32 values held in one
// MSA register: clamps each value to [-32768, 32767] and returns the 4
// results as scalar int16s.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<4>> {
  typedef RegBufferInt32<4> InputType;
  typedef RegBufferInt16<4> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  // The output stage carries no parameters, so there is nothing to store.
  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    // Signed saturate each 32-bit element to 16 bits.
    v8i16 tmp = reinterpret_cast<v8i16>(__builtin_msa_sat_s_w(
        input.reg[0], 15));
    // Extract the saturated value of each 32-bit lane from the even
    // 16-bit positions (the lane's low half, where the saturated value
    // now fits).
    output.reg[0] = __builtin_msa_copy_s_h(tmp, 0);
    output.reg[1] = __builtin_msa_copy_s_h(tmp, 2);
    output.reg[2] = __builtin_msa_copy_s_h(tmp, 4);
    output.reg[3] = __builtin_msa_copy_s_h(tmp, 6);
    return output;
  }
};
178
// Expands to code that saturating-casts the 8 int32 values held in the
// two v4i32 vectors in0, in1 to int16 and packs them into the single
// v8i16 vector `out`: PCKEV.H keeps the (saturated) low 16-bit half of
// every 32-bit lane, with in0's lanes in the low half of the result.
#define GEMMLOWP_MIPS_SAT_I16_8(out, in0, in1)                         \
  {                                                                    \
    v4i32 tmp0 = __builtin_msa_sat_s_w(in0, 15);                       \
    v4i32 tmp1 = __builtin_msa_sat_s_w(in1, 15);                       \
    out = __builtin_msa_pckev_h(                                       \
        reinterpret_cast<v8i16>(tmp1), reinterpret_cast<v8i16>(tmp0)); \
  }
186
// Saturating cast to int16 for a buffer of 8 int32 values held in two
// MSA registers; the 8 results are packed into one v8i16 register.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<8>> {
  typedef RegBufferInt32<8> InputType;
  typedef RegBufferInt16<8> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  // The output stage carries no parameters, so there is nothing to store.
  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    return output;
  }
};
203
// Saturating cast to int16 for a buffer of 16 int32 values held in four
// MSA registers; the 16 results are packed into two v8i16 registers.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<16>> {
  typedef RegBufferInt32<16> InputType;
  typedef RegBufferInt16<16> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  // The output stage carries no parameters, so there is nothing to store.
  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
    return output;
  }
};
221
// Saturating cast to int16 for a buffer of 32 int32 values held in eight
// MSA registers; the 32 results are packed into four v8i16 registers.
template <>
struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16,
                                 RegBufferInt32<32>> {
  typedef RegBufferInt32<32> InputType;
  typedef RegBufferInt16<32> OutputType;

  typedef OutputStageSaturatingCastToInt16 OutputStage;

  // The output stage carries no parameters, so there is nothing to store.
  OutputStageEvalBufferImpl(const OutputStage&) {}

  OutputType Eval(InputType input) const {
    OutputType output;
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[0], input.reg[0], input.reg[1]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[1], input.reg[2], input.reg[3]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[2], input.reg[4], input.reg[5]);
    GEMMLOWP_MIPS_SAT_I16_8(output.reg[3], input.reg[6], input.reg[7]);
    return output;
  }
};
241
242#undef GEMMLOWP_MIPS_SAT_I16_8
243
// Stores a 4x1 block (one column) of int32 results.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 1>, DstType> {
  static void Run(const RegBlockInt32<4, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      // Column-major destination: the 4 rows are contiguous in memory,
      // so one vector store suffices.
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
    } else {
      // Row-major destination: scatter the 4 lanes one at a time.
      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
    }
  }
};
258
// Stores an 8x1 block (one column) of int32 results held in two
// registers: reg[0] carries rows 0-3, reg[1] rows 4-7.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 1>, DstType> {
  static void Run(const RegBlockInt32<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      // Column-major destination: the 8 rows are contiguous in memory.
      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
      StoreInt32x4(dst->data(row + 4, col), src.buf.reg[1]);
    } else {
      // Row-major destination: scatter the 8 lanes one at a time.
      *dst->data(row + 0, col) = GetLane<0>(src.buf.reg[0]);
      *dst->data(row + 1, col) = GetLane<1>(src.buf.reg[0]);
      *dst->data(row + 2, col) = GetLane<2>(src.buf.reg[0]);
      *dst->data(row + 3, col) = GetLane<3>(src.buf.reg[0]);
      *dst->data(row + 4, col) = GetLane<0>(src.buf.reg[1]);
      *dst->data(row + 5, col) = GetLane<1>(src.buf.reg[1]);
      *dst->data(row + 6, col) = GetLane<2>(src.buf.reg[1]);
      *dst->data(row + 7, col) = GetLane<3>(src.buf.reg[1]);
    }
  }
};
278
279template <typename DstType>
280struct StoreFinalOutputImpl<RegBlockInt16<4, 1>, DstType> {
281  static void Run(const RegBlockInt16<4, 1>& src, DstType* dst, int row,
282                  int col) {
283    *dst->data(row + 0, col) = src.buf.reg[0];
284    *dst->data(row + 1, col) = src.buf.reg[1];
285    *dst->data(row + 2, col) = src.buf.reg[2];
286    *dst->data(row + 3, col) = src.buf.reg[3];
287  }
288};
289
// Stores an 8x1 block (one column) of int16 results held in one v8i16.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 1>, DstType> {
  static void Run(const RegBlockInt16<8, 1>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      // Column-major destination: the 8 rows are contiguous in memory.
      StoreInt16x8(dst->data(row, col), src.buf.reg[0]);
    } else {
      // Row-major destination: extract the lanes one at a time. This
      // stays unrolled because COPY_S.H takes a constant lane index.
      *dst->data(row + 0, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 0);
      *dst->data(row + 1, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 1);
      *dst->data(row + 2, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 2);
      *dst->data(row + 3, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 3);
      *dst->data(row + 4, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 4);
      *dst->data(row + 5, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 5);
      *dst->data(row + 6, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 6);
      *dst->data(row + 7, col) = __builtin_msa_copy_s_h(src.buf.reg[0], 7);
    }
  }
};
308
// Transposes a 4x4 block of int32 values held one column per register,
// using 32-bit interleaves followed by 64-bit interleaves.
inline RegBlockInt32<4, 4> Transpose(const RegBlockInt32<4, 4>& src) {
  RegBlockInt32<4, 4> result;
  v4i32 tmp0, tmp1;
  // Interleave the low (right) 32-bit lanes of columns 0,1 and of
  // columns 2,3.
  tmp0 = __builtin_msa_ilvr_w(src.buf.reg[1], src.buf.reg[0]);
  tmp1 = __builtin_msa_ilvr_w(src.buf.reg[3], src.buf.reg[2]);
  // Combine their 64-bit halves to form the first two output rows.
  result.buf.reg[0] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  result.buf.reg[1] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  // Same with the high (left) 32-bit lanes, yielding the last two rows.
  tmp0 = __builtin_msa_ilvl_w(src.buf.reg[1], src.buf.reg[0]);
  tmp1 = __builtin_msa_ilvl_w(src.buf.reg[3], src.buf.reg[2]);
  result.buf.reg[2] = reinterpret_cast<v4i32>(__builtin_msa_ilvr_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  result.buf.reg[3] = reinterpret_cast<v4i32>(__builtin_msa_ilvl_d(
      reinterpret_cast<v2i64>(tmp1), reinterpret_cast<v2i64>(tmp0)));
  return result;
}
326
// Stores a 4x4 block of int32 results (one register per column).
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<4, 4>, DstType> {
  static void Run(const RegBlockInt32<4, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      // Each source register is exactly one destination column.
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row, col + i), src.buf.reg[i]);
      }
    } else {
      // Row-major: transpose in registers, after which each register is
      // one destination row.
      const auto transpose = Transpose(src);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col), transpose.buf.reg[i]);
      }
    }
  }
};
343
344template <typename DstType>
345struct StoreFinalOutputImpl<RegBlockInt16<4, 4>, DstType> {
346  static void Run(const RegBlockInt16<4, 4>& src, DstType* dst, int row,
347                  int col) {
348    std::int16_t buf[16];
349    StoreInt16x8(buf + 0, src.buf.reg[0]);
350    StoreInt16x8(buf + 8, src.buf.reg[1]);
351    for (int i = 0; i < 4; i++) {
352      for (int j = 0; j < 4; j++) {
353        *dst->data(row + i, col + j) = buf[i + 4 * j];
354      }
355    }
356  }
357};
358
// Stores an 8x4 block of int32 results held two registers per column:
// even-indexed registers hold rows 0-3, odd-indexed ones rows 4-7.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 4>, DstType> {
  static void Run(const RegBlockInt32<8, 4>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      // Two vector stores per destination column.
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
        StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
      }
    } else {
      // Row-major: transpose the top and bottom 4x4 halves separately,
      // then store one row per register.
      RegBlockInt32<4, 4> top;
      top.buf.reg[0] = src.buf.reg[0];
      top.buf.reg[1] = src.buf.reg[2];
      top.buf.reg[2] = src.buf.reg[4];
      top.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top = Transpose(top);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col), transpose_top.buf.reg[i]);
      }
      RegBlockInt32<4, 4> bottom;
      bottom.buf.reg[0] = src.buf.reg[1];
      bottom.buf.reg[1] = src.buf.reg[3];
      bottom.buf.reg[2] = src.buf.reg[5];
      bottom.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom = Transpose(bottom);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + 4 + i, col), transpose_bottom.buf.reg[i]);
      }
    }
  }
};
390
391template <typename DstType>
392struct StoreFinalOutputImpl<RegBlockInt16<8, 4>, DstType> {
393  static void Run(const RegBlockInt16<8, 4>& src, DstType* dst, int row,
394                  int col) {
395    if (DstType::kOrder == MapOrder::ColMajor) {
396      for (int i = 0; i < 4; i++) {
397        StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
398      }
399    } else {
400      std::int16_t buf[32];
401      StoreInt16x8(buf + 0, src.buf.reg[0]);
402      StoreInt16x8(buf + 8, src.buf.reg[1]);
403      StoreInt16x8(buf + 16, src.buf.reg[2]);
404      StoreInt16x8(buf + 24, src.buf.reg[3]);
405      for (int i = 0; i < 8; i++) {
406        for (int j = 0; j < 4; j++) {
407          *dst->data(row + i, col + j) = buf[i + 8 * j];
408        }
409      }
410    }
411  }
412};
413
// Stores an 8x8 block of int32 results held two registers per column:
// reg[2*i] holds rows 0-3 of column i, reg[2*i+1] holds rows 4-7.
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt32<8, 8>, DstType> {
  static void Run(const RegBlockInt32<8, 8>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      // Two vector stores per destination column.
      for (int i = 0; i < 8; i++) {
        StoreInt32x4(dst->data(row, col + i), src.buf.reg[2 * i]);
        StoreInt32x4(dst->data(row + 4, col + i), src.buf.reg[2 * i + 1]);
      }
    } else {
      // Row-major: transpose each of the four 4x4 quadrants separately,
      // then store one row per register.
      // Top-left quadrant: rows 0-3 of columns 0-3.
      RegBlockInt32<4, 4> top_left;
      top_left.buf.reg[0] = src.buf.reg[0];
      top_left.buf.reg[1] = src.buf.reg[2];
      top_left.buf.reg[2] = src.buf.reg[4];
      top_left.buf.reg[3] = src.buf.reg[6];
      const auto transpose_top_left = Transpose(top_left);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col), transpose_top_left.buf.reg[i]);
      }
      // Bottom-left quadrant: rows 4-7 of columns 0-3.
      RegBlockInt32<4, 4> bottom_left;
      bottom_left.buf.reg[0] = src.buf.reg[1];
      bottom_left.buf.reg[1] = src.buf.reg[3];
      bottom_left.buf.reg[2] = src.buf.reg[5];
      bottom_left.buf.reg[3] = src.buf.reg[7];
      const auto transpose_bottom_left = Transpose(bottom_left);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + 4 + i, col),
                     transpose_bottom_left.buf.reg[i]);
      }
      // Top-right quadrant: rows 0-3 of columns 4-7.
      RegBlockInt32<4, 4> top_right;
      top_right.buf.reg[0] = src.buf.reg[8];
      top_right.buf.reg[1] = src.buf.reg[10];
      top_right.buf.reg[2] = src.buf.reg[12];
      top_right.buf.reg[3] = src.buf.reg[14];
      const auto transpose_top_right = Transpose(top_right);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + i, col + 4),
                     transpose_top_right.buf.reg[i]);
      }
      // Bottom-right quadrant: rows 4-7 of columns 4-7.
      RegBlockInt32<4, 4> bottom_right;
      bottom_right.buf.reg[0] = src.buf.reg[9];
      bottom_right.buf.reg[1] = src.buf.reg[11];
      bottom_right.buf.reg[2] = src.buf.reg[13];
      bottom_right.buf.reg[3] = src.buf.reg[15];
      const auto transpose_bottom_right = Transpose(bottom_right);
      for (int i = 0; i < 4; i++) {
        StoreInt32x4(dst->data(row + 4 + i, col + 4),
                     transpose_bottom_right.buf.reg[i]);
      }
    }
  }
};
466
// Stores an 8x8 block of int16 results (one register per column).
template <typename DstType>
struct StoreFinalOutputImpl<RegBlockInt16<8, 8>, DstType> {
  static void Run(const RegBlockInt16<8, 8>& src, DstType* dst, int row,
                  int col) {
    if (DstType::kOrder == MapOrder::ColMajor) {
      // Each source register is exactly one destination column.
      for (int i = 0; i < 8; i++) {
        StoreInt16x8(dst->data(row, col + i), src.buf.reg[i]);
      }
    } else {
      // Row-major: transpose the 8x8 int16 block in registers with three
      // rounds of interleaves (16-bit, then 32-bit, then 64-bit), then
      // store one complete row per StoreInt16x8 below.
      // top-left 4x4
      v4i32 t0 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[1],
          src.buf.reg[0]));
      v4i32 t1 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[3],
          src.buf.reg[2]));
      v2i64 u0 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t1, t0));
      v2i64 u1 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t1, t0));
      // top-right 4x4
      v4i32 t2 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[5],
          src.buf.reg[4]));
      v4i32 t3 = reinterpret_cast<v4i32>(__builtin_msa_ilvr_h(src.buf.reg[7],
          src.buf.reg[6]));
      v2i64 u2 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t3, t2));
      v2i64 u3 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t3, t2));
      // bottom-left 4x4
      v4i32 t4 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[1],
          src.buf.reg[0]));
      v4i32 t5 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[3],
          src.buf.reg[2]));
      v2i64 u4 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t5, t4));
      v2i64 u5 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t5, t4));
      // bottom-right 4x4
      v4i32 t6 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[5],
          src.buf.reg[4]));
      v4i32 t7 = reinterpret_cast<v4i32>(__builtin_msa_ilvl_h(src.buf.reg[7],
          src.buf.reg[6]));
      v2i64 u6 = reinterpret_cast<v2i64>(__builtin_msa_ilvr_w(t7, t6));
      v2i64 u7 = reinterpret_cast<v2i64>(__builtin_msa_ilvl_w(t7, t6));

      // Final 64-bit interleaves glue left and right halves into rows.
      StoreInt16x8(dst->data(row + 0, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvr_d(u2, u0)));
      StoreInt16x8(dst->data(row + 1, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvl_d(u2, u0)));
      StoreInt16x8(dst->data(row + 2, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvr_d(u3, u1)));
      StoreInt16x8(dst->data(row + 3, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvl_d(u3, u1)));
      StoreInt16x8(dst->data(row + 4, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvr_d(u6, u4)));
      StoreInt16x8(dst->data(row + 5, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvl_d(u6, u4)));
      StoreInt16x8(dst->data(row + 6, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvr_d(u7, u5)));
      StoreInt16x8(dst->data(row + 7, col), reinterpret_cast<v8i16>(
          __builtin_msa_ilvl_d(u7, u5)));
    }
  }
};
524
525template <typename DstType>
526struct StoreFinalOutputImpl<RegBlockInt32<1, 4>, DstType> {
527  static void Run(const RegBlockInt32<1, 4>& src, DstType* dst, int row,
528                  int col) {
529    if (DstType::kOrder == MapOrder::ColMajor) {
530      *dst->data(row, col + 0) = GetLane<0>(src.buf.reg[0]);
531      *dst->data(row, col + 1) = GetLane<1>(src.buf.reg[0]);
532      *dst->data(row, col + 2) = GetLane<2>(src.buf.reg[0]);
533      *dst->data(row, col + 3) = GetLane<3>(src.buf.reg[0]);
534    } else {
535      StoreInt32x4(dst->data(row, col), src.buf.reg[0]);
536    }
537  }
538};
539
540template <typename DstType>
541struct StoreFinalOutputImpl<RegBlockUint8<4, 1>, DstType> {
542  static void Run(const RegBlockUint8<4, 1>& src, DstType* dst, int row,
543                  int col) {
544    const std::uint32_t src_reg = src.buf.reg[0];
545    for (int i = 0; i < 4; i++) {
546      *dst->data(row + i, col) = (src_reg >> (8 * i));
547    }
548  }
549};
550
551template <typename DstType>
552struct StoreFinalOutputImpl<RegBlockUint8<8, 1>, DstType> {
553  static void Run(const RegBlockUint8<8, 1>& src, DstType* dst, int row,
554                  int col) {
555    for (int i = 0; i < 4; i++) {
556      *dst->data(row + i, col) = (src.buf.reg[0] >> (8 * i));
557    }
558    for (int i = 0; i < 4; i++) {
559      *dst->data(row + 4 + i, col) = (src.buf.reg[1] >> (8 * i));
560    }
561  }
562};
563
564template <typename DstType>
565struct StoreFinalOutputImpl<RegBlockUint8<1, 4>, DstType> {
566  static void Run(const RegBlockUint8<1, 4>& src, DstType* dst, int row,
567                  int col) {
568    for (int i = 0; i < 4; i++) {
569      *dst->data(row, col + i) = (src.buf.reg[0] >> (8 * i));
570    }
571  }
572};
573
574template <typename DstType>
575struct StoreFinalOutputImpl<RegBlockUint8<4, 4>, DstType> {
576  static void Run(const RegBlockUint8<4, 4>& src, DstType* dst, int row,
577                  int col) {
578    std::uint8_t buf[16];
579    StoreUint8x16(buf, src.buf.reg[0]);
580    for (int c = 0; c < 4; c++) {
581      for (int r = 0; r < 4; r++) {
582        *dst->data(row + r, col + c) = buf[r + 4 * c];
583      }
584    }
585  }
586};
587
588template <typename DstType>
589struct StoreFinalOutputImpl<RegBlockUint8<8, 4>, DstType> {
590  static void Run(const RegBlockUint8<8, 4>& src, DstType* dst, int row,
591                  int col) {
592    std::uint8_t buf[32];
593    StoreUint8x16(buf, src.buf.reg[0]);
594    StoreUint8x16(buf + 16, src.buf.reg[1]);
595    for (int c = 0; c < 4; c++) {
596      for (int r = 0; r < 8; r++) {
597        *dst->data(row + r, col + c) = buf[r + 8 * c];
598      }
599    }
600  }
601};
602
603template <typename DstType>
604struct StoreFinalOutputImpl<RegBlockUint8<8, 8>, DstType> {
605  static void Run(const RegBlockUint8<8, 8>& src, DstType* dst, int row,
606                  int col) {
607    std::uint8_t buf[64];
608    StoreUint8x16(buf, src.buf.reg[0]);
609    StoreUint8x16(buf + 16, src.buf.reg[1]);
610    StoreUint8x16(buf + 32, src.buf.reg[2]);
611    StoreUint8x16(buf + 48, src.buf.reg[3]);
612    for (int c = 0; c < 8; c++) {
613      for (int r = 0; r < 8; r++) {
614        *dst->data(row + r, col + c) = buf[r + 8 * c];
615      }
616    }
617  }
618};
619
620}  // namespace gemmlowp
621
622#endif  // GEMMLOWP_INTERNAL_OUTPUT_MSA_H_
623