1// Copyright 2015 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "eight_bit_int_gemm.h"
16
17#include <memory>
18
// gemmlowp symbols should have hidden visibility.
// currently this is ensured in the build system by
// passing -fvisibility-inlines-hidden. TODO: it would be
// safer to hardcode it here with some #pragma's.
23#include "../public/gemmlowp.h"
24
25// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
26// code. This code path consists of a number of meta-programmed, automatically
27// generated GEMM kernels that are suitable for some sizes of input matrices.
28// Due to the fact that the generated code relies heavily on loop unrolling,
// inlining and currying of runtime parameters, the size of the generated binary
30// is quite significant (approx. 200kb) which might be prohibitive in
31// low-memory situations.
32
33#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
34#include "../meta/multi_thread_gemm.h"
35#endif
36
37namespace gemmlowp {
38namespace eight_bit_int_gemm {
39namespace {
40
// To be used as template parameter for GlobalLock.
// GlobalLock<EightBitIntGemmLockId> is the global lock
// on EightBitIntGemm entry points, protecting
// EightBitIntGemm's global state.
struct EightBitIntGemmLockId;

// Global state: consists of one global GemmContext instance.
// Lazily created by GetOrCreateGlobalContext() and destroyed by
// DestroyGlobalContext(); all access is serialized by the global lock
// taken in the public entry points below.
GemmContext* global_context;
49
50GemmContext* GetOrCreateGlobalContext() {
51  if (!global_context) {
52    global_context = new GemmContext;
53  }
54  return global_context;
55}
56
57void DestroyGlobalContext() {
58  delete global_context;
59  global_context = nullptr;
60}
61
62template <bool transpose_a, bool transpose_b, bool transpose_c>
63void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
64                         const std::uint8_t* a, std::int32_t a_offset, int lda,
65                         const std::uint8_t* b, std::int32_t b_offset, int ldb,
66                         std::uint8_t* c, std::int32_t c_offset,
67                         std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
68                         BitDepthSetting bit_depth) {
69  const int lhs_offset = a_offset;
70  const int rhs_offset = b_offset;
71  const int result_offset = c_offset;
72  const int result_mult_int = c_mult_int;
73  const int result_shift = c_shift;
74
75  static const MapOrder ResultOrder =
76      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
77  static const MapOrder LhsOrder =
78      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
79  static const MapOrder RhsOrder =
80      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
81
82  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
83  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
84  MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);
85
86  switch (bit_depth) {
87#define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS)     \
88  case BitDepthSetting::BIT_DEPTH_SETTING:                                 \
89    Gemm<std::uint8_t, BIT_DEPTH_PARAMS>(                                  \
90        context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
91        result_mult_int, result_shift);                                    \
92    return;
93    GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
94    GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
95    default:
96      abort();
97#undef GEMMLOWP_HANDLE_BIT_DEPTH
98  }
99}
100
101template <bool transpose_a, bool transpose_b, bool transpose_c>
102void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
103                              const std::uint8_t* a, std::int32_t a_offset,
104                              int lda, const std::uint8_t* b,
105                              std::int32_t b_offset, int ldb, std::int32_t* c,
106                              int ldc, BitDepthSetting bit_depth) {
107  const int lhs_offset = a_offset;
108  const int rhs_offset = b_offset;
109
110  static const MapOrder ResultOrder =
111      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
112  static const MapOrder LhsOrder =
113      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
114  static const MapOrder RhsOrder =
115      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
116
117  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
118  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
119  MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);
120
121  auto empty_pipeline = std::make_tuple();
122
123  switch (bit_depth) {
124#define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
125  case BitDepthSetting::BIT_DEPTH_SETTING:                                   \
126    GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>(    \
127        context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
128    return;
129    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
130    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
131    default:
132      abort();
133#undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
134  }
135}
136
// A lazily-grown byte buffer used as temporary workspace. Growing discards
// any previous contents; shrink requests are ignored so the buffer only
// ever reallocates upward.
class Scratch {
 public:
  Scratch() : buffer_(), size_(0) {}

  // Ensures the buffer holds at least |required_size| bytes. If a
  // reallocation happens, existing contents are NOT preserved.
  void AssureSize(std::int32_t required_size) {
    if (required_size <= size_) {
      return;
    }
    buffer_.reset(new std::uint8_t[required_size]);
    size_ = required_size;
  }

  // Releases the buffer and resets the tracked size to zero.
  void Clear() {
    size_ = 0;
    buffer_.reset(nullptr);
  }

