/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/meta_support.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

#if (defined(GEMMLOWP_NEON_32) || defined(GEMMLOWP_NEON_64)) && \
    !defined(TENSORFLOW_DISABLE_META) && !defined(__APPLE__)
#define TENSORFLOW_USE_META (1)
#endif

namespace tensorflow {
namespace meta {

namespace {

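// Library-wide settings, adjusted through the SetNumThreads /
// SetUseLocalContext / SetEnabled entry points below. A thread count of 0
// means "use the TensorFlow device's intra-op worker thread count".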
int g_num_threads = 0;
bool g_enabled = true;
bool g_use_local_context = false;

#ifdef TENSORFLOW_USE_META

const int kAlignment = 32;
const int kScratchSize = 2048 * 1024 + kAlignment;

class Scratch : public ResourceBase {
 public:
  Scratch() : scratch_(new uint8_t[kScratchSize]) {
    // Make sure scratch is aligned to 32 bytes. Scratch object owns the
    // scratch buffer.
    scratch_32_aligned_ =
        scratch_.get() + kAlignment -
        (reinterpret_cast<uintptr_t>(scratch_.get()) % kAlignment);
  }

  uint8_t* buffer() { return scratch_32_aligned_; }

  string DebugString() { return "MetaGemmScratchResource"; }

 private:
  std::unique_ptr<uint8_t[]> scratch_;
  uint8_t* scratch_32_aligned_;
};

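// Returns the 32-byte aligned scratch buffer, creating it on first use and
// caching it in the per-device resource manager under
// "MetaGemm/ScratchBuffer". Returns nullptr (and flags the op context) if the
// lookup fails.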
uint8_t* GetScratch(OpKernelContext* context) {
  Scratch* scratch = nullptr;
  std::function<Status(Scratch**)> creator = [](Scratch** resource) {
    *resource = new Scratch();
    return Status::OK();
  };
  Status s = context->resource_manager()->LookupOrCreate(
      "MetaGemm", "ScratchBuffer", &scratch, creator);
  if (!s.ok()) {
    context->CtxFailureWithWarning(s);
    return nullptr;
  }
  return scratch->buffer();
}

gemmlowp::WorkersPool* GetWorkersPool() {
  static gemmlowp::WorkersPool* pool = new gemmlowp::WorkersPool();
  return pool;
}

mutex& GetMutex() {
  static mutex mu(LINKER_INITIALIZED);
  return mu;
}

int GetWorkersCount(OpKernelContext* tf_context) {
  if (g_num_threads == 0) {
    return tf_context->device()->tensorflow_cpu_worker_threads()->num_threads;
  }
  return g_num_threads;
}

typedef gemmlowp::meta::SimpleContext<gemmlowp::WorkersPool> LocalContext;

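// Dispatches to a gemmlowp meta GEMM executor based on the output shape: a
// small LHS (m <= 4) uses the 1x8x8 kernel configuration with cache-friendly
// LHS packing; otherwise the 2x4x8 configuration is used, with cache-friendly
// packing applied to whichever operand has the smaller output dimension.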
template <typename Context, typename Params>
void MultiThreadGemm(Context* context, const Params& params) {
  if (params.m <= 4) {
    gemmlowp::meta::MultiThreadGemm<
        Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params, 1,
        8, 8>(context, params);
  } else {
    if (params.m >= params.n) {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackRHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    } else {
      gemmlowp::meta::MultiThreadGemm<
          Context, gemmlowp::meta::GemmExecutorPackLHSCacheFriendly<>, Params,
          2, 4, 8>(context, params);
    }
  }
}

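// Computes the quantized matrix product C(m x n) = A(m x k) * B(k x n) with
// the gemmlowp meta kernels. The LeftStream/RightStream types encode the
// storage order of A and B; the sum offsets fold the quantization zero points
// (offset_a, offset_b) into the int32 accumulation.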
template <typename LeftStream, typename RightStream>
void QuantizedGemmImpl(OpKernelContext* tf_context, const quint8* a_data,
                       const quint8* b_data, qint32* c_data, int m, int n,
                       int k, int offset_a, int offset_b, int lda, int ldb,
                       int ldc) {
  typedef gemmlowp::meta::GemmParams<
      uint8_t, int32_t, LeftStream, RightStream,
      gemmlowp::meta::QuantizedStaticPreprocessedAsInt32,
      gemmlowp::meta::RowMajor>
      Params;
  Params params;

  params.m = m;
  params.n = n;
  params.k = k;

  params.lhs = reinterpret_cast<const uint8_t*>(&(a_data->value));
  params.rhs = reinterpret_cast<const uint8_t*>(&(b_data->value));
  params.result = reinterpret_cast<int32_t*>(&(c_data->value));
  params.scratch = CHECK_NOTNULL(GetScratch(tf_context));

  params.left_stream.count = k;
  params.left_stream.stride = lda;
  params.left_stream.multiplicative_sum_offset = offset_b;
  params.left_stream.additive_sum_offset = k * offset_a * offset_b;

  params.right_stream.count = k;
  params.right_stream.stride = ldb;
  params.right_stream.multiplicative_sum_offset = offset_a;
  params.right_stream.additive_sum_offset = 0;

  params.fused_kernel.kernel.count = k;
  params.fused_kernel.output_stream.stride = ldc * sizeof(int32_t);

  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    MultiThreadGemm<LocalContext, Params>(&local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    MultiThreadGemm<TensorflowGemmContext, Params>(&context, params);
  }
}

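// Runs a gemmlowp meta 1D transform, either on the process-local worker pool
// or on the TensorFlow device's intra-op thread pool, depending on
// g_use_local_context.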
template <typename Params, int kernel_size>
void MultiThreadTransform1D(OpKernelContext* tf_context,
                            const Params& params) {
  if (g_use_local_context) {
    LocalContext local_context(GetWorkersCount(tf_context), GetWorkersPool());
    gemmlowp::meta::MultiThreadTransform1D<LocalContext, Params, kernel_size>(
        &local_context, params);
  } else {
    auto& workers = *(tf_context->device()->tensorflow_cpu_worker_threads());
    TensorflowGemmContext context(workers.num_threads, workers.workers);
    gemmlowp::meta::MultiThreadTransform1D<TensorflowGemmContext, Params,
                                           kernel_size>(&context, params);
  }
}

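// Quantization scale helpers: for a QuantizedType with b bits,
// CalculateRangeScale returns (max - min) / (2^b - 1) and
// CalculateOneOverRangeScale returns its reciprocal (or 0 when the range is
// empty).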
template <typename QuantizedType>
double CalculateRangeScale(float min, float max) {
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>(max - min) /
         ((static_cast<int64_t>(1) << bits) - 1);
}

template <typename QuantizedType>
double CalculateOneOverRangeScale(float min, float max) {
  if (min == max) {
    return 0.0;
  }
  const int bits = sizeof(QuantizedType) * 8;
  return static_cast<double>((static_cast<int64_t>(1) << bits) - 1) /
         (max - min);
}

#endif  // TENSORFLOW_USE_META

}  // namespace

void SetNumThreads(int num_threads) { g_num_threads = num_threads; }

int GetNumThreads() { return g_num_threads; }

void SetUseLocalContext(bool use_local_context) {
  g_use_local_context = use_local_context;
}

bool GetUseLocalContext() { return g_use_local_context; }

bool IsSupported() {
#if defined(TENSORFLOW_USE_META)
  return true;
#else
  return false;
#endif
}

bool IsEnabled() { return g_enabled; }

void SetEnabled(bool enabled) { g_enabled = enabled; }

bool IsSupportedAndEnabled() { return IsSupported() && IsEnabled(); }

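// Quantized GEMM entry point. The transpose flags pick the gemmlowp meta
// stream types: A is read as RowMajorWithSum (ColumnMajorWithSum when
// transposed) and B as ColumnMajorWithSum (RowMajorWithSum when transposed),
// so all four layout combinations share a single QuantizedGemmImpl.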
void QuantizedGemm(OpKernelContext* tf_context, bool transpose_a,
                   bool transpose_b, const quint8* a_data, const quint8* b_data,
                   qint32* c_data, int m, int n, int k, int offset_a,
                   int offset_b, int lda, int ldb, int ldc) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  if (transpose_a) {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::ColumnMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  } else {
    if (transpose_b) {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::RowMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    } else {
      QuantizedGemmImpl<gemmlowp::meta::RowMajorWithSum,
                        gemmlowp::meta::ColumnMajorWithSum>(
          tf_context, a_data, b_data, c_data, m, n, k, offset_a, offset_b, lda,
          ldb, ldc);
    }
  }
#else
  LOG(FATAL) << "QuantizedGemm: Meta fastpath not supported.";
#endif
}

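// Requantizes `count` int32 values from the [input_min, input_max] range into
// uint8 values covering [output_min, output_max].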
void Requantize(OpKernelContext* tf_context, const qint32* input, int count,
                float input_min, float input_max, float output_min,
                float output_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<int32_t, uint8_t,
                                            gemmlowp::meta::Requantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const int32_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.input_range_min = input_min;
  params.kernel.output_range_min = output_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<int32_t>(input_min, input_max);
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<uint8_t>(output_min, output_max);
  params.kernel.input_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());

  // After adding the output_range_offset, the value is cast from float to
  // uint. The float-to-int/uint cast in NEON rounds toward zero. To keep the
  // rounding consistent with Eigen, which rounds to nearest, we add 0.5f and
  // exploit the fact that we only operate on non-negative values.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest()) + 0.5f;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Requantize: Meta fastpath not supported.";
#endif
}

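// Converts `count` uint8 values quantized over [range_min, range_max] back to
// float.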
void Dequantize(OpKernelContext* tf_context, const quint8* input, int count,
                float range_min, float range_max, float* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, float,
                                            gemmlowp::meta::Dequantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<float*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateRangeScale<uint8_t>(range_min, range_max);
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest());

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Dequantize: Meta fastpath not supported.";
#endif
}

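// Quantizes `count` float values into uint8 over the [range_min, range_max]
// range.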
void Quantize(OpKernelContext* tf_context, const float* input, int count,
              float range_min, float range_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<float, uint8_t,
                                            gemmlowp::meta::Quantize>
      Params;

  Params params;
  params.input = reinterpret_cast<const float*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.range_min = range_min;
  params.kernel.range_scale =
      CalculateOneOverRangeScale<uint8_t>(range_min, range_max);

  // After adding the range_offset, the value is cast from float to uint. The
  // float-to-int/uint cast in NEON rounds toward zero. To keep the rounding
  // consistent with Eigen, which rounds to nearest, we add 0.5f and exploit
  // the fact that we only operate on non-negative values.
  // TODO(maciekc): fix the actual kernel in gemmlowp/meta
  params.kernel.range_offset =
      static_cast<float>(std::numeric_limits<uint8_t>::lowest()) + 0.5f;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Quantize: Meta fastpath not supported.";
#endif
}

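// Adds a quantized bias vector of `bias_count` values to each row of a
// quantized input of `input_count` values, producing int32 results in the
// [output_min, output_max] range. The input is treated as
// input_count / bias_count rows, so input_count is expected to be a multiple
// of bias_count.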
void QuantizedBiasAdd(OpKernelContext* tf_context, const quint8* input,
                      int input_count, const quint8* bias, int bias_count,
                      float input_min, float input_max, float bias_min,
                      float bias_max, float output_min, float output_max,
                      qint32* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, int32_t,
                                            gemmlowp::meta::BiasAdd<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<int32_t*>(output);
  params.kernel.bias = reinterpret_cast<const uint8_t*>(bias);
  params.kernel.count = bias_count;
  params.kernel.rows = input_count / bias_count;
  params.kernel.input_range_min = input_min;
  params.kernel.bias_range_min = bias_min;
  params.kernel.input_range_scale =
      CalculateRangeScale<uint8_t>(input_min, input_max);
  params.kernel.bias_range_scale =
      CalculateRangeScale<uint8_t>(bias_min, bias_max);
  params.kernel.input_range_offset = 0;
  params.kernel.bias_range_offset = 0;
  params.kernel.output_range_min = output_min;
  params.kernel.one_over_output_range_scale =
      CalculateOneOverRangeScale<int32_t>(output_min, output_max);
  params.kernel.output_range_offset =
      static_cast<float>(std::numeric_limits<int32_t>::lowest());

  // TODO(maciekc): add multithreading to bias add.
  // This kernel does not currently support multithreaded execution.
  gemmlowp::meta::Transform1D<Params, 16>(params);
#else
  LOG(FATAL) << "QuantizedBiasAdd: Meta fastpath not supported.";
#endif
}

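// Clamps `count` uint8 values to the [clamp_min, clamp_max] interval.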
void Clamp(OpKernelContext* tf_context, const quint8* input, int count,
           quint8 clamp_min, quint8 clamp_max, quint8* output) {
#ifdef TENSORFLOW_USE_META
  mutex_lock library_lock(GetMutex());
  typedef gemmlowp::meta::Transform1DParams<uint8_t, uint8_t,
                                            gemmlowp::meta::MinMax<uint8_t>>
      Params;

  Params params;
  params.input = reinterpret_cast<const uint8_t*>(input);
  params.output = reinterpret_cast<uint8_t*>(output);
  params.kernel.count = count;
  params.kernel.min = clamp_min;
  params.kernel.max = clamp_max;

  MultiThreadTransform1D<Params, 16>(tf_context, params);
#else
  LOG(FATAL) << "Clamp: Meta fastpath not supported.";
#endif
}

}  // namespace meta
}  // namespace tensorflow