1// Copyright 2017 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// dispatch_gemm_shape.h: dispatch GEMM calls according to their shape
16
17#ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
18#define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
19
20#include "../internal/kernel_default.h"
21#include "../public/map.h"
22#include "../public/output_stages.h"
23#include "multi_thread_gemm.h"
24
25namespace gemmlowp {
26
// Primary TransposeImpl template, used for any type T that has no dedicated
// specialization: the destination type is T itself and Run() returns a copy
// of its argument unchanged.
template <typename T>
struct TransposeImpl {
  using DstType = T;
  static DstType Run(const T& t) {
    return t;
  }
};
32
33template <typename T>
34using TransposeType = typename TransposeImpl<T>::DstType;
35
36template <typename T>
37TransposeType<T> Transpose(const T& t) {
38  return TransposeImpl<T>::Run(t);
39}
40
41template <MapOrder Order>
42struct TransposeMapOrder {
43  static constexpr MapOrder Value =
44      Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor;
45};
46
47template <VectorShape Shape>
48struct TransposeVectorShape {
49  static constexpr VectorShape Value =
50      Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row;
51};
52
53template <typename Scalar, VectorShape Shape>
54struct TransposeImpl<VectorMap<Scalar, Shape>> {
55  typedef VectorMap<Scalar, Shape> SrcType;
56  static constexpr VectorShape TransposedShape =
57      TransposeVectorShape<Shape>::Value;
58  typedef VectorMap<Scalar, TransposedShape> DstType;
59  static DstType Run(const SrcType& src) {
60    return DstType(src.data(), src.size());
61  }
62};
63
64template <typename Scalar, MapOrder Order>
65struct TransposeImpl<MatrixMap<Scalar, Order>> {
66  typedef MatrixMap<Scalar, Order> SrcType;
67  static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value;
68  typedef MatrixMap<Scalar, TransposedOrder> DstType;
69  static DstType Run(const SrcType& src) {
70    return DstType(src.data(), src.cols(), src.rows(), src.stride());
71  }
72};
73
74template <VectorShape Shape>
75struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> {
76  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType;
77  static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value;
78  typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType;
79  static DstType Run(const SrcType& src) {
80    DstType dst;
81    dst.result_shift = src.result_shift;
82    dst.result_offset = Transpose(src.result_offset);
83    dst.result_mult_int = Transpose(src.result_mult_int);
84    return dst;
85  }
86};
87
88template <typename VectorMapType>
89struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> {
90  typedef OutputStageBiasAddition<VectorMapType> SrcType;
91  typedef TransposeType<VectorMapType> TransposedVectorMapType;
92  typedef OutputStageBiasAddition<TransposedVectorMapType> DstType;
93  static DstType Run(const SrcType& src) {
94    DstType dst;
95    dst.bias_vector = Transpose(src.bias_vector);
96    return dst;
97  }
98};
99
100// TODO(benoitjacob) - does anyone understand C++ variadic templates?
101// How to use them to implement TransposeTuple? Note: there are lots
102// of answers on StackOverflow but they seem to all involve either
103// C++14/C++17 (we can only use C++11) or lots of abstract nonsense.
104inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; }
105
106template <typename T0>
107std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) {
108  return std::make_tuple(Transpose(std::get<0>(t)));
109}
110
111template <typename T0, typename T1>
112std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple(
113    const std::tuple<T0, T1>& t) {
114  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)));
115}
116
117template <typename T0, typename T1, typename T2>
118std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>>
119TransposeTuple(const std::tuple<T0, T1, T2>& t) {
120  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
121                         Transpose(std::get<2>(t)));
122}
123
124template <typename T0, typename T1, typename T2, typename T3>
125std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
126           TransposeType<T3>>
127TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) {
128  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
129                         Transpose(std::get<2>(t)), Transpose(std::get<3>(t)));
130}
131
132template <typename T0, typename T1, typename T2, typename T3, typename T4>
133std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
134           TransposeType<T3>, TransposeType<T4>>
135TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) {
136  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
137                         Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
138                         Transpose(std::get<4>(t)));
139}
140
141template <typename T0, typename T1, typename T2, typename T3, typename T4,
142          typename T5>
143std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>,
144           TransposeType<T3>, TransposeType<T4>, TransposeType<T5>>
145TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) {
146  return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)),
147                         Transpose(std::get<2>(t)), Transpose(std::get<3>(t)),
148                         Transpose(std::get<4>(t)), Transpose(std::get<5>(t)));
149}
150
151template <typename InputScalar, typename OutputScalar, typename BitDepthParams,
152          MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder,
153          typename LhsOffset, typename RhsOffset, typename OutputPipelineType,
154          typename GemmContextType>
155void DispatchGemmShape(GemmContextType* context,
156                       const MatrixMap<const InputScalar, LhsOrder>& lhs,
157                       const MatrixMap<const InputScalar, RhsOrder>& rhs,
158                       MatrixMap<OutputScalar, ResultOrder>* result,
159                       const LhsOffset& lhs_offset, const RhsOffset& rhs_offset,
160                       const OutputPipelineType& output_pipeline) {
161  assert(lhs.cols() == rhs.rows());
162
163  int rows = result->rows();
164  int cols = result->cols();
165  int depth = lhs.cols();
166
167  if (rows == 0 || cols == 0 || depth == 0) {
168    // Vacuous GEMM, return early to avoid having to deal with
169    // zero sizes below.
170    return;
171  }
172
173  if (rows < cols) {
174    auto transposed_result_map = Transpose(*result);
175    return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>(
176        context, Transpose(rhs), Transpose(lhs), &transposed_result_map,
177        Transpose(rhs_offset), Transpose(lhs_offset),
178        TransposeTuple(output_pipeline));
179  }
180
181  typedef DefaultKernel<BitDepthParams> Kernel;
182  MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar,
183                  BitDepthParams>(context, Kernel(), lhs, rhs, result,
184                                  lhs_offset, rhs_offset, output_pipeline);
185}
186
187}  // end namespace gemmlowp
188
189#endif  // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_
190