kernel.h revision 0a70f98b4be89f51cdd54bf739c953e82ec7fb55
1// Copyright 2015 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// kernel.h: general definitions for kernels.
16
17#ifndef GEMMLOWP_INTERNAL_KERNEL_H_
18#define GEMMLOWP_INTERNAL_KERNEL_H_
19
20#include "common.h"
21#include "../public/bit_depth.h"
22
23namespace gemmlowp {
24
25// Explanation of general gemmlowp terminology
26// ===========================================
27//
28// We use the following abbreviations:
29// LHS = "left-hand side"
30// RHS = "right-hand side"
31// Sometimes when referring to either LHS or RHS, we just say a "Side".
32//
33// In a matrix product of a MxK matrix times a KxN matrix,
34// we call K the 'depth'. Note that M is the number of rows
35// of the result (and of the LHS), and N is the number of columns
36// of the result (and of the RHS).
37//
38// In each of the LHS and RHS matrices, we call 'width' the
39// other dimension, besides the depth. So in the LHS, 'width'
40// is the number of rows, while in the RHS, 'width' is the number
41// of columns.
42//
// So in the LHS MxK matrix, the depth is K and the width is M.
// And in the RHS KxN matrix, the depth is K and the width is N.
45//
46// This is illustrated in this picture:
47//
48//                             RHS width
49//                        <----------------->
50//                        +-----------------+ ^
51//                        |       RHS       | | Depth
52//                        +-----------------+ v
53//                 ^ +--+ +-----------------+
54//                 | |L | |                 |
55//       LHS width | |H | |      Result     |
56//                 | |S | |                 |
57//                 v +--+ +-----------------+
58//                   <-->
59//                   Depth
60
61// Explanation of gemmlowp kernel formats and "cells"
62// ==================================================
63//
64// Kernels operate on small LHS and RHS blocks that fit in registers.
65// These blocks are stored contiguously in memory, but not always
66// in a traditional column-major or row-major order; instead,
67// they consist of a number of sub-blocks, which we call "cells",
68// that are stored in column-major or row-major order. However,
69// what really matters to us is not so much rows vs columns, but
70// rather width vs depth. So we refer to "width-major" and "depth-major"
71// storage orders. In the LHS, width-major means row-major,
72// while in the RHS, width-major means column-major.
73// There is also a third possibility, "diagonal order",
74// which is unused at the moment.
75//
76// We aim to treat both sides, LHS and RHS, on an equal footing,
77// so we call them both 'sides'. A KernelFormat thus is just a pair
78// of KernelSideFormat's, one for LHS and one for RHS; each KernelSideFormat
79// contains a CellFormat and a number of cells; cells are only ever
80// stacked in the width dimension, which means stacked vertically in the
// LHS and stacked horizontally in the RHS.
82//
83// Example
84// =======
85//
86// Let's work out the data layout expected by a kernel having the
87// following format (the struct names here are defined below in this file):
88//
89// KernelFormat<
90//   KernelSideFormat<CellFormat<3, 4>, 3>,
91//   KernelSideFormat<CellFormat<5, 4>, 2>
92// >
93//
94// The LHS format, KernelSideFormat<CellFormat<3, 4>, 3>, means:
95// 3 cells, each cell having dimensions (width=3, depth=4), laid out in
96// DepthMajor order (the default value, see CellFormat). In the LHS,
97// DepthMajor means column-major, so the LHS cells are of size 3x4 in
98// column-major order, so the LHS layout is:
99//
100// 0  3  6  9
101// 1  4  7  10
102// 2  5  8  11
103// 12 15 18 21
104// 13 16 19 22
105// 14 17 20 23
106// 24 27 30 33
107// 25 28 31 34
108// 26 29 32 35
109//
110// The RHS format, KernelSideFormat<CellFormat<5, 4>, 2>, means:
111// 2 cells each having dimensions (width=5, depth=4), laid out in
112// DepthMajor order (the default value, see CellFormat). In the RHS,
113// DepthMajor means row-major, so the RHS cells are of size 4x5 in
114// row-major order, so the RHS layout is:
115//
116// 0  1  2  3  4  20 21 22 23 24
117// 5  6  7  8  9  25 26 27 28 29
118// 10 11 12 13 14 30 31 32 33 34
119// 15 16 17 18 19 35 36 37 38 39
120
// CellOrder lists the storage orders (i.e. layouts) that a cell may use.
// See the explanation of kernel formats and cells earlier in this file.
enum class CellOrder {
  // Depth-major storage: column-major in the LHS, row-major in the RHS.
  DepthMajor,
  // Width-major storage: row-major in the LHS, column-major in the RHS.
  WidthMajor,
  // Diagonal order; only meaningful for square cells (width == depth).
  Diagonal
};
124
125// CellFormat describes how data is laid
126// out in a cell. That is, a CellOrder together with actual dimensions.
127template <int tWidth, int tDepth, CellOrder tOrder = CellOrder::DepthMajor>
128struct CellFormat {
129  static const int kWidth = tWidth;
130  static const int kDepth = tDepth;
131  static const CellOrder kOrder = tOrder;
132
133  static const int kSize = kWidth * kDepth;
134};
135
// KernelSideFormat is the compile-time description of one kernel side
// (LHS or RHS): a cell format plus the number of such cells.
// Cells are only ever stacked along the Width dimension, never along the
// Depth dimension. In the LHS, Width is the rows dimension, so LHS cells
// are stacked vertically.
template <typename tCellFormat, int tCells>
struct KernelSideFormat {
  // Format shared by every cell of this side.
  using Cell = tCellFormat;

  // Number of cells stacked along the width dimension.
  static const int kCells = tCells;

  // Depth is unaffected by stacking: it is just the depth of one cell.
  static const int kDepth = Cell::kDepth;

  // Total width of the side: all cells laid next to each other.
  static const int kWidth = tCells * Cell::kWidth;
};
149
// KernelFormat describes fully the input data layout that a kernel expects.
// It consists of two KernelSideFormat's, tLhs for the LHS and tRhs for the
// RHS. Both sides must use cells of the same depth; this is enforced at
// compile time.
template <typename tLhs, typename tRhs>
struct KernelFormat {
  typedef tLhs Lhs;
  typedef tRhs Rhs;

  static_assert(Lhs::Cell::kDepth == Rhs::Cell::kDepth,
                "KernelFormat: LHS and RHS cells must have the same depth");

  // Depth common to the LHS and RHS cells.
  static const int kDepth = Lhs::Cell::kDepth;
  // Total LHS width (cell width times number of LHS cells). The LHS width
  // is its number of rows.
  static const int kRows = Lhs::Cell::kWidth * Lhs::kCells;
  // Total RHS width (cell width times number of RHS cells). The RHS width
  // is its number of columns.
  static const int kCols = Rhs::Cell::kWidth * Rhs::kCells;
};
162
163inline const char* CellOrderName(CellOrder o) {
164  switch (o) {
165    case CellOrder::DepthMajor:
166      return "DepthMajor";
167    case CellOrder::WidthMajor:
168      return "WidthMajor";
169    case CellOrder::Diagonal:
170      return "Diagonal";
171    default:
172      assert(false);
173      return nullptr;
174  }
175}
176
177// Returns the offset into a cell, at which a given coefficient is stored.
178template <typename CellFormat>
179inline int OffsetIntoCell(int w, int d) {
180  switch (CellFormat::kOrder) {
181    case CellOrder::DepthMajor:
182      return w + d * CellFormat::kWidth;
183    case CellOrder::WidthMajor:
184      return d + w * CellFormat::kDepth;
185    case CellOrder::Diagonal:
186      assert(CellFormat::kWidth == CellFormat::kDepth);
187      static const int size = CellFormat::kWidth;
188      return ((size + w - d) * size + d) % (size * size);
189    default:
190      assert(false);
191      return 0;
192  }
193}
194
// KernelBase is the abstract base class from which all kernels derive.
// Having it means that code does not need to be templatized on the exact
// kernel type, only on the kernel format; kernels that share a format can
// therefore share the same packing/unpacking code.
struct KernelBase {
  // Returns a C string naming the kernel.
  // NOTE(review): lifetime of the returned string is up to the
  // implementation -- confirm against the concrete kernels.
  virtual const char* Name() const = 0;

  // The kernel implementation proper. Throughout gemmlowp, 'run' means an
  // inner loop whose implementation is provided by a separate optimized
  // function.
  // NOTE(review): presumably accumulates the product of the packed blocks
  // at lhs_ptr and rhs_ptr over run_depth levels of depth into dst_ptr,
  // addressed with the given row/column strides -- confirm against the
  // concrete kernels.
  virtual void Run(std::int32_t* dst_ptr,
                   int dst_row_stride,
                   int dst_col_stride,
                   const std::uint8_t* lhs_ptr,
                   const std::uint8_t* rhs_ptr,
                   int start_depth,
                   int run_depth) const = 0;

  // Virtual so that kernels can be destroyed through a KernelBase pointer.
  virtual ~KernelBase() = default;
};
212
213}  // namespace gemmlowp
214
215#endif  // GEMMLOWP_INTERNAL_KERNEL_H_
216