1// Copyright 2016 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// test_fixedpoint.cc: unit tests covering the fixedpoint/ directory.
16
17#define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
18
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <random>
#include <vector>
#include "test.h"

#include "../fixedpoint/fixedpoint.h"
26
27namespace gemmlowp {
28
29namespace {
30
31// Explanation of SimdVector type and associated functions
32// (LoadSimdVector, StoreSimdVector):
33// The fixedpoint stuff being tested here is generic in an underlying
34// integer type which may be either scalar (int32_t) or SIMD (e.g.
35// NEON int32x4_t). We want to write uniform tests that can test
36// both the scalar and SIMD paths. We achieve this by having this
37// generic SimdVector abstraction, local to this test.
38
// Exactly one of the three variants below is compiled. Each one provides:
//   SimdVector       the raw register type the tests operate on,
//   SimdVectorSize   the number of int32 lanes held by one SimdVector,
//   LoadSimdVector   reads SimdVectorSize consecutive int32 values,
//   StoreSimdVector  writes SimdVectorSize consecutive int32 values.
#if defined(GEMMLOWP_NEON)
// ARM NEON: four int32 lanes.
using SimdVector = int32x4_t;
constexpr std::size_t SimdVectorSize = 4;
SimdVector LoadSimdVector(const std::int32_t* src) {
  return vld1q_s32(src);
}
void StoreSimdVector(std::int32_t* dst, SimdVector v) {
  vst1q_s32(dst, v);
}
#elif defined(GEMMLOWP_SSE4)
// x86 SSE4: four int32 lanes, accessed with the unaligned load/store
// intrinsics so that plain int32 arrays can be used as buffers.
using SimdVector = __m128i;
constexpr std::size_t SimdVectorSize = 4;
SimdVector LoadSimdVector(const std::int32_t* src) {
  const __m128i* vector_src = reinterpret_cast<const __m128i*>(src);
  return _mm_loadu_si128(vector_src);
}
void StoreSimdVector(std::int32_t* dst, SimdVector v) {
  __m128i* vector_dst = reinterpret_cast<__m128i*>(dst);
  _mm_storeu_si128(vector_dst, v);
}
#else
// No SIMD support: the "vector" degenerates to a single int32.
using SimdVector = std::int32_t;
constexpr std::size_t SimdVectorSize = 1;
SimdVector LoadSimdVector(const std::int32_t* src) { return src[0]; }
void StoreSimdVector(std::int32_t* dst, SimdVector v) { dst[0] = v; }
#endif
59
60// Explanation of UnaryOpBase, its *Op subclasses below, and TestUnaryOp:
61// Most (though not all) of the fixedpoint functionality being tested
62// consists of functions taking one fixedpoint value and returning one
63// fixedpoint value, e.g. "exp" or "tanh". We call them "unary operators".
64// We factor a lot of testing boilerplate into a common TestUnaryOp function
65// taking a "unary op" object that fully describes the function to be tested.
66// These objects inherit UnaryOpBase mostly as a means to share some default
67// values for some properties.
68//
69// An important design element here is that the fixed-point values are passed
70// around as raw integers (e.g. int32_t or SIMD types such as int32x4_t), not
71// as higher-level FixedPoint objects. The motivation for this design is 1) to
72// avoid having to templatize everything in the tIntegerBits parameter of
73// class FixedPoint, and 2) to allow directly testing low-level functions
// operating on raw types (e.g. RoundingDivideByPOT) without needlessly
// requiring raw values to be wrapped in FixedPoint objects.
// Supplies default values for the properties that TestUnaryOp queries on an
// op object. None of these methods is virtual: an op customizes a property by
// declaring a method of the same name, and TestUnaryOp, being a template on
// the concrete op type, picks up the derived-class version statically.
class UnaryOpBase {
 public:
  // Min bound of the input range of this op. For example, an op only handling
  // nonnegative values would return 0.
  std::int32_t MinInput() const {
    return std::numeric_limits<std::int32_t>::min();
  }
  // Max bound of the input range of this op. For example, an op only handling
  // nonpositive values would return 0.
  std::int32_t MaxInput() const {
    return std::numeric_limits<std::int32_t>::max();
  }
  // Tolerated difference between actual and reference int32 values.
  // Note that the corresponding real-numbers tolerance depends on the number
  // of integer bits of the fixed-point representation of the results of this
  // op.
  // For example, for an op returning fixed-point values with 0 integer bits
  // (hence 31 fractional bits), the correspondence between real-number values
  // and raw values is
  // real_number = raw_value / (2^31),
  // so a tolerance of 1 raw unit corresponds to 2^-31 in real numbers.
  std::int32_t Tolerance() const { return 0; }
};
98
99// Op wrapping RoundingDivideByPOT
100class RoundingDivideByPOTOp final : public UnaryOpBase {
101 public:
102  RoundingDivideByPOTOp(int exponent) : exponent_(exponent) {}
103  std::int32_t ReferenceOp(std::int32_t x) const {
104    const double d = static_cast<double>(x) / (1ll << exponent_);
105    return static_cast<std::int32_t>(std::round(d));
106  }
107  template <typename tRawType>
108  tRawType Op(tRawType x) const {
109    return RoundingDivideByPOT(x, exponent_);
110  }
111
112 private:
113  const int exponent_;
114};
115
116// Op wrapping SaturatingRoundingMultiplyByPOT
117template <int tExponent>
118class SaturatingRoundingMultiplyByPOTOp final : public UnaryOpBase {
119 public:
120  std::int32_t ReferenceOp(std::int32_t x) const {
121    const double d = static_cast<double>(x) * std::pow(2., tExponent);
122    const double clamp_min = std::numeric_limits<std::int32_t>::min();
123    const double clamp_max = std::numeric_limits<std::int32_t>::max();
124    const double clamped = std::min(clamp_max, std::max(clamp_min, d));
125    return static_cast<std::int32_t>(std::round(clamped));
126  }
127  template <typename tRawType>
128  tRawType Op(tRawType x) const {
129    return SaturatingRoundingMultiplyByPOT<tExponent>(x);
130  }
131};
132
133// Op wrapping exp_on_interval_between_negative_one_quarter_and_0_excl
134class ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp final
135    : public UnaryOpBase {
136 public:
137  std::int32_t MinInput() const { return -(1 << 29); }
138  std::int32_t MaxInput() const { return 0; }
139  std::int32_t Tolerance() const { return 500; }
140  std::int32_t ReferenceOp(std::int32_t x) const {
141    using F = FixedPoint<std::int32_t, 0>;
142    const double d = ToDouble(F::FromRaw(x));
143    const double e = std::exp(d);
144    return F::FromDouble(e).raw();
145  }
146  template <typename tRawType>
147  tRawType Op(tRawType x) const {
148    using F = FixedPoint<tRawType, 0>;
149    const F f = F::FromRaw(x);
150    const F e = exp_on_interval_between_negative_one_quarter_and_0_excl(f);
151    return e.raw();
152  }
153};
154
155// Op wrapping exp_on_negative_values
156template <int tIntegerBits>
157class ExpOnNegativeValuesOp final : public UnaryOpBase {
158 public:
159  std::int32_t MaxInput() const { return 0; }
160  std::int32_t Tolerance() const { return 500; }
161  std::int32_t ReferenceOp(std::int32_t x) const {
162    using F = FixedPoint<std::int32_t, tIntegerBits>;
163    using F0 = FixedPoint<std::int32_t, 0>;
164    const double d = ToDouble(F::FromRaw(x));
165    const double e = std::exp(d);
166    return F0::FromDouble(e).raw();
167  }
168  template <typename tRawType>
169  tRawType Op(tRawType x) const {
170    using F = FixedPoint<tRawType, tIntegerBits>;
171    const F f = F::FromRaw(x);
172    return exp_on_negative_values(f).raw();
173  }
174};
175
176// Op wrapping one_minus_x_over_one_plus_x_for_x_in_0_1
177class OneMinusXOverOnePlusXForXIn01Op final : public UnaryOpBase {
178 public:
179  std::int32_t MinInput() const { return 0; }
180  std::int32_t Tolerance() const { return 12; }
181  std::int32_t ReferenceOp(std::int32_t x) const {
182    using F = FixedPoint<std::int32_t, 0>;
183    const double d = ToDouble(F::FromRaw(x));
184    const double e = (1 - d) / (1 + d);
185    return F::FromDouble(e).raw();
186  }
187  template <typename tRawType>
188  tRawType Op(tRawType x) const {
189    using F = FixedPoint<tRawType, 0>;
190    const F f = F::FromRaw(x);
191    return one_minus_x_over_one_plus_x_for_x_in_0_1(f).raw();
192  }
193};
194
195// Op wrapping tanh
196template <int tIntegerBits>
197class TanhOp final : public UnaryOpBase {
198 public:
199  std::int32_t Tolerance() const { return 310; }
200  std::int32_t ReferenceOp(std::int32_t x) const {
201    using F = FixedPoint<std::int32_t, tIntegerBits>;
202    using F0 = FixedPoint<std::int32_t, 0>;
203    const double d = ToDouble(F::FromRaw(x));
204    const double e = std::tanh(d);
205    return F0::FromDouble(e).raw();
206  }
207  template <typename tRawType>
208  tRawType Op(tRawType x) const {
209    using F = FixedPoint<tRawType, tIntegerBits>;
210    const F f = F::FromRaw(x);
211    return tanh(f).raw();
212  }
213};
214
215// Op wrapping one_over_one_plus_x_for_x_in_0_1
216class OneOverOnePlusXForXIn01Op final : public UnaryOpBase {
217 public:
218  std::int32_t MinInput() const { return 0; }
219  std::int32_t Tolerance() const { return 6; }
220  std::int32_t ReferenceOp(std::int32_t x) const {
221    using F = FixedPoint<std::int32_t, 0>;
222    const double d = ToDouble(F::FromRaw(x));
223    const double e = 1 / (1 + d);
224    return F::FromDouble(e).raw();
225  }
226  template <typename tRawType>
227  tRawType Op(tRawType x) const {
228    using F = FixedPoint<tRawType, 0>;
229    const F f = F::FromRaw(x);
230    return one_over_one_plus_x_for_x_in_0_1(f).raw();
231  }
232};
233
234// Op wrapping logistic
235template <int tIntegerBits>
236class LogisticOp final : public UnaryOpBase {
237 public:
238  std::int32_t Tolerance() const { return 155; }
239  std::int32_t ReferenceOp(std::int32_t x) const {
240    using F = FixedPoint<std::int32_t, tIntegerBits>;
241    using F0 = FixedPoint<std::int32_t, 0>;
242    const double d = ToDouble(F::FromRaw(x));
243    const double e = 1 / (1 + std::exp(-d));
244    return F0::FromDouble(e).raw();
245  }
246  template <typename tRawType>
247  tRawType Op(tRawType x) const {
248    using F = FixedPoint<tRawType, tIntegerBits>;
249    const F f = F::FromRaw(x);
250    return logistic(f).raw();
251  }
252};
253
254// Tests a given op, on a given list of int32 input values.
255template <typename tUnaryOpType>
256void TestUnaryOp(const tUnaryOpType& unary_op,
257                 const std::vector<std::int32_t>& testvals_int32) {
258  Check(0 == (testvals_int32.size() % SimdVectorSize));
259  for (std::size_t i = 0; i < testvals_int32.size(); i += SimdVectorSize) {
260    // First, clamp input int32 values accoding to the MinInput() and MaxInput()
261    // bounds returned by the op.
262    std::int32_t input[SimdVectorSize] = {0};
263    for (std::size_t j = 0; j < SimdVectorSize; j++) {
264      const std::int32_t raw_input = testvals_int32[i + j];
265      input[j] = std::min(unary_op.MaxInput(),
266                          std::max(unary_op.MinInput(), raw_input));
267    }
268    // Compute reference results and check that the actual results on
269    // scalar inputs agree with them, to the Tolerance() returned by the op.
270    std::int32_t reference[SimdVectorSize] = {0};
271    std::int32_t actual_scalar[SimdVectorSize] = {0};
272    for (std::size_t j = 0; j < SimdVectorSize; j++) {
273      reference[j] = unary_op.ReferenceOp(input[j]);
274      actual_scalar[j] = unary_op.Op(input[j]);
275      const std::int64_t diff = static_cast<std::int64_t>(actual_scalar[j]) -
276                                static_cast<std::int64_t>(reference[j]);
277      Check(std::abs(diff) <= unary_op.Tolerance());
278    }
279    // Check that the actual results on SIMD inputs agree *exactly* with the
280    // actual results on scalar inputs. I.e. SIMD must make absolutely no
281    // difference
282    // to the results, regardless of the fact that both scalar and SIMD results
283    // may differ from the reference results.
284    std::int32_t actual_simd[SimdVectorSize] = {0};
285    StoreSimdVector(actual_simd, unary_op.Op(LoadSimdVector(input)));
286    for (std::size_t j = 0; j < SimdVectorSize; j++) {
287      Check(actual_simd[j] == actual_scalar[j]);
288    }
289  }
290}
291
292template <int tIntegerBits>
293void test_convert(FixedPoint<std::int32_t, tIntegerBits> x) {
294  typedef FixedPoint<std::int32_t, tIntegerBits> F;
295  F y = F::FromDouble(ToDouble(x));
296  Check(y == x);
297}
298
299template <int tIntegerBits_a, int tIntegerBits_b>
300void test_Rescale(FixedPoint<std::int32_t, tIntegerBits_a> a) {
301  FixedPoint<std::int32_t, tIntegerBits_b> actual = Rescale<tIntegerBits_b>(a);
302  FixedPoint<std::int32_t, tIntegerBits_b> expected =
303      FixedPoint<std::int32_t, tIntegerBits_b>::FromDouble(ToDouble(a));
304  Check(actual == expected);
305}
306
307template <int tIntegerBits_a, int tIntegerBits_b>
308void test_Rescale(const std::vector<std::int32_t>& testvals_int32) {
309  for (auto a : testvals_int32) {
310    FixedPoint<std::int32_t, tIntegerBits_a> aq;
311    aq.raw() = a;
312    test_Rescale<tIntegerBits_a, tIntegerBits_b>(aq);
313  }
314}
315
316template <int tIntegerBits_a, int tIntegerBits_b>
317void test_mul(FixedPoint<std::int32_t, tIntegerBits_a> a,
318              FixedPoint<std::int32_t, tIntegerBits_b> b) {
319  static const int ProductIntegerBits = tIntegerBits_a + tIntegerBits_b;
320  using ProductFixedPoint = FixedPoint<std::int32_t, ProductIntegerBits>;
321  ProductFixedPoint ab;
322  ab = a * b;
323  double a_double = ToDouble(a);
324  double b_double = ToDouble(b);
325  double ab_double = a_double * b_double;
326  ProductFixedPoint expected = ProductFixedPoint::FromDouble(ab_double);
327  std::int64_t diff = std::int64_t(ab.raw()) - std::int64_t(expected.raw());
328  Check(std::abs(diff) <= 1);
329}
330
331template <int tIntegerBits_a, int tIntegerBits_b>
332void test_mul(const std::vector<std::int32_t>& testvals_int32) {
333  for (auto a : testvals_int32) {
334    for (auto b : testvals_int32) {
335      FixedPoint<std::int32_t, tIntegerBits_a> aq;
336      FixedPoint<std::int32_t, tIntegerBits_b> bq;
337      aq.raw() = a;
338      bq.raw() = b;
339      test_mul(aq, bq);
340    }
341  }
342}
343
344template <int tExponent, int tIntegerBits_a>
345void test_ExactMulByPot(FixedPoint<std::int32_t, tIntegerBits_a> a) {
346  double x = ToDouble(a) * std::pow(2.0, tExponent);
347  double y = ToDouble(ExactMulByPot<tExponent>(a));
348  Check(x == y);
349}
350
351template <int tExponent, int tIntegerBits_a>
352void test_ExactMulByPot(const std::vector<std::int32_t>& testvals_int32) {
353  for (auto a : testvals_int32) {
354    FixedPoint<std::int32_t, tIntegerBits_a> aq;
355    aq.raw() = a;
356    test_ExactMulByPot<tExponent, tIntegerBits_a>(aq);
357  }
358}
359
360// Make the list of test values to test each op against.
361std::vector<std::int32_t> MakeTestValsInt32() {
362  std::vector<std::int32_t> testvals_int32;
363
364  for (int i = 0; i < 31; i++) {
365    testvals_int32.push_back((1 << i) - 2);
366    testvals_int32.push_back((1 << i) - 1);
367    testvals_int32.push_back((1 << i));
368    testvals_int32.push_back((1 << i) + 1);
369    testvals_int32.push_back((1 << i) + 2);
370    testvals_int32.push_back(-(1 << i) - 2);
371    testvals_int32.push_back(-(1 << i) - 1);
372    testvals_int32.push_back(-(1 << i));
373    testvals_int32.push_back(-(1 << i) + 1);
374    testvals_int32.push_back(-(1 << i) + 2);
375  }
376  testvals_int32.push_back(std::numeric_limits<std::int32_t>::min());
377  testvals_int32.push_back(std::numeric_limits<std::int32_t>::min() + 1);
378  testvals_int32.push_back(std::numeric_limits<std::int32_t>::min() + 2);
379  testvals_int32.push_back(std::numeric_limits<std::int32_t>::max() - 2);
380  testvals_int32.push_back(std::numeric_limits<std::int32_t>::max() - 1);
381  testvals_int32.push_back(std::numeric_limits<std::int32_t>::max());
382
383  std::mt19937 random_engine;
384  std::uniform_int_distribution<std::int32_t> uniform_distribution(
385      std::numeric_limits<std::int32_t>::min(),
386      std::numeric_limits<std::int32_t>::max());
387  for (int i = 0; i < 1000; i++) {
388    testvals_int32.push_back(uniform_distribution(random_engine));
389  }
390
391  // SIMD tests will require the length of testvals_int32 to be a multiple
392  // of SIMD vector size.
393  while (testvals_int32.size() % SimdVectorSize) {
394    testvals_int32.push_back(0);
395  }
396
397  std::sort(testvals_int32.begin(), testvals_int32.end());
398  return testvals_int32;
399}
400
401}  // end anonymous namespace
402
403}  // end namespace gemmlowp
404
405int main() {
406  using namespace gemmlowp;
407
408  const std::vector<std::int32_t> testvals_int32 = MakeTestValsInt32();
409
410  for (int s = 0; s < 32; s++) {
411    TestUnaryOp(RoundingDivideByPOTOp(s), testvals_int32);
412  }
413
414  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-31>(), testvals_int32);
415  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-30>(), testvals_int32);
416  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-29>(), testvals_int32);
417  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-17>(), testvals_int32);
418  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-16>(), testvals_int32);
419  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-15>(), testvals_int32);
420  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-4>(), testvals_int32);
421  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-3>(), testvals_int32);
422  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-2>(), testvals_int32);
423  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<-1>(), testvals_int32);
424  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<0>(), testvals_int32);
425  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<1>(), testvals_int32);
426  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<2>(), testvals_int32);
427  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<3>(), testvals_int32);
428  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<4>(), testvals_int32);
429  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<15>(), testvals_int32);
430  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<16>(), testvals_int32);
431  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<17>(), testvals_int32);
432  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<29>(), testvals_int32);
433  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<30>(), testvals_int32);
434  TestUnaryOp(SaturatingRoundingMultiplyByPOTOp<31>(), testvals_int32);
435
436  TestUnaryOp(ExpOnIntervalBetweenNegativeOneQuarterAnd0ExclOp(),
437              testvals_int32);
438  TestUnaryOp(ExpOnNegativeValuesOp<0>(), testvals_int32);
439  TestUnaryOp(ExpOnNegativeValuesOp<1>(), testvals_int32);
440  TestUnaryOp(ExpOnNegativeValuesOp<2>(), testvals_int32);
441  TestUnaryOp(ExpOnNegativeValuesOp<3>(), testvals_int32);
442  TestUnaryOp(ExpOnNegativeValuesOp<4>(), testvals_int32);
443  TestUnaryOp(ExpOnNegativeValuesOp<5>(), testvals_int32);
444  TestUnaryOp(ExpOnNegativeValuesOp<6>(), testvals_int32);
445
446  TestUnaryOp(OneMinusXOverOnePlusXForXIn01Op(), testvals_int32);
447  TestUnaryOp(TanhOp<0>(), testvals_int32);
448  TestUnaryOp(TanhOp<1>(), testvals_int32);
449  TestUnaryOp(TanhOp<2>(), testvals_int32);
450  TestUnaryOp(TanhOp<3>(), testvals_int32);
451  TestUnaryOp(TanhOp<4>(), testvals_int32);
452  TestUnaryOp(TanhOp<5>(), testvals_int32);
453  TestUnaryOp(TanhOp<6>(), testvals_int32);
454
455  TestUnaryOp(OneOverOnePlusXForXIn01Op(), testvals_int32);
456  TestUnaryOp(LogisticOp<0>(), testvals_int32);
457  TestUnaryOp(LogisticOp<1>(), testvals_int32);
458  TestUnaryOp(LogisticOp<2>(), testvals_int32);
459  TestUnaryOp(LogisticOp<3>(), testvals_int32);
460  TestUnaryOp(LogisticOp<4>(), testvals_int32);
461  TestUnaryOp(LogisticOp<5>(), testvals_int32);
462  TestUnaryOp(LogisticOp<6>(), testvals_int32);
463
464  for (auto a : testvals_int32) {
465    FixedPoint<std::int32_t, 4> x;
466    x.raw() = a;
467    test_convert(x);
468  }
469
470  test_mul<0, 0>(testvals_int32);
471  test_mul<0, 1>(testvals_int32);
472  test_mul<2, 0>(testvals_int32);
473  test_mul<1, 1>(testvals_int32);
474  test_mul<4, 4>(testvals_int32);
475  test_mul<3, 5>(testvals_int32);
476  test_mul<7, 2>(testvals_int32);
477  test_mul<14, 15>(testvals_int32);
478
479  test_Rescale<0, 0>(testvals_int32);
480  test_Rescale<0, 1>(testvals_int32);
481  test_Rescale<2, 0>(testvals_int32);
482  test_Rescale<4, 4>(testvals_int32);
483  test_Rescale<4, 5>(testvals_int32);
484  test_Rescale<6, 3>(testvals_int32);
485  test_Rescale<13, 9>(testvals_int32);
486
487  test_ExactMulByPot<0, 0>(testvals_int32);
488  test_ExactMulByPot<0, 4>(testvals_int32);
489  test_ExactMulByPot<1, 4>(testvals_int32);
490  test_ExactMulByPot<3, 2>(testvals_int32);
491  test_ExactMulByPot<-4, 5>(testvals_int32);
492  test_ExactMulByPot<-2, 6>(testvals_int32);
493
494  std::cerr << "All tests passed." << std::endl;
495}
496