1// float-weight.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Float weight set and associated semiring operation definitions.
20//
21
22#ifndef FST_LIB_FLOAT_WEIGHT_H__
23#define FST_LIB_FLOAT_WEIGHT_H__
24
#include <climits>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>

#include <fst/util.h>
#include <fst/weight.h>
32
33
34namespace fst {
35
// Numeric limits for the floating-point type T underlying a weight class:
// the two infinities and a quiet NaN used as the "bad number" marker.
template <class T>
class FloatLimits {
 public:
  // Returns +infinity for T.
  static const T PosInfinity() {
    // Computed directly instead of being cached in a function-local static:
    // the value is a cheap constant, and a local static adds an
    // initialization guard that is not thread-safe before C++11.
    return std::numeric_limits<T>::infinity();
  }

  // Returns -infinity for T.
  static const T NegInfinity() {
    return -std::numeric_limits<T>::infinity();
  }

  // Returns a quiet NaN for T.
  static const T NumberBad() {
    return std::numeric_limits<T>::quiet_NaN();
  }

};
56
57// weight class to be templated on floating-points types
58template <class T = float>
59class FloatWeightTpl {
60 public:
61  FloatWeightTpl() {}
62
63  FloatWeightTpl(T f) : value_(f) {}
64
65  FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {}
66
67  FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) {
68    value_ = w.value_;
69    return *this;
70  }
71
72  istream &Read(istream &strm) {
73    return ReadType(strm, &value_);
74  }
75
76  ostream &Write(ostream &strm) const {
77    return WriteType(strm, value_);
78  }
79
80  size_t Hash() const {
81    union {
82      T f;
83      size_t s;
84    } u;
85    u.s = 0;
86    u.f = value_;
87    return u.s;
88  }
89
90  const T &Value() const { return value_; }
91
92 protected:
93  void SetValue(const T &f) { value_ = f; }
94
95  inline static string GetPrecisionString() {
96    int64 size = sizeof(T);
97    if (size == sizeof(float)) return "";
98    size *= CHAR_BIT;
99
100    string result;
101    Int64ToStr(size, &result);
102    return result;
103  }
104
105 private:
106  T value_;
107};
108
// Single-precision float weight: FloatWeightTpl instantiated with float.
typedef FloatWeightTpl<float> FloatWeight;
111
112template <class T>
113inline bool operator==(const FloatWeightTpl<T> &w1,
114                       const FloatWeightTpl<T> &w2) {
115  // Volatile qualifier thwarts over-aggressive compiler optimizations
116  // that lead to problems esp. with NaturalLess().
117  volatile T v1 = w1.Value();
118  volatile T v2 = w2.Value();
119  return v1 == v2;
120}
121
// Non-template overload for double-precision weights; forwards to the
// template operator== with T given explicitly (the explicit template
// argument is what keeps this call from resolving back to this overload).
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline bool operator==(const FloatWeightTpl<double> &w1,
                       const FloatWeightTpl<double> &w2) {
  return operator==<double>(w1, w2);
}

// Single-precision counterpart of the overload above.
inline bool operator==(const FloatWeightTpl<float> &w1,
                       const FloatWeightTpl<float> &w2) {
  return operator==<float>(w1, w2);
}
131
132template <class T>
133inline bool operator!=(const FloatWeightTpl<T> &w1,
134                       const FloatWeightTpl<T> &w2) {
135  return !(w1 == w2);
136}
137
// Non-template overload for double-precision weights; forwards to the
// template operator!= with T given explicitly.
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline bool operator!=(const FloatWeightTpl<double> &w1,
                       const FloatWeightTpl<double> &w2) {
  return operator!=<double>(w1, w2);
}

