// signed-log-weight.h revision dfd8b8327b93660601d016cdc6f29f433b45a8d8

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: krr@google.com (Kasturi Rangan Raghavan)
// \file
// LogWeight augmented with sign information. A value X of the linear (real)
// domain is represented as the pair <sign(X), -ln(|X|)>.
// The sign component is a TropicalWeight:
//   positive: TropicalWeight.Value() > 0.0 (recommended value 1.0)
//   negative: TropicalWeight.Value() <= 0.0 (recommended value -1.0)

23#ifndef FST_LIB_SIGNED_LOG_WEIGHT_H_
24#define FST_LIB_SIGNED_LOG_WEIGHT_H_
25
26#include <fst/float-weight.h>
27#include <fst/pair-weight.h>
28
29
30namespace fst {
31template <class T>
32class SignedLogWeightTpl
33    : public PairWeight<TropicalWeight, LogWeightTpl<T> > {
34 public:
35  typedef TropicalWeight X1;
36  typedef LogWeightTpl<T> X2;
37  using PairWeight<X1, X2>::Value1;
38  using PairWeight<X1, X2>::Value2;
39
40  using PairWeight<X1, X2>::Reverse;
41  using PairWeight<X1, X2>::Quantize;
42  using PairWeight<X1, X2>::Member;
43
44  typedef SignedLogWeightTpl<T> ReverseWeight;
45
46  SignedLogWeightTpl() : PairWeight<X1, X2>() {}
47
48  SignedLogWeightTpl(const SignedLogWeightTpl<T>& w)
49      : PairWeight<X1, X2> (w) { }
50
51  SignedLogWeightTpl(const PairWeight<X1, X2>& w)
52      : PairWeight<X1, X2> (w) { }
53
54  SignedLogWeightTpl(const X1& x1, const X2& x2)
55      : PairWeight<X1, X2>(x1, x2) { }
56
57  static const SignedLogWeightTpl<T> &Zero() {
58    static const SignedLogWeightTpl<T> zero(X1(1.0), X2::Zero());
59    return zero;
60  }
61
62  static const SignedLogWeightTpl<T> &One() {
63    static const SignedLogWeightTpl<T> one(X1(1.0), X2::One());
64    return one;
65  }
66
67  static const SignedLogWeightTpl<T> &NoWeight() {
68    static const SignedLogWeightTpl<T> no_weight(X1(1.0), X2::NoWeight());
69    return no_weight;
70  }
71
72  static const string &Type() {
73    static const string type = "signed_log_" + X1::Type() + "_" + X2::Type();
74    return type;
75  }
76
77  ProductWeight<X1, X2> Quantize(float delta = kDelta) const {
78    return PairWeight<X1, X2>::Quantize();
79  }
80
81  ReverseWeight Reverse() const {
82    return PairWeight<X1, X2>::Reverse();
83  }
84
85  bool Member() const {
86    return PairWeight<X1, X2>::Member();
87  }
88
89  static uint64 Properties() {
90    // not idempotent nor path
91    return kLeftSemiring | kRightSemiring | kCommutative;
92  }
93
94  size_t Hash() const {
95    size_t h1;
96    if (Value2() == X2::Zero() || Value1().Value() > 0.0)
97      h1 = TropicalWeight(1.0).Hash();
98    else
99      h1 = TropicalWeight(-1.0).Hash();
100    size_t h2 = Value2().Hash();
101    const int lshift = 5;
102    const int rshift = CHAR_BIT * sizeof(size_t) - 5;
103    return h1 << lshift ^ h1 >> rshift ^ h2;
104  }
105};
106
107template <class T>
108inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1,
109                                  const SignedLogWeightTpl<T> &w2) {
110  if (!w1.Member() || !w2.Member())
111    return SignedLogWeightTpl<T>::NoWeight();
112  bool s1 = w1.Value1().Value() > 0.0;
113  bool s2 = w2.Value1().Value() > 0.0;
114  T f1 = w1.Value2().Value();
115  T f2 = w2.Value2().Value();
116  if (f1 == FloatLimits<T>::PosInfinity())
117    return w2;
118  else if (f2 == FloatLimits<T>::PosInfinity())
119    return w1;
120  else if (f1 == f2) {
121    if (s1 == s2)
122      return SignedLogWeightTpl<T>(w1.Value1(), (f2 - log(2.0F)));
123    else
124      return SignedLogWeightTpl<T>::Zero();
125  } else if (f1 > f2) {
126    if (s1 == s2) {
127      return SignedLogWeightTpl<T>(
128        w1.Value1(), (f2 - log(1.0F + exp(f2 - f1))));
129    } else {
130      return SignedLogWeightTpl<T>(
131        w2.Value1(), (f2 - log(1.0F - exp(f2 - f1))));
132    }
133  } else {
134    if (s2 == s1) {
135      return SignedLogWeightTpl<T>(
136        w2.Value1(), (f1 - log(1.0F + exp(f1 - f2))));
137    } else {
138      return SignedLogWeightTpl<T>(
139        w1.Value1(), (f1 - log(1.0F - exp(f1 - f2))));
140    }
141  }
142}
143
144template <class T>
145inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1,
146                                   const SignedLogWeightTpl<T> &w2) {
147  SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2());
148  return Plus(w1, minus_w2);
149}
150
151template <class T>
152inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1,
153                                   const SignedLogWeightTpl<T> &w2) {
154  if (!w1.Member() || !w2.Member())
155    return SignedLogWeightTpl<T>::NoWeight();
156  bool s1 = w1.Value1().Value() > 0.0;
157  bool s2 = w2.Value1().Value() > 0.0;
158  T f1 = w1.Value2().Value();
159  T f2 = w2.Value2().Value();
160  if (s1 == s2)
161    return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 + f2));
162  else
163    return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 + f2));
164}
165
166template <class T>
167inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1,
168                                    const SignedLogWeightTpl<T> &w2,
169                                    DivideType typ = DIVIDE_ANY) {
170  if (!w1.Member() || !w2.Member())
171    return SignedLogWeightTpl<T>::NoWeight();
172  bool s1 = w1.Value1().Value() > 0.0;
173  bool s2 = w2.Value1().Value() > 0.0;
174  T f1 = w1.Value2().Value();
175  T f2 = w2.Value2().Value();
176  if (f2 == FloatLimits<T>::PosInfinity())
177    return SignedLogWeightTpl<T>(TropicalWeight(1.0),
178      FloatLimits<T>::NumberBad());
179  else if (f1 == FloatLimits<T>::PosInfinity())
180    return SignedLogWeightTpl<T>(TropicalWeight(1.0),
181      FloatLimits<T>::PosInfinity());
182  else if (s1 == s2)
183    return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 - f2));
184  else
185    return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 - f2));
186}
187
188template <class T>
189inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
190                        const SignedLogWeightTpl<T> &w2,
191                        float delta = kDelta) {
192  bool s1 = w1.Value1().Value() > 0.0;
193  bool s2 = w2.Value1().Value() > 0.0;
194  if (s1 == s2) {
195    return ApproxEqual(w1.Value2(), w2.Value2(), delta);
196  } else {
197    return w1.Value2() == LogWeightTpl<T>::Zero()
198        && w2.Value2() == LogWeightTpl<T>::Zero();
199  }
200}
201
202template <class T>
203inline bool operator==(const SignedLogWeightTpl<T> &w1,
204                       const SignedLogWeightTpl<T> &w2) {
205  bool s1 = w1.Value1().Value() > 0.0;
206  bool s2 = w2.Value1().Value() > 0.0;
207  if (s1 == s2)
208    return w1.Value2() == w2.Value2();
209  else
210    return (w1.Value2() == LogWeightTpl<T>::Zero()) &&
211           (w2.Value2() == LogWeightTpl<T>::Zero());
212}
213
214
215// Single-precision signed-log weight
216typedef SignedLogWeightTpl<float> SignedLogWeight;
217// Double-precision signed-log weight
218typedef SignedLogWeightTpl<double> SignedLog64Weight;
219
//
// WEIGHT CONVERTER SPECIALIZATIONS.
//

