1// Copyright 2017 The Gemmlowp Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// dispatch_gemm_shape.h: dispatch GEMM calls according to their shape 16 17#ifndef GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ 18#define GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ 19 20#include "../internal/kernel_default.h" 21#include "../public/map.h" 22#include "../public/output_stages.h" 23#include "multi_thread_gemm.h" 24 25namespace gemmlowp { 26 27template <typename T> 28struct TransposeImpl { 29 typedef T DstType; 30 static T Run(const T& t) { return t; } 31}; 32 33template <typename T> 34using TransposeType = typename TransposeImpl<T>::DstType; 35 36template <typename T> 37TransposeType<T> Transpose(const T& t) { 38 return TransposeImpl<T>::Run(t); 39} 40 41template <MapOrder Order> 42struct TransposeMapOrder { 43 static constexpr MapOrder Value = 44 Order == MapOrder::RowMajor ? MapOrder::ColMajor : MapOrder::RowMajor; 45}; 46 47template <VectorShape Shape> 48struct TransposeVectorShape { 49 static constexpr VectorShape Value = 50 Shape == VectorShape::Row ? VectorShape::Col : VectorShape::Row; 51}; 52 53template <typename Scalar, VectorShape Shape> 54struct TransposeImpl<VectorMap<Scalar, Shape>> { 55 typedef VectorMap<Scalar, Shape> SrcType; 56 static constexpr VectorShape TransposedShape = 57 TransposeVectorShape<Shape>::Value; 58 typedef VectorMap<Scalar, TransposedShape> DstType; 59 static DstType Run(const SrcType& src) { 60 return DstType(src.data(), src.size()); 61 } 62}; 63 64template <typename Scalar, MapOrder Order> 65struct TransposeImpl<MatrixMap<Scalar, Order>> { 66 typedef MatrixMap<Scalar, Order> SrcType; 67 static constexpr MapOrder TransposedOrder = TransposeMapOrder<Order>::Value; 68 typedef MatrixMap<Scalar, TransposedOrder> DstType; 69 static DstType Run(const SrcType& src) { 70 return DstType(src.data(), src.cols(), src.rows(), src.stride()); 71 } 72}; 73 74template <VectorShape Shape> 75struct TransposeImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>> { 76 typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> SrcType; 77 static const VectorShape TransposedShape = TransposeVectorShape<Shape>::Value; 78 typedef OutputStageQuantizeDownInt32ToUint8ScalePC<TransposedShape> DstType; 79 static DstType Run(const SrcType& src) { 80 DstType dst; 81 dst.result_shift = src.result_shift; 82 dst.result_offset = Transpose(src.result_offset); 83 dst.result_mult_int = Transpose(src.result_mult_int); 84 return dst; 85 } 86}; 87 88template <typename VectorMapType> 89struct TransposeImpl<OutputStageBiasAddition<VectorMapType>> { 90 typedef OutputStageBiasAddition<VectorMapType> SrcType; 91 typedef TransposeType<VectorMapType> TransposedVectorMapType; 92 typedef OutputStageBiasAddition<TransposedVectorMapType> DstType; 93 static DstType Run(const SrcType& src) { 94 DstType dst; 95 dst.bias_vector = Transpose(src.bias_vector); 96 return dst; 97 } 98}; 99 100// TODO(benoitjacob) - does anyone understand C++ variadic templates? 101// How to use them to implement TransposeTuple? Note: there are lots 102// of answers on StackOverflow but they seem to all involve either 103// C++14/C++17 (we can only use C++11) or lots of abstract nonsense. 104inline std::tuple<> TransposeTuple(const std::tuple<>& t) { return t; } 105 106template <typename T0> 107std::tuple<TransposeType<T0>> TransposeTuple(const std::tuple<T0>& t) { 108 return std::make_tuple(Transpose(std::get<0>(t))); 109} 110 111template <typename T0, typename T1> 112std::tuple<TransposeType<T0>, TransposeType<T1>> TransposeTuple( 113 const std::tuple<T0, T1>& t) { 114 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t))); 115} 116 117template <typename T0, typename T1, typename T2> 118std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>> 119TransposeTuple(const std::tuple<T0, T1, T2>& t) { 120 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), 121 Transpose(std::get<2>(t))); 122} 123 124template <typename T0, typename T1, typename T2, typename T3> 125std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, 126 TransposeType<T3>> 127TransposeTuple(const std::tuple<T0, T1, T2, T3>& t) { 128 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), 129 Transpose(std::get<2>(t)), Transpose(std::get<3>(t))); 130} 131 132template <typename T0, typename T1, typename T2, typename T3, typename T4> 133std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, 134 TransposeType<T3>, TransposeType<T4>> 135TransposeTuple(const std::tuple<T0, T1, T2, T3, T4>& t) { 136 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), 137 Transpose(std::get<2>(t)), Transpose(std::get<3>(t)), 138 Transpose(std::get<4>(t))); 139} 140 141template <typename T0, typename T1, typename T2, typename T3, typename T4, 142 typename T5> 143std::tuple<TransposeType<T0>, TransposeType<T1>, TransposeType<T2>, 144 TransposeType<T3>, TransposeType<T4>, TransposeType<T5>> 145TransposeTuple(const std::tuple<T0, T1, T2, T3, T4, T5>& t) { 146 return std::make_tuple(Transpose(std::get<0>(t)), Transpose(std::get<1>(t)), 147 Transpose(std::get<2>(t)), Transpose(std::get<3>(t)), 148 Transpose(std::get<4>(t)), Transpose(std::get<5>(t))); 149} 150 151template <typename InputScalar, typename OutputScalar, typename BitDepthParams, 152 MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder, 153 typename LhsOffset, typename RhsOffset, typename OutputPipelineType, 154 typename GemmContextType> 155void DispatchGemmShape(GemmContextType* context, 156 const MatrixMap<const InputScalar, LhsOrder>& lhs, 157 const MatrixMap<const InputScalar, RhsOrder>& rhs, 158 MatrixMap<OutputScalar, ResultOrder>* result, 159 const LhsOffset& lhs_offset, const RhsOffset& rhs_offset, 160 const OutputPipelineType& output_pipeline) { 161 assert(lhs.cols() == rhs.rows()); 162 163 int rows = result->rows(); 164 int cols = result->cols(); 165 int depth = lhs.cols(); 166 167 if (rows == 0 || cols == 0 || depth == 0) { 168 // Vacuous GEMM, return early to avoid having to deal with 169 // zero sizes below. 170 return; 171 } 172 173 if (rows < cols) { 174 auto transposed_result_map = Transpose(*result); 175 return DispatchGemmShape<InputScalar, OutputScalar, BitDepthParams>( 176 context, Transpose(rhs), Transpose(lhs), &transposed_result_map, 177 Transpose(rhs_offset), Transpose(lhs_offset), 178 TransposeTuple(output_pipeline)); 179 } 180 181 typedef DefaultKernel<BitDepthParams> Kernel; 182 MultiThreadGemm<typename Kernel::Format, InputScalar, OutputScalar, 183 BitDepthParams>(context, Kernel(), lhs, rhs, result, 184 lhs_offset, rhs_offset, output_pipeline); 185} 186 187} // end namespace gemmlowp 188 189#endif // GEMMLOWP_INTERNAL_DISPATCH_GEMM_SHAPE_H_ 190