1a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// Copyright 2015 Google Inc. All Rights Reserved.
2a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang//
3a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// Licensed under the Apache License, Version 2.0 (the "License");
4a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// you may not use this file except in compliance with the License.
5a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// You may obtain a copy of the License at
6a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang//
7a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang//     http://www.apache.org/licenses/LICENSE-2.0
8a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang//
9a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// Unless required by applicable law or agreed to in writing, software
10a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// distributed under the License is distributed on an "AS IS" BASIS,
11a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// See the License for the specific language governing permissions and
13a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// limitations under the License.
14a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
15a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// multi_thread_gemv.h: Entry point to the multithreaded version of the
16a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang// generated (meta) gemv library.
17a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
18a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#ifndef GEMMLOWP_META_MULTI_THREAD_GEMV_H_
19a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#define GEMMLOWP_META_MULTI_THREAD_GEMV_H_
20a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
21a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#ifdef GEMMLOWP_NEON_32
22a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
23a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#include "multi_thread_common.h"
24a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#include "operations_common.h"
25a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#include "single_thread_gemm.h"
26a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
27a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangnamespace gemmlowp {
28a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangnamespace meta {
29a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangnamespace internal {
30a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
31a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangclass GemvQuantized8BitOperation : public Quantized8BitOperation {
32a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang public:
33a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  GemvQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
34a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                             std::int32_t sum_offset, std::int32_t multiplier,
35a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                             std::int32_t shift)
36a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang      : Quantized8BitOperation(lhs_offset, rhs_offset, sum_offset, multiplier,
37a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                               shift) {}
38a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
39a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
40a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           const std::uint8_t* rhs, std::int32_t m,
41a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           std::int32_t n, std::int32_t k, std::uint8_t* result,
42a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           std::int32_t result_stride) const {
43a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    gemv_q8(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, sum_offset,
44a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang            multiplier, shift, result);
45a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
46a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
47a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
48a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                                       std::int32_t k) {
49a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    return 128 * 1024;
50a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
51a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang};
52a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
53a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangclass GemvFloatOperation : public FloatOperation {
54a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang public:
55a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  GemvFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
56a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                     float result_offset)
57a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang      : FloatOperation(lhs_offset, rhs_offset, result_offset) {}
58a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
59a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
60a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           const std::uint8_t* rhs, std::int32_t m,
61a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           std::int32_t n, std::int32_t k, float* result,
62a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           std::int32_t result_stride) const {
63a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    gemv_f(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result_offset,
64a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang           result);
65a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
66a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
67a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
68a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                                       std::int32_t k) {
69a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    return 128 * 1024;
70a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
71a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang};
72a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
73a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangclass GemvInt32Operation : public Int32Operation {
74a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang public:
75a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  GemvInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
76a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang      : Int32Operation(lhs_offset, rhs_offset) {}
77a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
78a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
79a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           const std::uint8_t* rhs, std::int32_t m,
80a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           std::int32_t n, std::int32_t k, std::int32_t* result,
81a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           std::int32_t result_stride) const {
82a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    gemv_i32(scratch, lhs, rhs, n, k, lhs_offset, rhs_offset, result);
83a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
84a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
85a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
86a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                                       std::int32_t k) {
87a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    return 128 * 1024;
88a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
89a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang};
90a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
91a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}  // namespace internal
92a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
93a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangstd::int32_t gemv_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
94a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                             std::int32_t max_threads) {
95a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  return internal::ResolveMaxThreads(max_threads) *
96a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang         internal::GemvQuantized8BitOperation::ScratchPerThread(m, n, k);
97a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}
98a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
99a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangvoid multi_thread_gemv_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
100a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                          std::uint8_t* scratch, const std::uint8_t* lhs,
101a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                          const std::uint8_t* rhs, std::int32_t n,
102a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                          std::int32_t k, std::int32_t lhs_offset,
103a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                          std::int32_t rhs_offset, std::int32_t sum_offset,
104a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                          std::int32_t multiplier, std::int32_t shift,
105a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                          std::uint8_t* result) {
106a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  max_threads = internal::ResolveMaxThreads(max_threads);
107a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  internal::GemvQuantized8BitOperation operation(lhs_offset, rhs_offset,
108a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                                                 sum_offset, multiplier, shift);
109a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  if (max_threads == 1) {
110a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
111a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  } else {
112a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
113a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                                        n, k, result, n, operation);
114a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
115a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}
116a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
117a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangstd::int32_t gemv_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
118a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                            std::int32_t max_threads) {
119a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  return internal::ResolveMaxThreads(max_threads) *
120a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang         internal::GemvFloatOperation::ScratchPerThread(m, n, k);
121a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}
122a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
123a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangvoid multi_thread_gemv_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
124a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                         std::uint8_t* scratch, const std::uint8_t* lhs,
125a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                         const std::uint8_t* rhs, std::int32_t n,
126a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                         std::int32_t k, std::int32_t lhs_offset,
127a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                         std::int32_t rhs_offset, float result_offset,
128a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                         float* result) {
129a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  max_threads = internal::ResolveMaxThreads(max_threads);
130a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  internal::GemvFloatOperation operation(lhs_offset, rhs_offset, result_offset);
131a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  if (max_threads == 1) {
132a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
133a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  } else {
134a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
135a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                                        n, k, result, n, operation);
136a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
137a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}
138a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
139a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangstd::int32_t gemv_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
140a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                              std::int32_t max_threads) {
141a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  return internal::ResolveMaxThreads(max_threads) *
142a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang         internal::GemvInt32Operation::ScratchPerThread(m, n, k);
143a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}
144a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
145a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wangvoid multi_thread_gemv_i32(gemmlowp::WorkersPool* pool,
146a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           std::int32_t max_threads, std::uint8_t* scratch,
147a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           const std::uint8_t* lhs, const std::uint8_t* rhs,
148a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           std::int32_t n, std::int32_t k,
149a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           std::int32_t lhs_offset, std::int32_t rhs_offset,
150a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                           std::int32_t* result) {
151a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  max_threads = internal::ResolveMaxThreads(max_threads);
152a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  internal::GemvInt32Operation operation(lhs_offset, rhs_offset);
153a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  if (max_threads == 1) {
154a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    operation.ExecuteMatrixMatrix(scratch, lhs, rhs, 1, n, k, result, n);
155a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  } else {
156a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang    internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, 1,
157a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang                                        n, k, result, n, operation);
158a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang  }
159a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}
160a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
161a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}  // namespace meta
162a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang}  // namespace gemmlowp
163a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
164a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#else
165a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
166a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#endif
167a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang
168a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1Miao Wang#endif  // GEMMLOWP_META_MULTI_THREAD_GEMV_H_
169