  // Raw pointer to the current buffer; null before the first AssureSize.
  std::uint8_t* buffer() { return buffer_.get(); }

 private:
  std::unique_ptr<std::uint8_t[]> buffer_;  // owned storage
  std::int32_t size_;                       // current capacity in bytes
};
160
// Global state: scratch buffer shared by the EightBitIntGemm entry points.
// Access is serialized by the global lock taken in those entry points.
Scratch* global_scratch = nullptr;
162
163Scratch* GetOrCreateGlobalScratch() {
164  if (global_scratch == nullptr) {
165    global_scratch = new Scratch();
166  }
167  return global_scratch;
168}
169
170void DestroyGlobalScratch() {
171  delete global_scratch;
172  global_scratch = nullptr;
173}
174
175#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
176
// Returns true when the matrix described by (transpose, stride, rows, cols)
// can be treated as row major: either it genuinely is row major with no
// padding between rows, or it is a single-row vector (a vector is both row
// and column major).
bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  const bool is_packed_row_major = transpose && stride == cols;
  const bool is_row_vector = rows == 1;
  return is_packed_row_major || is_row_vector;
}
190
// Returns true when the matrix described by (transpose, stride, rows, cols)
// can be treated as column major: either it genuinely is column major with
// no padding between columns, or it is a single-column vector (a vector is
// both row and column major).
bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  const bool is_packed_col_major = !transpose && stride == rows;
  const bool is_col_vector = cols == 1;
  return is_packed_col_major || is_col_vector;
}
204
205bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
206                           int m, int n, int k, int lda, int ldb, int ldc,
207                           BitDepthSetting depth_setting) {
208  // Meta fastpath only supports 8bit x 8bit and k up to 2048.
209  if (depth_setting != BitDepthSetting::A8B8 || k > 2048) {
210    return false;
211  }
212
213  // The first operand needs to be a row major matrix or a vector.
214  if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
215    return false;
216  }
217
218  // The second operand needs to be a column major matrix or a vector.
219  if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
220    return false;
221  }
222
223  // The result can either be a row major matrix, a column major matrix or
224  // a vector.
225  if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
226    return true;
227  }
228
229  if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
230    return true;
231  }
232
233  return false;
234}
235
// Assure enough scratch memory is allocated and run the fast path gemm.
// The meta kernel produces a row-major result; a column-major result is
// obtained by computing the transposed product instead (see below).
void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, int m, int n, int k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t sum_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, bool result_transpose,
                           std::int32_t result_stride, std::uint8_t* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    // Row-major (or vector) result: compute lhs * rhs directly.
    scratch->AssureSize(
        meta::gemm_q8_scratch(m, n, k, context->max_num_threads()));
    meta::multi_thread_gemm_q8(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        lhs, rhs, m, n, k, lhs_offset, rhs_offset, sum_offset,
        multiplicative_offset, shift, result);
  } else {
    // Column-major result: swap the operands (and their offsets) and
    // exchange m with n, computing the transposed product into the same
    // storage.
    scratch->AssureSize(
        meta::gemm_q8_scratch(n, m, k, context->max_num_threads()));
    meta::multi_thread_gemm_q8(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        rhs, lhs, n, m, k, rhs_offset, lhs_offset, sum_offset,
        multiplicative_offset, shift, result);
  }
}
261
// Assure enough scratch memory is allocated and run the 8bit to float fast
// path gemm. Same operand-swapping scheme as MetaGemmQuantized8Bit: a
// column-major result is produced by computing the transposed product.
void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, int m, int n, int k,
                   std::int32_t lhs_offset, std::int32_t rhs_offset,
                   float result_offset, bool result_transpose,
                   std::int32_t result_stride, float* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    // Row-major (or vector) result: compute lhs * rhs directly.
    scratch->AssureSize(
        meta::gemm_f_scratch(m, n, k, context->max_num_threads()));
    meta::multi_thread_gemm_f(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        lhs, rhs, m, n, k, lhs_offset, rhs_offset, result_offset, result);
  } else {
    // Column-major result: swap operands/offsets and exchange m with n.
    scratch->AssureSize(
        meta::gemm_f_scratch(n, m, k, context->max_num_threads()));
    meta::multi_thread_gemm_f(
        context->workers_pool(), context->max_num_threads(), scratch->buffer(),
        rhs, lhs, n, m, k, rhs_offset, lhs_offset, result_offset, result);
  }
}
284
285#endif
286
287}  // end anonymous namespace
288
289// Public interface entry points
290
// Public 8-bit GEMM entry point with a uint8 result. The offsets and the
// (c_offset, c_mult_int, c_shift) output parameters are forwarded to Gemm
// via EightBitIntGemmImpl. All calls are serialized by the EightBitIntGemm
// global lock.
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, int lda, const std::uint8_t* b,
                     std::int32_t b_offset, int ldb, std::uint8_t* c,
                     std::int32_t c_offset, std::int32_t c_mult_int,
                     std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  // Prefer the meta (NEON) fastpath when the operand layouts, sizes and
  // bit depth allow it.
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                          c_mult_int, c_shift, transpose_c, ldc, c);
    return;
  }
