1// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// fixedpoint.h: fixed-point arithmetic, with basic operations and
16// a few math functions such as tanh.
17
18#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
19#define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
20
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "../internal/common.h"
25
26namespace gemmlowp {
27
28// Part 1: Low-level integer-arithmetic primitives.
29// The implementations here are generic implementations valid for
30// scalar types (e.g. std::int32_t). Architecture-specific SIMD types
31// (e.g. NEON int32x4_t) may be supported by providing
32// specializations for them in separate files.
33//
34// The purpose of these primitives is two-fold:
35//  - They will be used to implement higher-level fixed-point
36//    abstractions, namely the FixedPoint class and its arithmetic
37//    operators.
38//  - They will be directly used to implement some more involved
39//    fixed-point computations, e.g. the fixed-point implementation
40//    of math functions such as tanh.
41
// Compile-time description of a raw storage type: the number of lanes it holds
// and the scalar integer type of each lane. A plain scalar integer type is a
// single lane of itself; SIMD types may be described by specializations
// provided in separate files.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

template <>
struct FixedPointRawTypeTraits<std::int16_t> {
  using ScalarRawType = std::int16_t;
  static const int kLanes = 1;
};

template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  using ScalarRawType = std::int32_t;
  static const int kLanes = 1;
};
58
59// Returns a SIMD value duplicating a scalar value across all lanes.
60template <typename tRawType>
61tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
62  return x;
63}
64
// Bitwise AND of two raw values.
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  const tIntegerType common_bits = a & b;
  return common_bits;
}
70
// Bitwise OR of two raw values.
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  const tIntegerType any_bits = a | b;
  return any_bits;
}
76
// Bitwise exclusive OR of two raw values.
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  const tIntegerType differing_bits = a ^ b;
  return differing_bits;
}
82
// Bitwise complement of a raw value.
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  const tIntegerType flipped = ~a;
  return flipped;
}
88
// Plain integer addition. It does not saturate; overflow is undefined
// behavior for signed scalar types.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  const tIntegerType sum = a + b;
  return sum;
}
94
// Integer multiplication. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  return a * b;
}
100
// Plain integer subtraction a - b. It does not saturate; overflow is undefined
// behavior for signed scalar types.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  const tIntegerType difference = a - b;
  return difference;
}
105
// Unary integer negation. It does not saturate; negating the most negative
// value overflows, which is undefined behavior for signed scalar types.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  const tIntegerType negated = -a;
  return negated;
}
111
// Arithmetic left shift by `offset` bits, i.e. multiplication by 2^offset.
// It does not saturate; shifting bits out of range is undefined behavior.
template <typename tIntegerType>
tIntegerType ShiftLeft(tIntegerType a, int offset) {
  const tIntegerType shifted = a << offset;
  return shifted;
}
118
// Arithmetic (sign-propagating) right shift by `offset` bits. There is no
// rounding: low bits are simply discarded. Right-shifting a negative value is
// implementation-defined before C++20; this relies on the sign-extending
// behavior that compilers consistently implement in practice.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  const tIntegerType shifted = a >> offset;
  return shifted;
}
126
// Bitwise select. Each result bit is taken from then_val where the
// corresponding bit of if_mask is 1, and from else_val where it is 0. This is
// the same operation as the ARM NEON VBSL instruction. The two masked terms
// never have a set bit in common, so XOR-ing them is equivalent to OR-ing them.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType from_then = BitAnd(if_mask, then_val);
  const tIntegerType from_else = BitAnd(BitNot(if_mask), else_val);
  return BitXor(from_then, from_else);
}
135
// Returns a value with every bit set when `a` is non-zero, and with no bit set
// when `a` is zero (per lane, for SIMD specializations).
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  static const tIntegerType kNoBits = 0;
  if (!a) {
    return kNoBits;
  }
  return BitNot(kNoBits);
}
143
144// For each input scalar, the corresponding bits of the result are set if the
145// input scalar is zero.
146template <typename tIntegerType>
147tIntegerType MaskIfZero(tIntegerType a) {
148  return MaskIfNonZero<tIntegerType>(!a);
149}
150
151// For each pair of input scalars, the corresponding bits of the result are
152// set if the input scalars are equal.
153template <typename tIntegerType>
154tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
155  return MaskIfNonZero<tIntegerType>(a == b);
156}
157
158// For each pair of input scalars, the corresponding bits of the result are
159// set if the input scalars are not equal.
160template <typename tIntegerType>
161tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
162  return MaskIfNonZero<tIntegerType>(a != b);
163}
164
165// For each pair of input scalars, the corresponding bits of the result are
166// set if the input scalars a, b satisfy a > b.
167template <typename tIntegerType>
168tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
169  return MaskIfNonZero<tIntegerType>(a > b);
170}
171
172// For each pair of input scalars, the corresponding bits of the result are
173// set if the input scalars a, b satisfy a >= b.
174template <typename tIntegerType>
175tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
176  return MaskIfNonZero<tIntegerType>(a >= b);
177}
178
179// For each pair of input scalars, the corresponding bits of the result are
180// set if the input scalars a, b satisfy a < b.
181template <typename tIntegerType>
182tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
183  return MaskIfNonZero<tIntegerType>(a < b);
184}
185
186// For each pair of input scalars, the corresponding bits of the result are
187// set if the input scalars a, b satisfy a <= b.
188template <typename tIntegerType>
189tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
190  return MaskIfNonZero<tIntegerType>(a <= b);
191}
192
// True when every input scalar is non-zero. Inputs are expected to be masks,
// with each scalar having either all or none of its bits set; behavior for
// other inputs is currently undefined.
template <typename tIntegerType>
bool All(tIntegerType a) {
  return static_cast<bool>(a);
}
200
// True when at least one input scalar is non-zero. Inputs are expected to be
// masks, with each scalar having either all or none of its bits set; behavior
// for other inputs is currently undefined.
template <typename tIntegerType>
bool Any(tIntegerType a) {
  return static_cast<bool>(a);
}
208
// Returns (a + b) / 2 rounded to the nearest integer. The scalar
// specializations below compute the sum in a wider type, so it cannot
// overflow, and round ties away from zero (1.5 -> 2, -1.5 -> -2). Intended to
// correspond to ARM NEON VRHADD. Note that VRHADD rounds ties upward, so the
// two differ for negative ties.
// Only the specializations are usable; instantiating the primary template
// fails to compile.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  const std::int64_t total = static_cast<std::int64_t>(a) + b;
  // Nudge away from zero so that the truncating division rounds ties outward.
  const std::int64_t nudged = total < 0 ? total - 1 : total + 1;
  return static_cast<std::int32_t>(nudged / 2);
}

