1321f69487c9244350b5e5b7d8fd68e56aa9eb6c8Benoit Jacob// Copyright 2015 Google Inc. All Rights Reserved.
275c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob//
375c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// Licensed under the Apache License, Version 2.0 (the "License");
475c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// you may not use this file except in compliance with the License.
575c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// You may obtain a copy of the License at
675c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob//
775c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob//     http://www.apache.org/licenses/LICENSE-2.0
875c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob//
975c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// Unless required by applicable law or agreed to in writing, software
1075c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// distributed under the License is distributed on an "AS IS" BASIS,
1175c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1275c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// See the License for the specific language governing permissions and
1375c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// limitations under the License.
1475c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob
1575c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// gemmlowp.h: the main public interface header of gemmlowp.
1675c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob
1775c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob#ifndef GEMMLOWP_PUBLIC_GEMMLOWP_H_
1875c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob#define GEMMLOWP_PUBLIC_GEMMLOWP_H_
197b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang#include "../internal/kernel_default.h"
207b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang#include "../internal/multi_thread_gemm.h"
217b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang#include "../internal/unpack.h"
220a70f98b4be89f51cdd54bf739c953e82ec7fb55Miao Wang#include "bit_depth.h"
23544690cac8f06f1b2f5fa3799e1e8f13c75d95e9Miao Wang#include "map.h"
247b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang#include "output_stages.h"
2575c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob
2675c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacobnamespace gemmlowp {
2775c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob
287b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wanginline bool IsRequantizationWorthIt(int rows, int cols) {
297b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  // We pack depth*(rows+cols) and compute depth*rows*cols.
307b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  // Thus the ratio of compute/packing cost is rows*cols/(rows+cols)
317b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  // In the square case rows==cols==N, it becomes N/2.
327b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  return 2 * rows * cols >= (rows + cols) * kMinimumWidthForRequantization;
337b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang}
347b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang
// Context object taken by the Gemm entry points declared in this header.
// It derives publicly from MultiThreadGemmContext and adds no members of its
// own.
3575c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacobclass GemmContext : public MultiThreadGemmContext {};
3675c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob
3775c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// Computes a general matrix product ("GEMM").
387b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang// This is a version that supports per channel quantization.
397b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wangtemplate <typename InputScalar, typename OutputScalar, typename BitDepthParams,
407b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang          MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
417b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang          typename LhsOffset, typename RhsOffset, typename OutputPipelineType>
427b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wangvoid GemmWithOutputPipelinePC(GemmContext* context,
437b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                              const MatrixMap<const InputScalar, LhsOrder>& lhs,
447b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                              const MatrixMap<const InputScalar, RhsOrder>& rhs,
457b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                              MatrixMap<OutputScalar, ResultOrder>* result,
467b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                              const LhsOffset& lhs_offset,
477b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                              const RhsOffset& rhs_offset,
487b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                              const OutputPipelineType& output_pipeline) {
497b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  assert(lhs.cols() == rhs.rows());
507b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang
517b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  int rows = result->rows();
527b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  int cols = result->cols();
537b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  int depth = lhs.cols();
547b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang
557b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  if (rows == 0 || cols == 0 || depth == 0) {
567b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang    // Vacuous GEMM, return early to avoid having to deal with
577b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang    // zero sizes below.
587b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang    return;
597b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  }
607b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang
617b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  if (cols == 1) {
627b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang    if (IsRequantizationWorthIt(rows, cols)) {
637b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      typedef DefaultKernel<KernelFamily::Gemv, BitDepthParams> Kernel;
647b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
657b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                      BitDepthParams>(context, Kernel(), lhs, rhs, result,
667b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                                      lhs_offset, rhs_offset, output_pipeline);
677b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang    } else {
687b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      typedef DefaultKernel<KernelFamily::Gemv, DefaultL8R8BitDepthParams>
697b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang          Kernel;
707b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
717b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                      DefaultL8R8BitDepthParams>(context, Kernel(), lhs, rhs,
727b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                                                 result, lhs_offset, rhs_offset,
737b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                                                 output_pipeline);
747b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang    }
757b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  } else {
767b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang    if (IsRequantizationWorthIt(rows, cols)) {
777b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      typedef DefaultKernel<KernelFamily::Gemm, BitDepthParams> Kernel;
787b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
797b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                      BitDepthParams>(context, Kernel(), lhs, rhs, result,
807b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                                      lhs_offset, rhs_offset, output_pipeline);
817b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang    } else {
827b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      typedef DefaultKernel<KernelFamily::Gemm, DefaultL8R8BitDepthParams>
837b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang          Kernel;
847b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
857b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                      DefaultL8R8BitDepthParams>(context, Kernel(), lhs, rhs,
867b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                                                 result, lhs_offset, rhs_offset,
877b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                                                 output_pipeline);
887b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang    }
897b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  }
907b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang}
917b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang
927b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang// Computes a general matrix product ("GEMM").
937b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang// This is the legacy version that does not support per channel quantization.
947b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang// The meaning of the offsets, result_mult_int and result_shift
957b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang// parameters is the same as in the standard EightBitIntGemm interface
967b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang// (which is also implemented in the eight_bit_int_gemm directory).
977b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wangtemplate <typename InputScalar, typename OutputScalar, typename BitDepthParams,
987b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang          MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
997b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang          typename OutputPipelineType>
1007b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wangvoid GemmWithOutputPipeline(GemmContext* context,
1017b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                            const MatrixMap<const InputScalar, LhsOrder>& lhs,
1027b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                            const MatrixMap<const InputScalar, RhsOrder>& rhs,
1037b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                            MatrixMap<OutputScalar, ResultOrder>* result,
1047b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                            int lhs_offset, int rhs_offset,
1057b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang                            const OutputPipelineType& output_pipeline) {
1067b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  const OffsetColDup lhs_offset_vector(lhs_offset, lhs.rows());
1077b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  const OffsetRowDup rhs_offset_vector(rhs_offset, rhs.cols());
1087b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  GemmWithOutputPipelinePC<InputScalar, OutputScalar, BitDepthParams>(
1097b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      context, lhs, rhs, result, lhs_offset_vector, rhs_offset_vector,
1107b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      output_pipeline);
1117b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang}
1127b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang
1137b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang// Computes a general matrix product ("GEMM").
11475c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// The meaning of the offsets, result_mult_int and result_shift
11575c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// parameters is the same as in the standard EightBitIntGemm interface
11675c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob// (which is also implemented in the eight_bit_int_gemm directory).
1177b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wangtemplate <typename Scalar, typename BitDepthParams, MapOrder LhsOrder,
1180a70f98b4be89f51cdd54bf739c953e82ec7fb55Miao Wang          MapOrder RhsOrder, MapOrder ResultOrder>
11975c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacobvoid Gemm(GemmContext* context, const MatrixMap<const Scalar, LhsOrder>& lhs,
12075c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob          const MatrixMap<const Scalar, RhsOrder>& rhs,
12175c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob          MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
12275c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob          int rhs_offset, int result_offset, int result_mult_int,
12375c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob          int result_shift) {
1247b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang  GemmWithOutputPipeline<Scalar, Scalar, BitDepthParams>(
1257b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      context, lhs, rhs, result, lhs_offset, rhs_offset,
1267b05d573cf2e0fd3a58e98cdbfc65153a83fd6f1Miao Wang      MakeStandardOutputPipeline(result_offset, result_mult_int, result_shift));
12775c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob}
12875c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob
12975c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob}  // namespace gemmlowp
13075c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob
13175c4ec0ba4dd86e4f763a54e01002ff29f1d57aBenoit Jacob#endif  // GEMMLOWP_PUBLIC_GEMMLOWP_H_
132