1// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include <unistd.h>
#ifdef __APPLE__
#include <sys/time.h>
#endif

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <vector>

#include "streams.h"
30
// Offsets used by this test: they are assigned to the stream params'
// multiplicative_sum_offset / additive_sum_offset fields and are used by
// check() to compute the expected per-row sum. Declared as typed constants
// rather than object-like macros so they obey normal scoping and type rules.
constexpr int MUL_OFFSET = 3;
constexpr int ADD_OFFSET = 100;
33
34using namespace gemmlowp::meta;
35
// Fills a row-major test buffer.
//
// The first rows * stride bytes of `data` are set to 255 (padding marker);
// afterwards the first `elements` bytes of every row (rows start `stride`
// bytes apart) are overwritten with the column index modulo 256.
void prepare_row_major_data(int rows, int elements, int stride, std::uint8_t* data) {
  // Pass 1: paint the whole region with the padding value.
  const int total_bytes = rows * stride;
  int offset = 0;
  while (offset < total_bytes) {
    data[offset] = 255;
    ++offset;
  }

  // Pass 2: write the payload of each row on top of the padding.
  for (int row = 0; row < rows; ++row) {
    const int row_offset = row * stride;
    for (int col = 0; col < elements; ++col) {
      data[row_offset + col] = col % 256;
    }
  }
}
46
// Fills a column-major test buffer.
//
// The first elements * stride bytes of `data` are set to 255 (padding
// marker); afterwards, for every element index i, the first `columns` bytes
// of the i-th line (lines start `stride` bytes apart) are overwritten with
// i modulo 256.
void prepare_column_major_data(int columns, int elements, int stride,
                               std::uint8_t* data) {
  // Pass 1: paint the whole region with the padding value.
  const int total_bytes = elements * stride;
  int offset = 0;
  while (offset < total_bytes) {
    data[offset] = 255;
    ++offset;
  }

  // Pass 2: every line holds the same value, the element index, in each
  // of its `columns` leading bytes.
  for (int element = 0; element < elements; ++element) {
    const int line_offset = element * stride;
    for (int column = 0; column < columns; ++column) {
      data[line_offset + column] = element % 256;
    }
  }
}
58
// Debug helper: prints the leading bytes of `result` to std::cout as decimal
// integers, each followed by a single space, then ends the line and flushes.
// The number of bytes printed is rows * ceil(elements / 8) * 8, i.e. each row
// is treated as padded up to a whole number of 8-byte chunks.
void print_out(std::uint8_t* result, int rows, int elements) {
  const int chunks_per_row = (elements + 7) / 8;
  const int total_bytes = rows * chunks_per_row * 8;

  int index = 0;
  while (index < total_bytes) {
    // Widen to int so the byte is printed as a number, not as a character.
    const int value = static_cast<int>(result[index]);
    std::cout << value << " ";
    ++index;
  }

  std::cout << std::endl;
  std::cout.flush();
}
66
67bool check(std::uint8_t* result, int rows, int elements) {
68  int chunks = elements / 8;
69  int leftover = elements % 8;
70  for (int i = 0; i < chunks; ++i) {
71    int chunk_index = i * rows * 8;
72    int chunk_start_value = i * 8;
73    for (int j = 0; j < rows; ++j) {
74      for (int k = 0; k < 8; ++k) {
75        if (result[chunk_index + j * 8 + k] != chunk_start_value + k) {
76          return false;
77        }
78      }
79    }
80  }
81
82  int leftover_index = chunks * rows * 8;
83  int leftover_start_value = chunks * 8;
84  for (int i = 0; i < rows; ++i) {
85    for (int j = 0; j < leftover; ++j) {
86      if (result[leftover_index + i * 8 + j] != leftover_start_value + j) {
87        return false;
88      }
89    }
90  }
91
92  int expected_sum =
93      ((elements * (elements - 1)) / 2) * MUL_OFFSET + ADD_OFFSET;
94  int sums_offset = rows * ((elements + 7) / 8) * 8;
95  std::int32_t* sums = reinterpret_cast<std::int32_t*>(result + sums_offset);
96  for (int i = 0; i < rows; ++i) {
97    if (sums[i] != expected_sum) {
98      return false;
99    }
100  }
101
102  return true;
103}
104
105template <int lanes, int leftover>
106void test_2(std::uint8_t* in, std::uint8_t* out) {
107  for (int elements = 8; elements < 64; elements += 8) {
108    int all_elements = elements + leftover;
109    for (int stride = all_elements; stride < all_elements + 4; ++stride) {
110      RowMajorWithSum params;
111      params.count = all_elements;
112      params.stride = stride;
113      params.multiplicative_sum_offset = MUL_OFFSET;
114      params.additive_sum_offset = ADD_OFFSET;
115
116      prepare_row_major_data(lanes, all_elements, stride, in);
117      Stream<std::uint8_t, lanes, 8, leftover, RowMajorWithSum>::Pack(in, params,
118                                                                 out);
119      if (check(out, lanes, all_elements)) {
120        //        std::cout << "Row: " << lanes << "x8x" << leftover << " : "
121        //                  << all_elements << "@" << stride << " -- OK" <<
122        //                  std::endl;
123      } else {
124        std::cout << "Row: " << lanes << "x8x" << leftover << " : "
125                  << all_elements << "@" << stride << " -- ERROR" << std::endl;
126        std::cout << "Exiting." << std::endl;
127        std::exit(1);
128      }
129    }
130
131    for (int stride = lanes; stride < lanes + 4; ++stride) {
132      ColumnMajorWithSum params;
133      params.count = all_elements;
134      params.stride = stride;
135      params.multiplicative_sum_offset = MUL_OFFSET;
136      params.additive_sum_offset = ADD_OFFSET;
137
138      prepare_column_major_data(lanes, all_elements, stride, in);
139      Stream<std::uint8_t, lanes, 8, leftover, ColumnMajorWithSum>::Pack(in, params,
140                                                                    out);
141      if (check(out, lanes, all_elements)) {
142        //        std::cout << "Column: " << lanes << "x8x" << leftover << " : "
143        //                  << all_elements << "@" << stride << " -- OK" <<
144        //                  std::endl;
145      } else {
146        std::cout << "Column: " << lanes << "x8x" << leftover << " : "
147                  << all_elements << "@" << stride << " -- ERROR" << std::endl;
148        std::cout << "Exiting." << std::endl;
149        std::exit(1);
150      }
151    }
152  }
153}
154
155template <int lanes>
156void test(std::uint8_t* in, std::uint8_t* out) {
157  test_2<lanes, 0>(in, out);
158  test_2<lanes, 1>(in, out);
159  test_2<lanes, 2>(in, out);
160  test_2<lanes, 3>(in, out);
161  test_2<lanes, 4>(in, out);
162  test_2<lanes, 5>(in, out);
163  test_2<lanes, 6>(in, out);
164  test_2<lanes, 7>(in, out);
165}
166
167int main() {
168  std::unique_ptr<std::uint8_t> in(new std::uint8_t[128 * 1024]);
169  std::unique_ptr<std::uint8_t> out(new std::uint8_t[128 * 1024]);
170
171  test<1>(in.get(), out.get());
172  test<2>(in.get(), out.get());
173  test<3>(in.get(), out.get());
174  test<4>(in.get(), out.get());
175  test<5>(in.get(), out.get());
176  test<6>(in.get(), out.get());
177  test<7>(in.get(), out.get());
178  test<8>(in.get(), out.get());
179
180  std::cout << "Ok." << std::endl;
181  return 0;
182}
183