// Single-precision counterpart of the overload above.
inline bool operator!=(const FloatWeightTpl<float> &w1,
                       const FloatWeightTpl<float> &w2) {
  return operator!=<float>(w1, w2);
}
147
148template <class T>
149inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
150                        const FloatWeightTpl<T> &w2,
151                        float delta = kDelta) {
152  return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
153}
154
155template <class T>
156inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) {
157  if (w.Value() == FloatLimits<T>::PosInfinity())
158    return strm << "Infinity";
159  else if (w.Value() == FloatLimits<T>::NegInfinity())
160    return strm << "-Infinity";
161  else if (w.Value() != w.Value())   // Fails for NaN
162    return strm << "BadNumber";
163  else
164    return strm << w.Value();
165}
166
167template <class T>
168inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) {
169  string s;
170  strm >> s;
171  if (s == "Infinity") {
172    w = FloatWeightTpl<T>(FloatLimits<T>::PosInfinity());
173  } else if (s == "-Infinity") {
174    w = FloatWeightTpl<T>(FloatLimits<T>::NegInfinity());
175  } else {
176    char *p;
177    T f = strtod(s.c_str(), &p);
178    if (p < s.c_str() + s.size())
179      strm.clear(std::ios::badbit);
180    else
181      w = FloatWeightTpl<T>(f);
182  }
183  return strm;
184}
185
186
187// Tropical semiring: (min, +, inf, 0)
188template <class T>
189class TropicalWeightTpl : public FloatWeightTpl<T> {
190 public:
191  using FloatWeightTpl<T>::Value;
192
193  typedef TropicalWeightTpl<T> ReverseWeight;
194
195  TropicalWeightTpl() : FloatWeightTpl<T>() {}
196
197  TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
198
199  TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
200
201  static const TropicalWeightTpl<T> Zero() {
202    return TropicalWeightTpl<T>(FloatLimits<T>::PosInfinity()); }
203
204  static const TropicalWeightTpl<T> One() {
205    return TropicalWeightTpl<T>(0.0F); }
206
207  static const TropicalWeightTpl<T> NoWeight() {
208    return TropicalWeightTpl<T>(FloatLimits<T>::NumberBad()); }
209
210  static const string &Type() {
211    static const string type = "tropical" +
212        FloatWeightTpl<T>::GetPrecisionString();
213    return type;
214  }
215
216  bool Member() const {
217    // First part fails for IEEE NaN
218    return Value() == Value() && Value() != FloatLimits<T>::NegInfinity();
219  }
220
221  TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
222    if (Value() == FloatLimits<T>::NegInfinity() ||
223        Value() == FloatLimits<T>::PosInfinity() ||
224        Value() != Value())
225      return *this;
226    else
227      return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
228  }
229
230  TropicalWeightTpl<T> Reverse() const { return *this; }
231
232  static uint64 Properties() {
233    return kLeftSemiring | kRightSemiring | kCommutative |
234        kPath | kIdempotent;
235  }
236};
237
// Single-precision tropical weight: TropicalWeightTpl instantiated with float.
typedef TropicalWeightTpl<float> TropicalWeight;
240
241template <class T>
242inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
243                                 const TropicalWeightTpl<T> &w2) {
244  if (!w1.Member() || !w2.Member())
245    return TropicalWeightTpl<T>::NoWeight();
246  return w1.Value() < w2.Value() ? w1 : w2;
247}
248
// Non-template overload for single-precision tropical weights; forwards to
// the template Plus with T given explicitly.
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
                                     const TropicalWeightTpl<float> &w2) {
  return Plus<float>(w1, w2);
}

// Double-precision counterpart of the overload above.
inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
                                      const TropicalWeightTpl<double> &w2) {
  return Plus<double>(w1, w2);
}
258
259template <class T>
260inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
261                                  const TropicalWeightTpl<T> &w2) {
262  if (!w1.Member() || !w2.Member())
263    return TropicalWeightTpl<T>::NoWeight();
264  T f1 = w1.Value(), f2 = w2.Value();
265  if (f1 == FloatLimits<T>::PosInfinity())
266    return w1;
267  else if (f2 == FloatLimits<T>::PosInfinity())
268    return w2;
269  else
270    return TropicalWeightTpl<T>(f1 + f2);
271}
272
// Non-template overload for single-precision tropical weights; forwards to
// the template Times with T given explicitly.
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
                                      const TropicalWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

// Double-precision counterpart of the overload above.
inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
                                       const TropicalWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}
282
283template <class T>
284inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
285                                   const TropicalWeightTpl<T> &w2,
286                                   DivideType typ = DIVIDE_ANY) {
287  if (!w1.Member() || !w2.Member())
288    return TropicalWeightTpl<T>::NoWeight();
289  T f1 = w1.Value(), f2 = w2.Value();
290  if (f2 == FloatLimits<T>::PosInfinity())
291    return FloatLimits<T>::NumberBad();
292  else if (f1 == FloatLimits<T>::PosInfinity())
293    return FloatLimits<T>::PosInfinity();
294  else
295    return TropicalWeightTpl<T>(f1 - f2);
296}
297
// Non-template overload for single-precision tropical weights; forwards to
// the template Divide with T given explicitly.
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
                                       const TropicalWeightTpl<float> &w2,
                                       DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

// Double-precision counterpart of the overload above.
inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
                                        const TropicalWeightTpl<double> &w2,
                                        DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}
