1// Copyright 2015 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// multi_thread_gemm.h: Entry point to the multithreaded version of the
16// generated (meta) gemm library.
17
18#ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_
19#define GEMMLOWP_META_MULTI_THREAD_GEMM_H_
20
21#ifdef GEMMLOWP_NEON_32
22
23#include "multi_thread_common.h"
24#include "single_thread_gemm.h"
25
26namespace gemmlowp {
27namespace meta {
28namespace internal {
29
// Upper bound, in bytes, on the slice of rhs handed to a single
// ExecuteCacheFriendlyMatrixMatrix call; larger rhs inputs are split into
// chunks that (approximately) fit this budget.
// NOTE(review): presumably tuned to the target core's data cache — confirm.
const std::int32_t kMaxCacheFriendlySize = 24 * 1024;

// Runs `operation.ExecuteCacheFriendlyMatrixMatrix` over lhs and rhs, splitting
// rhs into chunks when it exceeds kMaxCacheFriendlySize bytes.
//
// rhs is consumed as n consecutive runs of k IN_TYPE elements; chunk i starts
// at rhs + i * optimal_n * k and writes its output starting at
// result + i * optimal_n (result_stride is forwarded unchanged). lhs, m, k and
// scratch are passed through to every call.
//
// Chunk width (optimal_n) is the largest multiple of 3 runs whose bytes fit
// the budget, with a floor of 1. The final call absorbs the remainder, so it
// normally covers between optimal_n and 2 * optimal_n - 1 runs.
//
// Parameters:
//   scratch        working memory forwarded to `operation`.
//   lhs, rhs       input operands (rhs is the one that gets chunked).
//   m, n, k        problem dimensions; n is the dimension being chunked.
//   result         output base pointer; advanced by optimal_n per chunk.
//   result_stride  forwarded unchanged to every call.
//   operation      object providing ExecuteCacheFriendlyMatrixMatrix(...).
template <typename IN_TYPE, typename OUT_TYPE, typename F>
void CacheFriendlyMatrixMatrix(std::uint8_t* scratch, const IN_TYPE* lhs,
                               const IN_TYPE* rhs, std::int32_t m,
                               std::int32_t n, std::int32_t k, OUT_TYPE* result,
                               std::int32_t result_stride, const F& operation) {
  // Measure everything in bytes so the budget holds for any IN_TYPE, not only
  // 8-bit inputs (previously optimal_n ignored sizeof(IN_TYPE)).
  const std::int32_t element_size = static_cast<std::int32_t>(sizeof(IN_TYPE));
  const std::int32_t row_bytes = k * element_size;
  const std::int32_t rhs_size = n * row_bytes;

  if (rhs_size <= kMaxCacheFriendlySize) {
    operation.ExecuteCacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k,
                                               result, result_stride);
    return;
  }

  // Runs per chunk: a multiple of 3 that fits the byte budget, but at least
  // one run even when three runs alone exceed the budget.
  const std::int32_t optimal_n = std::max<std::int32_t>(
      1, 3 * (kMaxCacheFriendlySize / (row_bytes * 3)));
  // Number of full-width chunks issued before the remainder call. Clamped at
  // zero so a short rhs can never produce a negative count, which would move
  // the rhs/result pointers backwards.
  const std::int32_t chunks_count_less_one =
      std::max<std::int32_t>(0, n / optimal_n - 1);
  const std::int32_t chunk_size = optimal_n * k;  // In IN_TYPE elements.

  for (std::int32_t i = 0; i < chunks_count_less_one; ++i) {
    operation.ExecuteCacheFriendlyMatrixMatrix(
        scratch, lhs, rhs + i * chunk_size, m, optimal_n, k,
        result + i * optimal_n, result_stride);
  }

  // The last call takes the final full chunk plus any leftover runs.
  const std::int32_t n_left = n - chunks_count_less_one * optimal_n;
  operation.ExecuteCacheFriendlyMatrixMatrix(
      scratch, lhs, rhs + chunks_count_less_one * chunk_size, m, n_left, k,
      result + chunks_count_less_one * optimal_n, result_stride);
}
57
58class GemmQuantized8BitOperation {
59 public:
60  GemmQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
61                             std::int32_t sum_offset, std::int32_t multiplier,
62                             std::int32_t shift)
63      : lhs_offset(lhs_offset),
64        rhs_offset(rhs_offset),
65        sum_offset(sum_offset),
66        multiplier(multiplier),
67        shift(shift) {}
68
69  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
70                           const std::uint8_t* rhs, std::int32_t m,
71                           std::int32_t n, std::int32_t k, std::uint8_t* result,
72                           std::int32_t result_stride) const {
73    CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
74                              *this);
75  }
76
77  void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
78                                        const std::uint8_t* lhs,
79                                        const std::uint8_t* rhs, std::int32_t m,
80                                        std::int32_t n, std::int32_t k,
81                                        std::uint8_t* result,
82                                        std::int32_t result_stride) const {
83    gemm_q8_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
84                    sum_offset, multiplier, shift, result, result_stride);
85  }
86
87  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
88                                       std::int32_t k) {
89    return 128 * 1024;
90  }
91
92 private:
93  std::int32_t lhs_offset;
94  std::int32_t rhs_offset;
95  std::int32_t sum_offset;
96  std::int32_t multiplier;
97  std::int32_t shift;
98};
99
100class GemmFloatOperation {
101 public:
102  GemmFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
103                     float result_offset)
104      : lhs_offset(lhs_offset),
105        rhs_offset(rhs_offset),
106        result_offset(result_offset) {}
107
108  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
109                           const std::uint8_t* rhs, std::int32_t m,
110                           std::int32_t n, std::int32_t k, float* result,
111                           std::int32_t result_stride) const {
112    CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
113                              *this);
114  }
115
116  void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
117                                        const std::uint8_t* lhs,
118                                        const std::uint8_t* rhs, std::int32_t m,
119                                        std::int32_t n, std::int32_t k,
120                                        float* result,
121                                        std::int32_t result_stride) const {
122    gemm_f_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
123                   result_offset, result, result_stride);
124  }
125
126  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
127                                       std::int32_t k) {
128    return 128 * 1024;
129  }
130
131 private:
132  std::int32_t lhs_offset;
133  std::int32_t rhs_offset;
134  float result_offset;
135};
136
137class GemmInt32Operation {
138 public:
139  GemmInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
140      : lhs_offset(lhs_offset), rhs_offset(rhs_offset) {}
141
142  void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
143                           const std::uint8_t* rhs, std::int32_t m,
144                           std::int32_t n, std::int32_t k, std::int32_t* result,
145                           std::int32_t result_stride) const {
146    CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
147                              *this);
148  }
149
150  void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
151                                        const std::uint8_t* lhs,
152                                        const std::uint8_t* rhs, std::int32_t m,
153                                        std::int32_t n, std::int32_t k,
154                                        std::int32_t* result,
155                                        std::int32_t result_stride) const {
156    gemm_i32_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result,
157                     result_stride);
158  }
159
160  static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
161                                       std::int32_t k) {
162    return 128 * 1024;
163  }
164
165 private:
166  std::int32_t lhs_offset;
167  std::int32_t rhs_offset;
168};
169
170}  // namespace internal
171
172std::int32_t gemm_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
173                             std::int32_t max_threads) {
174  return internal::ResolveMaxThreads(max_threads) *
175         internal::GemmQuantized8BitOperation::ScratchPerThread(m, n, k);
176}
177
178void multi_thread_gemm_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
179                          std::uint8_t* scratch, const std::uint8_t* lhs,
180                          const std::uint8_t* rhs, std::int32_t m,
181                          std::int32_t n, std::int32_t k,
182                          std::int32_t lhs_offset, std::int32_t rhs_offset,
183                          std::int32_t sum_offset, std::int32_t multiplier,
184                          std::int32_t shift, std::uint8_t* result) {
185  internal::GemmQuantized8BitOperation operation(lhs_offset, rhs_offset,
186                                                 sum_offset, multiplier, shift);
187  internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
188                                      n, k, result, n, operation);
189}
190
191std::int32_t gemm_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
192                            std::int32_t max_threads) {
193  return internal::ResolveMaxThreads(max_threads) *
194         internal::GemmFloatOperation::ScratchPerThread(m, n, k);
195}
196
197void multi_thread_gemm_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
198                         std::uint8_t* scratch, const std::uint8_t* lhs,
199                         const std::uint8_t* rhs, std::int32_t m,
200                         std::int32_t n, std::int32_t k,
201                         std::int32_t lhs_offset, std::int32_t rhs_offset,
202                         float result_offset, float* result) {
203  internal::GemmFloatOperation operation(lhs_offset, rhs_offset, result_offset);
204  internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
205                                      n, k, result, n, operation);
206}
207
208std::int32_t gemm_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
209                              std::int32_t max_threads) {
210  return internal::ResolveMaxThreads(max_threads) *
211         internal::GemmInt32Operation::ScratchPerThread(m, n, k);
212}
213
214void multi_thread_gemm_i32(gemmlowp::WorkersPool* pool,
215                           std::int32_t max_threads, std::uint8_t* scratch,
216                           const std::uint8_t* lhs, const std::uint8_t* rhs,
217                           std::int32_t m, std::int32_t n, std::int32_t k,
218                           std::int32_t lhs_offset, std::int32_t rhs_offset,
219                           std::int32_t* result) {
220  internal::GemmInt32Operation operation(lhs_offset, rhs_offset);
221  internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
222                                      n, k, result, n, operation);
223}
224
225}  // namespace meta
226}  // namespace gemmlowp
227
228#else
229#warning "Meta gemm fast-path requires GEMMLOWP_NEON_32!"
230#endif
231
232#endif  // GEMMLOWP_META_MULTI_THREAD_GEMM_H_
233