1// Copyright 2015 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
16#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
17#endif
18#include "eight_bit_int_gemm.h"
19
20#include <memory>
21
22// gemmlowp symbols should have hidden visibility.
23// currently this is ensured in the build system by
// passing -fvisibility-inlines-hidden. TODO: it would be
25// safer to hardcode it here with some #pragma's.
26#include "../public/gemmlowp.h"
27
28// Define GEMMLOWP_USE_META_FASTPATH in order to use the fastpath ARM/NEON
29// code. This code path consists of a number of meta-programmed, automatically
30// generated GEMM kernels that are suitable for some sizes of input matrices.
31// Due to the fact that the generated code relies heavily on loop unrolling,
// inlining and currying of runtime parameters, the size of the generated binary
33// is quite significant (approx. 200kb) which might be prohibitive in
34// low-memory situations.
35
36#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
37#include "../meta/legacy_multi_thread_gemm.h"
38#else
39
40#if defined(GEMMLOWP_USE_META_FASTPATH)
41#warning "META fast path turned on without NEON!"
42#endif
43
44#endif
45
46namespace gemmlowp {
47namespace eight_bit_int_gemm {
48namespace {
49
50// To be used as template parameter for GlobalLock.
51// GlobalLock<EightBitIntGemmLockId> is the global lock
52// on EightBitIntGemm entry points, protecting
53// EightBitIntGemm's global state.
54struct EightBitIntGemmLockId;
55
56// Global state: consists of one global GemmContext instance.
57GemmContext* global_context;
58
59GemmContext* GetOrCreateGlobalContext() {
60  if (!global_context) {
61    global_context = new GemmContext;
62  }
63  return global_context;
64}
65
66void DestroyGlobalContext() {
67  delete global_context;
68  global_context = nullptr;
69}
70
71template <bool transpose_a, bool transpose_b, bool transpose_c>
72void EightBitIntGemmImpl(GemmContext* context, int m, int n, int k,
73                         const std::uint8_t* a, std::int32_t a_offset, int lda,
74                         const std::uint8_t* b, std::int32_t b_offset, int ldb,
75                         std::uint8_t* c, std::int32_t c_offset,
76                         std::int32_t c_mult_int, std::int32_t c_shift, int ldc,
77                         BitDepthSetting bit_depth) {
78  const int lhs_offset = a_offset;
79  const int rhs_offset = b_offset;
80  const int result_offset = c_offset;
81  const int result_mult_int = c_mult_int;
82  const int result_shift = c_shift;
83
84  static const MapOrder ResultOrder =
85      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
86  static const MapOrder LhsOrder =
87      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
88  static const MapOrder RhsOrder =
89      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
90
91  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
92  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
93  MatrixMap<std::uint8_t, ResultOrder> result(c, m, n, ldc);
94
95  switch (bit_depth) {
96#define GEMMLOWP_HANDLE_BIT_DEPTH(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS)     \
97  case BitDepthSetting::BIT_DEPTH_SETTING:                                 \
98    Gemm<std::uint8_t, BIT_DEPTH_PARAMS>(                                  \
99        context, lhs, rhs, &result, lhs_offset, rhs_offset, result_offset, \
100        result_mult_int, result_shift);                                    \
101    return;
102    GEMMLOWP_HANDLE_BIT_DEPTH(A8B8, DefaultL8R8BitDepthParams)
103    GEMMLOWP_HANDLE_BIT_DEPTH(A5B7, DefaultL7R5BitDepthParams)
104    default:
105      abort();
106#undef GEMMLOWP_HANDLE_BIT_DEPTH
107  }
108}
109
110template <bool transpose_a, bool transpose_b, bool transpose_c>
111void EightBitIntGemmInt32Impl(GemmContext* context, int m, int n, int k,
112                              const std::uint8_t* a, std::int32_t a_offset,
113                              int lda, const std::uint8_t* b,
114                              std::int32_t b_offset, int ldb, std::int32_t* c,
115                              int ldc, BitDepthSetting bit_depth) {
116  const int lhs_offset = a_offset;
117  const int rhs_offset = b_offset;
118
119  static const MapOrder ResultOrder =
120      transpose_c ? MapOrder::RowMajor : MapOrder::ColMajor;
121  static const MapOrder LhsOrder =
122      transpose_a ? MapOrder::RowMajor : MapOrder::ColMajor;
123  static const MapOrder RhsOrder =
124      transpose_b ? MapOrder::RowMajor : MapOrder::ColMajor;
125
126  MatrixMap<const std::uint8_t, LhsOrder> lhs(a, m, k, lda);
127  MatrixMap<const std::uint8_t, RhsOrder> rhs(b, k, n, ldb);
128  MatrixMap<std::int32_t, ResultOrder> result(c, m, n, ldc);
129
130  auto empty_pipeline = std::make_tuple();
131
132  switch (bit_depth) {
133#define GEMMLOWP_HANDLE_BIT_DEPTH_INT32(BIT_DEPTH_SETTING, BIT_DEPTH_PARAMS) \
134  case BitDepthSetting::BIT_DEPTH_SETTING:                                   \
135    GemmWithOutputPipeline<std::uint8_t, std::int32_t, BIT_DEPTH_PARAMS>(    \
136        context, lhs, rhs, &result, lhs_offset, rhs_offset, empty_pipeline); \
137    return;
138    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A8B8, DefaultL8R8BitDepthParams)
139    GEMMLOWP_HANDLE_BIT_DEPTH_INT32(A5B7, DefaultL7R5BitDepthParams)
140    default:
141      abort();
142#undef GEMMLOWP_HANDLE_BIT_DEPTH_INT32
143  }
144}
145
// Reusable scratch allocation that hands out a 32-byte-aligned buffer.
// The allocation grows monotonically via AssureSize() and is released
// only by Clear() (or destruction).
class Scratch {
 public:
  Scratch() : buffer_(), buffer_32_(nullptr), size_(0) {}

  // Guarantees that buffer() points at |required_size| usable,
  // 32-byte-aligned bytes. Growing discards any previous contents.
  void AssureSize(std::int32_t required_size) {
    if (size_ >= required_size) {
      return;
    }
    // Over-allocate by 32 bytes so that an aligned pointer can always be
    // carved out of the allocation, wherever operator new[] placed it.
    buffer_.reset(new std::uint8_t[required_size + 32]);
    const uintptr_t address = reinterpret_cast<uintptr_t>(buffer_.get());
    buffer_32_ = buffer_.get() + ((32 - address % 32) % 32);
    assert((reinterpret_cast<uintptr_t>(buffer_32_) % 32) == 0);
    size_ = required_size;
  }

  // Releases the allocation and resets to the empty state.
  void Clear() {
    buffer_.reset();
    buffer_32_ = nullptr;
    size_ = 0;
  }

  // 32-byte-aligned pointer into the current allocation (null when empty).
  std::uint8_t* buffer() { return buffer_32_; }

 private:
  std::unique_ptr<std::uint8_t[]> buffer_;  // owning allocation
  std::uint8_t* buffer_32_;                 // aligned view into buffer_
  std::int32_t size_;                       // usable bytes at buffer_32_
};
175
176Scratch* global_scratch = nullptr;
177
178Scratch* GetOrCreateGlobalScratch() {
179  if (global_scratch == nullptr) {
180    global_scratch = new Scratch();
181  }
182  return global_scratch;
183}
184
185void DestroyGlobalScratch() {
186  delete global_scratch;
187  global_scratch = nullptr;
188}
189
190#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
191
// True when the matrix can be consumed as row-major data: either it is
// stored row-major with no padding between rows (stride == cols), or it
// is a single row (a vector is both row- and column-major).
bool IsRowMajorOrVector(bool transpose, int stride, int rows, int cols) {
  const bool packed_row_major = transpose && (stride == cols);
  const bool single_row = (rows == 1);
  return packed_row_major || single_row;
}
205
// True when the matrix can be consumed as column-major data: either it is
// stored column-major with no padding between columns (stride == rows),
// or it is a single column (a vector is both row- and column-major).
bool IsColumnMajorOrVector(bool transpose, int stride, int rows, int cols) {
  const bool packed_column_major = !transpose && (stride == rows);
  const bool single_column = (cols == 1);
  return packed_column_major || single_column;
}
219
220bool CanHandleMetaFastpath(bool transpose_a, bool transpose_b, bool transpose_c,
221                           int m, int n, int k, int lda, int ldb, int ldc,
222                           BitDepthSetting depth_setting) {
223  // Meta fastpath only supports 8bit x 8bit and k between 8 and 2048.
224  if (depth_setting != BitDepthSetting::A8B8 || k < 8 || k > 2048) {
225    return false;
226  }
227
228  // The first operand needs to be a row major matrix or a vector.
229  if (!IsRowMajorOrVector(transpose_a, lda, m, k)) {
230    return false;
231  }
232
233  // The second operand needs to be a column major matrix or a vector.
234  if (!IsColumnMajorOrVector(transpose_b, ldb, k, n)) {
235    return false;
236  }
237
238  // The result can either be a row major matrix, a column major matrix or
239  // a vector.
240  if (IsRowMajorOrVector(transpose_c, ldc, m, n)) {
241    return true;
242  }
243
244  if (IsColumnMajorOrVector(transpose_c, ldc, m, n)) {
245    return true;
246  }
247
248  return false;
249}
250
// Assure enough scratch memory is allocated and run the fast path gemm.
// The meta kernels produce a row-major result; when the caller wants a
// column-major result, the call below swaps lhs/rhs and m/n so that the
// row-major output of the swapped product is exactly the column-major
// result requested (C^T = B^T * A^T).
void MetaGemmQuantized8Bit(GemmContext* context, const std::uint8_t* lhs,
                           const std::uint8_t* rhs, int m, int n, int k,
                           std::int32_t lhs_offset, std::int32_t rhs_offset,
                           std::int32_t sum_offset,
                           std::int32_t multiplicative_offset,
                           std::int32_t shift, bool result_transpose,
                           std::int32_t result_stride, std::uint8_t* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  const std::int32_t max_num_threads = context->max_num_threads();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    // Row-major (or vector) result: run the product as-is.
    scratch->AssureSize(meta::gemm_q8_scratch(m, n, k, max_num_threads));
    meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
                               scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
                               rhs_offset, sum_offset, multiplicative_offset,
                               shift, result);
  } else {
    // Column-major result: swap operands and dimensions (see above).
    scratch->AssureSize(meta::gemm_q8_scratch(n, m, k, max_num_threads));
    meta::multi_thread_gemm_q8(context->workers_pool(), max_num_threads,
                               scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
                               lhs_offset, sum_offset, multiplicative_offset,
                               shift, result);
  }
}
275
// Assure enough scratch memory is allocated and run the 8bit to float fast
// path gemm. As in MetaGemmQuantized8Bit, a column-major result is
// obtained by swapping lhs/rhs and m/n so the kernels' row-major output
// lands in the requested layout (C^T = B^T * A^T).
void MetaGemmFloat(GemmContext* context, const std::uint8_t* lhs,
                   const std::uint8_t* rhs, int m, int n, int k,
                   std::int32_t lhs_offset, std::int32_t rhs_offset,
                   float result_offset, bool result_transpose,
                   std::int32_t result_stride, float* result) {
  Scratch* scratch = GetOrCreateGlobalScratch();
  const std::int32_t max_num_threads = context->max_num_threads();
  if (IsRowMajorOrVector(result_transpose, result_stride, m, n)) {
    // Row-major (or vector) result: run the product as-is.
    scratch->AssureSize(meta::gemm_f_scratch(m, n, k, max_num_threads));
    meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
                              scratch->buffer(), lhs, rhs, m, n, k, lhs_offset,
                              rhs_offset, result_offset, result);
  } else {
    // Column-major result: swap operands and dimensions (see above).
    scratch->AssureSize(meta::gemm_f_scratch(n, m, k, max_num_threads));
    meta::multi_thread_gemm_f(context->workers_pool(), max_num_threads,
                              scratch->buffer(), rhs, lhs, n, m, k, rhs_offset,
                              lhs_offset, result_offset, result);
  }
}
297
298#endif
299
300}  // end anonymous namespace
301
302// Public interface entry points
303
// Public entry point: quantized gemm with a uint8 result. All global
// state (context, scratch) is protected by the EightBitIntGemm lock for
// the duration of the call.
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, int lda, const std::uint8_t* b,
                     std::int32_t b_offset, int ldb, std::uint8_t* c,
                     std::int32_t c_offset, std::int32_t c_mult_int,
                     std::int32_t c_shift, int ldc, BitDepthSetting bit_depth) {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
  // Prefer the meta-generated NEON kernels when operand layouts and
  // sizes allow them.
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmQuantized8Bit(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                          c_mult_int, c_shift, transpose_c, ldc, c);
    return;
  }
