1// Copyright 2015 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
#include <unistd.h>
#ifdef __APPLE__
#include <sys/time.h>
#endif

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <map>
#include <vector>

#include "../eight_bit_int_gemm/eight_bit_int_gemm.h"
#include "test.h"
30
31#if defined(__arm__) && !defined(GEMMLOWP_NEON)
32#warning "Building without NEON support on ARM, check your compiler setup!"
33#endif
34
// Returns the current reading of a monotonic clock, in seconds.
//
// The absolute value has no meaning on its own; callers are expected to
// subtract two readings to obtain an elapsed interval. A steady clock is used
// instead of the wall clock so that system time adjustments happening during a
// benchmark run cannot produce negative or inflated intervals.
double time() {
  const std::chrono::steady_clock::duration since_epoch =
      std::chrono::steady_clock::now().time_since_epoch();
  // duration<double> defaults to a period of one second.
  return std::chrono::duration<double>(since_epoch).count();
}
46
// Lower bound (2 MiB) on the combined size of all working sets allocated for
// one shape; Shape::init() divides this by the size of a single working set to
// decide how many sets to create.
const std::int32_t MIN_WORKING_SET_SIZE = 2 * (1 << 20);

// Operation budget (1e9) that Shape::init() divides by n * m * k to derive the
// number of repetitions per timed batch.
const double MIN_OPS = 1e9;
49
// Owns the three byte buffers used by one GEMM invocation: the two operands
// (lhs, rhs) and the output (result).
//
// The buffers are released when the object is destroyed. Copying is disabled
// because a member-wise copy would leave two objects freeing the same memory;
// moving transfers ownership and leaves the source with null pointers, which
// is what std::vector needs when it grows.
struct WorkingSet {
  WorkingSet() : lhs(nullptr), rhs(nullptr), result(nullptr) {}

  WorkingSet(const WorkingSet&) = delete;
  WorkingSet& operator=(const WorkingSet&) = delete;

  WorkingSet(WorkingSet&& other) noexcept
      : lhs(other.lhs), rhs(other.rhs), result(other.result) {
    other.lhs = nullptr;
    other.rhs = nullptr;
    other.result = nullptr;
  }

  WorkingSet& operator=(WorkingSet&& other) noexcept {
    if (this != &other) {
      release();
      lhs = other.lhs;
      rhs = other.rhs;
      result = other.result;
      other.lhs = nullptr;
      other.rhs = nullptr;
      other.result = nullptr;
    }
    return *this;
  }

  ~WorkingSet() { release(); }

  // Allocates zero-filled buffers of n * k (lhs), k * m (rhs) and m * n
  // (result) bytes. Any buffers held from a previous call are freed first.
  void init(std::int32_t n, std::int32_t m, std::int32_t k) {
    release();
    // Widen before multiplying so the element counts are not computed in
    // 32-bit signed arithmetic.
    const std::size_t rows = static_cast<std::size_t>(n);
    const std::size_t cols = static_cast<std::size_t>(m);
    const std::size_t depth = static_cast<std::size_t>(k);
    // The trailing () value-initializes the arrays, so the GEMM never reads
    // indeterminate bytes.
    lhs = new std::uint8_t[rows * depth]();
    rhs = new std::uint8_t[depth * cols]();
    result = new std::uint8_t[cols * rows]();
  }

  std::uint8_t* lhs;
  std::uint8_t* rhs;
  std::uint8_t* result;