309
310
311// Log semiring: (log(e^-x + e^y), +, inf, 0)
312template <class T>
313class LogWeightTpl : public FloatWeightTpl<T> {
314 public:
315  using FloatWeightTpl<T>::Value;
316
317  typedef LogWeightTpl ReverseWeight;
318
319  LogWeightTpl() : FloatWeightTpl<T>() {}
320
321  LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
322
323  LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
324
325  static const LogWeightTpl<T> Zero() {
326    return LogWeightTpl<T>(FloatLimits<T>::PosInfinity());
327  }
328
329  static const LogWeightTpl<T> One() {
330    return LogWeightTpl<T>(0.0F);
331  }
332
333  static const LogWeightTpl<T> NoWeight() {
334    return LogWeightTpl<T>(FloatLimits<T>::NumberBad()); }
335
336  static const string &Type() {
337    static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString();
338    return type;
339  }
340
341  bool Member() const {
342    // First part fails for IEEE NaN
343    return Value() == Value() && Value() != FloatLimits<T>::NegInfinity();
344  }
345
346  LogWeightTpl<T> Quantize(float delta = kDelta) const {
347    if (Value() == FloatLimits<T>::NegInfinity() ||
348        Value() == FloatLimits<T>::PosInfinity() ||
349        Value() != Value())
350      return *this;
351    else
352      return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
353  }
354
355  LogWeightTpl<T> Reverse() const { return *this; }
356
357  static uint64 Properties() {
358    return kLeftSemiring | kRightSemiring | kCommutative;
359  }
360};
361
// Single-precision log weight: LogWeightTpl instantiated with float.
typedef LogWeightTpl<float> LogWeight;
// Double-precision log weight: LogWeightTpl instantiated with double.
typedef LogWeightTpl<double> Log64Weight;
366
// Returns log(1 + e^-x); helper for Plus() in the log semiring.
template <class T>
inline T LogExp(T x) {
  return log(exp(-x) + 1.0F);
}
369
370template <class T>
371inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
372                            const LogWeightTpl<T> &w2) {
373  T f1 = w1.Value(), f2 = w2.Value();
374  if (f1 == FloatLimits<T>::PosInfinity())
375    return w2;
376  else if (f2 == FloatLimits<T>::PosInfinity())
377    return w1;
378  else if (f1 > f2)
379    return LogWeightTpl<T>(f2 - LogExp(f1 - f2));
380  else
381    return LogWeightTpl<T>(f1 - LogExp(f2 - f1));
382}
383
// Non-template overload for single-precision log weights; forwards to the
// template Plus with T given explicitly.
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
                                const LogWeightTpl<float> &w2) {
  return Plus<float>(w1, w2);
}

// Double-precision counterpart of the overload above.
inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
                                 const LogWeightTpl<double> &w2) {
  return Plus<double>(w1, w2);
}
393
394template <class T>
395inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
396                             const LogWeightTpl<T> &w2) {
397  if (!w1.Member() || !w2.Member())
398    return LogWeightTpl<T>::NoWeight();
399  T f1 = w1.Value(), f2 = w2.Value();
400  if (f1 == FloatLimits<T>::PosInfinity())
401    return w1;
402  else if (f2 == FloatLimits<T>::PosInfinity())
403    return w2;
404  else
405    return LogWeightTpl<T>(f1 + f2);
406}
407
// Non-template overload for single-precision log weights; forwards to the
// template Times with T given explicitly.
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
                                 const LogWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

// Double-precision counterpart of the overload above.
inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
                                  const LogWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}
417
418template <class T>
419inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
420                              const LogWeightTpl<T> &w2,
421                              DivideType typ = DIVIDE_ANY) {
422  if (!w1.Member() || !w2.Member())
423    return LogWeightTpl<T>::NoWeight();
424  T f1 = w1.Value(), f2 = w2.Value();
425  if (f2 == FloatLimits<T>::PosInfinity())
426    return FloatLimits<T>::NumberBad();
427  else if (f1 == FloatLimits<T>::PosInfinity())
428    return FloatLimits<T>::PosInfinity();
429  else
430    return LogWeightTpl<T>(f1 - f2);
431}
432
// Non-template overload for single-precision log weights; forwards to the
// template Divide with T given explicitly.
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
                                  const LogWeightTpl<float> &w2,
                                  DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

// Double-precision counterpart of the overload above.
inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
                                   const LogWeightTpl<double> &w2,
                                   DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}
