1// Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15// fixedpoint.h: fixed-point arithmetic, with basic operations and 16// a few math functions such as tanh. 17 18#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 19#define GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 20 21#include <cassert> 22#include <limits> 23 24#include "../internal/common.h" 25 26namespace gemmlowp { 27 28// Part 1: Low-level integer-arithmetic primitives. 29// The implementations here are generic implementations valid for 30// scalar types (e.g. std::int32_t). Architecture-specific SIMD types 31// (e.g. NEON int32x4_t) may be supported by providing 32// specializations for them in separate files. 33// 34// The purpose of these primitives is two-fold: 35// - They will be used to implement higher-level fixed-point 36// abstractions, namely the FixedPoint class and its arithmetic 37// operators. 38// - They will be directly used to implement some more involved 39// fixed-point computations, e.g. the fixed-point implementation 40// of math functions such as tanh. 41 42// Some compile-time traits around raw types to handle SIMD aspects: 43// number of lanes, underlying scalar type. 
// Compile-time traits describing a raw type: its underlying scalar type and
// its number of lanes. The primary template is intentionally empty; every
// supported raw type provides its own specialization.
template <typename tIntegerType>
struct FixedPointRawTypeTraits {};

// Traits for the plain scalar std::int32_t: one lane, and the scalar type is
// std::int32_t itself.
template <>
struct FixedPointRawTypeTraits<std::int32_t> {
  typedef std::int32_t ScalarRawType;
  static const int kLanes = 1;
};

// Returns a SIMD value duplicating a scalar value across all lanes.
// This generic version just returns x converted to tRawType.
template <typename tRawType>
tRawType Dup(typename FixedPointRawTypeTraits<tRawType>::ScalarRawType x) {
  return x;
}

// Plain bit-wise AND
template <typename tIntegerType>
tIntegerType BitAnd(tIntegerType a, tIntegerType b) {
  return a & b;
}

// Plain bit-wise OR
template <typename tIntegerType>
tIntegerType BitOr(tIntegerType a, tIntegerType b) {
  return a | b;
}

// Plain bit-wise XOR
template <typename tIntegerType>
tIntegerType BitXor(tIntegerType a, tIntegerType b) {
  return a ^ b;
}

// Plain bit-wise NOT
template <typename tIntegerType>
tIntegerType BitNot(tIntegerType a) {
  return ~a;
}

// Integer addition. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Add(tIntegerType a, tIntegerType b) {
  return a + b;
}

// Integer multiplication. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Mul(tIntegerType a, tIntegerType b) {
  return a * b;
}

// Integer subtraction. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Sub(tIntegerType a, tIntegerType b) {
  return a - b;
}

// Integer unary negative. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType Neg(tIntegerType a) {
  return -a;
}

// Integer arithmetic left-shift, equivalent to multiplying with a
// power of two. Not saturating. Overflow is undefined behavior.
template <typename tIntegerType>
tIntegerType ShiftLeft(tIntegerType a, int offset) {
  return a << offset;
}

// Integer arithmetic right-shift. Not rounding.
// Relying on implementation-defined, but in-practice-consistent,
// C++ compiler behavior.
template <typename tIntegerType>
tIntegerType ShiftRight(tIntegerType a, int offset) {
  return a >> offset;
}

// Bit-wise select: every result bit is taken from then_val where the matching
// bit of if_mask is set, and from else_val where it is clear.
// Equivalent to the VBSL instruction in ARM NEON.
template <typename tIntegerType>
tIntegerType SelectUsingMask(tIntegerType if_mask, tIntegerType then_val,
                             tIntegerType else_val) {
  // The two masked halves never share a set bit, so XOR merges them.
  const tIntegerType taken_from_then = BitAnd(if_mask, then_val);
  const tIntegerType taken_from_else = BitAnd(BitNot(if_mask), else_val);
  return BitXor(taken_from_then, taken_from_else);
}

// Returns a value with all bits set when the input scalar is non-zero, and
// with no bits set otherwise.
template <typename tIntegerType>
tIntegerType MaskIfNonZero(tIntegerType a) {
  static const tIntegerType kNoBitsSet = 0;
  if (a) {
    return BitNot(kNoBitsSet);
  }
  return kNoBitsSet;
}

// Returns a value with all bits set when the input scalar is zero, and with
// no bits set otherwise.
template <typename tIntegerType>
tIntegerType MaskIfZero(tIntegerType a) {
  return MaskIfNonZero<tIntegerType>(!a);
}