#endif

// Dispatch the runtime transpose flags to the matching compile-time
// template instantiation; exactly one of the 8 cases below matches.
#define GEMMLOWP_HANDLE_CASE(ta, tb, tc)                                    \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {        \
    EightBitIntGemmImpl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, b,  \
                                    b_offset, ldb, c, c_offset, c_mult_int, \
                                    c_shift, ldc, bit_depth);               \
  }

  GEMMLOWP_HANDLE_CASE(false, false, false)
  GEMMLOWP_HANDLE_CASE(false, false, true)
  GEMMLOWP_HANDLE_CASE(false, true, false)
  GEMMLOWP_HANDLE_CASE(false, true, true)
  GEMMLOWP_HANDLE_CASE(true, false, false)
  GEMMLOWP_HANDLE_CASE(true, false, true)
  GEMMLOWP_HANDLE_CASE(true, true, false)
  GEMMLOWP_HANDLE_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_CASE
}
340
// Public entry point: quantized gemm with a float result. On the slow
// path, int32 accumulators are computed into global scratch memory and
// then converted to float scaled by c_offset — note that c_offset acts
// as a multiplicative scale factor here, despite its name.
void EightBitIntGemm(bool transpose_a, bool transpose_b, bool transpose_c,
                     int m, int n, int k, const std::uint8_t* a,
                     std::int32_t a_offset, std::int32_t lda,
                     const std::uint8_t* b, std::int32_t b_offset,
                     std::int32_t ldb, float* c, float c_offset,
                     std::int32_t ldc, BitDepthSetting bit_depth) {
  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
  GemmContext* context = GetOrCreateGlobalContext();

#if defined(GEMMLOWP_USE_META_FASTPATH) && defined(GEMMLOWP_NEON)
  // Fast path: meta-generated NEON kernels with direct float output.
  if (CanHandleMetaFastpath(transpose_a, transpose_b, transpose_c, m, n, k, lda,
                            ldb, ldc, bit_depth)) {
    MetaGemmFloat(context, a, b, m, n, k, a_offset, b_offset, c_offset,
                  transpose_c, ldc, c);
    return;
  }
#endif

  // TODO(maciekc): implement a float output stage, get rid of scratch memory.
  // The temporary int32 result reuses ldc as its leading stride, so it
  // occupies (outer dimension) * ldc elements.
  // NOTE(review): m * ldc * sizeof(std::int32_t) can overflow int for very
  // large matrices before reaching AssureSize — confirm callers' size limits.
  Scratch* scratch = GetOrCreateGlobalScratch();
  if (transpose_c) {
    scratch->AssureSize(m * ldc * sizeof(std::int32_t));
  } else {
    scratch->AssureSize(n * ldc * sizeof(std::int32_t));
  }
  std::int32_t* temp_c = reinterpret_cast<std::int32_t*>(scratch->buffer());

// Dispatch the runtime transpose flags to the matching compile-time
// template instantiation; exactly one of the 8 cases below matches.
#define GEMMLOWP_HANDLE_INT32_CASE(ta, tb, tc)                               \
  if (transpose_a == ta && transpose_b == tb && transpose_c == tc) {         \
    EightBitIntGemmInt32Impl<ta, tb, tc>(context, m, n, k, a, a_offset, lda, \
                                         b, b_offset, ldb, temp_c, ldc,      \
                                         bit_depth);                         \
  }

  GEMMLOWP_HANDLE_INT32_CASE(false, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(false, true, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, false, true)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, false)
  GEMMLOWP_HANDLE_INT32_CASE(true, true, true)

#undef GEMMLOWP_HANDLE_INT32_CASE

  // Convert the int32 accumulators to float, scaling each by c_offset.
  if (transpose_c) {
    // Row major.
    for (int i = 0; i < m; ++i) {
      float* dest_row = c + i * ldc;
      std::int32_t* src_row = temp_c + i * ldc;
      for (int j = 0; j < n; ++j) {
        dest_row[j] = static_cast<float>(src_row[j]) * c_offset;
      }
    }
  } else {
    // Column major.
    for (int i = 0; i < n; ++i) {
      float* dest_column = c + i * ldc;
      std::int32_t* src_column = temp_c + i * ldc;
      for (int j = 0; j < m; ++j) {
        dest_column[j] = static_cast<float>(src_column[j]) * c_offset;
      }
    }
  }
}
406
407void SetMaxNumThreads(int n) {
408  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
409  GemmContext* context = GetOrCreateGlobalContext();
410  context->set_max_num_threads(n);
411}
412
413void FreePersistentResources() {
414  ScopedLock sl(GlobalMutexes::EightBitIntGemm());
415  DestroyGlobalContext();
416  DestroyGlobalScratch();
417}
418
419}  // namespace eight_bit_int_gemm
420}  // namespace gemmlowp
421