#endif

  // Dispatch the runtime transpose flags to the matching compile-time
  // instantiation of EightBitIntGemmImpl. Exactly one of the eight cases
  // below matches.
#define GEMMLOWP_HANDLE_CASE(ta, tb, tc)                                    \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {        \
    EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b,  \
                                    b_offset, ldb, c, c_offset, c_mult_int, \
                                    c_shift, ldc, bit_depth);               \
  }

  GEMMLOWP_HANDLE_CASE(false, false, false)
  GEMMLOWP_HANDLE_CASE(false, false, true)
  GEMMLOWP_HANDLE_CASE(false, true, false)
  GEMMLOWP_HANDLE_CASE(false, true, true)
  GEMMLOWP_HANDLE_CASE(true, false, false)
  GEMMLOWP_HANDLE_CASE(true, false, true)
  GEMMLOWP_HANDLE_CASE(true, true, false)
  GEMMLOWP_HANDLE_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_CASE
}
327
// Public 8-bit GEMM entry point with a float result. Despite its name,
// c_offset is applied multiplicatively: the int32 GEMM result is computed
// into scratch memory and each entry is then scaled by c_offset to produce
// the float output (see the conversion loops below). All calls are
// serialized by the EightBitIntGemm global lock.
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, std::int32_t lda,
                     const std::uint8_t* b, std::int32_t b_offset,
                     std::int32_t ldb, float* c, float c_offset,
                     std::int32_t ldc, BitDepthSetting bit_depth) {
  AutoGlobalLock<EightBitIntGemmLockId> lock;
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON_32)
  // The meta (NEON) fastpath writes the float result directly, with no
  // intermediate int32 buffer.
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                  transpose_c, ldc, c);
    return;
  }
#endif

  // TODO(maciekc): implement a float output stage, get rid of scratch memory.
  Scratch* scratch = GetOrCreateGlobalScratch();
  // The temporary int32 result reuses the output's stride (ldc) and storage
  // order, so it needs rows-times-ldc (row major) or cols-times-ldc
  // (column major) entries.
  if (transpose_c) {
    scratch->AssureSize(m * ldc * sizeof(std::int32_t));
  } else {
    scratch->AssureSize(n * ldc * sizeof(std::int32_t));
  }
  std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());

  // Dispatch the runtime transpose flags to the matching compile-time
  // instantiation of EightBitIntGemmInt32Impl. Exactly one case matches.
#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                               \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {         \
    EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
                                         b, b_offset, ldb, temp_c, ldc,      \
                                         bit_depth);                         \
  }

  GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_INT32_CASE

  // Convert the int32 result to float, scaling each entry by c_offset.
  if (transpose_c) {
    // Row major.
    for (int i = 0; i < m; ++i) {
      float* dest_row = c + i * ldc;
      std::int32_t* src_row = temp_c + i * ldc;
      for (int j = 0; j < n; ++j) {
        dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
      }
    }
  } else {
    // Column major.
    for (int i = 0; i < n; ++i) {
      float* dest_column = c + i * ldc;
      std::int32_t* src_column = temp_c + i * ldc;
      for (int j = 0; j < m; ++j) {
        dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
      }
    }
  }
}
393
394void SetMaxNumThreads(int n) {
395  AutoGlobalLock<EightBitIntGemmLockId> lock;
396  GemmContext* context = GetOrCreateGlobalContext();
397  context->set_max_num_threads(n);
398}
399
400void FreePersistentResources() {
401  AutoGlobalLock<EightBitIntGemmLockId> lock;
402  DestroyGlobalContext();
403  DestroyGlobalScratch();
404}
405
406}  // namespace eight_bit_int_gemm
407}  // namespace gemmlowp
408