template <>
inline std::int16_t RoundingHalfSum(std::int16_t a, std::int16_t b) {
  const std::int32_t total = static_cast<std::int32_t>(a) + b;
  // Nudge away from zero so that the truncating division rounds ties outward.
  const std::int32_t nudged = total < 0 ? total - 1 : total + 1;
  return static_cast<std::int16_t>(nudged / 2);
}
234
// Returns a + b, clamped to the range of IntegerType instead of overflowing.
// Only the specializations are usable; instantiating the primary template
// fails to compile.
template <typename IntegerType>
IntegerType SaturatingAdd(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// So far this is only needed for int16.
template <>
inline std::int16_t SaturatingAdd(std::int16_t a, std::int16_t b) {
  // The sum of two int16 values always fits in int32.
  const std::int32_t sum =
      static_cast<std::int32_t>(a) + static_cast<std::int32_t>(b);
  // The clamp bounds are typed std::int32_t so that std::min/std::max deduce a
  // single type even on platforms where std::int32_t is not plain int.
  const std::int32_t lowest = std::numeric_limits<std::int16_t>::min();
  const std::int32_t highest = std::numeric_limits<std::int16_t>::max();
  return static_cast<std::int16_t>(std::min(highest, std::max(lowest, sum)));
}
249
250// Returns a+b, saturating if the integers are 16bit or narrower,
251// otherwise just a plain addition.
252template <typename IntegerType, bool Is16Bit>
253struct AddSaturatingIf16BitImpl {
254  static IntegerType Run(IntegerType a, IntegerType b) { return Add(a, b); }
255};
256template <typename IntegerType>
257struct AddSaturatingIf16BitImpl<IntegerType, true> {
258  static IntegerType Run(IntegerType a, IntegerType b) {
259    return SaturatingAdd(a, b);
260  }
261};
262template <typename IntegerType>
263IntegerType AddSaturatingIf16Bit(IntegerType a, IntegerType b) {
264  using ScalarType =
265      typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
266  return AddSaturatingIf16BitImpl<IntegerType, sizeof(ScalarType) == 2>::Run(a,
267                                                                             b);
268}
269
// SaturatingRoundingDoublingHighMul(a, b) returns the raw integer of the
// fixed-point product a * b. Both operands and the result are read as
// fixed-point numbers in the half-open interval [-1, 1). The product is rounded
// to the nearest representable value, with ties rounded upward. The only
// product that falls outside [-1, 1) is (-1) * (-1) = 1, and it saturates to
// the maximum raw value.
//
// Worked through for std::int32_t (IntegerType is assumed signed): a raw
// value v stands for the real number v / 2^31. Ignoring rounding and
// saturation, the result is therefore
//   ((a / 2^31) * (b / 2^31)) * 2^31 = (a * b) / 2^31.
// A plain "multiply-high" would keep only the top 32 bits of the 64-bit
// product, i.e. (a * b) / 2^32. This function returns twice that (hence
// "doubling"), because 1 / 2^31 = 2 * (1 / 2^32). That way all 32 bits of the
// int32 destination carry information.
//
// This matches the ARM NEON VQRDMULH instruction.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// Scalar int32 version of the ARMv7 NEON VQRDMULH computation.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  // (-1) * (-1) is the single unrepresentable product; saturate it.
  if (a == std::numeric_limits<std::int32_t>::min() && b == a) {
    return std::numeric_limits<std::int32_t>::max();
  }
  const std::int64_t product = static_cast<std::int64_t>(a) * b;
  // Half of the 2^31 divisor, adjusted for negative products because the
  // division below truncates toward zero.
  const std::int64_t rounding_offset =
      product >= 0 ? (std::int64_t{1} << 30) : (1 - (std::int64_t{1} << 30));
  return static_cast<std::int32_t>((product + rounding_offset) /
                                   (std::int64_t{1} << 31));
}

