fixedpoint.h revision a9fd919a0080e2c3c7ed1ce451c85a4d86f2f8c1
1// Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// fixedpoint.h: fixed-point arithmetic, with basic operations and
16// a few math functions such as tanh.
17
18#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_
19#define GEMMLOWP_INTERNAL_FIXEDPOINT_H_
20
#include <cassert>
#include <limits>
#include <type_traits>
23
24#include "../internal/common.h"
25
26namespace gemmlowp {
27
28// Part 1: Low-level integer-arithmetic primitives.
29// The implementations here are generic implementations valid for
30// scalar types (e.g. std::int32_t). Architecture-specific SIMD types
31// (e.g. NEON int32x4_t) may be supported by providing
32// specializations for them in separate files.
33//
34// The purpose of these primitives is two-fold:
35//  - They will be used to implement higher-level fixed-point
36//    abstractions, namely the FixedPoint class and its arithmetic
37//    operators.
38//  - They will be directly used to implement some more involved
39//    fixed-point computations, e.g. the fixed-point implementation
40//    of math functions such as tanh.
41
42// Some compile-time traits around raw types to handle SIMD aspects:
43// number of lanes, underlying scalar type.
// Compile-time description of a raw integer type that can store FixedPoint
// values:
//   ScalarRawType - the integer type of a single lane;
//   kLanes        - how many lanes the raw type holds.
// The primary template is deliberately empty; every supported raw type gets
// a specialization (SIMD types are expected to be specialized in separate,
// architecture-specific files).
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// A plain int32 is a single int32 lane.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  using ScalarRawType = std::int32_t;
  static const int kLanes = 1;
};
52
53// Returns a SIMD value duplicating a scalar value across all lanes.
54template <typename tRawType>
55tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
56  return x;
57}
58
// Bitwise AND of `a` and `b`.
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  const tIntegerType both_set = a & b;
  return both_set;
}
64
// Bitwise OR of `a` and `b`.
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  const tIntegerType either_set = a | b;
  return either_set;
}
70
// Bitwise exclusive OR of `a` and `b`.
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  const tIntegerType differing = a ^ b;
  return differing;
}
76
// Bitwise complement of `a` (every bit flipped).
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  const tIntegerType flipped = ~a;
  return flipped;
}
82
// Integer addition a + b. Does not saturate; for signed types, overflow is
// undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  const tIntegerType sum = a + b;
  return sum;
}
88
// Integer multiplication a * b. Does not saturate; for signed types, overflow
// is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  return a * b;
}

