// float-weight.h revision 4a68b3365c8c50aa93505e99ead2565ab73dcdb0
// float-weight.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// \file
// Float weight set and associated semiring operation definitions.
//
19
20#ifndef FST_LIB_FLOAT_WEIGHT_H__
21#define FST_LIB_FLOAT_WEIGHT_H__
22
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <limits>

#include "fst/lib/weight.h"
26
27namespace fst {
28
// Single-precision +infinity and -infinity. +infinity serves as the semiring
// Zero() below; -infinity is rejected by the Member() checks below.
static const float kPosInfinity = numeric_limits<float>::infinity();
static const float kNegInfinity = -numeric_limits<float>::infinity();
31
32// Single precision floating point weight base class
33class FloatWeight {
34 public:
35  FloatWeight() {}
36
37  FloatWeight(float f) : value_(f) {}
38
39  FloatWeight(const FloatWeight &w) : value_(w.value_) {}
40
41  FloatWeight &operator=(const FloatWeight &w) {
42    value_ = w.value_;
43    return *this;
44  }
45
46  istream &Read(istream &strm) {
47    return ReadType(strm, &value_);
48  }
49
50  ostream &Write(ostream &strm) const {
51    return WriteType(strm, value_);
52  }
53
54  ssize_t Hash() const {
55    union {
56      float f;
57      ssize_t s;
58    } u = { value_ };
59    return u.s;
60  }
61
62  const float &Value() const { return value_; }
63
64 protected:
65  float value_;
66};
67
68inline bool operator==(const FloatWeight &w1, const FloatWeight &w2) {
69  // Volatile qualifier thwarts over-aggressive compiler optimizations
70  // that lead to problems esp. with NaturalLess().
71  volatile float v1 = w1.Value();
72  volatile float v2 = w2.Value();
73  return v1 == v2;
74}
75
76inline bool operator!=(const FloatWeight &w1, const FloatWeight &w2) {
77  return !(w1 == w2);
78}
79
80inline bool ApproxEqual(const FloatWeight &w1, const FloatWeight &w2,
81                        float delta = kDelta) {
82  return w1.Value() <= w2.Value() + delta && w2.Value() <= w1.Value() + delta;
83}
84
85inline ostream &operator<<(ostream &strm, const FloatWeight &w) {
86  if (w.Value() == kPosInfinity)
87    return strm << "Infinity";
88  else if (w.Value() == kNegInfinity)
89    return strm << "-Infinity";
90  else if (w.Value() != w.Value())   // Fails for NaN
91    return strm << "BadFloat";
92  else
93    return strm << w.Value();
94}
95
96inline istream &operator>>(istream &strm, FloatWeight &w) {
97  string s;
98  strm >> s;
99  if (s == "Infinity") {
100    w = FloatWeight(kPosInfinity);
101  } else if (s == "-Infinity") {
102    w = FloatWeight(kNegInfinity);
103  } else {
104    char *p;
105    float f = strtod(s.c_str(), &p);
106    if (p < s.c_str() + s.size())
107      strm.clear(std::ios::badbit);
108    else
109      w = FloatWeight(f);
110  }
111  return strm;
112}
113
114
115// Tropical semiring: (min, +, inf, 0)
116class TropicalWeight : public FloatWeight {
117 public:
118  typedef TropicalWeight ReverseWeight;
119
120  TropicalWeight() : FloatWeight() {}
121
122  TropicalWeight(float f) : FloatWeight(f) {}
123
124  TropicalWeight(const TropicalWeight &w) : FloatWeight(w) {}
125
126  static const TropicalWeight Zero() { return TropicalWeight(kPosInfinity); }
127
128  static const TropicalWeight One() { return TropicalWeight(0.0F); }
129
130  static const string &Type() {
131    static const string type = "tropical";
132    return type;
133  }
134
135  bool Member() const {
136    // First part fails for IEEE NaN
137    return Value() == Value() && Value() != kNegInfinity;
138  }
139
140  TropicalWeight Quantize(float delta = kDelta) const {
141    return TropicalWeight(floor(Value()/delta + 0.5F) * delta);
142  }
143
144  TropicalWeight Reverse() const { return *this; }
145
146  static uint64 Properties() {
147    return kLeftSemiring | kRightSemiring | kCommutative |
148      kPath | kIdempotent;
149  }
150};
151
152inline TropicalWeight Plus(const TropicalWeight &w1,
153                           const TropicalWeight &w2) {
154  return w1.Value() < w2.Value() ? w1 : w2;
155}
156
157inline TropicalWeight Times(const TropicalWeight &w1,
158                            const TropicalWeight &w2) {
159  float f1 = w1.Value(), f2 = w2.Value();
160  if (f1 == kPosInfinity)
161    return w1;
162  else if (f2 == kPosInfinity)
163    return w2;
164  else
165    return TropicalWeight(f1 + f2);
166}
167
168inline TropicalWeight Divide(const TropicalWeight &w1,
169                             const TropicalWeight &w2,
170                             DivideType typ = DIVIDE_ANY) {
171  float f1 = w1.Value(), f2 = w2.Value();
172  if (f2 == kPosInfinity)
173    return kNegInfinity;
174  else if (f1 == kPosInfinity)
175    return kPosInfinity;
176  else
177    return TropicalWeight(f1 - f2);
178}
179
180
181// Log semiring: (log(e^-x + e^y), +, inf, 0)
182class LogWeight : public FloatWeight {
183 public:
184  typedef LogWeight ReverseWeight;
185
186  LogWeight() : FloatWeight() {}
187
188  LogWeight(float f) : FloatWeight(f) {}
189
190  LogWeight(const LogWeight &w) : FloatWeight(w) {}
191
192  static const LogWeight Zero() {   return LogWeight(kPosInfinity); }
193
194  static const LogWeight One() { return LogWeight(0.0F); }
195
196  static const string &Type() {
197    static const string type = "log";
198    return type;
199  }
200
201  bool Member() const {
202    // First part fails for IEEE NaN
203    return Value() == Value() && Value() != kNegInfinity;
204  }
205
206  LogWeight Quantize(float delta = kDelta) const {
207    return LogWeight(floor(Value()/delta + 0.5F) * delta);
208  }
209
210  LogWeight Reverse() const { return *this; }
211
212  static uint64 Properties() {
213    return kLeftSemiring | kRightSemiring | kCommutative;
214  }
215};
216
// Returns log(1 + e^-x); the log-semiring Plus subtracts this correction term.
inline double LogExp(double x) {
  const double tail = exp(-x);
  return log(1.0F + tail);
}
218
219inline LogWeight Plus(const LogWeight &w1, const LogWeight &w2) {
220  float f1 = w1.Value(), f2 = w2.Value();
221  if (f1 == kPosInfinity)
222    return w2;
223  else if (f2 == kPosInfinity)
224    return w1;
225  else if (f1 > f2)
226    return LogWeight(f2 - LogExp(f1 - f2));
227  else
228    return LogWeight(f1 - LogExp(f2 - f1));
229}
230
231inline LogWeight Times(const LogWeight &w1, const LogWeight &w2) {
232  float f1 = w1.Value(), f2 = w2.Value();
233  if (f1 == kPosInfinity)
234    return w1;
235  else if (f2 == kPosInfinity)
236    return w2;
237  else
238    return LogWeight(f1 + f2);
239}
240
241inline LogWeight Divide(const LogWeight &w1,
242                             const LogWeight &w2,
243                             DivideType typ = DIVIDE_ANY) {
244  float f1 = w1.Value(), f2 = w2.Value();
245  if (f2 == kPosInfinity)
246    return kNegInfinity;
247  else if (f1 == kPosInfinity)
248    return kPosInfinity;
249  else
250    return LogWeight(f1 - f2);
251}
252
253}  // namespace fst;
254
255#endif  // FST_LIB_FLOAT_WEIGHT_H__
256