1// float-weight.h 2 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Copyright 2005-2010 Google, Inc. 16// Author: riley@google.com (Michael Riley) 17// 18// \file 19// Float weight set and associated semiring operation definitions. 20// 21 22#ifndef FST_LIB_FLOAT_WEIGHT_H__ 23#define FST_LIB_FLOAT_WEIGHT_H__ 24 25#include <limits> 26#include <climits> 27#include <sstream> 28#include <string> 29 30#include <fst/util.h> 31#include <fst/weight.h> 32 33 34namespace fst { 35 36// numeric limits class 37template <class T> 38class FloatLimits { 39 public: 40 static const T kPosInfinity; 41 static const T kNegInfinity; 42 static const T kNumberBad; 43}; 44 45template <class T> 46const T FloatLimits<T>::kPosInfinity = numeric_limits<T>::infinity(); 47 48template <class T> 49const T FloatLimits<T>::kNegInfinity = -FloatLimits<T>::kPosInfinity; 50 51template <class T> 52const T FloatLimits<T>::kNumberBad = numeric_limits<T>::quiet_NaN(); 53 54// weight class to be templated on floating-points types 55template <class T = float> 56class FloatWeightTpl { 57 public: 58 FloatWeightTpl() {} 59 60 FloatWeightTpl(T f) : value_(f) {} 61 62 FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {} 63 64 FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) { 65 value_ = w.value_; 66 return *this; 67 } 68 69 istream &Read(istream &strm) { 70 return ReadType(strm, &value_); 71 } 72 73 ostream &Write(ostream &strm) const { 74 return WriteType(strm, value_); 75 } 76 77 size_t Hash() const { 78 union { 79 T f; 80 size_t s; 81 } u; 82 u.s = 0; 83 u.f = value_; 84 return u.s; 85 } 86 87 const T &Value() const { return value_; } 88 89 protected: 90 void SetValue(const T &f) { value_ = f; } 91 92 inline static string GetPrecisionString() { 93 int64 size = sizeof(T); 94 if (size == sizeof(float)) return ""; 95 size *= CHAR_BIT; 96 97 string result; 98 Int64ToStr(size, &result); 99 return result; 100 } 101 102 private: 103 T value_; 104}; 105 106// Single-precision float weight 107typedef FloatWeightTpl<float> FloatWeight; 108 109template <class T> 110inline bool operator==(const FloatWeightTpl<T> &w1, 111 const FloatWeightTpl<T> &w2) { 112 // Volatile qualifier thwarts over-aggressive compiler optimizations 113 // that lead to problems esp. with NaturalLess(). 114 volatile T v1 = w1.Value(); 115 volatile T v2 = w2.Value(); 116 return v1 == v2; 117} 118 119inline bool operator==(const FloatWeightTpl<double> &w1, 120 const FloatWeightTpl<double> &w2) { 121 return operator==<double>(w1, w2); 122} 123 124inline bool operator==(const FloatWeightTpl<float> &w1, 125 const FloatWeightTpl<float> &w2) { 126 return operator==<float>(w1, w2); 127} 128 129template <class T> 130inline bool operator!=(const FloatWeightTpl<T> &w1, 131 const FloatWeightTpl<T> &w2) { 132 return !(w1 == w2); 133} 134 135inline bool operator!=(const FloatWeightTpl<double> &w1, 136 const FloatWeightTpl<double> &w2) { 137 return operator!=<double>(w1, w2); 138} 139 140inline bool operator!=(const FloatWeightTpl<float> &w1, 141 const FloatWeightTpl<float> &w2) { 142 return operator!=<float>(w1, w2); 143} 144 145template <class T> 146inline bool ApproxEqual(const FloatWeightTpl<T> &w1, 147 const FloatWeightTpl<T> &w2, 148 float delta = kDelta) { 149 return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta; 150} 151 152template <class T> 153inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) { 154 if (w.Value() == FloatLimits<T>::kPosInfinity) 155 return strm << "Infinity"; 156 else if (w.Value() == FloatLimits<T>::kNegInfinity) 157 return strm << "-Infinity"; 158 else if (w.Value() != w.Value()) // Fails for NaN 159 return strm << "BadNumber"; 160 else 161 return strm << w.Value(); 162} 163 164template <class T> 165inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) { 166 string s; 167 strm >> s; 168 if (s == "Infinity") { 169 w = FloatWeightTpl<T>(FloatLimits<T>::kPosInfinity); 170 } else if (s == "-Infinity") { 171 w = FloatWeightTpl<T>(FloatLimits<T>::kNegInfinity); 172 } else { 173 char *p; 174 T f = strtod(s.c_str(), &p); 175 if (p < s.c_str() + s.size()) 176 strm.clear(std::ios::badbit); 177 else 178 w = FloatWeightTpl<T>(f); 179 } 180 return strm; 181} 182 183 184// Tropical semiring: (min, +, inf, 0) 185template <class T> 186class TropicalWeightTpl : public FloatWeightTpl<T> { 187 public: 188 using FloatWeightTpl<T>::Value; 189 190 typedef TropicalWeightTpl<T> ReverseWeight; 191 192 TropicalWeightTpl() : FloatWeightTpl<T>() {} 193 194 TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {} 195 196 TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 197 198 static const TropicalWeightTpl<T> Zero() { 199 return TropicalWeightTpl<T>(FloatLimits<T>::kPosInfinity); } 200 201 static const TropicalWeightTpl<T> One() { 202 return TropicalWeightTpl<T>(0.0F); } 203 204 static const TropicalWeightTpl<T> NoWeight() { 205 return TropicalWeightTpl<T>(FloatLimits<T>::kNumberBad); } 206 207 static const string &Type() { 208 static const string type = "tropical" + 209 FloatWeightTpl<T>::GetPrecisionString(); 210 return type; 211 } 212 213 bool Member() const { 214 // First part fails for IEEE NaN 215 return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity; 216 } 217 218 TropicalWeightTpl<T> Quantize(float delta = kDelta) const { 219 if (Value() == FloatLimits<T>::kNegInfinity || 220 Value() == FloatLimits<T>::kPosInfinity || 221 Value() != Value()) 222 return *this; 223 else 224 return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 225 } 226 227 TropicalWeightTpl<T> Reverse() const { return *this; } 228 229 static uint64 Properties() { 230 return kLeftSemiring | kRightSemiring | kCommutative | 231 kPath | kIdempotent; 232 } 233}; 234 235// Single precision tropical weight 236typedef TropicalWeightTpl<float> TropicalWeight; 237 238template <class T> 239inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1, 240 const TropicalWeightTpl<T> &w2) { 241 if (!w1.Member() || !w2.Member()) 242 return TropicalWeightTpl<T>::NoWeight(); 243 return w1.Value() < w2.Value() ? w1 : w2; 244} 245 246inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1, 247 const TropicalWeightTpl<float> &w2) { 248 return Plus<float>(w1, w2); 249} 250 251inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1, 252 const TropicalWeightTpl<double> &w2) { 253 return Plus<double>(w1, w2); 254} 255 256template <class T> 257inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1, 258 const TropicalWeightTpl<T> &w2) { 259 if (!w1.Member() || !w2.Member()) 260 return TropicalWeightTpl<T>::NoWeight(); 261 T f1 = w1.Value(), f2 = w2.Value(); 262 if (f1 == FloatLimits<T>::kPosInfinity) 263 return w1; 264 else if (f2 == FloatLimits<T>::kPosInfinity) 265 return w2; 266 else 267 return TropicalWeightTpl<T>(f1 + f2); 268} 269 270inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1, 271 const TropicalWeightTpl<float> &w2) { 272 return Times<float>(w1, w2); 273} 274 275inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1, 276 const TropicalWeightTpl<double> &w2) { 277 return Times<double>(w1, w2); 278} 279 280template <class T> 281inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1, 282 const TropicalWeightTpl<T> &w2, 283 DivideType typ = DIVIDE_ANY) { 284 if (!w1.Member() || !w2.Member()) 285 return TropicalWeightTpl<T>::NoWeight(); 286 T f1 = w1.Value(), f2 = w2.Value(); 287 if (f2 == FloatLimits<T>::kPosInfinity) 288 return FloatLimits<T>::kNumberBad; 289 else if (f1 == FloatLimits<T>::kPosInfinity) 290 return FloatLimits<T>::kPosInfinity; 291 else 292 return TropicalWeightTpl<T>(f1 - f2); 293} 294 295inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1, 296 const TropicalWeightTpl<float> &w2, 297 DivideType typ = DIVIDE_ANY) { 298 return Divide<float>(w1, w2, typ); 299} 300 301inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1, 302 const TropicalWeightTpl<double> &w2, 303 DivideType typ = DIVIDE_ANY) { 304 return Divide<double>(w1, w2, typ); 305} 306 307 308// Log semiring: (log(e^-x + e^y), +, inf, 0) 309template <class T> 310class LogWeightTpl : public FloatWeightTpl<T> { 311 public: 312 using FloatWeightTpl<T>::Value; 313 314 typedef LogWeightTpl ReverseWeight; 315 316 LogWeightTpl() : FloatWeightTpl<T>() {} 317 318 LogWeightTpl(T f) : FloatWeightTpl<T>(f) {} 319 320 LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 321 322 static const LogWeightTpl<T> Zero() { 323 return LogWeightTpl<T>(FloatLimits<T>::kPosInfinity); 324 } 325 326 static const LogWeightTpl<T> One() { 327 return LogWeightTpl<T>(0.0F); 328 } 329 330 static const LogWeightTpl<T> NoWeight() { 331 return LogWeightTpl<T>(FloatLimits<T>::kNumberBad); } 332 333 static const string &Type() { 334 static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString(); 335 return type; 336 } 337 338 bool Member() const { 339 // First part fails for IEEE NaN 340 return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity; 341 } 342 343 LogWeightTpl<T> Quantize(float delta = kDelta) const { 344 if (Value() == FloatLimits<T>::kNegInfinity || 345 Value() == FloatLimits<T>::kPosInfinity || 346 Value() != Value()) 347 return *this; 348 else 349 return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 350 } 351 352 LogWeightTpl<T> Reverse() const { return *this; } 353 354 static uint64 Properties() { 355 return kLeftSemiring | kRightSemiring | kCommutative; 356 } 357}; 358 359// Single-precision log weight 360typedef LogWeightTpl<float> LogWeight; 361// Double-precision log weight 362typedef LogWeightTpl<double> Log64Weight; 363 364template <class T> 365inline T LogExp(T x) { return log(1.0F + exp(-x)); } 366 367template <class T> 368inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1, 369 const LogWeightTpl<T> &w2) { 370 T f1 = w1.Value(), f2 = w2.Value(); 371 if (f1 == FloatLimits<T>::kPosInfinity) 372 return w2; 373 else if (f2 == FloatLimits<T>::kPosInfinity) 374 return w1; 375 else if (f1 > f2) 376 return LogWeightTpl<T>(f2 - LogExp(f1 - f2)); 377 else 378 return LogWeightTpl<T>(f1 - LogExp(f2 - f1)); 379} 380 381inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1, 382 const LogWeightTpl<float> &w2) { 383 return Plus<float>(w1, w2); 384} 385 386inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1, 387 const LogWeightTpl<double> &w2) { 388 return Plus<double>(w1, w2); 389} 390 391template <class T> 392inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1, 393 const LogWeightTpl<T> &w2) { 394 if (!w1.Member() || !w2.Member()) 395 return LogWeightTpl<T>::NoWeight(); 396 T f1 = w1.Value(), f2 = w2.Value(); 397 if (f1 == FloatLimits<T>::kPosInfinity) 398 return w1; 399 else if (f2 == FloatLimits<T>::kPosInfinity) 400 return w2; 401 else 402 return LogWeightTpl<T>(f1 + f2); 403} 404 405inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1, 406 const LogWeightTpl<float> &w2) { 407 return Times<float>(w1, w2); 408} 409 410inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1, 411 const LogWeightTpl<double> &w2) { 412 return Times<double>(w1, w2); 413} 414 415template <class T> 416inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1, 417 const LogWeightTpl<T> &w2, 418 DivideType typ = DIVIDE_ANY) { 419 if (!w1.Member() || !w2.Member()) 420 return LogWeightTpl<T>::NoWeight(); 421 T f1 = w1.Value(), f2 = w2.Value(); 422 if (f2 == FloatLimits<T>::kPosInfinity) 423 return FloatLimits<T>::kNumberBad; 424 else if (f1 == FloatLimits<T>::kPosInfinity) 425 return FloatLimits<T>::kPosInfinity; 426 else 427 return LogWeightTpl<T>(f1 - f2); 428} 429 430inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1, 431 const LogWeightTpl<float> &w2, 432 DivideType typ = DIVIDE_ANY) { 433 return Divide<float>(w1, w2, typ); 434} 435 436inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1, 437 const LogWeightTpl<double> &w2, 438 DivideType typ = DIVIDE_ANY) { 439 return Divide<double>(w1, w2, typ); 440} 441 442// MinMax semiring: (min, max, inf, -inf) 443template <class T> 444class MinMaxWeightTpl : public FloatWeightTpl<T> { 445 public: 446 using FloatWeightTpl<T>::Value; 447 448 typedef MinMaxWeightTpl<T> ReverseWeight; 449 450 MinMaxWeightTpl() : FloatWeightTpl<T>() {} 451 452 MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {} 453 454 MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {} 455 456 static const MinMaxWeightTpl<T> Zero() { 457 return MinMaxWeightTpl<T>(FloatLimits<T>::kPosInfinity); 458 } 459 460 static const MinMaxWeightTpl<T> One() { 461 return MinMaxWeightTpl<T>(FloatLimits<T>::kNegInfinity); 462 } 463 464 static const MinMaxWeightTpl<T> NoWeight() { 465 return MinMaxWeightTpl<T>(FloatLimits<T>::kNumberBad); } 466 467 static const string &Type() { 468 static const string type = "minmax" + 469 FloatWeightTpl<T>::GetPrecisionString(); 470 return type; 471 } 472 473 bool Member() const { 474 // Fails for IEEE NaN 475 return Value() == Value(); 476 } 477 478 MinMaxWeightTpl<T> Quantize(float delta = kDelta) const { 479 // If one of infinities, or a NaN 480 if (Value() == FloatLimits<T>::kNegInfinity || 481 Value() == FloatLimits<T>::kPosInfinity || 482 Value() != Value()) 483 return *this; 484 else 485 return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta); 486 } 487 488 MinMaxWeightTpl<T> Reverse() const { return *this; } 489 490 static uint64 Properties() { 491 return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath; 492 } 493}; 494 495// Single-precision min-max weight 496typedef MinMaxWeightTpl<float> MinMaxWeight; 497 498// Min 499template <class T> 500inline MinMaxWeightTpl<T> Plus( 501 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { 502 if (!w1.Member() || !w2.Member()) 503 return MinMaxWeightTpl<T>::NoWeight(); 504 return w1.Value() < w2.Value() ? w1 : w2; 505} 506 507inline MinMaxWeightTpl<float> Plus( 508 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { 509 return Plus<float>(w1, w2); 510} 511 512inline MinMaxWeightTpl<double> Plus( 513 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { 514 return Plus<double>(w1, w2); 515} 516 517// Max 518template <class T> 519inline MinMaxWeightTpl<T> Times( 520 const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) { 521 if (!w1.Member() || !w2.Member()) 522 return MinMaxWeightTpl<T>::NoWeight(); 523 return w1.Value() >= w2.Value() ? w1 : w2; 524} 525 526inline MinMaxWeightTpl<float> Times( 527 const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) { 528 return Times<float>(w1, w2); 529} 530 531inline MinMaxWeightTpl<double> Times( 532 const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) { 533 return Times<double>(w1, w2); 534} 535 536// Defined only for special cases 537template <class T> 538inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1, 539 const MinMaxWeightTpl<T> &w2, 540 DivideType typ = DIVIDE_ANY) { 541 if (!w1.Member() || !w2.Member()) 542 return MinMaxWeightTpl<T>::NoWeight(); 543 // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2 544 return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::kNumberBad; 545} 546 547inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1, 548 const MinMaxWeightTpl<float> &w2, 549 DivideType typ = DIVIDE_ANY) { 550 return Divide<float>(w1, w2, typ); 551} 552 553inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1, 554 const MinMaxWeightTpl<double> &w2, 555 DivideType typ = DIVIDE_ANY) { 556 return Divide<double>(w1, w2, typ); 557} 558 559// 560// WEIGHT CONVERTER SPECIALIZATIONS. 561// 562 563// Convert to tropical 564template <> 565struct WeightConvert<LogWeight, TropicalWeight> { 566 TropicalWeight operator()(LogWeight w) const { return w.Value(); } 567}; 568 569template <> 570struct WeightConvert<Log64Weight, TropicalWeight> { 571 TropicalWeight operator()(Log64Weight w) const { return w.Value(); } 572}; 573 574// Convert to log 575template <> 576struct WeightConvert<TropicalWeight, LogWeight> { 577 LogWeight operator()(TropicalWeight w) const { return w.Value(); } 578}; 579 580template <> 581struct WeightConvert<Log64Weight, LogWeight> { 582 LogWeight operator()(Log64Weight w) const { return w.Value(); } 583}; 584 585// Convert to log64 586template <> 587struct WeightConvert<TropicalWeight, Log64Weight> { 588 Log64Weight operator()(TropicalWeight w) const { return w.Value(); } 589}; 590 591template <> 592struct WeightConvert<LogWeight, Log64Weight> { 593 Log64Weight operator()(LogWeight w) const { return w.Value(); } 594}; 595 596} // namespace fst 597 598#endif // FST_LIB_FLOAT_WEIGHT_H__ 599