float-weight.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// float-weight.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Float weight set and associated semiring operation definitions.
20//
21
22#ifndef FST_LIB_FLOAT_WEIGHT_H__
23#define FST_LIB_FLOAT_WEIGHT_H__
24
#include <limits>
#include <climits>
#include <cstring>
#include <sstream>
#include <string>

#include <fst/util.h>
#include <fst/weight.h>
32
33
34namespace fst {
35
// Numeric limits for the floating-point type T used by the weights below.
//
// Every constant is initialized directly from std::numeric_limits rather than
// from another member: dynamic initialization of static data members of class
// template specializations is unordered, so defining kNegInfinity as
// -kPosInfinity could (e.g. under C++98, where numeric_limits<T>::infinity()
// is not a constant expression) read a still zero-initialized kPosInfinity and
// produce -0.0.  The static functions return the same values and, unlike the
// constants, are safe to use during static initialization of other
// translation units.
template <class T>
class FloatLimits {
 public:
  static const T kPosInfinity;
  static const T kNegInfinity;
  static const T kNumberBad;

  static T PosInfinity() { return std::numeric_limits<T>::infinity(); }
  static T NegInfinity() { return -std::numeric_limits<T>::infinity(); }
  static T NumberBad() { return std::numeric_limits<T>::quiet_NaN(); }
};

template <class T>
const T FloatLimits<T>::kPosInfinity = std::numeric_limits<T>::infinity();

template <class T>
const T FloatLimits<T>::kNegInfinity = -std::numeric_limits<T>::infinity();

template <class T>
const T FloatLimits<T>::kNumberBad = std::numeric_limits<T>::quiet_NaN();
53
54// weight class to be templated on floating-points types
55template <class T = float>
56class FloatWeightTpl {
57 public:
58  FloatWeightTpl() {}
59
60  FloatWeightTpl(T f) : value_(f) {}
61
62  FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {}
63
64  FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) {
65    value_ = w.value_;
66    return *this;
67  }
68
69  istream &Read(istream &strm) {
70    return ReadType(strm, &value_);
71  }
72
73  ostream &Write(ostream &strm) const {
74    return WriteType(strm, value_);
75  }
76
77  size_t Hash() const {
78    union {
79      T f;
80      size_t s;
81    } u;
82    u.s = 0;
83    u.f = value_;
84    return u.s;
85  }
86
87  const T &Value() const { return value_; }
88
89 protected:
90  void SetValue(const T &f) { value_ = f; }
91
92  inline static string GetPrecisionString() {
93    int64 size = sizeof(T);
94    if (size == sizeof(float)) return "";
95    size *= CHAR_BIT;
96
97    string result;
98    Int64ToStr(size, &result);
99    return result;
100  }
101
102 private:
103  T value_;
104};
105
106// Single-precision float weight
107typedef FloatWeightTpl<float> FloatWeight;
108
109template <class T>
110inline bool operator==(const FloatWeightTpl<T> &w1,
111                       const FloatWeightTpl<T> &w2) {
112  // Volatile qualifier thwarts over-aggressive compiler optimizations
113  // that lead to problems esp. with NaturalLess().
114  volatile T v1 = w1.Value();
115  volatile T v2 = w2.Value();
116  return v1 == v2;
117}
118
119inline bool operator==(const FloatWeightTpl<double> &w1,
120                       const FloatWeightTpl<double> &w2) {
121  return operator==<double>(w1, w2);
122}
123
124inline bool operator==(const FloatWeightTpl<float> &w1,
125                       const FloatWeightTpl<float> &w2) {
126  return operator==<float>(w1, w2);
127}
128
129template <class T>
130inline bool operator!=(const FloatWeightTpl<T> &w1,
131                       const FloatWeightTpl<T> &w2) {
132  return !(w1 == w2);
133}
134
135inline bool operator!=(const FloatWeightTpl<double> &w1,
136                       const FloatWeightTpl<double> &w2) {
137  return operator!=<double>(w1, w2);
138}
139
140inline bool operator!=(const FloatWeightTpl<float> &w1,
141                       const FloatWeightTpl<float> &w2) {
142  return operator!=<float>(w1, w2);
143}
144
145template <class T>
146inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
147                        const FloatWeightTpl<T> &w2,
148                        float delta = kDelta) {
149  return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
150}
151
152template <class T>
153inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) {
154  if (w.Value() == FloatLimits<T>::kPosInfinity)
155    return strm << "Infinity";
156  else if (w.Value() == FloatLimits<T>::kNegInfinity)
157    return strm << "-Infinity";
158  else if (w.Value() != w.Value())   // Fails for NaN
159    return strm << "BadNumber";
160  else
161    return strm << w.Value();
162}
163
164template <class T>
165inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) {
166  string s;
167  strm >> s;
168  if (s == "Infinity") {
169    w = FloatWeightTpl<T>(FloatLimits<T>::kPosInfinity);
170  } else if (s == "-Infinity") {
171    w = FloatWeightTpl<T>(FloatLimits<T>::kNegInfinity);
172  } else {
173    char *p;
174    T f = strtod(s.c_str(), &p);
175    if (p < s.c_str() + s.size())
176      strm.clear(std::ios::badbit);
177    else
178      w = FloatWeightTpl<T>(f);
179  }
180  return strm;
181}
182
183
184// Tropical semiring: (min, +, inf, 0)
185template <class T>
186class TropicalWeightTpl : public FloatWeightTpl<T> {
187 public:
188  using FloatWeightTpl<T>::Value;
189
190  typedef TropicalWeightTpl<T> ReverseWeight;
191
192  TropicalWeightTpl() : FloatWeightTpl<T>() {}
193
194  TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}
195
196  TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
197
198  static const TropicalWeightTpl<T> Zero() {
199    return TropicalWeightTpl<T>(FloatLimits<T>::kPosInfinity); }
200
201  static const TropicalWeightTpl<T> One() {
202    return TropicalWeightTpl<T>(0.0F); }
203
204  static const TropicalWeightTpl<T> NoWeight() {
205    return TropicalWeightTpl<T>(FloatLimits<T>::kNumberBad); }
206
207  static const string &Type() {
208    static const string type = "tropical" +
209        FloatWeightTpl<T>::GetPrecisionString();
210    return type;
211  }
212
213  bool Member() const {
214    // First part fails for IEEE NaN
215    return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity;
216  }
217
218  TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
219    if (Value() == FloatLimits<T>::kNegInfinity ||
220        Value() == FloatLimits<T>::kPosInfinity ||
221        Value() != Value())
222      return *this;
223    else
224      return TropicalWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
225  }
226
227  TropicalWeightTpl<T> Reverse() const { return *this; }
228
229  static uint64 Properties() {
230    return kLeftSemiring | kRightSemiring | kCommutative |
231        kPath | kIdempotent;
232  }
233};
234
235// Single precision tropical weight
236typedef TropicalWeightTpl<float> TropicalWeight;
237
238template <class T>
239inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
240                                 const TropicalWeightTpl<T> &w2) {
241  if (!w1.Member() || !w2.Member())
242    return TropicalWeightTpl<T>::NoWeight();
243  return w1.Value() < w2.Value() ? w1 : w2;
244}
245
246inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
247                                     const TropicalWeightTpl<float> &w2) {
248  return Plus<float>(w1, w2);
249}
250
251inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
252                                      const TropicalWeightTpl<double> &w2) {
253  return Plus<double>(w1, w2);
254}
255
256template <class T>
257inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
258                                  const TropicalWeightTpl<T> &w2) {
259  if (!w1.Member() || !w2.Member())
260    return TropicalWeightTpl<T>::NoWeight();
261  T f1 = w1.Value(), f2 = w2.Value();
262  if (f1 == FloatLimits<T>::kPosInfinity)
263    return w1;
264  else if (f2 == FloatLimits<T>::kPosInfinity)
265    return w2;
266  else
267    return TropicalWeightTpl<T>(f1 + f2);
268}
269
270inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
271                                      const TropicalWeightTpl<float> &w2) {
272  return Times<float>(w1, w2);
273}
274
275inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
276                                       const TropicalWeightTpl<double> &w2) {
277  return Times<double>(w1, w2);
278}
279
280template <class T>
281inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
282                                   const TropicalWeightTpl<T> &w2,
283                                   DivideType typ = DIVIDE_ANY) {
284  if (!w1.Member() || !w2.Member())
285    return TropicalWeightTpl<T>::NoWeight();
286  T f1 = w1.Value(), f2 = w2.Value();
287  if (f2 == FloatLimits<T>::kPosInfinity)
288    return FloatLimits<T>::kNumberBad;
289  else if (f1 == FloatLimits<T>::kPosInfinity)
290    return FloatLimits<T>::kPosInfinity;
291  else
292    return TropicalWeightTpl<T>(f1 - f2);
293}
294
295inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
296                                       const TropicalWeightTpl<float> &w2,
297                                       DivideType typ = DIVIDE_ANY) {
298  return Divide<float>(w1, w2, typ);
299}
300
301inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
302                                        const TropicalWeightTpl<double> &w2,
303                                        DivideType typ = DIVIDE_ANY) {
304  return Divide<double>(w1, w2, typ);
305}
306
307
// Log semiring: (-log(e^-x + e^-y), +, inf, 0)
309template <class T>
310class LogWeightTpl : public FloatWeightTpl<T> {
311 public:
312  using FloatWeightTpl<T>::Value;
313
314  typedef LogWeightTpl ReverseWeight;
315
316  LogWeightTpl() : FloatWeightTpl<T>() {}
317
318  LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}
319
320  LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
321
322  static const LogWeightTpl<T> Zero() {
323    return LogWeightTpl<T>(FloatLimits<T>::kPosInfinity);
324  }
325
326  static const LogWeightTpl<T> One() {
327    return LogWeightTpl<T>(0.0F);
328  }
329
330  static const LogWeightTpl<T> NoWeight() {
331    return LogWeightTpl<T>(FloatLimits<T>::kNumberBad); }
332
333  static const string &Type() {
334    static const string type = "log" + FloatWeightTpl<T>::GetPrecisionString();
335    return type;
336  }
337
338  bool Member() const {
339    // First part fails for IEEE NaN
340    return Value() == Value() && Value() != FloatLimits<T>::kNegInfinity;
341  }
342
343  LogWeightTpl<T> Quantize(float delta = kDelta) const {
344    if (Value() == FloatLimits<T>::kNegInfinity ||
345        Value() == FloatLimits<T>::kPosInfinity ||
346        Value() != Value())
347      return *this;
348    else
349      return LogWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
350  }
351
352  LogWeightTpl<T> Reverse() const { return *this; }
353
354  static uint64 Properties() {
355    return kLeftSemiring | kRightSemiring | kCommutative;
356  }
357};
358
359// Single-precision log weight
360typedef LogWeightTpl<float> LogWeight;
361// Double-precision log weight
362typedef LogWeightTpl<double> Log64Weight;
363
// Returns log(1 + e^-x), the correction term of the log-semiring sum.  The
// log-semiring Plus() calls it with the non-negative difference of its two
// values, for which the result lies in [0, log 2].
template <class T>
inline T LogExp(T x) {
  return log(exp(-x) + 1.0F);
}
366
367template <class T>
368inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
369                            const LogWeightTpl<T> &w2) {
370  T f1 = w1.Value(), f2 = w2.Value();
371  if (f1 == FloatLimits<T>::kPosInfinity)
372    return w2;
373  else if (f2 == FloatLimits<T>::kPosInfinity)
374    return w1;
375  else if (f1 > f2)
376    return LogWeightTpl<T>(f2 - LogExp(f1 - f2));
377  else
378    return LogWeightTpl<T>(f1 - LogExp(f2 - f1));
379}
380
381inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
382                                const LogWeightTpl<float> &w2) {
383  return Plus<float>(w1, w2);
384}
385
386inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
387                                 const LogWeightTpl<double> &w2) {
388  return Plus<double>(w1, w2);
389}
390
391template <class T>
392inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
393                             const LogWeightTpl<T> &w2) {
394  if (!w1.Member() || !w2.Member())
395    return LogWeightTpl<T>::NoWeight();
396  T f1 = w1.Value(), f2 = w2.Value();
397  if (f1 == FloatLimits<T>::kPosInfinity)
398    return w1;
399  else if (f2 == FloatLimits<T>::kPosInfinity)
400    return w2;
401  else
402    return LogWeightTpl<T>(f1 + f2);
403}
404
405inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
406                                 const LogWeightTpl<float> &w2) {
407  return Times<float>(w1, w2);
408}
409
410inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
411                                  const LogWeightTpl<double> &w2) {
412  return Times<double>(w1, w2);
413}
414
415template <class T>
416inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
417                              const LogWeightTpl<T> &w2,
418                              DivideType typ = DIVIDE_ANY) {
419  if (!w1.Member() || !w2.Member())
420    return LogWeightTpl<T>::NoWeight();
421  T f1 = w1.Value(), f2 = w2.Value();
422  if (f2 == FloatLimits<T>::kPosInfinity)
423    return FloatLimits<T>::kNumberBad;
424  else if (f1 == FloatLimits<T>::kPosInfinity)
425    return FloatLimits<T>::kPosInfinity;
426  else
427    return LogWeightTpl<T>(f1 - f2);
428}
429
430inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
431                                  const LogWeightTpl<float> &w2,
432                                  DivideType typ = DIVIDE_ANY) {
433  return Divide<float>(w1, w2, typ);
434}
435
436inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
437                                   const LogWeightTpl<double> &w2,
438                                   DivideType typ = DIVIDE_ANY) {
439  return Divide<double>(w1, w2, typ);
440}
441
442// MinMax semiring: (min, max, inf, -inf)
443template <class T>
444class MinMaxWeightTpl : public FloatWeightTpl<T> {
445 public:
446  using FloatWeightTpl<T>::Value;
447
448  typedef MinMaxWeightTpl<T> ReverseWeight;
449
450  MinMaxWeightTpl() : FloatWeightTpl<T>() {}
451
452  MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}
453
454  MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}
455
456  static const MinMaxWeightTpl<T> Zero() {
457    return MinMaxWeightTpl<T>(FloatLimits<T>::kPosInfinity);
458  }
459
460  static const MinMaxWeightTpl<T> One() {
461    return MinMaxWeightTpl<T>(FloatLimits<T>::kNegInfinity);
462  }
463
464  static const MinMaxWeightTpl<T> NoWeight() {
465    return MinMaxWeightTpl<T>(FloatLimits<T>::kNumberBad); }
466
467  static const string &Type() {
468    static const string type = "minmax" +
469        FloatWeightTpl<T>::GetPrecisionString();
470    return type;
471  }
472
473  bool Member() const {
474    // Fails for IEEE NaN
475    return Value() == Value();
476  }
477
478  MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
479    // If one of infinities, or a NaN
480    if (Value() == FloatLimits<T>::kNegInfinity ||
481        Value() == FloatLimits<T>::kPosInfinity ||
482        Value() != Value())
483      return *this;
484    else
485      return MinMaxWeightTpl<T>(floor(Value()/delta + 0.5F) * delta);
486  }
487
488  MinMaxWeightTpl<T> Reverse() const { return *this; }
489
490  static uint64 Properties() {
491    return kLeftSemiring | kRightSemiring | kCommutative | kIdempotent | kPath;
492  }
493};
494
495// Single-precision min-max weight
496typedef MinMaxWeightTpl<float> MinMaxWeight;
497
498// Min
499template <class T>
500inline MinMaxWeightTpl<T> Plus(
501    const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
502  if (!w1.Member() || !w2.Member())
503    return MinMaxWeightTpl<T>::NoWeight();
504  return w1.Value() < w2.Value() ? w1 : w2;
505}
506
507inline MinMaxWeightTpl<float> Plus(
508    const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
509  return Plus<float>(w1, w2);
510}
511
512inline MinMaxWeightTpl<double> Plus(
513    const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
514  return Plus<double>(w1, w2);
515}
516
517// Max
518template <class T>
519inline MinMaxWeightTpl<T> Times(
520    const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
521  if (!w1.Member() || !w2.Member())
522    return MinMaxWeightTpl<T>::NoWeight();
523  return w1.Value() >= w2.Value() ? w1 : w2;
524}
525
526inline MinMaxWeightTpl<float> Times(
527    const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
528  return Times<float>(w1, w2);
529}
530
531inline MinMaxWeightTpl<double> Times(
532    const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
533  return Times<double>(w1, w2);
534}
535
536// Defined only for special cases
537template <class T>
538inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
539                                 const MinMaxWeightTpl<T> &w2,
540                                 DivideType typ = DIVIDE_ANY) {
541  if (!w1.Member() || !w2.Member())
542    return MinMaxWeightTpl<T>::NoWeight();
543  // min(w1, x) = w2, w1 >= w2 => min(w1, x) = w2, x = w2
544  return w1.Value() >= w2.Value() ? w1 : FloatLimits<T>::kNumberBad;
545}
546
547inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
548                                     const MinMaxWeightTpl<float> &w2,
549                                     DivideType typ = DIVIDE_ANY) {
550  return Divide<float>(w1, w2, typ);
551}
552
553inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
554                                      const MinMaxWeightTpl<double> &w2,
555                                      DivideType typ = DIVIDE_ANY) {
556  return Divide<double>(w1, w2, typ);
557}
558
559//
560// WEIGHT CONVERTER SPECIALIZATIONS.
561//
562
563// Convert to tropical
564template <>
565struct WeightConvert<LogWeight, TropicalWeight> {
566  TropicalWeight operator()(LogWeight w) const { return w.Value(); }
567};
568
569template <>
570struct WeightConvert<Log64Weight, TropicalWeight> {
571  TropicalWeight operator()(Log64Weight w) const { return w.Value(); }
572};
573
574// Convert to log
575template <>
576struct WeightConvert<TropicalWeight, LogWeight> {
577  LogWeight operator()(TropicalWeight w) const { return w.Value(); }
578};
579
580template <>
581struct WeightConvert<Log64Weight, LogWeight> {
582  LogWeight operator()(Log64Weight w) const { return w.Value(); }
583};
584
585// Convert to log64
586template <>
587struct WeightConvert<TropicalWeight, Log64Weight> {
588  Log64Weight operator()(TropicalWeight w) const { return w.Value(); }
589};
590
591template <>
592struct WeightConvert<LogWeight, Log64Weight> {
593  Log64Weight operator()(LogWeight w) const { return w.Value(); }
594};
595
596}  // namespace fst
597
598#endif  // FST_LIB_FLOAT_WEIGHT_H__
599