// Returns a value with all bits set when the two input scalars are equal,
// and with no bits set otherwise.
template <typename tIntegerType>
tIntegerType MaskIfEqual(tIntegerType a, tIntegerType b) {
  const bool are_equal = (a == b);
  return MaskIfNonZero<tIntegerType>(are_equal);
}

// Returns a value with all bits set when the two input scalars differ, and
// with no bits set otherwise.
template <typename tIntegerType>
tIntegerType MaskIfNotEqual(tIntegerType a, tIntegerType b) {
  const bool are_different = (a != b);
  return MaskIfNonZero<tIntegerType>(are_different);
}

// For each pair of input scalars, the corresponding bits of the result are
// set if the input scalars a, b satisfy a > b.
161template <typename tIntegerType> 162tIntegerType MaskIfGreaterThan(tIntegerType a, tIntegerType b) { 163 return MaskIfNonZero<tIntegerType>(a > b); 164} 165 166// For each pair of input scalars, the corresponding bits of the result are 167// set if the input scalars a, b satisfy a >= b. 168template <typename tIntegerType> 169tIntegerType MaskIfGreaterThanOrEqual(tIntegerType a, tIntegerType b) { 170 return MaskIfNonZero<tIntegerType>(a >= b); 171} 172 173// For each pair of input scalars, the corresponding bits of the result are 174// set if the input scalars a, b satisfy a < b. 175template <typename tIntegerType> 176tIntegerType MaskIfLessThan(tIntegerType a, tIntegerType b) { 177 return MaskIfNonZero<tIntegerType>(a < b); 178} 179 180// For each pair of input scalars, the corresponding bits of the result are 181// set if the input scalars a, b satisfy a <= b. 182template <typename tIntegerType> 183tIntegerType MaskIfLessThanOrEqual(tIntegerType a, tIntegerType b) { 184 return MaskIfNonZero<tIntegerType>(a <= b); 185} 186 187// Returns true if all of the input scalars are nonzero. 188// This function may currently assume that each of the input scalars has either 189// all or none of its bits set. Otherwise, its behavior is currently undefined. 190template <typename tIntegerType> 191bool All(tIntegerType a) { 192 return a; 193} 194 195// Returns true if any of the input scalars are nonzero. 196// This function may currently assume that each of the input scalars has either 197// all or none of its bits set. Otherwise, its behavior is currently undefined. 198template <typename tIntegerType> 199bool Any(tIntegerType a) { 200 return a; 201} 202 203// Returns (a+b)/2, rounded to the nearest integer. 204// Equivalent to VRHADD in the ARM NEON instruction set. 
template <typename IntegerType>
IntegerType RoundingHalfSum(IntegerType a, IntegerType b) {
  // No generic implementation: the assertion can only hold for
  // IntegerType == void, so instantiating this primary template is a
  // compile-time error. Supported types provide explicit specializations.
  static_assert(std::is_same<IntegerType, void>::value, "unimplemented");
  return a;
}

// std::int32_t specialization. The sum is formed in 64 bits so it cannot
// overflow; an exact tie (odd sum) is rounded away from zero.
template <>
inline std::int32_t RoundingHalfSum(std::int32_t a, std::int32_t b) {
  const std::int64_t wide_sum =
      static_cast<std::int64_t>(a) + static_cast<std::int64_t>(b);
  // Push the sum one unit away from zero so that the truncating division
  // below rounds halves away from zero instead of toward zero.
  const std::int64_t bias = (wide_sum < 0) ? -1 : 1;
  return static_cast<std::int32_t>((wide_sum + bias) / 2);
}

