1// Copyright 2015 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code
16
17#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
18#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
19
20#include "simd_wrappers.h"
21
22namespace gemmlowp {
23
24template <typename SrcScalarType, int N>
25struct LoadImpl<RegBlockInt32<4, N>,
26                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
27  static RegBlockInt32<4, N> Run(
28      const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
29      int col) {
30    RegBlockInt32<4, N> result;
31    for (int i = 0; i < N; i++) {
32      result.buf.reg[i] = LoadInt32x4(src.data(row, col + i));
33    }
34    return result;
35  }
36};
37
38template <typename SrcScalarType, int N>
39struct LoadImpl<RegBlockInt32<8, N>,
40                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
41  static RegBlockInt32<8, N> Run(
42      const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
43      int col) {
44    RegBlockInt32<8, N> result;
45    for (int i = 0; i < N; i++) {
46      result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i));
47      result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i));
48    }
49    return result;
50  }
51};
52
53template <typename SrcScalarType>
54struct LoadImpl<RegBlockInt32<1, 4>,
55                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
56  static RegBlockInt32<1, 4> Run(
57      const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
58      int col) {
59    RegBlockInt32<1, 4> result;
60    std::int32_t buf[4];
61    for (int i = 0; i < 4; i++) {
62      buf[i] = src(row, col + i);
63    }
64    result.buf.reg[0] = LoadInt32x4(buf);
65    return result;
66  }
67};
68
69template <typename SrcScalarType>
70struct LoadImpl<RegBlockInt32<1, 8>,
71                MatrixMap<SrcScalarType, MapOrder::ColMajor>> {
72  static RegBlockInt32<1, 8> Run(
73      const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row,
74      int col) {
75    RegBlockInt32<1, 8> result;
76    std::int32_t buf[8];
77    for (int i = 0; i < 8; i++) {
78      buf[i] = src(row, col + i);
79    }
80    result.buf.reg[0] = LoadInt32x4(buf);
81    result.buf.reg[1] = LoadInt32x4(buf + 4);
82    return result;
83  }
84};
85
86template <typename SrcScalarType>
87struct LoadImpl<RegBlockInt32<4, 1>,
88                VectorMap<SrcScalarType, VectorShape::Col>> {
89  static RegBlockInt32<4, 1> Run(
90      const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) {
91    RegBlockInt32<4, 1> result;
92    result.buf.reg[0] = LoadInt32x4(src.data(pos));
93    return result;
94  }
95};
96
97template <typename SrcScalarType>
98struct LoadImpl<RegBlockInt32<4, 1>,
99                VectorDup<SrcScalarType, VectorShape::Col>> {
100  static RegBlockInt32<4, 1> Run(
101      const VectorDup<SrcScalarType, VectorShape::Col>& src, int) {
102    RegBlockInt32<4, 1> result;
103    result.buf.reg[0] = LoadInt32x4(src(0));
104    return result;
105  }
106};
107
108template <typename SrcScalarType, int N>
109struct LoadForBroadcastingImpl<RegBlockInt32<4, N>,
110                               VectorMap<SrcScalarType, VectorShape::Col>> {
111  using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
112  using RegisterBlockType = RegBlockInt32<4, N>;
113  using ResultBlockType =
114      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
115                                                SrcObjectType>::Type;
116
117  static ResultBlockType Run(const SrcObjectType& src, int pos) {
118    ResultBlockType result;
119    static_assert(ResultBlockType::kRegisterCount == 1, "");
120    result.buf.reg[0] = LoadInt32x4(src.data(pos));
121    return result;
122  }
123};
124
125template <typename SrcScalarType, int N>
126struct LoadForBroadcastingImpl<RegBlockInt32<8, N>,
127                               VectorMap<SrcScalarType, VectorShape::Col>> {
128  using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>;
129  using RegisterBlockType = RegBlockInt32<8, N>;
130  using ResultBlockType =
131      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
132                                                SrcObjectType>::Type;
133
134  static ResultBlockType Run(const SrcObjectType& src, int pos) {
135    ResultBlockType result;
136    static_assert(ResultBlockType::kRegisterCount == 2, "");
137    result.buf.reg[0] = LoadInt32x4(src.data(pos));
138    result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
139    return result;
140  }
141};
142
143template <typename SrcScalarType>
144struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>,
145                               VectorMap<SrcScalarType, VectorShape::Row>> {
146  using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
147  using RegisterBlockType = RegBlockInt32<4, 1>;
148  using ResultBlockType =
149      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
150                                                SrcObjectType>::Type;
151
152  static ResultBlockType Run(const SrcObjectType& src, int pos) {
153    ResultBlockType result;
154    result.buf.reg[0] = src(pos);
155    return result;
156  }
157};
158
159template <typename SrcScalarType, int N>
160struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>,
161                               VectorMap<SrcScalarType, VectorShape::Row>> {
162  using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
163  using RegisterBlockType = RegBlockInt32<N, 4>;
164  using ResultBlockType =
165      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
166                                                SrcObjectType>::Type;
167
168  static ResultBlockType Run(const SrcObjectType& src, int pos) {
169    ResultBlockType result;
170    static_assert(ResultBlockType::kRegisterCount == 1, "");
171    result.buf.reg[0] = LoadInt32x4(src.data(pos));
172    return result;
173  }
174};
175
176template <typename SrcScalarType, int N>
177struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>,
178                               VectorMap<SrcScalarType, VectorShape::Row>> {
179  using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>;
180  using RegisterBlockType = RegBlockInt32<N, 8>;
181  using ResultBlockType =
182      typename LoadForBroadcastingRegisterBlock<RegisterBlockType,
183                                                SrcObjectType>::Type;
184
185  static ResultBlockType Run(const SrcObjectType& src, int pos) {
186    ResultBlockType result;
187    static_assert(ResultBlockType::kRegisterCount == 2, "");
188    result.buf.reg[0] = LoadInt32x4(src.data(pos));
189    result.buf.reg[1] = LoadInt32x4(src.data(pos + 4));
190    return result;
191  }
192};
193
194// 4x1 := 4x1 + 1x1
195template <>
196struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
197  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
198                                 const RegBlockInt32<1, 1>& rhs) {
199    RegBlockInt32<4, 1> result;
200    result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
201    return result;
202  }
203};
204
205// 1x4 := 1x4 + 1x1
206template <>
207struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
208  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
209                                 const RegBlockInt32<1, 1>& rhs) {
210    RegBlockInt32<1, 4> result;
211    result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
212    return result;
213  }
214};
215
216// 4x1 := 4x1 + 4x1
217template <>
218struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
219  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
220                                 const RegBlockInt32<4, 1>& rhs) {
221    RegBlockInt32<4, 1> result;
222    result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
223    return result;
224  }
225};
226
227// 1x4 := 1x4 + 1x4
228template <>
229struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
230  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
231                                 const RegBlockInt32<1, 4>& rhs) {
232    RegBlockInt32<1, 4> result;
233    result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
234    return result;
235  }
236};
237
238// 4x4 := 4x4 + 1x4
239template <>
240struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
241  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
242                                 const RegBlockInt32<1, 4>& rhs) {
243    RegBlockInt32<4, 4> result;
244    result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
245    result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0]));
246    result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0]));
247    result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0]));
248    return result;
249  }
250};
251
252// 4x4 := 4x4 + 4x1
253template <>
254struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
255  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
256                                 const RegBlockInt32<4, 1>& rhs) {
257    RegBlockInt32<4, 4> result;
258    result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
259    result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]);
260    result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
261    result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]);
262    return result;
263  }
264};
265
266// 8x1 := 8x1 + 1x1
267template <>
268struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
269  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
270                                 const RegBlockInt32<1, 1>& rhs) {
271    RegBlockInt32<8, 1> result;
272    const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]);
273    for (int i = 0; i < 2; i++) {
274      result.buf.reg[i] = Add(lhs.buf.reg[i], p);
275    }
276    return result;
277  }
278};
279
280// 8x1 := 8x1 + 8x1
281template <>
282struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
283  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
284                                 const RegBlockInt32<8, 1>& rhs) {
285    RegBlockInt32<8, 1> result;
286    for (int i = 0; i < 2; i++) {
287      result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]);
288    }
289    return result;
290  }
291};
292
293// 8x4 := 8x4 + 1x4
294template <>
295struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
296  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
297                                 const RegBlockInt32<1, 4>& rhs) {
298    RegBlockInt32<8, 4> result;
299    result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0]));
300    result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0]));
301    result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0]));
302    result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0]));
303    result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0]));
304    result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0]));
305    result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0]));
306    result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0]));
307    return result;
308  }
309};
310
311// 8x4 := 8x4 + 8x1
312template <>
313struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
314  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
315                                 const RegBlockInt32<8, 1>& rhs) {
316    RegBlockInt32<8, 4> result;
317    result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
318    result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
319    result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]);
320    result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]);
321    result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]);
322    result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]);
323    result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]);
324    result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]);
325    return result;
326  }
327};
328
329// 1x8 := 1x8 + 1x8
330template <>
331struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> {
332  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
333                                 const RegBlockInt32<1, 8>& rhs) {
334    RegBlockInt32<1, 8> result;
335    result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]);
336    result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]);
337    return result;
338  }
339};
340
341// 1x8 := 1x8 + 1x1
342template <>
343struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> {
344  static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs,
345                                 const RegBlockInt32<1, 1>& rhs) {
346    RegBlockInt32<1, 8> result;
347    result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
348    result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0]));
349    return result;
350  }
351};
352
353// 4x1 := 4x1 * 1x1
354template <>
355struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> {
356  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
357                                 const RegBlockInt32<1, 1>& rhs) {
358    RegBlockInt32<4, 1> result;
359    result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0]));
360    return result;
361  }
362};
363
364// 4x1 := 4x1 * 4x1
365template <>
366struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> {
367  static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs,
368                                 const RegBlockInt32<4, 1>& rhs) {
369    RegBlockInt32<4, 1> result;
370    result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
371    return result;
372  }
373};
374
375// 1x4 := 1x4 * 1x4
376template <>
377struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> {
378  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
379                                 const RegBlockInt32<1, 4>& rhs) {
380    RegBlockInt32<1, 4> result;
381    result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
382    return result;
383  }
384};
385
386// 1x4 := 1x4 * 1x1
387template <>
388struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> {
389  static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs,
390                                 const RegBlockInt32<1, 1>& rhs) {
391    RegBlockInt32<1, 4> result;
392    result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
393    return result;
394  }
395};
396
397// 4x4 := 4x4 * 1x4
398template <>
399struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> {
400  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
401                                 const RegBlockInt32<1, 4>& rhs) {
402    RegBlockInt32<4, 4> result;
403    const Int32x4 p = rhs.buf.reg[0];
404    result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p);
405    result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p);
406    result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p);
407    result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p);
408    return result;
409  }
410};
411
412// 4x4 := 4x4 * 4x1
413template <>
414struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> {
415  static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs,
416                                 const RegBlockInt32<4, 1>& rhs) {
417    RegBlockInt32<4, 4> result;
418    const Int32x4 p = rhs.buf.reg[0];
419    result.buf.reg[0] = Mul(lhs.buf.reg[0], p);
420    result.buf.reg[1] = Mul(lhs.buf.reg[1], p);
421    result.buf.reg[2] = Mul(lhs.buf.reg[2], p);
422    result.buf.reg[3] = Mul(lhs.buf.reg[3], p);
423    return result;
424  }
425};
426
427// 8x1 := 8x1 * 1x1
428template <>
429struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> {
430  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
431                                 const RegBlockInt32<1, 1>& rhs) {
432    RegBlockInt32<8, 1> result;
433    const std::int32_t p = rhs.buf.reg[0];
434    for (int i = 0; i < 2; i++) {
435      result.buf.reg[i] = Mul(lhs.buf.reg[i], p);
436    }
437    return result;
438  }
439};
440
441// 8x1 := 8x1 * 8x1
442template <>
443struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> {
444  static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs,
445                                 const RegBlockInt32<8, 1>& rhs) {
446    RegBlockInt32<8, 1> result;
447    for (int i = 0; i < 2; i++) {
448      result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]);
449    }
450    return result;
451  }
452};
453
454// 8x4 := 8x4 * 1x4
455template <>
456struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> {
457  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
458                                 const RegBlockInt32<1, 4>& rhs) {
459    RegBlockInt32<8, 4> result;
460    const Int32x4 p = rhs.buf.reg[0];
461    for (int i = 0; i < 2; i++) {
462      result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p);
463      result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p);
464      result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p);
465      result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p);
466    }
467    return result;
468  }
469};
470
471// 8x4 := 8x4 * 8x1
472template <>
473struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> {
474  static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs,
475                                 const RegBlockInt32<8, 1>& rhs) {
476    RegBlockInt32<8, 4> result;
477    const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]};
478    for (int i = 0; i < 4; i++) {
479      for (int j = 0; j < 2; j++) {
480        const int k = j + 2 * i;
481        result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]);
482      }
483    }
484    return result;
485  }
486};
487
488// Rx1 += Rx1 * 1x1
489template <int Rows>
490struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
491                           RegBlockInt32<Rows, 1>> {
492  static void Run(const RegBlockInt32<Rows, 1>& lhs,
493                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) {
494    const std::int32_t p = rhs.buf.reg[0];
495    for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) {
496      MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
497    }
498  }
499};
500
501// RxC += Rx1 * 1x1
502template <int Rows, int Cols>
503struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>,
504                           RegBlockInt32<Rows, Cols>> {
505  static void Run(const RegBlockInt32<Rows, 1>& lhs,
506                  const RegBlockInt32<1, 1>& rhs,
507                  RegBlockInt32<Rows, Cols>* acc) {
508    const std::int32_t p = rhs.buf.reg[0];
509    static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
510    for (int i = 0; i < kRegsPerCol; i++) {
511      const Int32x4 q = Mul(lhs.buf.reg[i], p);
512      for (int j = 0; j < Cols; j++) {
513        acc->buf.reg[i + j * kRegsPerCol] =
514            Add(acc->buf.reg[i + j * kRegsPerCol], q);
515      }
516    }
517  }
518};
519
520// 1xC += 1xC * 1x1
521template <int Cols>
522struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>,
523                           RegBlockInt32<1, Cols>> {
524  static void Run(const RegBlockInt32<1, Cols>& lhs,
525                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
526    const std::int32_t p = rhs.buf.reg[0];
527    for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
528      MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]);
529    }
530  }
531};
532
533// RxC += 1x1 * 1x1
534template <int Rows, int Cols>
535struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
536                           RegBlockInt32<Rows, Cols>> {
537  static void Run(const RegBlockInt32<1, 1>& lhs,
538                  const RegBlockInt32<1, 1>& rhs,
539                  RegBlockInt32<Rows, Cols>* acc) {
540    const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
541    for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) {
542      acc->buf.reg[i] = Add(acc->buf.reg[i], p);
543    }
544  }
545};
546
547// 1x1 += 1x1 * 1x1
548template <>
549struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
550                           RegBlockInt32<1, 1>> {
551  static void Run(const RegBlockInt32<1, 1>& lhs,
552                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) {
553    MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]);
554  }
555};
556
557// Rx4 += Rx1 * 1x4
558template <int Rows>
559struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>,
560                           RegBlockInt32<Rows, 4>> {
561  static void Run(const RegBlockInt32<Rows, 1>& lhs,
562                  const RegBlockInt32<1, 4>& rhs, RegBlockInt32<Rows, 4>* acc) {
563    const Int32x4 p = rhs.buf.reg[0];
564    static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
565    for (int i = 0; i < kRegsPerCol; i++) {
566      MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]);
567      MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]);
568      MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]);
569      MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]);
570    }
571  }
572};
573
574// Rx4 += 1x4 * 1x1
575template <int Rows>
576struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
577                           RegBlockInt32<Rows, 4>> {
578  static void Run(const RegBlockInt32<1, 4>& lhs,
579                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) {
580    const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
581    Int32x4 q[4];
582    q[0] = DupLane<0>(p);
583    q[1] = DupLane<1>(p);
584    q[2] = DupLane<2>(p);
585    q[3] = DupLane<3>(p);
586    static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount;
587    for (int i = 0; i < kRegsPerCol; i++) {
588      for (int j = 0; j < 4; j++) {
589        acc->buf.reg[i + j * kRegsPerCol] =
590            Add(q[j], acc->buf.reg[i + j * kRegsPerCol]);
591      }
592    }
593  }
594};
595
596// 1xC += 1x1 * 1x1
597template <int Cols>
598struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>,
599                           RegBlockInt32<1, Cols>> {
600  static void Run(const RegBlockInt32<1, 1>& lhs,
601                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) {
602    const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0]));
603    for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) {
604      acc->buf.reg[i] = Add(acc->buf.reg[i], p);
605    }
606  }
607};
608
609// 1x4 += 1x4 * 1x1
610template <>
611struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>,
612                           RegBlockInt32<1, 4>> {
613  static void Run(const RegBlockInt32<1, 4>& lhs,
614                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) {
615    const std::int32_t p = rhs.buf.reg[0];
616    MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
617  }
618};
619
620// 4xC += 4x1 * 1x1
621template <int Cols>
622struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
623                           RegBlockInt32<4, Cols>> {
624  static void Run(const RegBlockInt32<4, 1>& lhs,
625                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) {
626    const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]);
627    for (int i = 0; i < Cols; i++) {
628      acc->buf.reg[i] = Add(p, acc->buf.reg[i]);
629    }
630  }
631};
632
633// 4x1 += 4x1 * 1x1
634template <>
635struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>,
636                           RegBlockInt32<4, 1>> {
637  static void Run(const RegBlockInt32<4, 1>& lhs,
638                  const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) {
639    const std::int32_t p = rhs.buf.reg[0];
640    MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]);
641  }
642};
643
644}  // namespace gemmlowp
645
646#endif  // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_
647