// Scalar int16 version: the same computation with a 32-bit intermediate.
template <>
inline std::int16_t SaturatingRoundingDoublingHighMul(std::int16_t a,
                                                      std::int16_t b) {
  // (-1) * (-1) is the single unrepresentable product; saturate it.
  if (a == std::numeric_limits<std::int16_t>::min() && b == a) {
    return std::numeric_limits<std::int16_t>::max();
  }
  const std::int32_t product = static_cast<std::int32_t>(a) * b;
  // Half of the 2^15 divisor, adjusted for negative products because the
  // division below truncates toward zero.
  const std::int32_t rounding_offset =
      product >= 0 ? (1 << 14) : (1 - (1 << 14));
  return static_cast<std::int16_t>((product + rounding_offset) / (1 << 15));
}
331
332// Correctly-rounded-to-nearest division by a power-of-two.
333// Also known as a rounding arithmetic right shift.
334template <typename IntegerType>
335inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
336  assert(exponent >= 0);
337  assert(exponent <= 31);
338  const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
339  const IntegerType zero = Dup<IntegerType>(0);
340  const IntegerType one = Dup<IntegerType>(1);
341  const IntegerType remainder = BitAnd(x, mask);
342  const IntegerType threshold =
343      Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
344  return Add(ShiftRight(x, exponent),
345             BitAnd(MaskIfGreaterThan(remainder, threshold), one));
346}
347
348// Returns the product of a run-time integer value by a compile-time power
349// of two, with either a positive exponent (equivalent to an arithmetic
350// left shift, saturating) or a negative exponent (equivalent to an arithmetic
351// right shift, rounding to nearest).
352template <int Exponent, typename IntegerType,
353          int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
354struct ImplSaturatingRoundingMultiplyByPOT {};
355
356template <int Exponent, typename IntegerType>
357struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
358  static IntegerType eval(IntegerType x) { return x; }
359};
360
361template <int Exponent, typename IntegerType>
362struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
363  static IntegerType eval(IntegerType x) {
364    using ScalarIntegerType =
365        typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
366    const IntegerType min =
367        Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::min());
368    const IntegerType max =
369        Dup<IntegerType>(std::numeric_limits<ScalarIntegerType>::max());
370    const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType);
371
372    const std::int32_t threshold =
373        ((1 << (ScalarIntegerTypeBits - 1 - Exponent)) - 1);
374    const IntegerType positive_mask =
375        MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
376    const IntegerType negative_mask =
377        MaskIfLessThan(x, Dup<IntegerType>(-threshold));
378
379    IntegerType result = ShiftLeft(x, Exponent);
380    result = SelectUsingMask(positive_mask, max, result);
381    result = SelectUsingMask(negative_mask, min, result);
382    return result;
383  }
384};
385
386template <int Exponent, typename IntegerType>
387struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
388  static IntegerType eval(IntegerType x) {
389    return RoundingDivideByPOT<IntegerType>(x, -Exponent);
390  }
391};
392
393template <int Exponent, typename IntegerType>
394IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
395  return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
396}
397
398// Part 2: the FixedPoint class.
399
400// A FixedPoint object represents a fixed-point value stored in the underlying
401// integer type tRawType, if tRawType is a plain scalar integer type.
402// Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
403// case a FixedPoint object represents a corresponding SIMD vector of fixed
404// point values.
405//
406// tIntegerBits describes the range of the fixed-point format: if
407// tIntegerBits == m then the range of representable values is the half-open
408// interval [-2^m; 2^m) where the open boundary on the right side means that
// 2^m is not representable (how close the maximum representable value comes to
// it depends on the bit depth of tRawType).
411//
412// In "Q format notation",
413//   https://en.wikipedia.org/wiki/Q_(number_format)
414// we are describing the format
415//   Qm.n
416// where
417//   m = tIntegerBits
418// and
419//   n = NumberOfBits(tRawType) - (m + 1)
420// Note that the (m + 1) in the above line is because we adopt the convention
421// that we count the integer bits exclusively of the sign bit; so (m + 1) is
422// the total number of integer bits inclusive of the sign bit.
423//
424// Accordingly, the number of integral representable values in our range
425//   [-2^m ; 2^m)
426// is equal to 2^(m+1).
427template <typename tRawType, int tIntegerBits>
428class FixedPoint {
429 public:
430  typedef tRawType RawType;
431
432  typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
433  typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
434
435  static const int kTotalBits = 8 * sizeof(ScalarRawType);
436  static const int kIntegerBits = tIntegerBits;
437  static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
438  static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
439                "bad IntegerBits");
440
441  typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
442
443  static const ScalarRawType ScalarRawMin() {
444    return std::numeric_limits<ScalarRawType>::min();
445  }
446
447  static const ScalarRawType ScalarRawMax() {
448    return std::numeric_limits<ScalarRawType>::max();
449  }
450
451  static const ScalarRawType RawMin() {
452    return VectorFromScalar(ScalarRawMin());
453  }
454
455  static const ScalarRawType RawMax() {
456    return VectorFromScalar(ScalarRawMax());
457  }
458
459  static FixedPoint FromRaw(RawType x) {
460    FixedPoint retval;
461    retval.raw() = x;
462    return retval;
463  }
464
465  static FixedPoint FromScalarRaw(ScalarRawType x) {
466    FixedPoint retval;
467    retval.raw() = Dup<RawType>(x);
468    return retval;
469  }
470
471  static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
472    return FromScalarRaw(x.raw());
473  }
474
475  template <int Exponent>
476  static FixedPoint ConstantPOT() {
477    static const int kOffset = kFractionalBits + Exponent;
478    static_assert(
479        kOffset < 31,
480        "Constant not exactly representable in this fixed-point format");
481    return FromScalarRaw(ScalarRawType(1) << kOffset);
482  }
483
484  static FixedPoint Zero() { return FromScalarRaw(0); }
485
486  static FixedPoint One() {
487    return FromScalarRaw(
488        kIntegerBits == 0
489            ? ScalarRawMax()
490            : (ScalarRawType(1) << (kIntegerBits == 0 ? 0 : kFractionalBits)));
491  }
492
493  static FixedPoint FromDouble(double x) {
494    const double min_bound = static_cast<double>(ScalarRawMin());
495    const double max_bound = static_cast<double>(ScalarRawMax());
496    return FromScalarRaw(static_cast<ScalarRawType>(std::min(
497        std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
498                 min_bound),
499        max_bound)));
500  }
501
502  RawType raw() const { return i_; }
503  RawType& raw() { return i_; }
504
505 private:
506  RawType i_;
507};
508
509// Part 3: implementation of arithmetic operators for the
510// FixedPoint class, and a few related functions.
511
512// A FixedPoint multiplication is just a
513// SaturatingRoundingDoublingHighMul operation on the underlying
514// raw integer values. The IntegerBits simply add up, as is obvious
515// from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
516template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
517FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
518    FixedPoint<tRawType, tIntegerBits_a> a,
519    FixedPoint<tRawType, tIntegerBits_b> b) {
520  FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
521  c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
522  return c;
523}
524
525// Tweaking IntegerBits gives exact multiplication by a power of two.
526template <int tExponent, typename tRawType, int tIntegerBits>
527FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
528    FixedPoint<tRawType, tIntegerBits> a) {
529  FixedPoint<tRawType, tExponent + tIntegerBits> c;
530  c.raw() = a.raw();
531  return c;
532}
533
534// If we want to leave IntegerBits fixed, then multiplication
535// by a power of two has to be saturating/rounding, not exact anymore.
536template <int tExponent, typename tRawType, int tIntegerBits>
537FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
538    FixedPoint<tRawType, tIntegerBits> a) {
539  return FixedPoint<tRawType, tIntegerBits>::FromRaw(
540      SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
541}
542
543// Generic arithmetic operators.
544
545#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
546  template <typename tRawType, int tIntegerBits>                               \
547  FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
548      FixedPoint<tRawType, tIntegerBits> a) {                                  \
549    return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
550  }
551
552#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
553  template <typename tRawType, int tIntegerBits>            \
554  FixedPoint<tRawType, tIntegerBits> FuncName(              \
555      FixedPoint<tRawType, tIntegerBits> a,                 \
556      FixedPoint<tRawType, tIntegerBits> b) {               \
557    return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
558        ImplFuncName(a.raw(), b.raw()));                    \
559  }
560
561MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
562MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
563MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
564MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
565MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
566MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
567MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
568MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
569
570#undef MAKE_FIXEDPOINT_UNARY_FUNC
571#undef MAKE_FIXEDPOINT_BINARY_FUNC
572
573#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
574  template <typename tRawType, int tIntegerBits>            \
575  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
576    return FuncName(a.raw());                               \
577  }
578
579#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
580  template <typename tRawType, int tIntegerBits>            \
581  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
582                    FixedPoint<tRawType, tIntegerBits> b) { \
583    return FuncName(a.raw(), b.raw());                      \
584  }
585
586MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
587MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
588MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
589MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
590MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
591MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
592MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
593MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
594
595#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
596#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
597
598template <typename tRawType, int tIntegerBits>
599FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
600    tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
601    FixedPoint<tRawType, tIntegerBits> else_val) {
602  return FixedPoint<tRawType, tIntegerBits>::FromRaw(
603      SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
604}
605
606template <typename tRawType, int tIntegerBits>
607bool operator==(FixedPoint<tRawType, tIntegerBits> a,
608                FixedPoint<tRawType, tIntegerBits> b) {
609  return All(MaskIfEqual(a.raw(), b.raw()));
610}
611
612template <typename tRawType, int tIntegerBits>
613bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
614                FixedPoint<tRawType, tIntegerBits> b) {
615  return !(a == b);
616}
617
618template <typename tRawType, int tIntegerBits>
619FixedPoint<tRawType, tIntegerBits> SaturatingAdd(
620    FixedPoint<tRawType, tIntegerBits> a,
621    FixedPoint<tRawType, tIntegerBits> b) {
622  return FixedPoint<tRawType, tIntegerBits>::FromRaw(
623      SaturatingAdd(a.raw(), b.raw()));
624}
625
626template <typename tRawType, int tIntegerBits>
627FixedPoint<tRawType, tIntegerBits> AddSaturatingIf16Bit(
628    FixedPoint<tRawType, tIntegerBits> a,
629    FixedPoint<tRawType, tIntegerBits> b) {
630  return FixedPoint<tRawType, tIntegerBits>::FromRaw(
631      AddSaturatingIf16Bit(a.raw(), b.raw()));
632}
633
634// Conversion to floating-point.
635template <typename tRawType, int tIntegerBits>
636double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
637  static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
638                "not applicable to SIMD types");
639  typedef FixedPoint<tRawType, tIntegerBits> F;
640  return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
641}
642
643// Rescale changes the number of IntegerBits and updates the underlying
644// raw integer value accordingly.
645template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
646FixedPoint<tRawType, tIntegerBitsDst> Rescale(
647    FixedPoint<tRawType, tIntegerBitsSrc> x) {
648  static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
649  FixedPoint<tRawType, tIntegerBitsDst> result;
650  result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
651  return result;
652}
653
654// CheckedFixedPointConstant allows to specify fixed-point constants
655// initialized as real numbers, in a way that does not compile floating-point
656// arithmetic in production code, yet still checks agreement with the
657// floating-point expressions when asserts are enabled.
658//
659// The raw integer value provided is always a int32, encoding a 32-bit
660// fixed-point value, regardless of the actual Scalar type. This allows
661// writing generic code that applies just as well to the 32-bit and 16-bit
662// cases. In the 16-bit case, the raw integer value is internally
663// rounding-shifted by 16 bits to the right.
664template <typename FixedPointType>
665inline typename FixedPointType::ScalarRawType RescaleConstantInitializer(
666    std::int32_t int32_value) {
667  typedef typename FixedPointType::ScalarRawType ScalarRawType;
668  static constexpr int ScalarTypeBits = 8 * sizeof(ScalarRawType);
669  return static_cast<ScalarRawType>(
670      RoundingDivideByPOT<std::int32_t>(int32_value, 32 - ScalarTypeBits));
671}
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
// Checked mode: builds the constant from its already-rescaled raw value and
// asserts that it equals FixedPointType::FromDouble(double_value).
template <typename FixedPointType>
FixedPointType CheckedFixedPointConstant(std::int32_t raw_value,
                                         double double_value) {
  const FixedPointType constant = FixedPointType::FromScalarRaw(raw_value);
  assert(constant == FixedPointType::FromDouble(double_value));
  return constant;
}
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FPType, RawInt32, RealValue) \
  (gemmlowp::CheckedFixedPointConstant<FPType>(                          \
      gemmlowp::RescaleConstantInitializer<FPType>(RawInt32), RealValue))

