1// Copyright 2015 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
// kernel_reference.h: a reference kernel for CPU architectures where we don't
// have optimized kernels yet. It is also useful for testing: because it is
// templatized on an arbitrary Format, tests can use it to cover all sorts of
// corner cases.
19
20#ifndef GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_
21#define GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_
22
23#include "kernel.h"
24
25#include <cstdio>
26#include <cstring>
27
28namespace gemmlowp {
29
30// This kernel is templatized in an arbitrary Format template parameter,
31// allowing it to have any arbitrary format.
32template <typename tFormat>
33struct ReferenceKernel : KernelBase {
34  typedef tFormat Format;
35
36  const char* Name() const override {
37    static char buf[256];
38    snprintf(buf, sizeof(buf),
39             "reference(Lhs: %d cells %dx%d %s, Rhs: %d cells %dx%d %s)",
40             Format::Lhs::kCells, Format::Lhs::Cell::kWidth,
41             Format::Lhs::Cell::kDepth,
42             CellOrderName(Format::Lhs::Cell::kOrder), Format::Rhs::kCells,
43             Format::Rhs::Cell::kDepth, Format::Rhs::Cell::kWidth,
44             CellOrderName(Format::Rhs::Cell::kOrder));
45    return buf;
46  }
47
48  void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride,
49           std::size_t dst_col_stride, const std::uint8_t* lhs_ptr,
50           const std::uint8_t* rhs_ptr, std::size_t start_depth,
51           std::size_t run_depth) const override {
52    std::int32_t accumulator[Format::kRows * Format::kCols];
53    memset(accumulator, 0, sizeof(accumulator));
54
55    const int run_depth_cells = static_cast<int>(run_depth / Format::kDepth);
56
57    // The outer loop is over the depth dimension.
58    for (int dc = 0; dc < run_depth_cells; dc++) {
59      // The next two loops are over cells of the Lhs (stacked vertically),
60      // and over cells of the Rhs (stacked horizontally).
61      for (int rc = 0; rc < Format::Lhs::kCells; rc++) {
62        const std::uint8_t* lhs_cell_ptr = lhs_ptr +
63                                           (dc * Format::Lhs::kCells + rc) *
64                                               Format::Lhs::Cell::kWidth *
65                                               Format::kDepth;
66        for (int cc = 0; cc < Format::Rhs::kCells; cc++) {
67          const std::uint8_t* rhs_cell_ptr = rhs_ptr +
68                                             (dc * Format::Rhs::kCells + cc) *
69                                                 Format::Rhs::Cell::kWidth *
70                                                 Format::kDepth;
71
72          // Now we are inside one cell of the Lhs and inside one cell
73          // of the Rhs, so the remaining inner loops are just
74          // traditional three loops of matrix multiplication.
75          for (int di = 0; di < Format::kDepth; di++) {
76            for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) {
77              for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) {
78                const std::uint8_t* lhs_coeff_ptr =
79                    lhs_cell_ptr +
80                    OffsetIntoCell<typename Format::Lhs::Cell>(ri, di);
81                const std::uint8_t* rhs_coeff_ptr =
82                    rhs_cell_ptr +
83                    OffsetIntoCell<typename Format::Rhs::Cell>(ci, di);
84                std::int32_t* accumulator_coeff_ptr =
85                    accumulator + (ri + rc * Format::Lhs::Cell::kWidth) +
86                    (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows;
87                *accumulator_coeff_ptr +=
88                    std::int32_t(*lhs_coeff_ptr) * std::int32_t(*rhs_coeff_ptr);
89              }
90            }
91          }
92        }
93      }
94    }
95
96    if (start_depth == 0) {
97      // start_depth == 0 means we haven't accumulated anything yet, so we need
98      // to overwrite the accumulator, as it hasn't been initialized to zero.
99      for (int r = 0; r < Format::kRows; r++) {
100        for (int c = 0; c < Format::kCols; c++) {
101          dst_ptr[r * dst_row_stride + c * dst_col_stride] =
102              accumulator[r + c * Format::kRows];
103        }
104      }
105    } else {
106      // We have already accumulated stuff, so we need to continue accumulating
107      // instead of just overwriting.
108      for (int r = 0; r < Format::kRows; r++) {
109        for (int c = 0; c < Format::kCols; c++) {
110          dst_ptr[r * dst_row_stride + c * dst_col_stride] +=
111              accumulator[r + c * Format::kRows];
112        }
113      }
114    }
115  }
116};
117
118}  // namespace gemmlowp
119
120#endif  // GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_
121