// Integer subtraction a - b. Does not saturate; for signed types, overflow is
// undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  return a - b;
}
99
// Integer negation -a. Does not saturate; for signed types, negating the
// minimum value overflows, which is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  const tIntegerType negated = -a;
  return negated;
}
105
// Integer left shift by `offset` bits, i.e. multiplication by 2^offset.
// Not saturating: results that do not fit wrap modulo 2^N (N = bit width).
//
// The shift is performed on the unsigned counterpart of tIntegerType. Before
// C++20, left-shifting a negative signed value, or a signed shift whose result
// does not fit, is undefined behavior. Callers such as the saturating
// multiply-by-POT specialization shift values that may be negative or
// overflowing, and only afterwards mask the overflowing ones out. The unsigned
// result is converted back to tIntegerType. This conversion wraps on
// two's-complement targets and is guaranteed to since C++20.
//
// Requires 0 <= offset < N.
template <typename tIntegerType>
tIntegerType ShiftLeft(tIntegerType a, int offset) {
  typedef typename std::make_unsigned<tIntegerType>::type UnsignedType;
  return static_cast<tIntegerType>(static_cast<UnsignedType>(a) << offset);
}
112
// Arithmetic (sign-propagating) right shift by `offset` bits. The result is
// rounded toward negative infinity, not to nearest; see RoundingDivideByPOT
// for the rounding variant. For negative `a` this relies on the
// implementation-defined behavior of `>>` before C++20, which is an arithmetic
// shift on all mainstream compilers.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  const tIntegerType shifted = a >> offset;
  return shifted;
}
120
// Bitwise select. Each result bit comes from then_val where the corresponding
// bit of if_mask is 1, and from else_val where it is 0. This matches the ARM
// NEON VBSL instruction.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  const tIntegerType bits_from_then = BitAnd(if_mask, then_val);
  const tIntegerType bits_from_else = BitAnd(BitNot(if_mask), else_val);
  // The two parts have no set bits in common, so XOR simply merges them.
  return BitXor(bits_from_then, bits_from_else);
}
129
// Lane-wise: yields all bits set (BitNot of zero) where `a` is nonzero, and
// zero where `a` is zero.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  static const tIntegerType zero = 0;
  if (a) {
    return BitNot(zero);
  }
  return zero;
}
137
138// For each input scalar, the corresponding bits of the result are set if the
139// input scalar is zero.
140template <typename tIntegerType>
141tIntegerType MaskIfZero(tIntegerType a) {
142  return MaskIfNonZero<tIntegerType>(!a);
143}
144
145// For each pair of input scalars, the corresponding bits of the result are
146// set if the input scalars are equal.
147template <typename tIntegerType>
148tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
149  return MaskIfNonZero<tIntegerType>(a == b);
150}
151
152// For each pair of input scalars, the corresponding bits of the result are
153// set if the input scalars are not equal.
154template <typename tIntegerType>
155tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
156  return MaskIfNonZero<tIntegerType>(a != b);
157}
158
159// For each pair of input scalars, the corresponding bits of the result are
160// set if the input scalars a, b satisfy a > b.
161template <typename tIntegerType>
162tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) {
163  return MaskIfNonZero<tIntegerType>(a > b);
164}
165
166// For each pair of input scalars, the corresponding bits of the result are
167// set if the input scalars a, b satisfy a >= b.
168template <typename tIntegerType>
169tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) {
170  return MaskIfNonZero<tIntegerType>(a >= b);
171}
172
173// For each pair of input scalars, the corresponding bits of the result are
174// set if the input scalars a, b satisfy a < b.
175template <typename tIntegerType>
176tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) {
177  return MaskIfNonZero<tIntegerType>(a < b);
178}
179
180// For each pair of input scalars, the corresponding bits of the result are
181// set if the input scalars a, b satisfy a <= b.
182template <typename tIntegerType>
183tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) {
184  return MaskIfNonZero<tIntegerType>(a <= b);
185}
186
// Returns true if every lane of `a` is nonzero; for a plain scalar this is
// just `a != 0`. Callers are expected to pass masks whose lanes have either
// all or none of their bits set; the result for other inputs is unspecified.
template <typename tIntegerType>
bool All(tIntegerType a) {
  return static_cast<bool>(a);
}
194
// Returns true if at least one lane of `a` is nonzero; for a plain scalar
// this is just `a != 0`. Callers are expected to pass masks whose lanes have
// either all or none of their bits set; the result for other inputs is
// unspecified.
template <typename tIntegerType>
bool Any(tIntegerType a) {
  return static_cast<bool>(a);
}
202
// Returns (a + b) / 2 rounded to the nearest integer, computed without
// intermediate overflow.
//
// Only the std::int32_t specialization below is implemented here. Using the
// primary template with any other type triggers the static_assert. SIMD raw
// types are expected to supply their own specializations.
//
// Tie-breaking: the scalar int32 version rounds exact halves away from zero
// (e.g. -3/2 -> -2, 3/2 -> 2). ARM NEON VRHADD computes (a + b + 1) >> 1 and
// so rounds halves toward +infinity (-3/2 -> -1). The two agree except when
// a + b is negative and odd.
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  // Widen first so that a + b cannot overflow.
  const std::int64_t sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  // Step one unit away from zero before the truncating division, so that
  // exact halves end up rounded away from zero.
  const std::int64_t nudge = (sum < 0) ? -1 : 1;
  return static_cast<std::int32_t>((sum + nudge) / 2);
}
219
// Fixed-point multiply of two raw values that both represent numbers in
// [-1, 1). The result is rounded to nearest. The single product that does not
// fit, (-1) * (-1) = 1, saturates to the maximum raw value.
//
// Using int32 as an example: the mapping to [-1, 1) is
//   real_value = integer_value / 2^31,
// so, ignoring rounding and saturation, the result is
//   ((a / 2^31) * (b / 2^31)) * 2^31 = (a * b) / 2^31.
// A plain "multiply-high" keeps only the top 32 bits of the 64-bit product,
// i.e. (a * b) / 2^32. This function returns twice that ("doubling"), because
// 1 / 2^31 = 2 * (1 / 2^32). That way all 32 bits of the destination are used.
//
// This matches the ARM NEON VQRDMULH instruction. Only the std::int32_t
// specialization is implemented here; SIMD types are expected to specialize
// separately.
template <typename IntegerType>
IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) {
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// Same computation as the ARMv7 NEON VQRDMULH instruction.
template <>
inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
                                                      std::int32_t b) {
  const std::int32_t kMin = std::numeric_limits<std::int32_t>::min();
  // (-1) * (-1) == 1 is the one result outside [-1, 1).
  if (a == kMin && b == kMin) {
    return std::numeric_limits<std::int32_t>::max();
  }
  const std::int64_t product =
      static_cast<std::int64_t>(a) * static_cast<std::int64_t>(b);
  // Adding half of 2^31 before the division (which truncates toward zero)
  // rounds to nearest. The negative-side nudge makes ties go toward +infinity,
  // as VQRDMULH does.
  const std::int64_t nudge = (product < 0) ? (1 - (std::int64_t{1} << 30))
                                           : (std::int64_t{1} << 30);
  return static_cast<std::int32_t>((product + nudge) /
                                   (std::int64_t{1} << 31));
}
268
269// Correctly-rounded-to-nearest division by a power-of-two.
270// Also known as a rounding arithmetic right shift.
271template <typename IntegerType>
272inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) {
273  using ScalarIntegerType =
274      typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
275  static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value,
276                "Currently only supporting int32 scalar and SIMD types");
277  assert(exponent >= 0);
278  assert(exponent <= 31);
279  const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1);
280  const IntegerType zero = Dup<IntegerType>(0);
281  const IntegerType one = Dup<IntegerType>(1);
282  const IntegerType remainder = BitAnd(x, mask);
283  const IntegerType threshold =
284      Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one));
285  return Add(ShiftRight(x, exponent),
286             BitAnd(MaskIfGreaterThan(remainder, threshold), one));
287}
288
289// Returns the product of a run-time integer value by a compile-time power
290// of two, with either a positive exponent (equivalent to an arithmetic
291// left shift, saturating) or a negative exponent (equivalent to an arithmetic
292// right shift, rounding to nearest).
293template <int Exponent, typename IntegerType,
294          int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)>
295struct ImplSaturatingRoundingMultiplyByPOT {};
296
297template <int Exponent, typename IntegerType>
298struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> {
299  static IntegerType eval(IntegerType x) { return x; }
300};
301
302template <int Exponent, typename IntegerType>
303struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> {
304  static IntegerType eval(IntegerType x) {
305    using ScalarIntegerType =
306        typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType;
307    static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value,
308                  "Currently only supporting int32 scalar and SIMD types");
309    const IntegerType min =
310        Dup<IntegerType>(std::numeric_limits<std::int32_t>::min());
311    const IntegerType max =
312        Dup<IntegerType>(std::numeric_limits<std::int32_t>::max());
313
314    const std::int32_t threshold = ((1 << (31 - Exponent)) - 1);
315    const IntegerType positive_mask =
316        MaskIfGreaterThan(x, Dup<IntegerType>(threshold));
317    const IntegerType negative_mask =
318        MaskIfLessThan(x, Dup<IntegerType>(-threshold));
319
320    IntegerType result = ShiftLeft(x, Exponent);
321    result = SelectUsingMask(positive_mask, max, result);
322    result = SelectUsingMask(negative_mask, min, result);
323    return result;
324  }
325};
326
327template <int Exponent, typename IntegerType>
328struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> {
329  static IntegerType eval(IntegerType x) {
330    return RoundingDivideByPOT<IntegerType>(x, -Exponent);
331  }
332};
333
334template <int Exponent, typename IntegerType>
335IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) {
336  return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x);
337}
338
339// Part 2: the FixedPoint class.
340
341// A FixedPoint object represents a fixed-point value stored in the underlying
342// integer type tRawType, if tRawType is a plain scalar integer type.
343// Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which
344// case a FixedPoint object represents a corresponding SIMD vector of fixed
345// point values.
346//
347// tIntegerBits describes the range of the fixed-point format: if
348// tIntegerBits == m then the range of representable values is the half-open
349// interval [-2^m; 2^m) where the open boundary on the right side means that
350// 2^m is not representable (how close the maximum representable value is to
351// it, depends on bit-depth of tRawType).
352//
353// In "Q format notation",
354//   https://en.wikipedia.org/wiki/Q_(number_format)
355// we are describing the format
356//   Qm.n
357// where
358//   m = tIntegerBits
359// and
360//   n = NumberOfBits(tRawType) - (m + 1)
361// Note that the (m + 1) in the above line is because we adopt the convention
362// that we count the integer bits exclusively of the sign bit; so (m + 1) is
363// the total number of integer bits inclusive of the sign bit.
364//
365// Accordingly, the number of integral representable values in our range
366//   [-2^m ; 2^m)
367// is equal to 2^(m+1).
368template <typename tRawType, int tIntegerBits>
369class FixedPoint {
370 public:
371  typedef tRawType RawType;
372
373  typedef FixedPointRawTypeTraits<RawType> RawTypeTraits;
374  typedef typename RawTypeTraits::ScalarRawType ScalarRawType;
375
376  static const int kTotalBits = 8 * sizeof(ScalarRawType);
377  static const int kIntegerBits = tIntegerBits;
378  static const int kFractionalBits = kTotalBits - 1 - kIntegerBits;
379  static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits,
380                "bad IntegerBits");
381
382  typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType;
383
384  static const ScalarRawType ScalarRawMin() {
385    return std::numeric_limits<ScalarRawType>::min();
386  }
387
388  static const ScalarRawType ScalarRawMax() {
389    return std::numeric_limits<ScalarRawType>::max();
390  }
391
392  static const ScalarRawType RawMin() {
393    return VectorFromScalar(ScalarRawMin());
394  }
395
396  static const ScalarRawType RawMax() {
397    return VectorFromScalar(ScalarRawMax());
398  }
399
400  static FixedPoint FromRaw(RawType x) {
401    FixedPoint retval;
402    retval.raw() = x;
403    return retval;
404  }
405
406  static FixedPoint FromScalarRaw(ScalarRawType x) {
407    FixedPoint retval;
408    retval.raw() = Dup<RawType>(x);
409    return retval;
410  }
411
412  static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) {
413    return FromScalarRaw(x.raw());
414  }
415
416  template <int Exponent>
417  static FixedPoint ConstantPOT() {
418    static const int kOffset = kFractionalBits + Exponent;
419    static_assert(
420        kOffset < 31,
421        "Constant not exactly representable in this fixed-point format");
422    return FromScalarRaw(ScalarRawType(1) << kOffset);
423  }
424
425  static FixedPoint Zero() { return FromScalarRaw(0); }
426
427  static FixedPoint One() {
428    return FromScalarRaw(kIntegerBits == 0
429                             ? ScalarRawMax()
430                             : (ScalarRawType(1) << kFractionalBits));
431  }
432
433  static FixedPoint FromDouble(double x) {
434    const double min_bound = static_cast<double>(ScalarRawMin());
435    const double max_bound = static_cast<double>(ScalarRawMax());
436    return FromScalarRaw(static_cast<std::int32_t>(std::min(
437        std::max(round(x * static_cast<double>(1ll << kFractionalBits)),
438                 min_bound),
439        max_bound)));
440  }
441
442  RawType raw() const { return i_; }
443  RawType& raw() { return i_; }
444
445 private:
446  RawType i_;
447};
448
449// Part 3: implementation of arithmetic operators for the
450// FixedPoint class, and a few related functions.
451
452// A FixedPoint multiplication is just a
453// SaturatingRoundingDoublingHighMul operation on the underlying
454// raw integer values. The IntegerBits simply add up, as is obvious
455// from the fact that the range is [-2^IntegerBits, 2^IntegerBits).
456template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b>
457FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*(
458    FixedPoint<tRawType, tIntegerBits_a> a,
459    FixedPoint<tRawType, tIntegerBits_b> b) {
460  FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c;
461  c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw());
462  return c;
463}
464
465// Tweaking IntegerBits gives exact multiplication by a power of two.
466template <int tExponent, typename tRawType, int tIntegerBits>
467FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot(
468    FixedPoint<tRawType, tIntegerBits> a) {
469  FixedPoint<tRawType, tExponent + tIntegerBits> c;
470  c.raw() = a.raw();
471  return c;
472}
473
474// If we want to leave IntegerBits fixed, then multiplication
475// by a power of two has to be saturating/rounding, not exact anymore.
476template <int tExponent, typename tRawType, int tIntegerBits>
477FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
478    FixedPoint<tRawType, tIntegerBits> a) {
479  return FixedPoint<tRawType, tIntegerBits>::FromRaw(
480      SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
481}
482
483// Generic arithmetic operators.
484
485#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
486  template <typename tRawType, int tIntegerBits>                               \
487  FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
488      FixedPoint<tRawType, tIntegerBits> a) {                                  \
489    return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
490  }
491
492#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
493  template <typename tRawType, int tIntegerBits>            \
494  FixedPoint<tRawType, tIntegerBits> FuncName(              \
495      FixedPoint<tRawType, tIntegerBits> a,                 \
496      FixedPoint<tRawType, tIntegerBits> b) {               \
497    return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
498        ImplFuncName(a.raw(), b.raw()));                    \
499  }
500
501MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
502MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
503MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
504MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
505MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
506MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
507MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
508MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)
509
510#undef MAKE_FIXEDPOINT_UNARY_FUNC
511#undef MAKE_FIXEDPOINT_BINARY_FUNC
512
513#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
514  template <typename tRawType, int tIntegerBits>            \
515  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
516    return FuncName(a.raw());                               \
517  }
518
519#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
520  template <typename tRawType, int tIntegerBits>            \
521  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
522                    FixedPoint<tRawType, tIntegerBits> b) { \
523    return FuncName(a.raw(), b.raw());                      \
524  }
525
526MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero)
527MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero)
528MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual)
529MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual)
530MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan)
531MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual)
532MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan)
533MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual)
534
535#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW
536#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW
537
538template <typename tRawType, int tIntegerBits>
539FixedPoint<tRawType, tIntegerBits> SelectUsingMask(
540    tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val,
541    FixedPoint<tRawType, tIntegerBits> else_val) {
542  return FixedPoint<tRawType, tIntegerBits>::FromRaw(
543      SelectUsingMask(if_mask, then_val.raw(), else_val.raw()));
544}
545
546template <typename tRawType, int tIntegerBits>
547bool operator==(FixedPoint<tRawType, tIntegerBits> a,
548                FixedPoint<tRawType, tIntegerBits> b) {
549  return All(MaskIfEqual(a.raw(), b.raw()));
550}
551
552template <typename tRawType, int tIntegerBits>
553bool operator!=(FixedPoint<tRawType, tIntegerBits> a,
554                FixedPoint<tRawType, tIntegerBits> b) {
555  return !(a == b);
556}
557
558// Conversion to floating-point.
559template <typename tRawType, int tIntegerBits>
560double ToDouble(FixedPoint<tRawType, tIntegerBits> x) {
561  static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1,
562                "not applicable to SIMD types");
563  typedef FixedPoint<tRawType, tIntegerBits> F;
564  return x.raw() / static_cast<double>(1ll << F::kFractionalBits);
565}
566
567// Rescale changes the number of IntegerBits and updates the underlying
568// raw integer value accordingly.
569template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc>
570FixedPoint<tRawType, tIntegerBitsDst> Rescale(
571    FixedPoint<tRawType, tIntegerBitsSrc> x) {
572  static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst;
573  FixedPoint<tRawType, tIntegerBitsDst> result;
574  result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw());
575  return result;
576}
577
// GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue,
// DoubleValue) builds a fixed-point constant from its raw integer value. It
// also records the real number the constant stands for.
//  - If GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS is defined, the raw value is
//    assert()ed to equal FromDouble(DoubleValue). The check is therefore only
//    effective when NDEBUG is not defined.
//  - Otherwise the DoubleValue argument is dropped by the preprocessor, so no
//    floating-point arithmetic is compiled into production code.
#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS
template <typename FixedPointType>
FixedPointType CheckedFixedPointConstant(
    typename FixedPointType::ScalarRawType raw_value, double double_value) {
  const FixedPointType result = FixedPointType::FromScalarRaw(raw_value);
  assert(result == FixedPointType::FromDouble(double_value));
  (void)double_value;  // Otherwise unused when NDEBUG disables assert().
  return result;
}
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
                                             DoubleValue)                    \
  (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue))

