factor-weight.h revision 4a68b3365c8c50aa93505e99ead2565ab73dcdb0
1// factor-weight.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16//
17// \file
18// Classes to factor weights in an FST.
19
20#ifndef FST_LIB_FACTOR_WEIGHT_H__
21#define FST_LIB_FACTOR_WEIGHT_H__
22
23#include <algorithm>
24
25#include <ext/hash_map>
26using __gnu_cxx::hash_map;
27#include <ext/slist>
28using __gnu_cxx::slist;
29
30#include "fst/lib/cache.h"
31#include "fst/lib/test-properties.h"
32
33namespace fst {
34
35struct FactorWeightOptions : CacheOptions {
36  float delta;
37  bool final_only;  // only factor final weights when true
38
39  FactorWeightOptions(const CacheOptions &opts, float d, bool of)
40      : CacheOptions(opts), delta(d), final_only(of) {}
41
42  explicit FactorWeightOptions(float d, bool of = false)
43      : delta(d), final_only(of) {}
44
45  FactorWeightOptions(bool of = false)
46      : delta(kDelta), final_only(of) {}
47};
48
49
50// A factor iterator takes as argument a weight w and returns a
51// sequence of pairs of weights (xi,yi) such that the sum of the
52// products xi times yi is equal to w. If w is fully factored,
53// the iterator should return nothing.
54//
55// template <class W>
56// class FactorIterator {
57//  public:
58//   FactorIterator(W w);
59//   bool Done() const;
60//   void Next();
61//   pair<W, W> Value() const;
62//   void Reset();
63// }
64
65
// Trivial factor iterator: every weight is treated as already fully factored,
// so the iteration is empty from the moment it is constructed.
template <class W>
class IdentityFactor {
 public:
  IdentityFactor(const W & /* w */) {}

  // Always exhausted; there is never a factor pair to report.
  bool Done() const { return true; }

  void Next() {}

  // Not meant to be called (Done() is always true); yields (One, One).
  pair<W, W> Value() const {
    const W one = W::One();
    return make_pair(one, one);
  }

  void Reset() {}
};
76
77
78// Factor a StringWeight w as 'ab' where 'a' is a label.
79template <typename L, StringType S = STRING_LEFT>
80class StringFactor {
81 public:
82  StringFactor(const StringWeight<L, S> &w)
83      : weight_(w), done_(w.Size() <= 1) {}
84
85  bool Done() const { return done_; }
86
87  void Next() { done_ = true; }
88
89  pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
90    StringWeightIterator<L, S> iter(weight_);
91    StringWeight<L, S> w1(iter.Value());
92    StringWeight<L, S> w2;
93    for (iter.Next(); !iter.Done(); iter.Next())
94      w2.PushBack(iter.Value());
95    return make_pair(w1, w2);
96  }
97
98  void Reset() { done_ = weight_.Size() <= 1; }
99
100 private:
101  StringWeight<L, S> weight_;
102  bool done_;
103};
104
105
106// Factor a GallicWeight using StringFactor.
107template <class L, class W, StringType S = STRING_LEFT>
108class GallicFactor {
109 public:
110  GallicFactor(const GallicWeight<L, W, S> &w)
111      : weight_(w), done_(w.Value1().Size() <= 1) {}
112
113  bool Done() const { return done_; }
114
115  void Next() { done_ = true; }
116
117  pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
118    StringFactor<L, S> iter(weight_.Value1());
119    GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
120    GallicWeight<L, W, S> w2(iter.Value().second, W::One());
121    return make_pair(w1, w2);
122  }
123
124  void Reset() { done_ = weight_.Value1().Size() <= 1; }
125
126 private:
127  GallicWeight<L, W, S> weight_;
128  bool done_;
129};
130
131
132// Implementation class for FactorWeight
133template <class A, class F>
134class FactorWeightFstImpl
135    : public CacheImpl<A> {
136 public:
137  using FstImpl<A>::SetType;
138  using FstImpl<A>::SetProperties;
139  using FstImpl<A>::Properties;
140  using FstImpl<A>::SetInputSymbols;
141  using FstImpl<A>::SetOutputSymbols;
142
143  using CacheBaseImpl< CacheState<A> >::HasStart;
144  using CacheBaseImpl< CacheState<A> >::HasFinal;
145  using CacheBaseImpl< CacheState<A> >::HasArcs;
146
147  typedef A Arc;
148  typedef typename A::Label Label;
149  typedef typename A::Weight Weight;
150  typedef typename A::StateId StateId;
151  typedef F FactorIterator;
152
153  struct Element {
154    Element() {}
155
156    Element(StateId s, Weight w) : state(s), weight(w) {}
157
158    StateId state;     // Input state Id
159    Weight weight;     // Residual weight
160  };
161
162  FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts)
163      : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta),
164        final_only_(opts.final_only) {
165    SetType("factor-weight");
166    uint64 props = fst.Properties(kFstProperties, false);
167    SetProperties(FactorWeightProperties(props), kCopyProperties);
168
169    SetInputSymbols(fst.InputSymbols());
170    SetOutputSymbols(fst.OutputSymbols());
171  }
172
173  ~FactorWeightFstImpl() {
174    delete fst_;
175  }
176
177  StateId Start() {
178    if (!HasStart()) {
179      StateId s = fst_->Start();
180      if (s == kNoStateId)
181        return kNoStateId;
182      StateId start = FindState(Element(fst_->Start(), Weight::One()));
183      SetStart(start);
184    }
185    return CacheImpl<A>::Start();
186  }
187
188  Weight Final(StateId s) {
189    if (!HasFinal(s)) {
190      const Element &e = elements_[s];
191      // TODO: fix so cast is unnecessary
192      Weight w = e.state == kNoStateId
193                 ? e.weight
194                 : (Weight) Times(e.weight, fst_->Final(e.state));
195      FactorIterator f(w);
196      if (w != Weight::Zero() && f.Done())
197        SetFinal(s, w);
198      else
199        SetFinal(s, Weight::Zero());
200    }
201    return CacheImpl<A>::Final(s);
202  }
203
204  size_t NumArcs(StateId s) {
205    if (!HasArcs(s))
206      Expand(s);
207    return CacheImpl<A>::NumArcs(s);
208  }
209
210  size_t NumInputEpsilons(StateId s) {
211    if (!HasArcs(s))
212      Expand(s);
213    return CacheImpl<A>::NumInputEpsilons(s);
214  }
215
216  size_t NumOutputEpsilons(StateId s) {
217    if (!HasArcs(s))
218      Expand(s);
219    return CacheImpl<A>::NumOutputEpsilons(s);
220  }
221
222  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
223    if (!HasArcs(s))
224      Expand(s);
225    CacheImpl<A>::InitArcIterator(s, data);
226  }
227
228
229  // Find state corresponding to an element. Create new state
230  // if element not found.
231  StateId FindState(const Element &e) {
232    if (final_only_ && e.weight == Weight::One()) {
233      while (unfactored_.size() <= (unsigned int)e.state)
234        unfactored_.push_back(kNoStateId);
235      if (unfactored_[e.state] == kNoStateId) {
236        unfactored_[e.state] = elements_.size();
237        elements_.push_back(e);
238      }
239      return unfactored_[e.state];
240    } else {
241      typename ElementMap::iterator eit = element_map_.find(e);
242      if (eit != element_map_.end()) {
243        return (*eit).second;
244      } else {
245        StateId s = elements_.size();
246        elements_.push_back(e);
247        element_map_.insert(pair<const Element, StateId>(e, s));
248        return s;
249      }
250    }
251  }
252
253  // Computes the outgoing transitions from a state, creating new destination
254  // states as needed.
255  void Expand(StateId s) {
256    Element e = elements_[s];
257    if (e.state != kNoStateId) {
258      for (ArcIterator< Fst<A> > ait(*fst_, e.state);
259           !ait.Done();
260           ait.Next()) {
261        const A &arc = ait.Value();
262        Weight w = Times(e.weight, arc.weight);
263        FactorIterator fit(w);
264        if (final_only_ || fit.Done()) {
265          StateId d = FindState(Element(arc.nextstate, Weight::One()));
266          AddArc(s, Arc(arc.ilabel, arc.olabel, w, d));
267        } else {
268          for (; !fit.Done(); fit.Next()) {
269            const pair<Weight, Weight> &p = fit.Value();
270            StateId d = FindState(Element(arc.nextstate,
271                                          p.second.Quantize(delta_)));
272            AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
273          }
274        }
275      }
276    }
277    if ((e.state == kNoStateId) ||
278        (fst_->Final(e.state) != Weight::Zero())) {
279      Weight w = e.state == kNoStateId
280                 ? e.weight
281                 : Times(e.weight, fst_->Final(e.state));
282      for (FactorIterator fit(w);
283           !fit.Done();
284           fit.Next()) {
285        const pair<Weight, Weight> &p = fit.Value();
286        StateId d = FindState(Element(kNoStateId,
287                                      p.second.Quantize(delta_)));
288        AddArc(s, Arc(0, 0, p.first, d));
289      }
290    }
291    SetArcs(s);
292  }
293
294 private:
295  // Equality function for Elements, assume weights have been quantized.
296  class ElementEqual {
297   public:
298    bool operator()(const Element &x, const Element &y) const {
299      return x.state == y.state && x.weight == y.weight;
300    }
301  };
302
303  // Hash function for Elements to Fst states.
304  class ElementKey {
305   public:
306    size_t operator()(const Element &x) const {
307      return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
308    }
309   private:
310    static const int kPrime = 7853;
311  };
312
313  typedef hash_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
314
315  const Fst<A> *fst_;
316  float delta_;
317  bool final_only_;
318  vector<Element> elements_;  // mapping Fst state to Elements
319  ElementMap element_map_;    // mapping Elements to Fst state
320  // mapping between old/new 'StateId' for states that do not need to
321  // be factored when 'final_only_' is true
322  vector<StateId> unfactored_;
323
324  DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl);
325};
326
327
328// FactorWeightFst takes as template parameter a FactorIterator as
329// defined above. The result of weight factoring is a transducer
330// equivalent to the input whose path weights have been factored
331// according to the FactorIterator. States and transitions will be
332// added as necessary. The algorithm is a generalization to arbitrary
333// weights of the second step of the input epsilon-normalization
334// algorithm due to Mohri, "Generic epsilon-removal and input
335// epsilon-normalization algorithms for weighted transducers",
336// International Journal of Computer Science 13(1): 129-143 (2002).
337template <class A, class F>
338class FactorWeightFst : public Fst<A> {
339 public:
340  friend class ArcIterator< FactorWeightFst<A, F> >;
341  friend class CacheStateIterator< FactorWeightFst<A, F> >;
342  friend class CacheArcIterator< FactorWeightFst<A, F> >;
343
344  typedef A Arc;
345  typedef typename A::Weight Weight;
346  typedef typename A::StateId StateId;
347  typedef CacheState<A> State;
348
349  FactorWeightFst(const Fst<A> &fst)
350      : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {}
351
352  FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions &opts)
353      : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {}
354  FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) {
355    impl_->IncrRefCount();
356  }
357
358  virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_;  }
359
360  virtual StateId Start() const { return impl_->Start(); }
361
362  virtual Weight Final(StateId s) const { return impl_->Final(s); }
363
364  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
365
366  virtual size_t NumInputEpsilons(StateId s) const {
367    return impl_->NumInputEpsilons(s);
368  }
369
370  virtual size_t NumOutputEpsilons(StateId s) const {
371    return impl_->NumOutputEpsilons(s);
372  }
373
374  virtual uint64 Properties(uint64 mask, bool test) const {
375    if (test) {
376      uint64 known, test = TestProperties(*this, mask, &known);
377      impl_->SetProperties(test, known);
378      return test & mask;
379    } else {
380      return impl_->Properties(mask);
381    }
382  }
383
384  virtual const string& Type() const { return impl_->Type(); }
385
386  virtual FactorWeightFst<A, F> *Copy() const {
387    return new FactorWeightFst<A, F>(*this);
388  }
389
390  virtual const SymbolTable* InputSymbols() const {
391    return impl_->InputSymbols();
392  }
393
394  virtual const SymbolTable* OutputSymbols() const {
395    return impl_->OutputSymbols();
396  }
397
398  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
399
400  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
401    impl_->InitArcIterator(s, data);
402  }
403
404 private:
405  FactorWeightFstImpl<A, F> *Impl() { return impl_; }
406
407  FactorWeightFstImpl<A, F> *impl_;
408
409  void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
410};
411
412
413// Specialization for FactorWeightFst.
414template<class A, class F>
415class StateIterator< FactorWeightFst<A, F> >
416    : public CacheStateIterator< FactorWeightFst<A, F> > {
417 public:
418  explicit StateIterator(const FactorWeightFst<A, F> &fst)
419      : CacheStateIterator< FactorWeightFst<A, F> >(fst) {}
420};
421
422
423// Specialization for FactorWeightFst.
424template <class A, class F>
425class ArcIterator< FactorWeightFst<A, F> >
426    : public CacheArcIterator< FactorWeightFst<A, F> > {
427 public:
428  typedef typename A::StateId StateId;
429
430  ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
431      : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) {
432    if (!fst.impl_->HasArcs(s))
433      fst.impl_->Expand(s);
434  }
435
436 private:
437  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
438};
439
440template <class A, class F> inline
441void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
442{
443  data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
444}
445
446
447}  // namespace fst
448
449#endif // FST_LIB_FACTOR_WEIGHT_H__
450