1// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// kernel.h: general definitions for kernels. 16 17#ifndef GEMMLOWP_INTERNAL_KERNEL_H_ 18#define GEMMLOWP_INTERNAL_KERNEL_H_ 19 20#include "../public/bit_depth.h" 21#include "common.h" 22 23namespace gemmlowp { 24 25// Explanation of general gemmlowp terminology 26// =========================================== 27// 28// We use the following abbreviations: 29// LHS = "left-hand side" 30// RHS = "right-hand side" 31// Sometimes when referring to either LHS or RHS, we just say a "Side". 32// 33// In a matrix product of a MxK matrix times a KxN matrix, 34// we call K the 'depth'. Note that M is the number of rows 35// of the result (and of the LHS), and N is the number of columns 36// of the result (and of the RHS). 37// 38// In each of the LHS and RHS matrices, we call 'width' the 39// other dimension, besides the depth. So in the LHS, 'width' 40// is the number of rows, while in the RHS, 'width' is the number 41// of columns. 42// 43// So in the LHS MxK matrix, the depth is K and the width is M. 44// And in the RHS KxN matrix, the depth is K and the width is N. 
//
// This is illustrated in this picture:
//
//                       RHS width
//                  <----------------->
//                  +-----------------+ ^
//                  |       RHS       | | Depth
//                  +-----------------+ v
//           ^ +--+ +-----------------+
//           | |L | |                 |
// LHS width | |H | |     Result      |
//           | |S | |                 |
//           v +--+ +-----------------+
//             <-->
//             Depth

// Explanation of gemmlowp kernel formats and "cells"
// ==================================================
//
// Kernels operate on small LHS and RHS blocks that fit in registers.
// These blocks are stored contiguously in memory, but not always
// in a traditional column-major or row-major order; instead,
// they consist of a number of sub-blocks, which we call "cells",
// that are stored in column-major or row-major order. However,
// what really matters to us is not so much rows vs columns, but
// rather width vs depth. So we refer to "width-major" and "depth-major"
// storage orders. In the LHS, width-major means row-major,
// while in the RHS, width-major means column-major.
// There is also a third possibility, "diagonal order",
// which is unused at the moment.
//
// We aim to treat both sides, LHS and RHS, on an equal footing,
// so we call them both 'sides'. A KernelFormat thus is just a pair
// of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat
// contains a CellFormat and a number of cells; cells are only ever
// stacked in the width dimension, which means stacked vertically in the
// LHS and stacked horizontally in the RHS.
//
// Example
// =======
//
// Let's work out the data layout expected by a kernel having the
// following format (the struct names here are defined below in this file):
//
//   KernelFormat<
//     KernelSideFormat<CellFormat<3, 4>, 3>,
//     KernelSideFormat<CellFormat<5, 4>, 2>
//   >
//
// The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means:
// 3 cells, each cell having dimensions (width=3, depth=4), laid out in
// DepthMajor order (the default value, see CellFormat). In the LHS,
// DepthMajor means column-major, so the LHS cells are of size 3x4 in
// column-major order, so the LHS layout is:
//
//   0  3  6  9
//   1  4  7  10
//   2  5  8  11
//   12 15 18 21
//   13 16 19 22
//   14 17 20 23
//   24 27 30 33
//   25 28 31 34
//   26 29 32 35
//
// The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means:
// 2 cells each having dimensions (width=5, depth=4), laid out in
// DepthMajor order (the default value, see CellFormat). In the RHS,
// DepthMajor means row-major, so the RHS cells are of size 4x5 in
// row-major order, so the RHS layout is:
//
//   0  1  2  3  4  20 21 22 23 24
//   5  6  7  8  9  25 26 27 28 29
//   10 11 12 13 14 30 31 32 33 34
//   15 16 17 18 19 35 36 37 38 39

// CellOrder enumerates the possible storage orders (=layouts) for
// a cell (see explanation above).
enum class CellOrder { DepthMajor, WidthMajor, Diagonal };