// Returns the integer that represents the product of two fixed-point
// numbers, interpreting all integers as fixed-point values in the
// interval [-1, 1), rounding to the nearest value, and saturating
// -1 * -1 to the maximum value (since 1 is not in the half-open
// interval [-1, 1)).
//
// [The explanation below specializes to std::int32_t for example purpose.]
//
// The mapping between IntegerType and the interval [-1, 1) is unique and
// implied by IntegerType, which is assumed to be signed. For example,
// for IntegerType==std::int32_t, the mapping is
//   real_value = integer_value / 2^31.
// So in this case, and leaving aside rounding and saturating, this
// function computes ((a / 2^31) * (b / 2^31)) * 2^31, which simplifies to
//   (a * b) / 2^31.
//
// The 'doubling' part in the name of this function comes from the fact that
// this operation is very close to a "multiply-high" operation, keeping only
// the top half bits, except that that would be effectively computing
//   (a * b) / 2^32,
// so here we are computing 2x that, since
//   1/2^31 = 2 * 1/2^32.
// The idea is to use all of the available 32 bits in the destination int32
// value.
//
// [End of the explanation specializing to int32.]
//
// This is equivalent to the VQRDMULH instruction in ARM NEON.
248template <typename IntegerType> 249IntegerType SaturatingRoundingDoublingHighMul(IntegerType a, IntegerType b) { 250 static_assert(std::is_same<IntegerType, void>::value, "unimplemented"); 251 return a; 252} 253 254// This function implements the same computation as the ARMv7 NEON VQRDMULH 255// instruction. 256template <> 257inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, 258 std::int32_t b) { 259 bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min(); 260 std::int64_t a_64(a); 261 std::int64_t b_64(b); 262 std::int64_t ab_64 = a_64 * b_64; 263 std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30)); 264 std::int32_t ab_x2_high32 = 265 static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31)); 266 return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32; 267} 268 269// Correctly-rounded-to-nearest division by a power-of-two. 270// Also known as a rounding arithmetic right shift. 271template <typename IntegerType> 272inline IntegerType RoundingDivideByPOT(IntegerType x, int exponent) { 273 using ScalarIntegerType = 274 typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 275 static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value, 276 "Currently only supporting int32 scalar and SIMD types"); 277 assert(exponent >= 0); 278 assert(exponent <= 31); 279 const IntegerType mask = Dup<IntegerType>((1ll << exponent) - 1); 280 const IntegerType zero = Dup<IntegerType>(0); 281 const IntegerType one = Dup<IntegerType>(1); 282 const IntegerType remainder = BitAnd(x, mask); 283 const IntegerType threshold = 284 Add(ShiftRight(mask, 1), BitAnd(MaskIfLessThan(x, zero), one)); 285 return Add(ShiftRight(x, exponent), 286 BitAnd(MaskIfGreaterThan(remainder, threshold), one)); 287} 288 289// Returns the product of a run-time integer value by a compile-time power 290// of two, with either a positive exponent (equivalent to an arithmetic 291// left shift, saturating) or a negative exponent (equivalent 
to an arithmetic 292// right shift, rounding to nearest). 293template <int Exponent, typename IntegerType, 294 int ExponentSign = (Exponent > 0 ? 1 : Exponent < 0 ? -1 : 0)> 295struct ImplSaturatingRoundingMultiplyByPOT {}; 296 297template <int Exponent, typename IntegerType> 298struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 0> { 299 static IntegerType eval(IntegerType x) { return x; } 300}; 301 302template <int Exponent, typename IntegerType> 303struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, 1> { 304 static IntegerType eval(IntegerType x) { 305 using ScalarIntegerType = 306 typename FixedPointRawTypeTraits<IntegerType>::ScalarRawType; 307 static_assert(std::is_same<ScalarIntegerType, std::int32_t>::value, 308 "Currently only supporting int32 scalar and SIMD types"); 309 const IntegerType min = 310 Dup<IntegerType>(std::numeric_limits<std::int32_t>::min()); 311 const IntegerType max = 312 Dup<IntegerType>(std::numeric_limits<std::int32_t>::max()); 313 314 const std::int32_t threshold = ((1 << (31 - Exponent)) - 1); 315 const IntegerType positive_mask = 316 MaskIfGreaterThan(x, Dup<IntegerType>(threshold)); 317 const IntegerType negative_mask = 318 MaskIfLessThan(x, Dup<IntegerType>(-threshold)); 319 320 IntegerType result = ShiftLeft(x, Exponent); 321 result = SelectUsingMask(positive_mask, max, result); 322 result = SelectUsingMask(negative_mask, min, result); 323 return result; 324 } 325}; 326 327template <int Exponent, typename IntegerType> 328struct ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType, -1> { 329 static IntegerType eval(IntegerType x) { 330 return RoundingDivideByPOT<IntegerType>(x, -Exponent); 331 } 332}; 333 334template <int Exponent, typename IntegerType> 335IntegerType SaturatingRoundingMultiplyByPOT(IntegerType x) { 336 return ImplSaturatingRoundingMultiplyByPOT<Exponent, IntegerType>::eval(x); 337} 338 339// Part 2: the FixedPoint class. 
340 341// A FixedPoint object represents a fixed-point value stored in the underlying 342// integer type tRawType, if tRawType is a plain scalar integer type. 343// Alternatively, tRawType may be a SIMD type (e.g. NEON int32x4_t) in which 344// case a FixedPoint object represents a corresponding SIMD vector of fixed 345// point values. 346// 347// tIntegerBits describes the range of the fixed-point format: if 348// tIntegerBits == m then the range of representable values is the half-open 349// interval [-2^m; 2^m) where the open boundary on the right side means that 350// 2^m is not representable (how close the maximum representable value is to 351// it, depends on bit-depth of tRawType). 352// 353// In "Q format notation", 354// https://en.wikipedia.org/wiki/Q_(number_format) 355// we are describing the format 356// Qm.n 357// where 358// m = tIntegerBits 359// and 360// n = NumberOfBits(tRawType) - (m + 1) 361// Note that the (m + 1) in the above line is because we adopt the convention 362// that we count the integer bits exclusively of the sign bit; so (m + 1) is 363// the total number of integer bits inclusive of the sign bit. 364// 365// Accordingly, the number of integral representable values in our range 366// [-2^m ; 2^m) 367// is equal to 2^(m+1). 
368template <typename tRawType, int tIntegerBits> 369class FixedPoint { 370 public: 371 typedef tRawType RawType; 372 373 typedef FixedPointRawTypeTraits<RawType> RawTypeTraits; 374 typedef typename RawTypeTraits::ScalarRawType ScalarRawType; 375 376 static const int kTotalBits = 8 * sizeof(ScalarRawType); 377 static const int kIntegerBits = tIntegerBits; 378 static const int kFractionalBits = kTotalBits - 1 - kIntegerBits; 379 static_assert(kIntegerBits >= 0 && kIntegerBits < kTotalBits, 380 "bad IntegerBits"); 381 382 typedef FixedPoint<ScalarRawType, kIntegerBits> ScalarFixedPointType; 383 384 static const ScalarRawType ScalarRawMin() { 385 return std::numeric_limits<ScalarRawType>::min(); 386 } 387 388 static const ScalarRawType ScalarRawMax() { 389 return std::numeric_limits<ScalarRawType>::max(); 390 } 391 392 static const ScalarRawType RawMin() { 393 return VectorFromScalar(ScalarRawMin()); 394 } 395 396 static const ScalarRawType RawMax() { 397 return VectorFromScalar(ScalarRawMax()); 398 } 399 400 static FixedPoint FromRaw(RawType x) { 401 FixedPoint retval; 402 retval.raw() = x; 403 return retval; 404 } 405 406 static FixedPoint FromScalarRaw(ScalarRawType x) { 407 FixedPoint retval; 408 retval.raw() = Dup<RawType>(x); 409 return retval; 410 } 411 412 static FixedPoint FromScalarFixedPoint(ScalarFixedPointType x) { 413 return FromScalarRaw(x.raw()); 414 } 415 416 template <int Exponent> 417 static FixedPoint ConstantPOT() { 418 static const int kOffset = kFractionalBits + Exponent; 419 static_assert( 420 kOffset < 31, 421 "Constant not exactly representable in this fixed-point format"); 422 return FromScalarRaw(ScalarRawType(1) << kOffset); 423 } 424 425 static FixedPoint Zero() { return FromScalarRaw(0); } 426 427 static FixedPoint One() { 428 return FromScalarRaw(kIntegerBits == 0 429 ? 
ScalarRawMax() 430 : (ScalarRawType(1) << kFractionalBits)); 431 } 432 433 static FixedPoint FromDouble(double x) { 434 const double min_bound = static_cast<double>(ScalarRawMin()); 435 const double max_bound = static_cast<double>(ScalarRawMax()); 436 return FromScalarRaw(static_cast<std::int32_t>(std::min( 437 std::max(round(x * static_cast<double>(1ll << kFractionalBits)), 438 min_bound), 439 max_bound))); 440 } 441 442 RawType raw() const { return i_; } 443 RawType& raw() { return i_; } 444 445 private: 446 RawType i_; 447}; 448 449// Part 3: implementation of arithmetic operators for the 450// FixedPoint class, and a few related functions. 451 452// A FixedPoint multiplication is just a 453// SaturatingRoundingDoublingHighMul operation on the underlying 454// raw integer values. The IntegerBits simply add up, as is obvious 455// from the fact that the range is [-2^IntegerBits, 2^IntegerBits). 456template <typename tRawType, int tIntegerBits_a, int tIntegerBits_b> 457FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> operator*( 458 FixedPoint<tRawType, tIntegerBits_a> a, 459 FixedPoint<tRawType, tIntegerBits_b> b) { 460 FixedPoint<tRawType, tIntegerBits_a + tIntegerBits_b> c; 461 c.raw() = SaturatingRoundingDoublingHighMul(a.raw(), b.raw()); 462 return c; 463} 464 465// Tweaking IntegerBits gives exact multiplication by a power of two. 466template <int tExponent, typename tRawType, int tIntegerBits> 467FixedPoint<tRawType, tExponent + tIntegerBits> ExactMulByPot( 468 FixedPoint<tRawType, tIntegerBits> a) { 469 FixedPoint<tRawType, tExponent + tIntegerBits> c; 470 c.raw() = a.raw(); 471 return c; 472} 473 474// If we want to leave IntegerBits fixed, then multiplication 475// by a power of two has to be saturating/rounding, not exact anymore. 
// FixedPoint overload: applies the raw-integer
// SaturatingRoundingMultiplyByPOT<tExponent> to the raw value of a and wraps
// the result in the same fixed-point format.
template <int tExponent, typename tRawType, int tIntegerBits>
FixedPoint<tRawType, tIntegerBits> SaturatingRoundingMultiplyByPOT(
    FixedPoint<tRawType, tIntegerBits> a) {
  return FixedPoint<tRawType, tIntegerBits>::FromRaw(
      SaturatingRoundingMultiplyByPOT<tExponent>(a.raw()));
}

