// Copyright 2018 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// simd_wrappers_msa.h: MSA specialization of simd_wrappers.h

#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_
#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_

#include <msa.h>

#include <cstdint>      // std::int32_t, std::int16_t, std::uint8_t
#include <type_traits>  // std::conditional

namespace gemmlowp {

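// 128-bit MSA vector types. Note that unsigned 8-bit data is carried in a
// v16i8 register: the GCC MSA builtins used below are declared on signed
// element types, and only the bit pattern matters here.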
using Int32x4 = v4i32;
using Int16x8 = v8i16;
using Uint8x16 = v16i8;

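// RegisterType picks the widest register that a block of ScalarCount scalars
// fills completely; smaller blocks fall back to plain scalars. For uint8
// there is an intermediate case: ScalarCount in [4, 16) packs four lanes
// into a std::uint32_t. For example:
//   RegisterType<std::int32_t, 4>::Type  == Int32x4
//   RegisterType<std::int32_t, 2>::Type  == std::int32_t
//   RegisterType<std::uint8_t, 16>::Type == Uint8x16
//   RegisterType<std::uint8_t, 4>::Type  == std::uint32_t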
template <int ScalarCount>
struct RegisterType<std::int32_t, ScalarCount> {
  using Type =
      typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type;
};

template <int ScalarCount>
struct RegisterType<std::int16_t, ScalarCount> {
  using Type =
      typename std::conditional<ScalarCount >= 8, Int16x8, std::int16_t>::type;
};

template <int ScalarCount>
struct RegisterType<std::uint8_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 16, Uint8x16,
      typename std::conditional<ScalarCount >= 4, std::uint32_t,
                                std::uint8_t>::type>::type;
};

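// Load/store wrappers. The MSA load/store builtins take a non-const pointer
// and an immediate byte offset (always 0 here), hence the const_casts on the
// load paths; the builtins only read through the pointer.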
inline Int32x4 LoadInt32x4(const std::int32_t* src) {
  return __builtin_msa_ld_w(const_cast<std::int32_t*>(src), 0);
}

inline Int32x4 LoadInt32x4(const Int32x4* src) {
  return __builtin_msa_ld_w(const_cast<Int32x4*>(src), 0);
}

inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
  __builtin_msa_st_w(value, dst, 0);
}

inline void StoreInt32x4(Int32x4* dst, Int32x4 value) {
  __builtin_msa_st_w(value, dst, 0);
}

inline Int16x8 LoadInt16x8(const std::int16_t* src) {
  return __builtin_msa_ld_h(const_cast<std::int16_t*>(src), 0);
}

inline Int16x8 LoadInt16x8(const Int16x8* src) {
  return __builtin_msa_ld_h(const_cast<Int16x8*>(src), 0);
}

inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
  __builtin_msa_st_h(value, dst, 0);
}

inline void StoreInt16x8(Int16x8* dst, Int16x8 value) {
  __builtin_msa_st_h(value, dst, 0);
}

inline Uint8x16 LoadUint8x16(const std::uint8_t* src) {
  return __builtin_msa_ld_b(const_cast<std::uint8_t*>(src), 0);
}

inline Uint8x16 LoadUint8x16(const Uint8x16* src) {
  return __builtin_msa_ld_b(const_cast<Uint8x16*>(src), 0);
}

inline void StoreUint8x16(std::uint8_t* dst, Uint8x16 value) {
  __builtin_msa_st_b(value, dst, 0);
}

inline void StoreUint8x16(Uint8x16* dst, Uint8x16 value) {
  __builtin_msa_st_b(value, dst, 0);
}

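// Per-lane operations. Lane must be a compile-time constant because the MSA
// copy_s.w and splati.w instructions encode the element index as an
// immediate.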
template <int Lane>
std::int32_t GetLane(Int32x4 value) {
  static_assert(Lane >= 0 && Lane <= 3, "Lane out of range");
  return __builtin_msa_copy_s_w(value, Lane);
}

template <int Lane>
Int32x4 DupLane(Int32x4 value) {
  static_assert(Lane >= 0 && Lane <= 3, "Lane out of range");
  return __builtin_msa_splati_w(value, Lane);
}

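// Elementwise arithmetic. Vector-scalar forms first broadcast the scalar to
// all four lanes with fill.w.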
inline Int32x4 Mul(Int32x4 a, std::int32_t b) {
  return __builtin_msa_mulv_w(a, __builtin_msa_fill_w(b));
}

inline Int32x4 Min(Int32x4 a, Int32x4 b) { return __builtin_msa_min_s_w(a, b); }

inline Int32x4 Max(Int32x4 a, Int32x4 b) { return __builtin_msa_max_s_w(a, b); }

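// mulr_q.w computes the rounded Q31 fixed-point product, i.e. the high 32
// bits of the doubled 64-bit product, with rounding. This matches the
// saturating-rounding-doubling-high-multiply semantics (NEON's vqrdmulh)
// that gemmlowp's fixed-point code expects.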
inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
  return __builtin_msa_mulr_q_w(a, __builtin_msa_fill_w(b));
}

template <int Lane>
Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) {
  static_assert(Lane >= 0 && Lane <= 3, "Lane out of range");
  return __builtin_msa_mulv_w(a, __builtin_msa_splati_w(b, Lane));
}

static inline v4i32 workaround_msa_maddv_w(v4i32 a, v4i32 b, v4i32 c) {
  // Workaround for gcc emitting an incorrect encoding of maddv.df (the a and
  // c operands exchanged), so issue the instruction directly via inline asm.
  // Computes a += b * c elementwise.
#if 0
  return __builtin_msa_maddv_w(a, b, c);
#else
  asm volatile("maddv.w %w[a], %w[b], %w[c]\n"
               // Outputs
               : [a] "+f"(a)
               // Inputs
               : [b] "f"(b), [c] "f"(c));
  return a;
#endif
}

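// Multiply-accumulate: *acc += lhs * rhs, with rhs either a full vector, a
// broadcast scalar, or a single lane of a vector.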
inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
  Int32x4 tmp = LoadInt32x4(acc);
  tmp = workaround_msa_maddv_w(tmp, lhs, rhs);
  StoreInt32x4(acc, tmp);
}

inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) {
  Int32x4 tmp = LoadInt32x4(acc);
  tmp = workaround_msa_maddv_w(tmp, lhs, __builtin_msa_fill_w(rhs));
  StoreInt32x4(acc, tmp);
}

template <int Lane>
inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
  static_assert(Lane >= 0 && Lane <= 3, "Lane out of range");
  Int32x4 tmp = LoadInt32x4(acc);
  tmp = workaround_msa_maddv_w(tmp, lhs, __builtin_msa_splati_w(rhs, Lane));
  StoreInt32x4(acc, tmp);
}

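// Contiguous loads of whole 8x8 register blocks. The loop counts follow from
// the block sizes: 64 uint8 values fill 4 16-byte vectors, 64 int32 values
// fill 16, and 64 int16 values fill 8.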
template <>
struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
  static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
    RegBlockUint8<8, 8> result;
    for (int i = 0; i < 4; i++) {
      result.buf.reg[i] = LoadUint8x16(src + 16 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
  static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
    RegBlockInt32<8, 8> result;
    for (int i = 0; i < 16; i++) {
      result.buf.reg[i] = LoadInt32x4(src + 4 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
  static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
    RegBlockInt16<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = LoadInt16x8(src + 8 * i);
    }
    return result;
  }
};

}  // end namespace gemmlowp

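// The remaining operations are provided by the wrappers shared with the NEON
// and SSE specializations.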
#include "simd_wrappers_common_neon_sse.h"

#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_MSA_H_