// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// simd_wrappers_neon.h: NEON specialization of simd_wrappers.h

#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_
#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_

#include <arm_neon.h>

namespace gemmlowp {

using Int32x4 = int32x4_t;
using Int16x4 = int16x4_t;
using Int16x8 = int16x8_t;
using Uint8x8 = uint8x8_t;

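// RegisterType specializations: choose the register type used to hold
// ScalarCount scalars of a given scalar type. Counts large enough to fill a
// NEON register use it; smaller counts fall back to narrower types (e.g. four
// uint8 values are carried in a plain std::uint32_t) or to the scalar itself.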
template <int ScalarCount>
struct RegisterType<std::int32_t, ScalarCount> {
  using Type =
      typename std::conditional<ScalarCount >= 4, Int32x4, std::int32_t>::type;
};

template <int ScalarCount>
struct RegisterType<std::int16_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 8, Int16x8,
      typename std::conditional<ScalarCount >= 4, Int16x4,
                                std::int16_t>::type>::type;
};

template <int ScalarCount>
struct RegisterType<std::uint8_t, ScalarCount> {
  using Type = typename std::conditional<
      ScalarCount >= 8, Uint8x8,
      typename std::conditional<ScalarCount >= 4, std::uint32_t,
                                std::uint8_t>::type>::type;
};

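// Load/store wrappers around the corresponding NEON vld1/vst1 intrinsics.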
inline Int32x4 LoadInt32x4(const std::int32_t* src) { return vld1q_s32(src); }
inline Int16x4 LoadInt16x4(const std::int16_t* src) { return vld1_s16(src); }
inline Int16x8 LoadInt16x8(const std::int16_t* src) { return vld1q_s16(src); }

inline void StoreInt32x4(std::int32_t* dst, Int32x4 value) {
  vst1q_s32(dst, value);
}

inline void StoreInt16x4(std::int16_t* dst, Int16x4 value) {
  vst1_s16(dst, value);
}

inline void StoreInt16x8(std::int16_t* dst, Int16x8 value) {
  vst1q_s16(dst, value);
}

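// Returns the scalar in lane Lane of an Int32x4.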
template <int Lane>
std::int32_t GetLane(Int32x4 value) {
  return vgetq_lane_s32(value, Lane);
}

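// Broadcasts lane Lane of value to all four lanes of the result. The switch
// is resolved at compile time since Lane is a template parameter.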
template <int Lane>
Int32x4 DupLane(Int32x4 value) {
  switch (Lane) {
    case 0:
      return vdupq_lane_s32(vget_low_s32(value), 0);
    case 1:
      return vdupq_lane_s32(vget_low_s32(value), 1);
    case 2:
      return vdupq_lane_s32(vget_high_s32(value), 0);
    case 3:
      return vdupq_lane_s32(vget_high_s32(value), 1);
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
      return vdupq_n_s32(0);
  }
}

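// Elementwise multiply-by-scalar, minimum and maximum.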
inline Int32x4 Mul(Int32x4 a, std::int32_t b) { return vmulq_n_s32(a, b); }

inline Int32x4 Min(Int32x4 a, Int32x4 b) { return vminq_s32(a, b); }

inline Int32x4 Max(Int32x4 a, Int32x4 b) { return vmaxq_s32(a, b); }

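// Fixed-point multiplication: returns the high half of the rounded, doubled
// product of each lane of a with b, saturating on overflow (VQRDMULH).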
inline Int32x4 SaturatingRoundingDoublingHighMul(Int32x4 a, std::int32_t b) {
  return vqrdmulhq_n_s32(a, b);
}

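// Multiplies each lane of a by lane Lane of b.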
template <int Lane>
Int32x4 MulByRhsLane(Int32x4 a, Int32x4 b) {
  switch (Lane) {
    case 0:
      return vmulq_lane_s32(a, vget_low_s32(b), 0);
    case 1:
      return vmulq_lane_s32(a, vget_low_s32(b), 1);
    case 2:
      return vmulq_lane_s32(a, vget_high_s32(b), 0);
    case 3:
      return vmulq_lane_s32(a, vget_high_s32(b), 1);
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
      return vdupq_n_s32(0);
  }
}

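// Multiply-accumulate: *acc += lhs * rhs, elementwise or by a scalar rhs.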
inline void MulAdd(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
  *acc = vmlaq_s32(*acc, lhs, rhs);
}

inline void MulAdd(Int32x4 lhs, std::int32_t rhs, Int32x4* acc) {
  *acc = vmlaq_n_s32(*acc, lhs, rhs);
}

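// Multiply-accumulate by a single lane of rhs: *acc += lhs * rhs[Lane].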
template <int Lane>
inline void MulAddByRhsLane(Int32x4 lhs, Int32x4 rhs, Int32x4* acc) {
  switch (Lane) {
    case 0:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 0);
      break;
    case 1:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_low_s32(rhs), 1);
      break;
    case 2:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 0);
      break;
    case 3:
      *acc = vmlaq_lane_s32(*acc, lhs, vget_high_s32(rhs), 1);
      break;
    default:
      static_assert(Lane >= 0 && Lane <= 3, "");
  }
}

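// LoadContiguousImpl specializations: load an 8x8 register block from
// contiguous memory, one NEON register at a time.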
template <>
struct LoadContiguousImpl<RegBlockInt16<8, 8>> {
  static RegBlockInt16<8, 8> Run(const std::int16_t* src) {
    RegBlockInt16<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = vld1q_s16(src + 8 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockUint8<8, 8>> {
  static RegBlockUint8<8, 8> Run(const std::uint8_t* src) {
    RegBlockUint8<8, 8> result;
    for (int i = 0; i < 8; i++) {
      result.buf.reg[i] = vld1_u8(src + 8 * i);
    }
    return result;
  }
};

template <>
struct LoadContiguousImpl<RegBlockInt32<8, 8>> {
  static RegBlockInt32<8, 8> Run(const std::int32_t* src) {
    RegBlockInt32<8, 8> result;
    for (int i = 0; i < 16; i++) {
      result.buf.reg[i] = vld1q_s32(src + 4 * i);
    }
    return result;
  }
};

}  // end namespace gemmlowp

#include "simd_wrappers_common_neon_sse.h"

#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_NEON_H_