// CellFormat describes how data is laid
// out in a cell. That is, a CellOrder together with actual dimensions.
//
// Template parameters:
//   tWidth - extent of the cell along the width dimension.
//   tDepth - extent of the cell along the depth dimension.
//   tOrder - storage order of the cell; defaults to CellOrder::DepthMajor.
//
// The members are constexpr (rather than plain `static const`) so that
// they can be ODR-used, e.g. bound to a `const int&` parameter, without
// requiring a separate out-of-class definition when built as C++17.
template <int tWidth, int tDepth, CellOrder tOrder = CellOrder::DepthMajor>
struct CellFormat {
  static constexpr int kWidth = tWidth;
  static constexpr int kDepth = tDepth;
  static constexpr CellOrder kOrder = tOrder;

  // Number of entries in one cell.
  static constexpr int kSize = kWidth * kDepth;
};

// KernelSideFormat describes how data is laid out in a kernel side
// (i.e. LHS or RHS). That is, a CellFormat together with a number of
// cells.
These cells are always stacked in the Width dimension. 139// For example, in the LHS case, the Width dimension is the rows dimension, 140// se we're saying that in the LHS, cells are stacked vertically. 141// We never stack cells in the Depth dimension. 142template <typename tCellFormat, int tCells> 143struct KernelSideFormat { 144 typedef tCellFormat Cell; 145 static const int kCells = tCells; 146 static const int kWidth = kCells * Cell::kWidth; 147 static const int kDepth = Cell::kDepth; 148 typedef std::uint8_t Scalar; 149}; 150 151template <typename tCellFormat, int tCells> 152struct KernelSideFormatInt8 : KernelSideFormat<tCellFormat, tCells> { 153 typedef std::int8_t Scalar; 154}; 155 156// KernelFormat describes fully the input data layout that a kernel expects. 157// It consists of two KernelSideFormat's, one for LHS and one for RHS. 158template <typename tLhs, typename tRhs> 159struct KernelFormat { 160 typedef tLhs Lhs; 161 typedef tRhs Rhs; 162 163 static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth, ""); 164 static const int kDepth = Lhs::Cell::kDepth; 165 static const int kRows = Lhs::Cell::kWidth * Lhs::kCells; 166 static const int kCols = Rhs::Cell::kWidth * Rhs::kCells; 167}; 168 169inline const char* CellOrderName(CellOrder o) { 170 switch (o) { 171 case CellOrder::DepthMajor: 172 return "DepthMajor"; 173 case CellOrder::WidthMajor: 174 return "WidthMajor"; 175 case CellOrder::Diagonal: 176 return "Diagonal"; 177 default: 178 assert(false); 179 return nullptr; 180 } 181} 182 183// Returns the offset into a cell, at which a given coefficient is stored. 
184template <typename CellFormat> 185inline int OffsetIntoCell(int w, int d) { 186 const int size = CellFormat::kWidth; 187 switch (CellFormat::kOrder) { 188 case CellOrder::DepthMajor: 189 return w + d * CellFormat::kWidth; 190 case CellOrder::WidthMajor: 191 return d + w * CellFormat::kDepth; 192 case CellOrder::Diagonal: 193 assert(CellFormat::kWidth == CellFormat::kDepth); 194 return ((size + w - d) * size + d) % (size * size); 195 default: 196 assert(false); 197 return 0; 198 } 199} 200 201// KernelBase is the virtual base class below all kernels. 202// The idea is that we don't need to templatize all our code on the exact 203// kernel type; we only need to templatize on kernel format. Kernels 204// sharing the same format can thus share the same packing/unpacking code. 205struct KernelBase { 206 virtual const char* Name() const = 0; 207 208 // This is the kernel implementation. We use the word 'run' consistently 209 // throughout gemmlowp to mean an inner loop, the implementation of which 210 // is to be provided by a separate optimized function. 211 virtual void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, 212 std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, 213 const std::uint8_t* rhs_ptr, std::size_t start_depth, 214 std::size_t run_depth) const = 0; 215 216 virtual ~KernelBase() {} 217}; 218 219template <typename KernelScalarType> 220struct ZeroPointInputValue {}; 221 222template <> 223struct ZeroPointInputValue<std::uint8_t> { 224 static constexpr std::uint8_t kValue = 0; 225}; 226 227template <> 228struct ZeroPointInputValue<std::int8_t> { 229 static constexpr std::uint8_t kValue = 128; 230}; 231 232} // namespace gemmlowp 233 234#endif // GEMMLOWP_INTERNAL_KERNEL_H_ 235