1// float-weight.h 2 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Copyright 2005-2010 Google, Inc. 16// Author: riley@google.com (Michael Riley) 17// 18// \file 19// Float weight set and associated semiring operation definitions. 20// 21 22#ifndef FST_LIB_FLOAT_WEIGHT_H__ 23#define FST_LIB_FLOAT_WEIGHT_H__ 24 25#include <limits> 26#include <climits> 27#include <sstream> 28#include <string> 29 30#include <fst/util.h> 31#include <fst/weight.h> 32 33 34namespace fst { 35 36// numeric limits class 37template <class T> 38class FloatLimits { 39 public: 40 static const T PosInfinity() { 41 static const T pos_infinity = numeric_limits<T>::infinity(); 42 return pos_infinity; 43 } 44 45 static const T NegInfinity() { 46 static const T neg_infinity = -PosInfinity(); 47 return neg_infinity; 48 } 49 50 static const T NumberBad() { 51 static const T number_bad = numeric_limits<T>::quiet_NaN(); 52 return number_bad; 53 } 54 55}; 56 57// weight class to be templated on floating-points types 58template <class T = float> 59class FloatWeightTpl { 60 public: 61 FloatWeightTpl() {} 62 63 FloatWeightTpl(T f) : value_(f) {} 64 65 FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {} 66 67 FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) { 68 value_ = w.value_; 69 return *this; 70 } 71 72 istream &Read(istream &strm) { 73 return ReadType(strm, &value_); 74 } 75 76 ostream &Write(ostream &strm) const { 77 return WriteType(strm, value_); 78 } 79 80 size_t Hash() const { 81 union { 82 T f; 83 size_t s; 84 } u; 85 u.s = 0; 86 u.f = value_; 87 return u.s; 88 } 89 90 const T &Value() const { return value_; } 91 92 protected: 93 void SetValue(const T &f) { value_ = f; } 94 95 inline static string GetPrecisionString() { 96 int64 size = sizeof(T); 97 if (size == sizeof(float)) return ""; 98 size *= CHAR_BIT; 99 100 string result; 101 Int64ToStr(size, &result); 102 return result; 103 } 104 105 private: 106 T value_; 107}; 108 109// Single-precision float weight 110typedef FloatWeightTpl<float> FloatWeight; 111 112template <class T> 113inline bool operator==(const FloatWeightTpl<T> &w1, 114 const FloatWeightTpl<T> &w2) { 115 // Volatile qualifier thwarts over-aggressive compiler optimizations 116 // that lead to problems esp. with NaturalLess(). 117 volatile T v1 = w1.Value(); 118 volatile T v2 = w2.Value(); 119 return v1 == v2; 120} 121 122inline bool operator==(const FloatWeightTpl<double> &w1, 123 const FloatWeightTpl<double> &w2) { 124 return operator==<double>(w1, w2); 125} 126 127inline bool operator==(const FloatWeightTpl<float> &w1, 128 const FloatWeightTpl<float> &w2) { 129 return operator==<float>(w1, w2); 130} 131 132template <class T> 133inline bool operator!=(const FloatWeightTpl<T> &w1, 134 const FloatWeightTpl<T> &w2) { 135 return !(w1 == w2); 136} 137 138inline bool operator!=(const FloatWeightTpl<double> &w1, 139 const FloatWeightTpl<double> &w2) { 140 return operator!=<double>(w1, w2); 141} 142 143inline bool operator!=(const FloatWeightTpl<float> &w1, 144 const FloatWeightTpl<float> &w2) { 145 return operator!=<float>(w1, w2); 146} 147 148template <class T> 149inline bool ApproxEqual(const FloatWeightTpl<T> &w1, 150 const FloatWeightTpl<T> &w2, 151 float delta = kDelta) { 152 return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta; 153} 154 155template <class T> 156inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) { 157 if (w.Value() == FloatLimits<T>::PosInfinity()) 158 return strm << "Infinity"; 159 else if (w.Value() == FloatLimits<T>::NegInfinity()) 160 return strm << "-Infinity"; 161 else if (w.Value() != w.Value()) // Fails for NaN 162 return strm << "BadNumber"; 163 else 164 return strm << w.Value(); 165} 166 167template <class T> 168inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) { 169 string s; 170 strm >> s; 171 if (s == "Infinity") { 172 w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity()); 173 } else if (s == "-Infinity") { 174 w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity()); 175 } else { 176 char *p; 177 T f = strtod(s.c_str(), &p); 178 if (p < s.c_str() + s.size()) 179 strm.clear(std::ios::badbit); 180 else 181 w = FloatWeightTpl<T>(f); 182 } 183 return strm; 184} 185 186 187// Tropical semiring: (min, +, inf, 0) 188template <class T> 189class TropicalWeightTpl : public FloatWeightTpl<T> { 190 public: 191 using FloatWeightTpl<T>::Value; 192 193 typedef TropicalWeightTpl<T> ReverseWeight; 194 195 TropicalWeightTpl() : FloatWeightTpl<T>() {} 196 197 TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {} 198 199 TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 200 201 static const TropicalWeightTpl<T> Zero() { 202 return TropicalWeightTpl<T>(FloatLimits<T>::PosInfinity()); } 203 204 static const TropicalWeightTpl<T> One() { 205 return TropicalWeightTpl<T>(0.0F); } 206 207 static const TropicalWeightTpl<T> NoWeight() { 208 return TropicalWeightTpl<T>(FloatLimits<T>::NumberBad()); } 209 210 static const string &Type() { 211 static const string type = "tropical" + 212 FloatWeightTpl<T>::GetPrecisionString(); 213 return type; 214 } 215 216 bool Member() const { 217 // First part fails for IEEE NaN 218 return Value() == Value() && Value() != FloatLimits<T>::NegInfinity(); 219 } 220 221 TropicalWeightTpl<T> Quantize(float delta = kDelta) const { 222 if (Value() == FloatLimits<T>::NegInfinity() || 223 Value() == FloatLimits<T>::PosInfinity() || 224 Value() != Value()) 225 return *this; 226 else 227 return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 228 } 229 230 TropicalWeightTpl<T> Reverse() const { return *this; } 231 232 static uint64 Properties() { 233 return kLeftSemiring | kRightSemiring | kCommutative | 234 kPath | kIdempotent; 235 } 236}; 237 238// Single precision tropical weight 239typedef TropicalWeightTpl<float> TropicalWeight; 240 241template <class T> 242inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1, 243 const TropicalWeightTpl<T> &w2) { 244 if (!w1.Member() || !w2.Member()) 245 return TropicalWeightTpl<T>::NoWeight(); 246 return w1.Value() < w2.Value() ? w1 : w2; 247} 248 249inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1, 250 const TropicalWeightTpl<float> &w2) { 251 return Plus<float>(w1, w2); 252} 253 254inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1, 255 const TropicalWeightTpl<double> &w2) { 256 return Plus<double>(w1, w2); 257} 258 259template <class T> 260inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1, 261 const TropicalWeightTpl<T> &w2) { 262 if (!w1.Member() || !w2.Member()) 263 return TropicalWeightTpl<T>::NoWeight(); 264 T f1 = w1.Value(), f2 = w2.Value(); 265 if (f1 == FloatLimits<T>::PosInfinity()) 266 return w1; 267 else if (f2 == FloatLimits<T>::PosInfinity()) 268 return w2; 269 else 270 return TropicalWeightTpl<T>(f1 + f2); 271} 272 273inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1, 274 const TropicalWeightTpl<float> &w2) { 275 return Times<float>(w1, w2); 276} 277 278inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1, 279 const TropicalWeightTpl<double> &w2) { 280 return Times<double>(w1, w2); 281} 282 283template <class T> 284inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1, 285 const TropicalWeightTpl<T> &w2, 286 DivideType typ = DIVIDE_ANY) { 287 if (!w1.Member() || !w2.Member()) 288 return TropicalWeightTpl<T>::NoWeight(); 289 T f1 = w1.Value(), f2 = w2.Value(); 290 if (f2 == FloatLimits<T>::PosInfinity()) 291 return FloatLimits<T>::NumberBad(); 292 else if (f1 == FloatLimits<T>::PosInfinity()) 293 return FloatLimits<T>::PosInfinity(); 294 else 295 return TropicalWeightTpl<T>(f1 - f2); 296} 297 298inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1, 299 const TropicalWeightTpl<float> &w2, 300 DivideType typ = DIVIDE_ANY) { 301 return Divide<float>(w1, w2, typ); 302} 303 304inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1, 305 const TropicalWeightTpl<double> &w2, 306 DivideType typ = DIVIDE_ANY) { 307 return Divide<double>(w1, w2, typ); 308} 309 310 311// Log semiring: (log(e^-x + e^y), +, inf, 0) 312template <class T> 313class LogWeightTpl : public FloatWeightTpl<T> { 314 public: 315 using FloatWeightTpl<T>::Value; 316 317 typedef LogWeightTpl ReverseWeight; 318 319 LogWeightTpl() : FloatWeightTpl<T>() {} 320 321 LogWeightTpl(T f) : FloatWeightTpl<T>(f) {} 322 323 LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 324 325 static const LogWeightTpl<T> Zero() { 326 return LogWeightTpl<T>(FloatLimits<T>::PosInfinity()); 327 } 328 329 static const LogWeightTpl<T> One() { 330 return LogWeightTpl<T>(0.0F); 331 } 332 333 static const LogWeightTpl<T> NoWeight() { 334 return LogWeightTpl<T>(FloatLimits<T>::NumberBad()); } 335 336 static const string &Type() { 337 static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString(); 338 return type; 339 } 340 341 bool Member() const { 342 // First part fails for IEEE NaN 343 return Value() == Value() && Value() != FloatLimits<T>::NegInfinity(); 344 } 345 346 LogWeightTpl<T> Quantize(float delta = kDelta) const { 347 if (Value() == FloatLimits<T>::NegInfinity() || 348 Value() == FloatLimits<T>::PosInfinity() || 349 Value() != Value()) 350 return *this; 351 else 352 return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 353 } 354 355 LogWeightTpl<T> Reverse() const { return *this; } 356 357 static uint64 Properties() { 358 return kLeftSemiring | kRightSemiring | kCommutative; 359 } 360}; 361 362// Single-precision log weight 363typedef LogWeightTpl<float> LogWeight; 364// Double-precision log weight 365typedef LogWeightTpl<double> Log64Weight; 366 367template <class T> 368inline T LogExp(T x) { return log(1.0F + exp(-x)); } 369 370template <class T> 371inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1, 372 const LogWeightTpl<T> &w2) { 373 T f1 = w1.Value(), f2 = w2.Value(); 374 if (f1 == FloatLimits<T>::PosInfinity()) 375 return w2; 376 else if (f2 == FloatLimits<T>::PosInfinity()) 377 return w1; 378 else if (f1 > f2) 379 return LogWeightTpl<T>(f2 - LogExp(f1 - f2)); 380 else 381 return LogWeightTpl<T>(f1 - LogExp(f2 - f1)); 382} 383 384inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1, 385 const LogWeightTpl<float> &w2) { 386 return Plus<float>(w1, w2); 387} 388 389inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1, 390 const LogWeightTpl<double> &w2) { 391 return Plus<double>(w1, w2); 392} 393 394template <class T> 395inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1, 396 const LogWeightTpl<T> &w2) { 397 if (!w1.Member() || !w2.Member()) 398 return LogWeightTpl<T>::NoWeight(); 399 T f1 = w1.Value(), f2 = w2.Value(); 400 if (f1 == FloatLimits<T>::PosInfinity()) 401 return w1; 402 else if (f2 == FloatLimits<T>::PosInfinity()) 403 return w2; 404 else 405 return LogWeightTpl<T>(f1 + f2); 406} 407 408inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1, 409 const LogWeightTpl<float> &w2) { 410 return Times<float>(w1, w2); 411} 412 413inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1, 414 const LogWeightTpl<double> &w2) { 415 return Times<double>(w1, w2); 416} 417 418template <class T> 419inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1, 420 const LogWeightTpl<T> &w2, 421 DivideType typ = DIVIDE_ANY) { 422 if (!w1.Member() || !w2.Member()) 423 return LogWeightTpl<T>::NoWeight(); 424 T f1 = w1.Value(), f2 = w2.Value(); 425 if (f2 == FloatLimits<T>::PosInfinity()) 426 return FloatLimits<T>::NumberBad(); 427 else if (f1 == FloatLimits<T>::PosInfinity()) 428 return FloatLimits<T>::PosInfinity(); 429 else 430 return LogWeightTpl<T>(f1 - f2); 431} 432 433inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1, 434 const LogWeightTpl<float> &w2, 435 DivideType typ = DIVIDE_ANY) { 436 return Divide<float>(w1, w2, typ); 437} 438 439inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1, 440 const LogWeightTpl<double> &w2, 441 DivideType typ = DIVIDE_ANY) { 442 return Divide<double>(w1, w2, typ); 443} 444 445// MinMax semiring: (min, max, inf, -inf) 446template <class T> 447class MinMaxWeightTpl : public FloatWeightTpl<T> { 448 public: 449 using FloatWeightTpl<T>::Value; 450 451 typedef MinMaxWeightTpl<T> ReverseWeight; 452 453 MinMaxWeightTpl() : FloatWeightTpl<T>() {} 454 455 MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {} 456 457 MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 458 459 static const MinMaxWeightTpl<T> Zero() { 460 return MinMaxWeightTpl<T>(FloatLimits<T>::PosInfinity()); 461 } 462 463 static const MinMaxWeightTpl<T> One() { 464 return MinMaxWeightTpl<T>(FloatLimits<T>::NegInfinity()); 465 } 466 467 static const MinMaxWeightTpl<T> NoWeight() { 468 return MinMaxWeightTpl<T>(FloatLimits<T>::NumberBad()); } 469 470 static const string &Type() { 471 static const string type = "minmax" + 472 FloatWeightTpl<T>::GetPrecisionString(); 473 return type; 474 } 475 476 bool Member() const { 477 // Fails for IEEE NaN 478 return Value() == Value(); 479 } 480 481 MinMaxWeightTpl<T> Quantize(float delta = kDelta) const { 482 // If one of infinities, or a NaN 483 if (Value() == FloatLimits<T>::NegInfinity() || 484 Value() == FloatLimits<T>::PosInfinity() || 485 Value() != Value()) 486 return *this; 487 else 488 return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 489 } 490 491 MinMaxWeightTpl<T> Reverse() const { return *this; } 492 493 static uint64 Properties() { 494 return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath; 495 } 496}; 497 498// Single-precision min-max weight 499typedef MinMaxWeightTpl<float> MinMaxWeight; 500 501// Min 502template <class T> 503inline MinMaxWeightTpl<T> Plus( 504 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { 505 if (!w1.Member() || !w2.Member()) 506 return MinMaxWeightTpl<T>::NoWeight(); 507 return w1.Value() < w2.Value() ? w1 : w2; 508} 509 510inline MinMaxWeightTpl<float> Plus( 511 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { 512 return Plus<float>(w1, w2); 513} 514 515inline MinMaxWeightTpl<double> Plus( 516 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { 517 return Plus<double>(w1, w2); 518} 519 520// Max 521template <class T> 522inline MinMaxWeightTpl<T> Times( 523 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { 524 if (!w1.Member() || !w2.Member()) 525 return MinMaxWeightTpl<T>::NoWeight(); 526 return w1.Value() >= w2.Value() ? w1 : w2; 527} 528 529inline MinMaxWeightTpl<float> Times( 530 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { 531 return Times<float>(w1, w2); 532} 533 534inline MinMaxWeightTpl<double> Times( 535 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { 536 return Times<double>(w1, w2); 537} 538 539// Defined only for special cases 540template <class T> 541inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1, 542 const MinMaxWeightTpl<T> &w2, 543 DivideType typ = DIVIDE_ANY) { 544 if (!w1.Member() || !w2.Member()) 545 return MinMaxWeightTpl<T>::NoWeight(); 546 // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2 547 return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::NumberBad(); 548} 549 550inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1, 551 const MinMaxWeightTpl<float> &w2, 552 DivideType typ = DIVIDE_ANY) { 553 return Divide<float>(w1, w2, typ); 554} 555 556inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1, 557 const MinMaxWeightTpl<double> &w2, 558 DivideType typ = DIVIDE_ANY) { 559 return Divide<double>(w1, w2, typ); 560} 561 562// 563// WEIGHT CONVERTER SPECIALIZATIONS. 564// 565 566// Convert to tropical 567template <> 568struct WeightConvert<LogWeight, TropicalWeight> { 569 TropicalWeight operator()(LogWeight w) const { return w.Value(); } 570}; 571 572template <> 573struct WeightConvert<Log64Weight, TropicalWeight> { 574 TropicalWeight operator()(Log64Weight w) const { return w.Value(); } 575}; 576 577// Convert to log 578template <> 579struct WeightConvert<TropicalWeight, LogWeight> { 580 LogWeight operator()(TropicalWeight w) const { return w.Value(); } 581}; 582 583template <> 584struct WeightConvert<Log64Weight, LogWeight> { 585 LogWeight operator()(Log64Weight w) const { return w.Value(); } 586}; 587 588// Convert to log64 589template <> 590struct WeightConvert<TropicalWeight, Log64Weight> { 591 Log64Weight operator()(TropicalWeight w) const { return w.Value(); } 592}; 593 594template <> 595struct WeightConvert<LogWeight, Log64Weight> { 596 Log64Weight operator()(LogWeight w) const { return w.Value(); } 597}; 598 599} // namespace fst 600 601#endif // FST_LIB_FLOAT_WEIGHT_H__ 602