// Generic arithmetic operators.

// Defines FuncName as a function template taking one FixedPoint and
// returning a FixedPoint of the same format, computed by applying
// ImplFuncName to the raw value.
#define MAKE_FIXEDPOINT_UNARY_FUNC(FuncName, ImplFuncName)                     \
  template <typename tRawType, int tIntegerBits>                               \
  FixedPoint<tRawType, tIntegerBits> FuncName(                                 \
      FixedPoint<tRawType, tIntegerBits> a) {                                  \
    return FixedPoint<tRawType, tIntegerBits>::FromRaw(ImplFuncName(a.raw())); \
  }

// Defines FuncName as a function template taking two FixedPoint values of
// the same format and returning a FixedPoint of that format, computed by
// applying ImplFuncName to the two raw values.
#define MAKE_FIXEDPOINT_BINARY_FUNC(FuncName, ImplFuncName) \
  template <typename tRawType, int tIntegerBits>            \
  FixedPoint<tRawType, tIntegerBits> FuncName(              \
      FixedPoint<tRawType, tIntegerBits> a,                 \
      FixedPoint<tRawType, tIntegerBits> b) {               \
    return FixedPoint<tRawType, tIntegerBits>::FromRaw(     \
        ImplFuncName(a.raw(), b.raw()));                    \
  }