 private:
  // Frees all three buffers and resets the pointers to null.
  void release() {
    delete[] lhs;
    delete[] rhs;
    delete[] result;
    lhs = nullptr;
    rhs = nullptr;
    result = nullptr;
  }
};
63
64struct Shape {
65  std::int32_t n;
66  std::int32_t m;
67  std::int32_t k;
68
69  std::int32_t repetitions;
70  std::int32_t current_set;
71  std::vector<WorkingSet> working_sets;
72
73  Shape(std::int32_t n, std::int32_t m, std::int32_t k)
74      : n(n), m(m), k(k), repetitions(1), current_set(0), working_sets() {}
75
76  void init() {
77    const std::int32_t size = n * k + k * m + n * m;
78    const std::int32_t count = MIN_WORKING_SET_SIZE / size + 1;
79    const double ops = static_cast<double>(n) * static_cast<double>(m) *
80                       static_cast<double>(k);
81    for (int i = 0; i < count; ++i) {
82      working_sets.push_back(WorkingSet());
83      working_sets.back().init(n, m, k);
84    }
85    current_set = 0;
86    repetitions = MIN_OPS / ops + 20;
87  }
88
89  WorkingSet& working_set() { return working_sets[current_set]; }
90
91  void next_working_set() {
92    current_set = (current_set + 1) % working_sets.size();
93  }
94};
95
96double run_gemm(std::int32_t n, std::int32_t m, std::int32_t k,
97                std::uint8_t* lhs, std::uint8_t* rhs, std::uint8_t* result) {
98  gemmlowp::eight_bit_int_gemm::EightBitIntGemm(
99      true, false, false, m, n, k, rhs, -100, k, lhs, -100, k, result, 10000,
100      10, 3, m, gemmlowp::eight_bit_int_gemm::BitDepthSetting::A8B8);
101  return static_cast<double>(n * m * k * 2);
102}
103
104double run_gemms(std::vector<Shape>* shapes) {
105  double ops = 0.0;
106  for (auto& shape : *shapes) {
107    ops += run_gemm(shape.n, shape.m, shape.k, shape.working_set().lhs,
108                    shape.working_set().rhs, shape.working_set().result);
109  }
110  return ops;
111}
112
// Prints latency statistics for the samples in |times| to std::cout.
//
// |times| holds latencies in seconds and is sorted in place (ascending) as a
// side effect. When |full| is true a multi-line report is printed: best,
// worst, mean, 25% trimmed mean and the mean of the 10% best samples.
// Otherwise only the mean, converted to milliseconds, is printed on a single
// line. If |times| is empty a placeholder line is printed instead, since no
// statistic is defined for zero samples.
void print_summary(std::vector<double>* times, bool full) {
  if (times->empty()) {
    // front()/back() on an empty vector is undefined behaviour and every mean
    // below would divide by zero.
    if (full) {
      std::cout << "Graph latency: no iterations were timed." << std::endl;
    } else {
      std::cout << "n/a" << std::endl;
    }
    return;
  }

  std::sort(times->begin(), times->end());

  const std::size_t sample_count = times->size();
  const float trim_ratio = 0.25;
  const float best_ratio = 0.1;
  // Number of samples discarded at each end for the trimmed mean.
  const std::size_t count_trimmed =
      static_cast<std::size_t>(sample_count * trim_ratio);
  // Always average at least one sample; with fewer than ten samples the plain
  // ratio rounds down to zero and the mean would come out as 0 / 0.
  const std::size_t count_best = std::max<std::size_t>(
      1, static_cast<std::size_t>(sample_count * best_ratio));

  double sum_times = 0;
  double sum_times_trimmed = 0;
  std::size_t count_times_trimmed = 0;
  double sum_times_best = 0;

  for (std::size_t i = 0; i < sample_count; ++i) {
    const double sample = (*times)[i];
    sum_times += sample;
    if (i >= count_trimmed && i < sample_count - count_trimmed) {
      sum_times_trimmed += sample;
      ++count_times_trimmed;
    }
    if (i < count_best) {
      sum_times_best += sample;
    }
  }

  const double min_latency = times->front();
  const double max_latency = times->back();
  const double mean_latency = sum_times / sample_count;
  const double trimmed_mean_latency = sum_times_trimmed / count_times_trimmed;
  const double best_mean_latency = sum_times_best / count_best;

  // std::endl is kept on purpose: it flushes, so results show up immediately
  // while the benchmark is still running.
  if (full) {
    std::cout << "Graph latency (over " << sample_count
              << " iterations):" << std::endl;
    std::cout << "  Best:             " << min_latency << "s" << std::endl;
    std::cout << "  Worst:            " << max_latency << "s" << std::endl;
    std::cout << "  Mean:             " << mean_latency << "s" << std::endl;
    std::cout << "  " << 100 * trim_ratio
              << "% trimmed mean: " << trimmed_mean_latency << "s" << std::endl;
    std::cout << "  Mean of " << 100 * best_ratio
              << "% best: " << best_mean_latency << "s" << std::endl;
  } else {
    std::cout << (mean_latency * 1000.0) << std::endl;
  }
}
158
159void time_all(std::vector<Shape>* shapes, std::int32_t repetitions,
160              double max_time) {
161  std::vector<double> times;
162  double ops = 0.0;
163  double sum_time = 0.0;
164
165  while (sum_time < max_time) {
166    double start = time();
167
168    for (int i = 0; i < repetitions; ++i) {
169      ops += run_gemms(shapes);
170    }
171    double delta_time = (time() - start);
172    times.push_back(delta_time / repetitions);
173    sum_time += delta_time;
174  }
175
176  print_summary(&times, true);
177}
178
179void time_one(Shape* shape, double max_time) {
180  std::vector<double> times;
181  double ops = 0.0;
182  double sum_time = 0.0;
183
184  std::cout << std::setprecision(6) << std::fixed << shape->n << ", "
185            << shape->m << ", " << shape->k << ", " << std::flush;
186
187  while (sum_time < max_time) {
188    double start = time();
189
190    for (int i = 0; i < shape->repetitions; ++i) {
191      ops += run_gemm(shape->n, shape->m, shape->k, shape->working_set().lhs,
192                      shape->working_set().rhs, shape->working_set().result);
193      shape->next_working_set();
194    }
195    double delta_time = (time() - start);
196    times.push_back(delta_time / shape->repetitions);
197    sum_time += delta_time;
198  }
199
200  print_summary(&times, false);
201}
202
203int main() {
204  std::vector<Shape> googlenet_gemms;
205  googlenet_gemms.push_back(Shape(12544, 64, 147));
206  googlenet_gemms.push_back(Shape(3136, 64, 64));
207  googlenet_gemms.push_back(Shape(3136, 192, 576));
208  googlenet_gemms.push_back(Shape(784, 64, 192));
209  googlenet_gemms.push_back(Shape(784, 96, 192));
210  googlenet_gemms.push_back(Shape(784, 128, 864));
211  googlenet_gemms.push_back(Shape(784, 16, 192));
212  googlenet_gemms.push_back(Shape(784, 32, 400));
213  googlenet_gemms.push_back(Shape(784, 32, 192));
214  googlenet_gemms.push_back(Shape(784, 128, 256));
215  googlenet_gemms.push_back(Shape(784, 128, 256));
216  googlenet_gemms.push_back(Shape(784, 192, 1152));
217  googlenet_gemms.push_back(Shape(784, 32, 256));
218  googlenet_gemms.push_back(Shape(784, 96, 800));
219  googlenet_gemms.push_back(Shape(784, 64, 256));
220  googlenet_gemms.push_back(Shape(196, 192, 480));
221  googlenet_gemms.push_back(Shape(196, 96, 480));
222  googlenet_gemms.push_back(Shape(196, 204, 864));
223  googlenet_gemms.push_back(Shape(196, 16, 480));
224  googlenet_gemms.push_back(Shape(196, 48, 400));
225  googlenet_gemms.push_back(Shape(196, 64, 480));
226  googlenet_gemms.push_back(Shape(196, 160, 508));
227  googlenet_gemms.push_back(Shape(196, 112, 508));
228  googlenet_gemms.push_back(Shape(196, 224, 1008));
229  googlenet_gemms.push_back(Shape(196, 24, 508));
230  googlenet_gemms.push_back(Shape(196, 64, 600));
231  googlenet_gemms.push_back(Shape(196, 64, 508));
232  googlenet_gemms.push_back(Shape(196, 128, 512));
233  googlenet_gemms.push_back(Shape(196, 128, 512));
234  googlenet_gemms.push_back(Shape(196, 256, 1152));
235  googlenet_gemms.push_back(Shape(196, 24, 512));
236  googlenet_gemms.push_back(Shape(196, 64, 600));
237  googlenet_gemms.push_back(Shape(196, 64, 512));
238  googlenet_gemms.push_back(Shape(196, 112, 512));
239  googlenet_gemms.push_back(Shape(196, 144, 512));
240  googlenet_gemms.push_back(Shape(196, 288, 1296));
241  googlenet_gemms.push_back(Shape(196, 32, 512));
242  googlenet_gemms.push_back(Shape(196, 64, 800));
243  googlenet_gemms.push_back(Shape(196, 64, 512));
244  googlenet_gemms.push_back(Shape(196, 256, 528));
245  googlenet_gemms.push_back(Shape(196, 160, 528));
246  googlenet_gemms.push_back(Shape(196, 320, 1440));
247  googlenet_gemms.push_back(Shape(196, 32, 528));
248  googlenet_gemms.push_back(Shape(196, 128, 800));
249  googlenet_gemms.push_back(Shape(196, 128, 528));
250  googlenet_gemms.push_back(Shape(49, 256, 832));
251  googlenet_gemms.push_back(Shape(49, 160, 832));
252  googlenet_gemms.push_back(Shape(49, 320, 1440));
253  googlenet_gemms.push_back(Shape(49, 48, 832));
254  googlenet_gemms.push_back(Shape(49, 128, 1200));
255  googlenet_gemms.push_back(Shape(49, 128, 832));
256  googlenet_gemms.push_back(Shape(49, 384, 832));
257  googlenet_gemms.push_back(Shape(49, 192, 832));
258  googlenet_gemms.push_back(Shape(49, 384, 1728));
259  googlenet_gemms.push_back(Shape(49, 48, 832));
260  googlenet_gemms.push_back(Shape(49, 128, 1200));
261  googlenet_gemms.push_back(Shape(49, 128, 832));
262  googlenet_gemms.push_back(Shape(16, 128, 508));
263  googlenet_gemms.push_back(Shape(1, 1024, 2048));
264  googlenet_gemms.push_back(Shape(1, 1008, 1024));
265  googlenet_gemms.push_back(Shape(16, 128, 528));
266  googlenet_gemms.push_back(Shape(1, 1024, 2048));
267  googlenet_gemms.push_back(Shape(1, 1008, 1024));
268  googlenet_gemms.push_back(Shape(1, 1008, 1024));
269
270  for (auto& shape : googlenet_gemms) {
271    shape.init();
272  }
273
274  std::vector<Shape> small_gemms;
275  small_gemms.push_back(Shape(29232, 16, 25));
276  small_gemms.push_back(Shape(7308, 6, 400));
277  small_gemms.push_back(Shape(203, 3002, 216));
278
279  for (auto& shape : small_gemms) {
280    shape.init();
281  }
282
283  std::vector<Shape> others;
284  others.push_back(Shape(100, 100, 100));
285  others.push_back(Shape(1000, 1000, 1000));
286  others.push_back(Shape(2000, 1000, 1000));
287
288  for (auto& shape : others) {
289    shape.init();
290  }
291
292  gemmlowp::eight_bit_int_gemm::SetMaxNumThreads(4);
293
294  std::cout << "Warmup run." << std::endl;
295  time_all(&googlenet_gemms, 10, 1.0);
296  time_all(&small_gemms, 50, 1.0);
297
298  std::cout << "Timing all." << std::endl;
299  time_all(&googlenet_gemms, 10, 20.0);
300  time_all(&small_gemms, 50, 10.0);
301
302  std::cout << "Timing separate." << std::endl;
303
304  for (auto& shape : googlenet_gemms) {
305    time_one(&shape, 0.10);
306  }
307
308  for (auto& shape : small_gemms) {
309    time_one(&shape, 0.10);
310  }
311
312  for (auto& shape : others) {
313    time_one(&shape, 0.10);
314  }
315
316  return 0;
317}
318