1// Copyright 2015 Google Inc. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// simd_wrappers_common_neon_sse.h: common SIMD (NEON and SSE) wrapper code 16 17#ifndef GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ 18#define GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ 19 20#include "simd_wrappers.h" 21 22namespace gemmlowp { 23 24template <typename SrcScalarType, int N> 25struct LoadImpl<RegBlockInt32<4, N>, 26 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 27 static RegBlockInt32<4, N> Run( 28 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 29 int col) { 30 RegBlockInt32<4, N> result; 31 for (int i = 0; i < N; i++) { 32 result.buf.reg[i] = LoadInt32x4(src.data(row, col + i)); 33 } 34 return result; 35 } 36}; 37 38template <typename SrcScalarType, int N> 39struct LoadImpl<RegBlockInt32<8, N>, 40 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 41 static RegBlockInt32<8, N> Run( 42 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 43 int col) { 44 RegBlockInt32<8, N> result; 45 for (int i = 0; i < N; i++) { 46 result.buf.reg[2 * i + 0] = LoadInt32x4(src.data(row + 0, col + i)); 47 result.buf.reg[2 * i + 1] = LoadInt32x4(src.data(row + 4, col + i)); 48 } 49 return result; 50 } 51}; 52 53template <typename SrcScalarType> 54struct LoadImpl<RegBlockInt32<1, 4>, 55 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 56 static RegBlockInt32<1, 4> Run( 57 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 58 int col) { 59 RegBlockInt32<1, 4> result; 60 std::int32_t buf[4]; 61 for (int i = 0; i < 4; i++) { 62 buf[i] = src(row, col + i); 63 } 64 result.buf.reg[0] = LoadInt32x4(buf); 65 return result; 66 } 67}; 68 69template <typename SrcScalarType> 70struct LoadImpl<RegBlockInt32<1, 8>, 71 MatrixMap<SrcScalarType, MapOrder::ColMajor>> { 72 static RegBlockInt32<1, 8> Run( 73 const MatrixMap<SrcScalarType, MapOrder::ColMajor>& src, int row, 74 int col) { 75 RegBlockInt32<1, 8> result; 76 std::int32_t buf[8]; 77 for (int i = 0; i < 8; i++) { 78 buf[i] = src(row, col + i); 79 } 80 result.buf.reg[0] = LoadInt32x4(buf); 81 result.buf.reg[1] = LoadInt32x4(buf + 4); 82 return result; 83 } 84}; 85 86template <typename SrcScalarType> 87struct LoadImpl<RegBlockInt32<4, 1>, 88 VectorMap<SrcScalarType, VectorShape::Col>> { 89 static RegBlockInt32<4, 1> Run( 90 const VectorMap<SrcScalarType, VectorShape::Col>& src, int pos) { 91 RegBlockInt32<4, 1> result; 92 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 93 return result; 94 } 95}; 96 97template <typename SrcScalarType> 98struct LoadImpl<RegBlockInt32<4, 1>, 99 VectorDup<SrcScalarType, VectorShape::Col>> { 100 static RegBlockInt32<4, 1> Run( 101 const VectorDup<SrcScalarType, VectorShape::Col>& src, int) { 102 RegBlockInt32<4, 1> result; 103 result.buf.reg[0] = LoadInt32x4(src(0)); 104 return result; 105 } 106}; 107 108template <typename SrcScalarType, int N> 109struct LoadForBroadcastingImpl<RegBlockInt32<4, N>, 110 VectorMap<SrcScalarType, VectorShape::Col>> { 111 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>; 112 using RegisterBlockType = RegBlockInt32<4, N>; 113 using ResultBlockType = 114 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 115 SrcObjectType>::Type; 116 117 static ResultBlockType Run(const SrcObjectType& src, int pos) { 118 ResultBlockType result; 119 static_assert(ResultBlockType::kRegisterCount == 1, ""); 120 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 121 return result; 122 } 123}; 124 125template <typename SrcScalarType, int N> 126struct LoadForBroadcastingImpl<RegBlockInt32<8, N>, 127 VectorMap<SrcScalarType, VectorShape::Col>> { 128 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Col>; 129 using RegisterBlockType = RegBlockInt32<8, N>; 130 using ResultBlockType = 131 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 132 SrcObjectType>::Type; 133 134 static ResultBlockType Run(const SrcObjectType& src, int pos) { 135 ResultBlockType result; 136 static_assert(ResultBlockType::kRegisterCount == 2, ""); 137 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 138 result.buf.reg[1] = LoadInt32x4(src.data(pos + 4)); 139 return result; 140 } 141}; 142 143template <typename SrcScalarType> 144struct LoadForBroadcastingImpl<RegBlockInt32<4, 1>, 145 VectorMap<SrcScalarType, VectorShape::Row>> { 146 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; 147 using RegisterBlockType = RegBlockInt32<4, 1>; 148 using ResultBlockType = 149 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 150 SrcObjectType>::Type; 151 152 static ResultBlockType Run(const SrcObjectType& src, int pos) { 153 ResultBlockType result; 154 result.buf.reg[0] = src(pos); 155 return result; 156 } 157}; 158 159template <typename SrcScalarType, int N> 160struct LoadForBroadcastingImpl<RegBlockInt32<N, 4>, 161 VectorMap<SrcScalarType, VectorShape::Row>> { 162 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; 163 using RegisterBlockType = RegBlockInt32<N, 4>; 164 using ResultBlockType = 165 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 166 SrcObjectType>::Type; 167 168 static ResultBlockType Run(const SrcObjectType& src, int pos) { 169 ResultBlockType result; 170 static_assert(ResultBlockType::kRegisterCount == 1, ""); 171 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 172 return result; 173 } 174}; 175 176template <typename SrcScalarType, int N> 177struct LoadForBroadcastingImpl<RegBlockInt32<N, 8>, 178 VectorMap<SrcScalarType, VectorShape::Row>> { 179 using SrcObjectType = VectorMap<SrcScalarType, VectorShape::Row>; 180 using RegisterBlockType = RegBlockInt32<N, 8>; 181 using ResultBlockType = 182 typename LoadForBroadcastingRegisterBlock<RegisterBlockType, 183 SrcObjectType>::Type; 184 185 static ResultBlockType Run(const SrcObjectType& src, int pos) { 186 ResultBlockType result; 187 static_assert(ResultBlockType::kRegisterCount == 2, ""); 188 result.buf.reg[0] = LoadInt32x4(src.data(pos)); 189 result.buf.reg[1] = LoadInt32x4(src.data(pos + 4)); 190 return result; 191 } 192}; 193 194// 4x1 := 4x1 + 1x1 195template <> 196struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { 197 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 198 const RegBlockInt32<1, 1>& rhs) { 199 RegBlockInt32<4, 1> result; 200 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 201 return result; 202 } 203}; 204 205// 1x4 := 1x4 + 1x1 206template <> 207struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { 208 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 209 const RegBlockInt32<1, 1>& rhs) { 210 RegBlockInt32<1, 4> result; 211 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 212 return result; 213 } 214}; 215 216// 4x1 := 4x1 + 4x1 217template <> 218struct BroadcastAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { 219 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 220 const RegBlockInt32<4, 1>& rhs) { 221 RegBlockInt32<4, 1> result; 222 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 223 return result; 224 } 225}; 226 227// 1x4 := 1x4 + 1x4 228template <> 229struct BroadcastAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { 230 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 231 const RegBlockInt32<1, 4>& rhs) { 232 RegBlockInt32<1, 4> result; 233 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 234 return result; 235 } 236}; 237 238// 4x4 := 4x4 + 1x4 239template <> 240struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> { 241 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 242 const RegBlockInt32<1, 4>& rhs) { 243 RegBlockInt32<4, 4> result; 244 result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); 245 result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<1>(rhs.buf.reg[0])); 246 result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<2>(rhs.buf.reg[0])); 247 result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<3>(rhs.buf.reg[0])); 248 return result; 249 } 250}; 251 252// 4x4 := 4x4 + 4x1 253template <> 254struct BroadcastAddImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { 255 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 256 const RegBlockInt32<4, 1>& rhs) { 257 RegBlockInt32<4, 4> result; 258 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 259 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[0]); 260 result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]); 261 result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[0]); 262 return result; 263 } 264}; 265 266// 8x1 := 8x1 + 1x1 267template <> 268struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { 269 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 270 const RegBlockInt32<1, 1>& rhs) { 271 RegBlockInt32<8, 1> result; 272 const Int32x4 p = Dup<Int32x4>(rhs.buf.reg[0]); 273 for (int i = 0; i < 2; i++) { 274 result.buf.reg[i] = Add(lhs.buf.reg[i], p); 275 } 276 return result; 277 } 278}; 279 280// 8x1 := 8x1 + 8x1 281template <> 282struct BroadcastAddImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { 283 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 284 const RegBlockInt32<8, 1>& rhs) { 285 RegBlockInt32<8, 1> result; 286 for (int i = 0; i < 2; i++) { 287 result.buf.reg[i] = Add(lhs.buf.reg[i], rhs.buf.reg[i]); 288 } 289 return result; 290 } 291}; 292 293// 8x4 := 8x4 + 1x4 294template <> 295struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { 296 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 297 const RegBlockInt32<1, 4>& rhs) { 298 RegBlockInt32<8, 4> result; 299 result.buf.reg[0] = Add(lhs.buf.reg[0], DupLane<0>(rhs.buf.reg[0])); 300 result.buf.reg[1] = Add(lhs.buf.reg[1], DupLane<0>(rhs.buf.reg[0])); 301 result.buf.reg[2] = Add(lhs.buf.reg[2], DupLane<1>(rhs.buf.reg[0])); 302 result.buf.reg[3] = Add(lhs.buf.reg[3], DupLane<1>(rhs.buf.reg[0])); 303 result.buf.reg[4] = Add(lhs.buf.reg[4], DupLane<2>(rhs.buf.reg[0])); 304 result.buf.reg[5] = Add(lhs.buf.reg[5], DupLane<2>(rhs.buf.reg[0])); 305 result.buf.reg[6] = Add(lhs.buf.reg[6], DupLane<3>(rhs.buf.reg[0])); 306 result.buf.reg[7] = Add(lhs.buf.reg[7], DupLane<3>(rhs.buf.reg[0])); 307 return result; 308 } 309}; 310 311// 8x4 := 8x4 + 8x1 312template <> 313struct BroadcastAddImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { 314 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 315 const RegBlockInt32<8, 1>& rhs) { 316 RegBlockInt32<8, 4> result; 317 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 318 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]); 319 result.buf.reg[2] = Add(lhs.buf.reg[2], rhs.buf.reg[0]); 320 result.buf.reg[3] = Add(lhs.buf.reg[3], rhs.buf.reg[1]); 321 result.buf.reg[4] = Add(lhs.buf.reg[4], rhs.buf.reg[0]); 322 result.buf.reg[5] = Add(lhs.buf.reg[5], rhs.buf.reg[1]); 323 result.buf.reg[6] = Add(lhs.buf.reg[6], rhs.buf.reg[0]); 324 result.buf.reg[7] = Add(lhs.buf.reg[7], rhs.buf.reg[1]); 325 return result; 326 } 327}; 328 329// 1x8 := 1x8 + 1x8 330template <> 331struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 8>> { 332 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, 333 const RegBlockInt32<1, 8>& rhs) { 334 RegBlockInt32<1, 8> result; 335 result.buf.reg[0] = Add(lhs.buf.reg[0], rhs.buf.reg[0]); 336 result.buf.reg[1] = Add(lhs.buf.reg[1], rhs.buf.reg[1]); 337 return result; 338 } 339}; 340 341// 1x8 := 1x8 + 1x1 342template <> 343struct BroadcastAddImpl<RegBlockInt32<1, 8>, RegBlockInt32<1, 1>> { 344 static RegBlockInt32<1, 8> Run(const RegBlockInt32<1, 8>& lhs, 345 const RegBlockInt32<1, 1>& rhs) { 346 RegBlockInt32<1, 8> result; 347 result.buf.reg[0] = Add(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 348 result.buf.reg[1] = Add(lhs.buf.reg[1], Dup<Int32x4>(rhs.buf.reg[0])); 349 return result; 350 } 351}; 352 353// 4x1 := 4x1 * 1x1 354template <> 355struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>> { 356 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 357 const RegBlockInt32<1, 1>& rhs) { 358 RegBlockInt32<4, 1> result; 359 result.buf.reg[0] = Mul(lhs.buf.reg[0], Dup<Int32x4>(rhs.buf.reg[0])); 360 return result; 361 } 362}; 363 364// 4x1 := 4x1 * 4x1 365template <> 366struct BroadcastMulImpl<RegBlockInt32<4, 1>, RegBlockInt32<4, 1>> { 367 static RegBlockInt32<4, 1> Run(const RegBlockInt32<4, 1>& lhs, 368 const RegBlockInt32<4, 1>& rhs) { 369 RegBlockInt32<4, 1> result; 370 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 371 return result; 372 } 373}; 374 375// 1x4 := 1x4 * 1x4 376template <> 377struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 4>> { 378 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 379 const RegBlockInt32<1, 4>& rhs) { 380 RegBlockInt32<1, 4> result; 381 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 382 return result; 383 } 384}; 385 386// 1x4 := 1x4 * 1x1 387template <> 388struct BroadcastMulImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>> { 389 static RegBlockInt32<1, 4> Run(const RegBlockInt32<1, 4>& lhs, 390 const RegBlockInt32<1, 1>& rhs) { 391 RegBlockInt32<1, 4> result; 392 result.buf.reg[0] = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 393 return result; 394 } 395}; 396 397// 4x4 := 4x4 * 1x4 398template <> 399struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<1, 4>> { 400 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 401 const RegBlockInt32<1, 4>& rhs) { 402 RegBlockInt32<4, 4> result; 403 const Int32x4 p = rhs.buf.reg[0]; 404 result.buf.reg[0] = MulByRhsLane<0>(lhs.buf.reg[0], p); 405 result.buf.reg[1] = MulByRhsLane<1>(lhs.buf.reg[1], p); 406 result.buf.reg[2] = MulByRhsLane<2>(lhs.buf.reg[2], p); 407 result.buf.reg[3] = MulByRhsLane<3>(lhs.buf.reg[3], p); 408 return result; 409 } 410}; 411 412// 4x4 := 4x4 * 4x1 413template <> 414struct BroadcastMulImpl<RegBlockInt32<4, 4>, RegBlockInt32<4, 1>> { 415 static RegBlockInt32<4, 4> Run(const RegBlockInt32<4, 4>& lhs, 416 const RegBlockInt32<4, 1>& rhs) { 417 RegBlockInt32<4, 4> result; 418 const Int32x4 p = rhs.buf.reg[0]; 419 result.buf.reg[0] = Mul(lhs.buf.reg[0], p); 420 result.buf.reg[1] = Mul(lhs.buf.reg[1], p); 421 result.buf.reg[2] = Mul(lhs.buf.reg[2], p); 422 result.buf.reg[3] = Mul(lhs.buf.reg[3], p); 423 return result; 424 } 425}; 426 427// 8x1 := 8x1 * 1x1 428template <> 429struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<1, 1>> { 430 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 431 const RegBlockInt32<1, 1>& rhs) { 432 RegBlockInt32<8, 1> result; 433 const std::int32_t p = rhs.buf.reg[0]; 434 for (int i = 0; i < 2; i++) { 435 result.buf.reg[i] = Mul(lhs.buf.reg[i], p); 436 } 437 return result; 438 } 439}; 440 441// 8x1 := 8x1 * 8x1 442template <> 443struct BroadcastMulImpl<RegBlockInt32<8, 1>, RegBlockInt32<8, 1>> { 444 static RegBlockInt32<8, 1> Run(const RegBlockInt32<8, 1>& lhs, 445 const RegBlockInt32<8, 1>& rhs) { 446 RegBlockInt32<8, 1> result; 447 for (int i = 0; i < 2; i++) { 448 result.buf.reg[i] = Mul(lhs.buf.reg[i], rhs.buf.reg[i]); 449 } 450 return result; 451 } 452}; 453 454// 8x4 := 8x4 * 1x4 455template <> 456struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<1, 4>> { 457 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 458 const RegBlockInt32<1, 4>& rhs) { 459 RegBlockInt32<8, 4> result; 460 const Int32x4 p = rhs.buf.reg[0]; 461 for (int i = 0; i < 2; i++) { 462 result.buf.reg[i + 0] = MulByRhsLane<0>(lhs.buf.reg[i + 0], p); 463 result.buf.reg[i + 2] = MulByRhsLane<1>(lhs.buf.reg[i + 2], p); 464 result.buf.reg[i + 4] = MulByRhsLane<2>(lhs.buf.reg[i + 4], p); 465 result.buf.reg[i + 6] = MulByRhsLane<3>(lhs.buf.reg[i + 6], p); 466 } 467 return result; 468 } 469}; 470 471// 8x4 := 8x4 * 8x1 472template <> 473struct BroadcastMulImpl<RegBlockInt32<8, 4>, RegBlockInt32<8, 1>> { 474 static RegBlockInt32<8, 4> Run(const RegBlockInt32<8, 4>& lhs, 475 const RegBlockInt32<8, 1>& rhs) { 476 RegBlockInt32<8, 4> result; 477 const Int32x4 p[2]{rhs.buf.reg[0], rhs.buf.reg[1]}; 478 for (int i = 0; i < 4; i++) { 479 for (int j = 0; j < 2; j++) { 480 const int k = j + 2 * i; 481 result.buf.reg[k] = Mul(lhs.buf.reg[k], p[j]); 482 } 483 } 484 return result; 485 } 486}; 487 488// Rx1 += Rx1 * 1x1 489template <int Rows> 490struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>, 491 RegBlockInt32<Rows, 1>> { 492 static void Run(const RegBlockInt32<Rows, 1>& lhs, 493 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 1>* acc) { 494 const std::int32_t p = rhs.buf.reg[0]; 495 for (int i = 0; i < RegBlockInt32<Rows, 1>::kRegisterCount; i++) { 496 MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]); 497 } 498 } 499}; 500 501// RxC += Rx1 * 1x1 502template <int Rows, int Cols> 503struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 1>, 504 RegBlockInt32<Rows, Cols>> { 505 static void Run(const RegBlockInt32<Rows, 1>& lhs, 506 const RegBlockInt32<1, 1>& rhs, 507 RegBlockInt32<Rows, Cols>* acc) { 508 const std::int32_t p = rhs.buf.reg[0]; 509 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; 510 for (int i = 0; i < kRegsPerCol; i++) { 511 const Int32x4 q = Mul(lhs.buf.reg[i], p); 512 for (int j = 0; j < Cols; j++) { 513 acc->buf.reg[i + j * kRegsPerCol] = 514 Add(acc->buf.reg[i + j * kRegsPerCol], q); 515 } 516 } 517 } 518}; 519 520// 1xC += 1xC * 1x1 521template <int Cols> 522struct BroadcastMulAddImpl<RegBlockInt32<1, Cols>, RegBlockInt32<1, 1>, 523 RegBlockInt32<1, Cols>> { 524 static void Run(const RegBlockInt32<1, Cols>& lhs, 525 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) { 526 const std::int32_t p = rhs.buf.reg[0]; 527 for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) { 528 MulAdd(lhs.buf.reg[i], p, &acc->buf.reg[i]); 529 } 530 } 531}; 532 533// RxC += 1x1 * 1x1 534template <int Rows, int Cols> 535struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, 536 RegBlockInt32<Rows, Cols>> { 537 static void Run(const RegBlockInt32<1, 1>& lhs, 538 const RegBlockInt32<1, 1>& rhs, 539 RegBlockInt32<Rows, Cols>* acc) { 540 const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0])); 541 for (int i = 0; i < RegBlockInt32<Rows, Cols>::kRegisterCount; i++) { 542 acc->buf.reg[i] = Add(acc->buf.reg[i], p); 543 } 544 } 545}; 546 547// 1x1 += 1x1 * 1x1 548template <> 549struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, 550 RegBlockInt32<1, 1>> { 551 static void Run(const RegBlockInt32<1, 1>& lhs, 552 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 1>* acc) { 553 MulAdd(lhs.buf.reg[0], rhs.buf.reg[0], &acc->buf.reg[0]); 554 } 555}; 556 557// Rx4 += Rx1 * 1x4 558template <int Rows> 559struct BroadcastMulAddImpl<RegBlockInt32<Rows, 1>, RegBlockInt32<1, 4>, 560 RegBlockInt32<Rows, 4>> { 561 static void Run(const RegBlockInt32<Rows, 1>& lhs, 562 const RegBlockInt32<1, 4>& rhs, RegBlockInt32<Rows, 4>* acc) { 563 const Int32x4 p = rhs.buf.reg[0]; 564 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; 565 for (int i = 0; i < kRegsPerCol; i++) { 566 MulAddByRhsLane<0>(lhs.buf.reg[i], p, &acc->buf.reg[i + 0 * kRegsPerCol]); 567 MulAddByRhsLane<1>(lhs.buf.reg[i], p, &acc->buf.reg[i + 1 * kRegsPerCol]); 568 MulAddByRhsLane<2>(lhs.buf.reg[i], p, &acc->buf.reg[i + 2 * kRegsPerCol]); 569 MulAddByRhsLane<3>(lhs.buf.reg[i], p, &acc->buf.reg[i + 3 * kRegsPerCol]); 570 } 571 } 572}; 573 574// Rx4 += 1x4 * 1x1 575template <int Rows> 576struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>, 577 RegBlockInt32<Rows, 4>> { 578 static void Run(const RegBlockInt32<1, 4>& lhs, 579 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<Rows, 4>* acc) { 580 const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 581 Int32x4 q[4]; 582 q[0] = DupLane<0>(p); 583 q[1] = DupLane<1>(p); 584 q[2] = DupLane<2>(p); 585 q[3] = DupLane<3>(p); 586 static constexpr int kRegsPerCol = RegBlockInt32<Rows, 1>::kRegisterCount; 587 for (int i = 0; i < kRegsPerCol; i++) { 588 for (int j = 0; j < 4; j++) { 589 acc->buf.reg[i + j * kRegsPerCol] = 590 Add(q[j], acc->buf.reg[i + j * kRegsPerCol]); 591 } 592 } 593 } 594}; 595 596// 1xC += 1x1 * 1x1 597template <int Cols> 598struct BroadcastMulAddImpl<RegBlockInt32<1, 1>, RegBlockInt32<1, 1>, 599 RegBlockInt32<1, Cols>> { 600 static void Run(const RegBlockInt32<1, 1>& lhs, 601 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, Cols>* acc) { 602 const Int32x4 p = Dup<Int32x4>(Mul(lhs.buf.reg[0], rhs.buf.reg[0])); 603 for (int i = 0; i < RegBlockInt32<1, Cols>::kRegisterCount; i++) { 604 acc->buf.reg[i] = Add(acc->buf.reg[i], p); 605 } 606 } 607}; 608 609// 1x4 += 1x4 * 1x1 610template <> 611struct BroadcastMulAddImpl<RegBlockInt32<1, 4>, RegBlockInt32<1, 1>, 612 RegBlockInt32<1, 4>> { 613 static void Run(const RegBlockInt32<1, 4>& lhs, 614 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<1, 4>* acc) { 615 const std::int32_t p = rhs.buf.reg[0]; 616 MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]); 617 } 618}; 619 620// 4xC += 4x1 * 1x1 621template <int Cols> 622struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>, 623 RegBlockInt32<4, Cols>> { 624 static void Run(const RegBlockInt32<4, 1>& lhs, 625 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, Cols>* acc) { 626 const Int32x4 p = Mul(lhs.buf.reg[0], rhs.buf.reg[0]); 627 for (int i = 0; i < Cols; i++) { 628 acc->buf.reg[i] = Add(p, acc->buf.reg[i]); 629 } 630 } 631}; 632 633// 4x1 += 4x1 * 1x1 634template <> 635struct BroadcastMulAddImpl<RegBlockInt32<4, 1>, RegBlockInt32<1, 1>, 636 RegBlockInt32<4, 1>> { 637 static void Run(const RegBlockInt32<4, 1>& lhs, 638 const RegBlockInt32<1, 1>& rhs, RegBlockInt32<4, 1>* acc) { 639 const std::int32_t p = rhs.buf.reg[0]; 640 MulAdd(lhs.buf.reg[0], p, &acc->buf.reg[0]); 641 } 642}; 643 644} // namespace gemmlowp 645 646#endif // GEMMLOWP_INTERNAL_SIMD_WRAPPERS_COMMON_NEON_SSE_H_ 647