MAKE_FIXEDPOINT_UNARY_FUNC(operator-, Neg)
MAKE_FIXEDPOINT_UNARY_FUNC(operator~, BitNot)
MAKE_FIXEDPOINT_BINARY_FUNC(operator+, Add)
MAKE_FIXEDPOINT_BINARY_FUNC(operator-, Sub)
MAKE_FIXEDPOINT_BINARY_FUNC(operator&, BitAnd)
MAKE_FIXEDPOINT_BINARY_FUNC(operator^, BitXor)
MAKE_FIXEDPOINT_BINARY_FUNC(operator|, BitOr)
MAKE_FIXEDPOINT_BINARY_FUNC(RoundingHalfSum, RoundingHalfSum)

#undef MAKE_FIXEDPOINT_UNARY_FUNC
#undef MAKE_FIXEDPOINT_BINARY_FUNC

// Defines a FixedPoint overload of FuncName that forwards to the raw-integer
// FuncName and returns its raw result (used for the mask functions).
#define MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(FuncName)  \
  template <typename tRawType, int tIntegerBits>            \
  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a) { \
    return FuncName(a.raw());                               \
  }

// Two-argument variant of the above: both FixedPoint arguments must have the
// same format, and the raw result of the raw-integer FuncName is returned.
#define MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(FuncName) \
  template <typename tRawType, int tIntegerBits>            \
  tRawType FuncName(FixedPoint<tRawType, tIntegerBits> a,   \
                    FixedPoint<tRawType, tIntegerBits> b) { \
    return FuncName(a.raw(), b.raw());                      \
  }

526MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfZero) 527MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW(MaskIfNonZero) 528MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfEqual) 529MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfNotEqual) 530MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThan) 531MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfGreaterThanOrEqual) 532MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThan) 533MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW(MaskIfLessThanOrEqual) 534 535#undef MAKE_FIXEDPOINT_UNARY_FUNC_RETURNING_RAW 536#undef MAKE_FIXEDPOINT_BINARY_FUNC_RETURNING_RAW 537 538template <typename tRawType, int tIntegerBits> 539FixedPoint<tRawType, tIntegerBits> SelectUsingMask( 540 tRawType if_mask, FixedPoint<tRawType, tIntegerBits> then_val, 541 FixedPoint<tRawType, tIntegerBits> else_val) { 542 return FixedPoint<tRawType, tIntegerBits>::FromRaw( 543 SelectUsingMask(if_mask, then_val.raw(), else_val.raw())); 544} 545 546template <typename tRawType, int tIntegerBits> 547bool operator==(FixedPoint<tRawType, tIntegerBits> a, 548 FixedPoint<tRawType, tIntegerBits> b) { 549 return All(MaskIfEqual(a.raw(), b.raw())); 550} 551 552template <typename tRawType, int tIntegerBits> 553bool operator!=(FixedPoint<tRawType, tIntegerBits> a, 554 FixedPoint<tRawType, tIntegerBits> b) { 555 return !(a == b); 556} 557 558// Conversion to floating-point. 559template <typename tRawType, int tIntegerBits> 560double ToDouble(FixedPoint<tRawType, tIntegerBits> x) { 561 static_assert(FixedPointRawTypeTraits<tRawType>::kLanes == 1, 562 "not applicable to SIMD types"); 563 typedef FixedPoint<tRawType, tIntegerBits> F; 564 return x.raw() / static_cast<double>(1ll << F::kFractionalBits); 565} 566 567// Rescale changes the number of IntegerBits and updates the underlying 568// raw integer value accordingly. 
569template <int tIntegerBitsDst, typename tRawType, int tIntegerBitsSrc> 570FixedPoint<tRawType, tIntegerBitsDst> Rescale( 571 FixedPoint<tRawType, tIntegerBitsSrc> x) { 572 static const int kExponent = tIntegerBitsSrc - tIntegerBitsDst; 573 FixedPoint<tRawType, tIntegerBitsDst> result; 574 result.raw() = SaturatingRoundingMultiplyByPOT<kExponent>(x.raw()); 575 return result; 576} 577 578// CheckedFixedPointConstant allows to specify fixed-point constants 579// initialized as real numbers, in a way that does not compile floating-point 580// arithmetic in production code, yet still checks agreement with the 581// floating-point expressions when asserts are enabled. 582#ifdef GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS 583template <typename FixedPointType> 584FixedPointType CheckedFixedPointConstant( 585 typename FixedPointType::ScalarRawType raw_value, double double_value) { 586 typedef typename FixedPointType::RawType RawType; 587 const FixedPointType result = FixedPointType::FromScalarRaw(raw_value); 588 assert(result == FixedPointType::FromDouble(double_value)); 589 return result; 590} 591#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ 592 DoubleValue) \ 593 (CheckedFixedPointConstant<FixedPointType>(ScalarRawValue, DoubleValue)) 594 595#else 596#define GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPointType, ScalarRawValue, \ 597 DoubleValue) \ 598 (FixedPointType::FromScalarRaw(ScalarRawValue)) 599#endif 600 601// Implementation of exponential function. 602 603// Returns exp(x) for x in [-1/4, 0). 
604template <typename tRawType> 605FixedPoint<tRawType, 0> exp_on_interval_between_negative_one_quarter_and_0_excl( 606 FixedPoint<tRawType, 0> a) { 607 typedef FixedPoint<tRawType, 0> F; 608 const F constant_term = 609 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 1895147668, std::exp(-1.0 / 8.0)); 610 const F constant_1_over_3 = 611 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F, 715827883, 1.0 / 3.0); 612 // We're evaluating a Taylor expansion around -1/8, so we do the change of 613 // variable: x = a + 1/8. 614 // In fixed-point with 0 integer bits, 1/8 is represented by 1 << 28. 615 F x = a + F::template ConstantPOT<-3>(); 616 F x2 = x * x; 617 F x3 = x2 * x; 618 F x4 = x2 * x2; 619 F x4_over_4 = SaturatingRoundingMultiplyByPOT<-2>(x4); 620 F x4_over_24_plus_x3_over_6_plus_x2_over_2 = 621 SaturatingRoundingMultiplyByPOT<-1>( 622 ((x4_over_4 + x3) * constant_1_over_3) + x2); 623 return constant_term + 624 constant_term * (x + x4_over_24_plus_x3_over_6_plus_x2_over_2); 625} 626 627// Returns exp(x) for x < 0. 
628template <typename tRawType, int tIntegerBits> 629FixedPoint<tRawType, 0> exp_on_negative_values( 630 FixedPoint<tRawType, tIntegerBits> a) { 631 typedef FixedPoint<tRawType, tIntegerBits> InputF; 632 typedef FixedPoint<tRawType, 0> ResultF; 633 static const int kFractionalBits = InputF::kFractionalBits; 634 static const int kIntegerBits = InputF::kIntegerBits; 635 static const InputF kOneQuarter = InputF::template ConstantPOT<-2>(); 636 InputF mask = kOneQuarter - InputF::FromScalarRaw(1); 637 InputF a_mod_quarter_minus_one_quarter = (a & mask) - kOneQuarter; 638 ResultF result = exp_on_interval_between_negative_one_quarter_and_0_excl( 639 Rescale<0>(a_mod_quarter_minus_one_quarter)); 640 tRawType remainder = (a_mod_quarter_minus_one_quarter - a).raw(); 641 642#define GEMMLOWP_EXP_BARREL_SHIFTER(Exponent, FixedPointMultiplier) \ 643 if (kIntegerBits > Exponent) { \ 644 const ResultF kMultiplier = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( \ 645 ResultF, FixedPointMultiplier, std::exp(-std::pow(2.0, Exponent))); \ 646 static constexpr int kShiftAmount = \ 647 kIntegerBits > Exponent ? kFractionalBits + Exponent : 0; \ 648 result = SelectUsingMask( \ 649 MaskIfNonZero(BitAnd(remainder, Dup<tRawType>(1 << kShiftAmount))), \ 650 result * kMultiplier, result); \ 651 } 652 653 GEMMLOWP_EXP_BARREL_SHIFTER(-2, 1672461947); 654 GEMMLOWP_EXP_BARREL_SHIFTER(-1, 1302514674); 655 GEMMLOWP_EXP_BARREL_SHIFTER(+0, 790015084); 656 GEMMLOWP_EXP_BARREL_SHIFTER(+1, 290630308); 657 GEMMLOWP_EXP_BARREL_SHIFTER(+2, 39332535); 658 GEMMLOWP_EXP_BARREL_SHIFTER(+3, 720401); 659 GEMMLOWP_EXP_BARREL_SHIFTER(+4, 242); 660 661#undef GEMMLOWP_EXP_BARREL_SHIFTER 662 663 if (kIntegerBits > 5) { 664 static const int b = kIntegerBits > 5 ? 
kFractionalBits + 5 : 0; 665 const InputF clamp = 666 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(InputF, -(1 << b), -32.0); 667 result = SelectUsingMask(MaskIfLessThan(a, clamp), ResultF::Zero(), result); 668 } 669 670 result = SelectUsingMask(MaskIfZero(a), ResultF::One(), result); 671 return result; 672} 673 674// Implementation of tanh: (1 - exp(-2x)) / (1 + exp(-2x)). 675 676// Returns (1 - x) / (1 + x) for x in (0, 1). 677template <typename tRawType> 678FixedPoint<tRawType, 0> one_minus_x_over_one_plus_x_for_x_in_0_1( 679 FixedPoint<tRawType, 0> a) { 680 typedef FixedPoint<tRawType, 0> F0; 681 typedef FixedPoint<tRawType, 2> F2; 682 F0 half_denominator = RoundingHalfSum(a, F0::One()); 683 // Newton-Raphson division 684 // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 685 // Refer to that page for the logic behind the 48/17 and 32/17 constants. 686 const F2 constant_48_over_17 = 687 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 688 const F2 constant_neg_32_over_17 = 689 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 690 F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 691 for (int i = 0; i < 3; i++) { 692 F2 half_denominator_times_x = half_denominator * x; 693 F2 one_minus_half_denominator_times_x = 694 F2::One() - half_denominator_times_x; 695 x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 696 } 697 return Rescale<0>(x - F2::One()); 698} 699 700// Returns -tanh(x) for x < 0. 701template <typename tRawType, int tIntegerBits> 702FixedPoint<tRawType, 0> neg_tanh_on_negative_values( 703 FixedPoint<tRawType, tIntegerBits> a) { 704 return one_minus_x_over_one_plus_x_for_x_in_0_1( 705 exp_on_negative_values(ExactMulByPot<1>(a))); 706} 707 708// Returns tanh(x) for any x. 
709template <typename tRawType, int tIntegerBits> 710FixedPoint<tRawType, 0> tanh(FixedPoint<tRawType, tIntegerBits> a) { 711 typedef FixedPoint<tRawType, tIntegerBits> InputF; 712 typedef FixedPoint<tRawType, 0> ResultF; 713 tRawType mask_if_negative = MaskIfLessThan(a, InputF::Zero()); 714 tRawType mask_if_zero = MaskIfZero(a); 715 InputF n = SelectUsingMask(mask_if_negative, a, -a); 716 ResultF t = neg_tanh_on_negative_values(n); 717 return SelectUsingMask(mask_if_zero, ResultF::Zero(), 718 SelectUsingMask(mask_if_negative, -t, t)); 719} 720 721// Implementation of logistic function. 722 723// Returns 1 / (1 + x) for x in (0, 1). 724template <typename tRawType> 725FixedPoint<tRawType, 0> one_over_one_plus_x_for_x_in_0_1( 726 FixedPoint<tRawType, 0> a) { 727 typedef FixedPoint<tRawType, 0> F0; 728 typedef FixedPoint<tRawType, 2> F2; 729 F0 half_denominator = RoundingHalfSum(a, F0::One()); 730 // Newton-Raphson division 731 // https://en.wikipedia.org/wiki/Division_algorithm#Newton.E2.80.93Raphson_division 732 // Refer to that page for the logic behind the 48/17 and 32/17 constants. 733 const F2 constant_48_over_17 = 734 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, 1515870810, 48.0 / 17.0); 735 const F2 constant_neg_32_over_17 = 736 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F2, -1010580540, -32.0 / 17.0); 737 F2 x = constant_48_over_17 + half_denominator * constant_neg_32_over_17; 738 for (int i = 0; i < 3; i++) { 739 F2 half_denominator_times_x = half_denominator * x; 740 F2 one_minus_half_denominator_times_x = 741 F2::One() - half_denominator_times_x; 742 x = x + Rescale<2>(x * one_minus_half_denominator_times_x); 743 } 744 return Rescale<0>(ExactMulByPot<-1>(x)); 745} 746 747// Returns logistic(x) = 1 / (1 + exp(-x)) for x > 0. 
748template <typename tRawType, int tIntegerBits> 749FixedPoint<tRawType, 0> logistic_on_positive_values( 750 FixedPoint<tRawType, tIntegerBits> a) { 751 return one_over_one_plus_x_for_x_in_0_1(exp_on_negative_values(-a)); 752} 753 754// Returns logistic(x) = 1 / (1 + exp(-x)) for any x. 755template <typename tRawType, int tIntegerBits> 756FixedPoint<tRawType, 0> logistic(FixedPoint<tRawType, tIntegerBits> a) { 757 typedef FixedPoint<tRawType, tIntegerBits> InputF; 758 typedef FixedPoint<tRawType, 0> ResultF; 759 tRawType mask_if_positive = MaskIfGreaterThan(a, InputF::Zero()); 760 tRawType mask_if_zero = MaskIfZero(a); 761 InputF abs_input = SelectUsingMask(mask_if_positive, a, -a); 762 ResultF result_if_positive = logistic_on_positive_values(abs_input); 763 ResultF result_if_negative = ResultF::One() - result_if_positive; 764 const ResultF one_half = 765 GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(ResultF, 1 << 30, 0.5); 766 return SelectUsingMask(mask_if_zero, one_half, 767 SelectUsingMask(mask_if_positive, result_if_positive, 768 result_if_negative)); 769} 770 771} // end namespace gemmlowp 772 773#ifdef GEMMLOWP_NEON 774#include "./fixedpoint_neon.h" 775#elif defined(GEMMLOWP_SSE4) 776#include "./fixedpoint_sse.h" 777#endif 778 779#endif // GEMMLOWP_INTERNAL_FIXEDPOINT_H_ 780