444
445// MinMax semiring: (min, max, inf, -inf)
446template <class T>
447class MinMaxWeightTpl : public FloatWeightTpl<T> {
448 public:
449  using FloatWeightTpl<T>::Value;
450
451  typedef MinMaxWeightTpl<T> ReverseWeight;
452
453  MinMaxWeightTpl() : FloatWeightTpl<T>() {}
454
455  MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}
456
457  MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
458
459  static const MinMaxWeightTpl<T> Zero() {
460    return MinMaxWeightTpl<T>(FloatLimits<T>::PosInfinity());
461  }
462
463  static const MinMaxWeightTpl<T> One() {
464    return MinMaxWeightTpl<T>(FloatLimits<T>::NegInfinity());
465  }
466
467  static const MinMaxWeightTpl<T> NoWeight() {
468    return MinMaxWeightTpl<T>(FloatLimits<T>::NumberBad()); }
469
470  static const string &Type() {
471    static const string type = "minmax" +
472        FloatWeightTpl<T>::GetPrecisionString();
473    return type;
474  }
475
476  bool Member() const {
477    // Fails for IEEE NaN
478    return Value() == Value();
479  }
480
481  MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
482    // If one of infinities, or a NaN
483    if (Value() == FloatLimits<T>::NegInfinity() ||
484        Value() == FloatLimits<T>::PosInfinity() ||
485        Value() != Value())
486      return *this;
487    else
488      return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
489  }
490
491  MinMaxWeightTpl<T> Reverse() const { return *this; }
492
493  static uint64 Properties() {
494    return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
495  }
496};
497
// Single-precision min-max weight: MinMaxWeightTpl instantiated with float.
typedef MinMaxWeightTpl<float> MinMaxWeight;
500
501// Min
502template <class T>
503inline MinMaxWeightTpl<T> Plus(
504    const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
505  if (!w1.Member() || !w2.Member())
506    return MinMaxWeightTpl<T>::NoWeight();
507  return w1.Value() < w2.Value() ? w1 : w2;
508}
509
// Non-template overload for single-precision min-max weights; forwards to
// the template Plus with T given explicitly.
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline MinMaxWeightTpl<float> Plus(
    const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
  return Plus<float>(w1, w2);
}

// Double-precision counterpart of the overload above.
inline MinMaxWeightTpl<double> Plus(
    const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
  return Plus<double>(w1, w2);
}
519
520// Max
521template <class T>
522inline MinMaxWeightTpl<T> Times(
523    const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
524  if (!w1.Member() || !w2.Member())
525    return MinMaxWeightTpl<T>::NoWeight();
526  return w1.Value() >= w2.Value() ? w1 : w2;
527}
528
// Non-template overload for single-precision min-max weights; forwards to
// the template Times with T given explicitly.
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline MinMaxWeightTpl<float> Times(
    const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

// Double-precision counterpart of the overload above.
inline MinMaxWeightTpl<double> Times(
    const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}
538
539// Defined only for special cases
540template <class T>
541inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
542                                 const MinMaxWeightTpl<T> &w2,
543                                 DivideType typ = DIVIDE_ANY) {
544  if (!w1.Member() || !w2.Member())
545    return MinMaxWeightTpl<T>::NoWeight();
546  // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2
547  return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::NumberBad();
548}
549
// Non-template overload for single-precision min-max weights; forwards to
// the template Divide with T given explicitly.
// NOTE(review): presumably provided so that an argument needing an implicit
// conversion, which template argument deduction would reject, still
// resolves -- confirm.
inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
                                     const MinMaxWeightTpl<float> &w2,
                                     DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

// Double-precision counterpart of the overload above.
inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
                                      const MinMaxWeightTpl<double> &w2,
                                      DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}
561
562//
563// WEIGHT CONVERTER SPECIALIZATIONS.
564//
565
566// Convert to tropical
567template <>
568struct WeightConvert<LogWeight, TropicalWeight> {
569  TropicalWeight operator()(LogWeight w) const { return w.Value(); }
570};
571
572template <>
573struct WeightConvert<Log64Weight, TropicalWeight> {
574  TropicalWeight operator()(Log64Weight w) const { return w.Value(); }
575};
576
577// Convert to log
578template <>
579struct WeightConvert<TropicalWeight, LogWeight> {
580  LogWeight operator()(TropicalWeight w) const { return w.Value(); }
581};
582
583template <>
584struct WeightConvert<Log64Weight, LogWeight> {
585  LogWeight operator()(Log64Weight w) const { return w.Value(); }
586};
587
588// Convert to log64
589template <>
590struct WeightConvert<TropicalWeight, Log64Weight> {
591  Log64Weight operator()(TropicalWeight w) const { return w.Value(); }
592};
593
594template <>
595struct WeightConvert<LogWeight, Log64Weight> {
596  Log64Weight operator()(LogWeight w) const { return w.Value(); }
597};
598
599}  // namespace fst
600
601#endif  // FST_LIB_FLOAT_WEIGHT_H__
602