224template <class W1, class W2>
225bool SignedLogConvertCheck(W1 w) {
226  if (w.Value1().Value() < 0.0) {
227    FSTERROR() << "WeightConvert: can't convert weight from \""
228               << W1::Type() << "\" to \"" << W2::Type();
229    return false;
230  }
231  return true;
232}
233
234// Convert to tropical
235template <>
236struct WeightConvert<SignedLogWeight, TropicalWeight> {
237  TropicalWeight operator()(SignedLogWeight w) const {
238    if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(w))
239      return TropicalWeight::NoWeight();
240    return w.Value2().Value();
241  }
242};
243
244template <>
245struct WeightConvert<SignedLog64Weight, TropicalWeight> {
246  TropicalWeight operator()(SignedLog64Weight w) const {
247    if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(w))
248      return TropicalWeight::NoWeight();
249    return w.Value2().Value();
250  }
251};
252
253// Convert to log
254template <>
255struct WeightConvert<SignedLogWeight, LogWeight> {
256  LogWeight operator()(SignedLogWeight w) const {
257    if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(w))
258      return LogWeight::NoWeight();
259    return w.Value2().Value();
260  }
261};
262
263template <>
264struct WeightConvert<SignedLog64Weight, LogWeight> {
265  LogWeight operator()(SignedLog64Weight w) const {
266    if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(w))
267      return LogWeight::NoWeight();
268    return w.Value2().Value();
269  }
270};
271
272// Convert to log64
273template <>
274struct WeightConvert<SignedLogWeight, Log64Weight> {
275  Log64Weight operator()(SignedLogWeight w) const {
276    if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(w))
277      return Log64Weight::NoWeight();
278    return w.Value2().Value();
279  }
280};
281
282template <>
283struct WeightConvert<SignedLog64Weight, Log64Weight> {
284  Log64Weight operator()(SignedLog64Weight w) const {
285    if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(w))
286      return Log64Weight::NoWeight();
287    return w.Value2().Value();
288  }
289};
290
291// Convert to signed log
292template <>
293struct WeightConvert<TropicalWeight, SignedLogWeight> {
294  SignedLogWeight operator()(TropicalWeight w) const {
295    TropicalWeight x1 = 1.0;
296    LogWeight x2 = w.Value();
297    return SignedLogWeight(x1, x2);
298  }
299};
300
301template <>
302struct WeightConvert<LogWeight, SignedLogWeight> {
303  SignedLogWeight operator()(LogWeight w) const {
304    TropicalWeight x1 = 1.0;
305    LogWeight x2 = w.Value();
306    return SignedLogWeight(x1, x2);
307  }
308};
309
310template <>
311struct WeightConvert<Log64Weight, SignedLogWeight> {
312  SignedLogWeight operator()(Log64Weight w) const {
313    TropicalWeight x1 = 1.0;
314    LogWeight x2 = w.Value();
315    return SignedLogWeight(x1, x2);
316  }
317};
318
319template <>
320struct WeightConvert<SignedLog64Weight, SignedLogWeight> {
321  SignedLogWeight operator()(SignedLog64Weight w) const {
322    TropicalWeight x1 = w.Value1();
323    LogWeight x2 = w.Value2().Value();
324    return SignedLogWeight(x1, x2);
325  }
326};
327
328// Convert to signed log64
329template <>
330struct WeightConvert<TropicalWeight, SignedLog64Weight> {
331  SignedLog64Weight operator()(TropicalWeight w) const {
332    TropicalWeight x1 = 1.0;
333    Log64Weight x2 = w.Value();
334    return SignedLog64Weight(x1, x2);
335  }
336};
337
338template <>
339struct WeightConvert<LogWeight, SignedLog64Weight> {
340  SignedLog64Weight operator()(LogWeight w) const {
341    TropicalWeight x1 = 1.0;
342    Log64Weight x2 = w.Value();
343    return SignedLog64Weight(x1, x2);
344  }
345};
346
347template <>
348struct WeightConvert<Log64Weight, SignedLog64Weight> {
349  SignedLog64Weight operator()(Log64Weight w) const {
350    TropicalWeight x1 = 1.0;
351    Log64Weight x2 = w.Value();
352    return SignedLog64Weight(x1, x2);
353  }
354};
355
356template <>
357struct WeightConvert<SignedLogWeight, SignedLog64Weight> {
358  SignedLog64Weight operator()(SignedLogWeight w) const {
359    TropicalWeight x1 = w.Value1();
360    Log64Weight x2 = w.Value2().Value();
361    return SignedLog64Weight(x1, x2);
362  }
363};
364
365}  // namespace fst
366
367#endif  // FST_LIB_SIGNED_LOG_WEIGHT_H_
368