#else
#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \
                                             DoubleValue)                    \
  (FixedPointType::FromScalarRaw(ScalarRawValue))
#endif
600
601// Implementation of exponential function.
602
603// Returns exp(x) for x in [-1/4, 0).
604template <typename tRawType>
605FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl(
606    FixedPoint<tRawType, 0> a) {
607  typedef FixedPoint<tRawType, 0> F;
608  const F constant_term =
609      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0));
610  const F constant_1_over_3 =
611      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0);
612  // We're evaluating a Taylor expansion around -1/8, so we do the change of
613  // variable: x = a + 1/8.
614  // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28.
615  F x = a + F::template ConstantPOT<-3>();
616  F x2 = x * x;
617  F x3 = x2 * x;
618  F x4 = x2 * x2;
619  F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4);
620  F x4_over_24_plus_x3_over_6_plus_x2_over_2 =
621      SaturatingRoundingMultiplyByPOT<-1>(
622          ((x4_over_4 + x3) * constant_1_over_3) + x2);
623  return constant_term +
624         constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2);
625}
626
627// Returns exp(x) for x < 0.
628template <typename tRawType, int tIntegerBits>
629FixedPoint<tRawType, 0> exp_on_negative_values(
630    FixedPoint<tRawType, tIntegerBits> a) {
631  typedef FixedPoint<tRawType, tIntegerBits> InputF;
632  typedef FixedPoint<tRawType, 0> ResultF;
633  static const int kFractionalBits = InputF::kFractionalBits;
634  static const int kIntegerBits = InputF::kIntegerBits;
635  static const InputF kOneQuarter = InputF::template ConstantPOT<-2>();
636  InputF mask = kOneQuarter - InputF::FromScalarRaw(1);
637  InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter;
638  ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl(
639      Rescale<0>(a_mod_quarter_minus_one_quarter));
640  tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw();
641
642#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier)         \
643  if (kIntegerBits > Exponent) {                                            \
644    const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(       \
645        ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \
646    static constexpr int kShiftAmount =                                     \
647        kIntegerBits > Exponent ? kFractionalBits + Exponent : 0;           \
648    result = SelectUsingMask(                                               \
649        MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \
650        result * kMultiplier, result);                                      \
651  }
652
653  GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947);
654  GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674);
655  GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084);
656  GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308);
657  GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535);
658  GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401);
659  GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242);
660
661#undef GEMMLOWP_EXP_BARREL_SHIFTER
662
663  if (kIntegerBits > 5) {
664    static const int b = kIntegerBits > 5 ? kFractionalBits + 5 : 0;
665    const InputF clamp =
666        GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0);
667    result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result);
668  }
669
670  result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result);
671  return result;
672}
673
674// Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)).
675
676// Returns (1 - x) / (1 + x) for x in (0, 1).
677template <typename tRawType>
678FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1(
679    FixedPoint<tRawType, 0> a) {
680  typedef FixedPoint<tRawType, 0> F0;
681  typedef FixedPoint<tRawType, 2> F2;
682  F0 half_denominator = RoundingHalfSum(a, F0::One());
683  // Newton-Raphson division
684  // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
685  // Refer to that page for the logic behind the 48/17 and 32/17 constants.
686  const F2 constant_48_over_17 =
687      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
688  const F2 constant_neg_32_over_17 =
689      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
690  F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
691  for (int i = 0; i < 3; i++) {
692    F2 half_denominator_times_x = half_denominator * x;
693    F2 one_minus_half_denominator_times_x =
694        F2::One() - half_denominator_times_x;
695    x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
696  }
697  return Rescale<0>(x - F2::One());
698}
699
700// Returns -tanh(x) for x < 0.
701template <typename tRawType, int tIntegerBits>
702FixedPoint<tRawType, 0> neg_tanh_on_negative_values(
703    FixedPoint<tRawType, tIntegerBits> a) {
704  return one_minus_x_over_one_plus_x_for_x_in_0_1(
705      exp_on_negative_values(ExactMulByPot<1>(a)));
706}
707
708// Returns tanh(x) for any x.
709template <typename tRawType, int tIntegerBits>
710FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) {
711  typedef FixedPoint<tRawType, tIntegerBits> InputF;
712  typedef FixedPoint<tRawType, 0> ResultF;
713  tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero());
714  tRawType mask_if_zero = MaskIfZero(a);
715  InputF n = SelectUsingMask(mask_if_negative, a, -a);
716  ResultF t = neg_tanh_on_negative_values(n);
717  return SelectUsingMask(mask_if_zero, ResultF::Zero(),
718                         SelectUsingMask(mask_if_negative, -t, t));
719}
720
721// Implementation of logistic function.
722
723// Returns 1 / (1 + x) for x in (0, 1).
724template <typename tRawType>
725FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1(
726    FixedPoint<tRawType, 0> a) {
727  typedef FixedPoint<tRawType, 0> F0;
728  typedef FixedPoint<tRawType, 2> F2;
729  F0 half_denominator = RoundingHalfSum(a, F0::One());
730  // Newton-Raphson division
731  // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division
732  // Refer to that page for the logic behind the 48/17 and 32/17 constants.
733  const F2 constant_48_over_17 =
734      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0);
735  const F2 constant_neg_32_over_17 =
736      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0);
737  F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17;
738  for (int i = 0; i < 3; i++) {
739    F2 half_denominator_times_x = half_denominator * x;
740    F2 one_minus_half_denominator_times_x =
741        F2::One() - half_denominator_times_x;
742    x = x + Rescale<2>(x * one_minus_half_denominator_times_x);
743  }
744  return Rescale<0>(ExactMulByPot<-1>(x));
745}
746
747// Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0.
748template <typename tRawType, int tIntegerBits>
749FixedPoint<tRawType, 0> logistic_on_positive_values(
750    FixedPoint<tRawType, tIntegerBits> a) {
751  return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a));
752}
753
754// Returns logistic(x) = 1 / (1 + exp(-x)) for any x.
755template <typename tRawType, int tIntegerBits>
756FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) {
757  typedef FixedPoint<tRawType, tIntegerBits> InputF;
758  typedef FixedPoint<tRawType, 0> ResultF;
759  tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero());
760  tRawType mask_if_zero = MaskIfZero(a);
761  InputF abs_input = SelectUsingMask(mask_if_positive, a, -a);
762  ResultF result_if_positive = logistic_on_positive_values(abs_input);
763  ResultF result_if_negative = ResultF::One() - result_if_positive;
764  const ResultF one_half =
765      GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5);
766  return SelectUsingMask(mask_if_zero, one_half,
767                         SelectUsingMask(mask_if_positive, result_if_positive,
768                                         result_if_negative));
769}
770
771}  // end namespace gemmlowp
772
773#ifdef GEMMLOWP_NEON
774#include "./fixedpoint_neon.h"
775#elif defined(GEMMLOWP_SSE4)
776#include "./fixedpoint_sse.h"
777#endif
778
779#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_H_
780