#else
// Unchecked mode: only the raw value is used. The real-valued argument is
// dropped by the preprocessor and never compiled.
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FPType, RawInt32, RealValue) \
  (FPType::FromScalarRaw(                                                \
      gemmlowp::RescaleConstantInitializer<FPType>(RawInt32)))
#endif
694
695// Implementation of exponential function.
696
// Returns exp(x) for x in [-1/4, 0).
//
// Uses exp(a) = exp(-1/8) * exp(a + 1/8). It approximates exp(x), with
// x = a + 1/8 in [-1/8, 1/8), by its degree-4 Taylor polynomial
//   1 + x + x^2/2 + x^3/6 + x^4/24.
template <typename tRawType>
FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
    FixedPoint<tRawType, 0> a) {
  typedef FixedPoint<tRawType, 0> F;
  // exp(-1/8) and 1/3, given as 32-bit raw values (rescaled for narrower raw
  // types by RescaleConstantInitializer).
  const F constant_term =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
  const F constant_1_over_3 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
  // We're evaluating a Taylor expansion around -1/8, so we do the change of
  // variable: x = a + 1/8.
  // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28 for a
  // 32-bit raw type (1 << 12 for a 16-bit raw type).
  F x = a + F::template ConstantPOT<-3>();
  F x2 = x * x;
  F x3 = x2 * x;
  F x4 = x2 * x2;
  F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
  // ((x^4/4 + x^3) * 1/3 + x^2) / 2 == x^4/24 + x^3/6 + x^2/2.
  F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
      SaturatingRoundingMultiplyByPOT<-1>(
          ((x4_over_4 + x3) * constant_1_over_3) + x2);
  // exp(-1/8) * (1 + x + ...), written as c + c * (x + ...) since 1 lies
  // outside the [-1, 1) range of F. The final addition saturates for 16-bit
  // raw types (see AddSaturatingIf16Bit).
  return AddSaturatingIf16Bit(
      constant_term,
      constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2));
}
721
722// Returns exp(x) for x < 0.
723template <typename tRawType, int tIntegerBits>
724FixedPoint<tRawType, 0> exp_on_negative_values(
725    FixedPoint<tRawType, tIntegerBits> a) {
726  typedef FixedPoint<tRawType, tIntegerBits> InputF;
727  typedef FixedPoint<tRawType, 0> ResultF;
728  static const int kFractionalBits = InputF::kFractionalBits;
729  static const int kIntegerBits = InputF::kIntegerBits;
730  static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
731  InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
732  InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
733  ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
734      Rescale<0>(a_mod_quarter_minus_one_quarter));
735  tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
736
737#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
738  if (kIntegerBits > Exponent) {                                            \
739    const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
740        ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
741    static constexpr int kShiftAmount =                                     \
742        kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
743    result = SelectUsingMask(                                               \
744        MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
745        result * kMultiplier, result);                                      \
746  }
747
748  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
749  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
750  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
751  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
752  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
753  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
754  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
755
756#undef GEMMLOWP_EXP_BARREL_SHIFTER
757
758  if (kIntegerBits > 5) {
759    static const int b = kIntegerBits > 5 ? 36 - kIntegerBits : 0;
760    const InputF clamp =
761        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
762    result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
763  }
764
765  result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
766  return result;
767}
768
769// Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
770
// Returns (1 - x) / (1 + x) for x in (0, 1).
//
// Uses Newton-Raphson division to approximate 1 / half_denominator, which is
// about 2 / (1 + x), in a format with 2 integer bits. It then returns that
// value minus 1, since 2 / (1 + x) - 1 == (1 - x) / (1 + x).
template <typename tRawType>
FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
    FixedPoint<tRawType, 0> a) {
  typedef FixedPoint<tRawType, 0> F0;
  typedef FixedPoint<tRawType, 2> F2;
  // Approximately (1 + a) / 2. F0::One() is the largest F0 value, just under
  // 1, because 1 itself is outside [-1, 1).
  F0 half_denominator = RoundingHalfSum(a, F0::One());
  // Newton-Raphson division
  // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
  // Refer to that page for the logic behind the 48/17 and 32/17 constants.
  const F2 constant_48_over_17 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
  const F2 constant_neg_32_over_17 =
      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
  // Initial estimate of 1 / d with d = half_denominator: 48/17 - 32/17 * d.
  F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
  // Three refinement steps: x <- x + x * (1 - d * x).
  for (int i = 0; i < 3; i++) {
    F2 half_denominator_times_x = half_denominator * x;
    F2 one_minus_half_denominator_times_x =
        F2::One() - half_denominator_times_x;
    // F2 * F2 has 4 integer bits; bring it back to 2 before adding.
    x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
  }
  // x ~= 2 / (1 + a), so x - 1 ~= (1 - a) / (1 + a), which fits in F0.
  return Rescale<0>(x - F2::One());
}
794
795// Returns -tanh(x) for x < 0.
796template <typename tRawType, int tIntegerBits>
797FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
798    FixedPoint<tRawType, tIntegerBits> a) {
799  return one_minus_x_over_one_plus_x_for_x_in_0_1(
800      exp_on_negative_values(ExactMulByPot<1>(a)));
801}
802
803// Returns tanh(x) for any x.
804template <typename tRawType, int tIntegerBits>
805FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
806  typedef FixedPoint<tRawType, tIntegerBits> InputF;
807  typedef FixedPoint<tRawType, 0> ResultF;
808  tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
809  tRawType mask_if_zero = MaskIfZero(a);
810  InputF n = SelectUsingMask(mask_if_negative, a, -a);
811  ResultF t = neg_tanh_on_negative_values(n);
812  return SelectUsingMask(mask_if_zero, ResultF::Zero(),
813                         SelectUsingMask(mask_if_negative, -t, t));
814}
815
816// Implementation of logistic function.
817
818// Returns 1 / (1 + x) for x in (0, 1).
819template <typename tRawType>
820FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
821    FixedPoint<tRawType, 0> a) {
822  typedef FixedPoint<tRawType, 0> F0;
823  typedef FixedPoint<tRawType, 2> F2;
824  F0 half_denominator = RoundingHalfSum(a, F0::One());
825  // Newton-Raphson division
826  // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
827  // Refer to that page for the logic behind the 48/17 and 32/17 constants.
828  const F2 constant_48_over_17 =
829      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
830  const F2 constant_neg_32_over_17 =
831      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
832  F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
833  for (int i = 0; i < 3; i++) {
834    F2 half_denominator_times_x = half_denominator * x;
835    F2 one_minus_half_denominator_times_x =
836        F2::One() - half_denominator_times_x;
837    x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
838  }
839  return Rescale<0>(ExactMulByPot<-1>(x));
840}
841
842// Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
843template <typename tRawType, int tIntegerBits>
844FixedPoint<tRawType, 0> logistic_on_positive_values(
845    FixedPoint<tRawType, tIntegerBits> a) {
846  return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
847}
848
849// Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
850template <typename tRawType, int tIntegerBits>
851FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
852  typedef FixedPoint<tRawType, tIntegerBits> InputF;
853  typedef FixedPoint<tRawType, 0> ResultF;
854  tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
855  tRawType mask_if_zero = MaskIfZero(a);
856  InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
857  ResultF result_if_positive = logistic_on_positive_values(abs_input);
858  ResultF result_if_negative = ResultF::One() - result_if_positive;
859  const ResultF one_half =
860      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
861  return SelectUsingMask(mask_if_zero, one_half,
862                         SelectUsingMask(mask_if_positive, result_if_positive,
863                                         result_if_negative));
864}
865
866}  // end namespace gemmlowp
867
868#ifdef GEMMLOWP_NEON
869#include "./fixedpoint_neon.h"
870#elif defined(GEMMLOWP_SSE4)
871#include "./fixedpoint_sse.h"
872#elif defined(GEMMLOWP_MSA)
873#include "./fixedpoint_msa.h"
874#endif
875
876#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
877