1// factor-weight.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16//
17// \file
18// Classes to factor weights in an FST.
19
20#ifndef FST_LIB_FACTOR_WEIGHT_H__
21#define FST_LIB_FACTOR_WEIGHT_H__
22
#include <algorithm>
#include <forward_list>
#include <unordered_map>
#include <utility>
#include <vector>

#include "fst/lib/cache.h"
#include "fst/lib/test-properties.h"
30
31namespace fst {
32
33struct FactorWeightOptions : CacheOptions {
34  float delta;
35  bool final_only;  // only factor final weights when true
36
37  FactorWeightOptions(const CacheOptions &opts, float d, bool of)
38      : CacheOptions(opts), delta(d), final_only(of) {}
39
40  explicit FactorWeightOptions(float d, bool of = false)
41      : delta(d), final_only(of) {}
42
43  FactorWeightOptions(bool of = false)
44      : delta(kDelta), final_only(of) {}
45};
46
47
// A factor iterator takes as argument a weight w and returns a
// sequence of pairs of weights (xi, yi) such that the sum of the
// products xi times yi is equal to w. If w is fully factored,
// the iterator should return nothing.
//
// template <class W>
// class FactorIterator {
//  public:
//   FactorIterator(W w);
//   bool Done() const;
//   void Next();
//   pair<W, W> Value() const;
//   void Reset();
// };
62
63
// Factor trivially: every weight is treated as already fully factored, so
// the iterator is exhausted from the start and never yields a pair.
template <class W>
class IdentityFactor {
 public:
  // The weight plays no role in trivial factoring, hence it is left unnamed.
  IdentityFactor(const W &) {}

  // Always true: there is never a factorization to report.
  bool Done() const { return true; }

  // No-op; the iterator is always exhausted.
  void Next() {}

  // Present only to satisfy the FactorIterator interface. Because Done() is
  // always true this is not expected to be called; it yields (One, One).
  std::pair<W, W> Value() const {
    return std::make_pair(W::One(), W::One());
  }

  // No-op; there is no position to rewind to.
  void Reset() {}
};
74
75
76// Factor a StringWeight w as 'ab' where 'a' is a label.
77template <typename L, StringType S = STRING_LEFT>
78class StringFactor {
79 public:
80  StringFactor(const StringWeight<L, S> &w)
81      : weight_(w), done_(w.Size() <= 1) {}
82
83  bool Done() const { return done_; }
84
85  void Next() { done_ = true; }
86
87  pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
88    StringWeightIterator<L, S> iter(weight_);
89    StringWeight<L, S> w1(iter.Value());
90    StringWeight<L, S> w2;
91    for (iter.Next(); !iter.Done(); iter.Next())
92      w2.PushBack(iter.Value());
93    return make_pair(w1, w2);
94  }
95
96  void Reset() { done_ = weight_.Size() <= 1; }
97
98 private:
99  StringWeight<L, S> weight_;
100  bool done_;
101};
102
103
104// Factor a GallicWeight using StringFactor.
105template <class L, class W, StringType S = STRING_LEFT>
106class GallicFactor {
107 public:
108  GallicFactor(const GallicWeight<L, W, S> &w)
109      : weight_(w), done_(w.Value1().Size() <= 1) {}
110
111  bool Done() const { return done_; }
112
113  void Next() { done_ = true; }
114
115  pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
116    StringFactor<L, S> iter(weight_.Value1());
117    GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
118    GallicWeight<L, W, S> w2(iter.Value().second, W::One());
119    return make_pair(w1, w2);
120  }
121
122  void Reset() { done_ = weight_.Value1().Size() <= 1; }
123
124 private:
125  GallicWeight<L, W, S> weight_;
126  bool done_;
127};
128
129
130// Implementation class for FactorWeight
131template <class A, class F>
132class FactorWeightFstImpl
133    : public CacheImpl<A> {
134 public:
135  using FstImpl<A>::SetType;
136  using FstImpl<A>::SetProperties;
137  using FstImpl<A>::Properties;
138  using FstImpl<A>::SetInputSymbols;
139  using FstImpl<A>::SetOutputSymbols;
140
141  using CacheBaseImpl< CacheState<A> >::HasStart;
142  using CacheBaseImpl< CacheState<A> >::HasFinal;
143  using CacheBaseImpl< CacheState<A> >::HasArcs;
144
145  typedef A Arc;
146  typedef typename A::Label Label;
147  typedef typename A::Weight Weight;
148  typedef typename A::StateId StateId;
149  typedef F FactorIterator;
150
151  struct Element {
152    Element() {}
153
154    Element(StateId s, Weight w) : state(s), weight(w) {}
155
156    StateId state;     // Input state Id
157    Weight weight;     // Residual weight
158  };
159
160  FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts)
161      : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta),
162        final_only_(opts.final_only) {
163    SetType("factor-weight");
164    uint64 props = fst.Properties(kFstProperties, false);
165    SetProperties(FactorWeightProperties(props), kCopyProperties);
166
167    SetInputSymbols(fst.InputSymbols());
168    SetOutputSymbols(fst.OutputSymbols());
169  }
170
171  ~FactorWeightFstImpl() {
172    delete fst_;
173  }
174
175  StateId Start() {
176    if (!HasStart()) {
177      StateId s = fst_->Start();
178      if (s == kNoStateId)
179        return kNoStateId;
180      StateId start = FindState(Element(fst_->Start(), Weight::One()));
181      this->SetStart(start);
182    }
183    return CacheImpl<A>::Start();
184  }
185
186  Weight Final(StateId s) {
187    if (!HasFinal(s)) {
188      const Element &e = elements_[s];
189      // TODO: fix so cast is unnecessary
190      Weight w = e.state == kNoStateId
191                 ? e.weight
192                 : (Weight) Times(e.weight, fst_->Final(e.state));
193      FactorIterator f(w);
194      if (w != Weight::Zero() && f.Done())
195        this->SetFinal(s, w);
196      else
197        this->SetFinal(s, Weight::Zero());
198    }
199    return CacheImpl<A>::Final(s);
200  }
201
202  size_t NumArcs(StateId s) {
203    if (!HasArcs(s))
204      Expand(s);
205    return CacheImpl<A>::NumArcs(s);
206  }
207
208  size_t NumInputEpsilons(StateId s) {
209    if (!HasArcs(s))
210      Expand(s);
211    return CacheImpl<A>::NumInputEpsilons(s);
212  }
213
214  size_t NumOutputEpsilons(StateId s) {
215    if (!HasArcs(s))
216      Expand(s);
217    return CacheImpl<A>::NumOutputEpsilons(s);
218  }
219
220  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
221    if (!HasArcs(s))
222      Expand(s);
223    CacheImpl<A>::InitArcIterator(s, data);
224  }
225
226
227  // Find state corresponding to an element. Create new state
228  // if element not found.
229  StateId FindState(const Element &e) {
230    if (final_only_ && e.weight == Weight::One()) {
231      while (unfactored_.size() <= (unsigned int)e.state)
232        unfactored_.push_back(kNoStateId);
233      if (unfactored_[e.state] == kNoStateId) {
234        unfactored_[e.state] = elements_.size();
235        elements_.push_back(e);
236      }
237      return unfactored_[e.state];
238    } else {
239      typename ElementMap::iterator eit = element_map_.find(e);
240      if (eit != element_map_.end()) {
241        return (*eit).second;
242      } else {
243        StateId s = elements_.size();
244        elements_.push_back(e);
245        element_map_.insert(pair<const Element, StateId>(e, s));
246        return s;
247      }
248    }
249  }
250
251  // Computes the outgoing transitions from a state, creating new destination
252  // states as needed.
253  void Expand(StateId s) {
254    Element e = elements_[s];
255    if (e.state != kNoStateId) {
256      for (ArcIterator< Fst<A> > ait(*fst_, e.state);
257           !ait.Done();
258           ait.Next()) {
259        const A &arc = ait.Value();
260        Weight w = Times(e.weight, arc.weight);
261        FactorIterator fit(w);
262        if (final_only_ || fit.Done()) {
263          StateId d = FindState(Element(arc.nextstate, Weight::One()));
264          this->AddArc(s, Arc(arc.ilabel, arc.olabel, w, d));
265        } else {
266          for (; !fit.Done(); fit.Next()) {
267            const pair<Weight, Weight> &p = fit.Value();
268            StateId d = FindState(Element(arc.nextstate,
269                                          p.second.Quantize(delta_)));
270            this->AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
271          }
272        }
273      }
274    }
275    if ((e.state == kNoStateId) ||
276        (fst_->Final(e.state) != Weight::Zero())) {
277      Weight w = e.state == kNoStateId
278                 ? e.weight
279                 : Times(e.weight, fst_->Final(e.state));
280      for (FactorIterator fit(w);
281           !fit.Done();
282           fit.Next()) {
283        const pair<Weight, Weight> &p = fit.Value();
284        StateId d = FindState(Element(kNoStateId,
285                                      p.second.Quantize(delta_)));
286        this->AddArc(s, Arc(0, 0, p.first, d));
287      }
288    }
289    this->SetArcs(s);
290  }
291
292 private:
293  // Equality function for Elements, assume weights have been quantized.
294  class ElementEqual {
295   public:
296    bool operator()(const Element &x, const Element &y) const {
297      return x.state == y.state && x.weight == y.weight;
298    }
299  };
300
301  // Hash function for Elements to Fst states.
302  class ElementKey {
303   public:
304    size_t operator()(const Element &x) const {
305      return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
306    }
307   private:
308    static const int kPrime = 7853;
309  };
310
311  typedef std::unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
312
313  const Fst<A> *fst_;
314  float delta_;
315  bool final_only_;
316  vector<Element> elements_;  // mapping Fst state to Elements
317  ElementMap element_map_;    // mapping Elements to Fst state
318  // mapping between old/new 'StateId' for states that do not need to
319  // be factored when 'final_only_' is true
320  vector<StateId> unfactored_;
321
322  DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl);
323};
324
325
// FactorWeightFst takes as template parameter a FactorIterator as
// defined above. The result of weight factoring is a transducer
// equivalent to the input whose path weights have been factored
// according to the FactorIterator. States and transitions will be
// added as necessary. The algorithm is a generalization to arbitrary
// weights of the second step of the input epsilon-normalization
// algorithm due to Mohri, "Generic epsilon-removal and input
// epsilon-normalization algorithms for weighted transducers",
// International Journal of Foundations of Computer Science 13(1):
// 129-143 (2002).
335template <class A, class F>
336class FactorWeightFst : public Fst<A> {
337 public:
338  friend class ArcIterator< FactorWeightFst<A, F> >;
339  friend class CacheStateIterator< FactorWeightFst<A, F> >;
340  friend class CacheArcIterator< FactorWeightFst<A, F> >;
341
342  typedef A Arc;
343  typedef typename A::Weight Weight;
344  typedef typename A::StateId StateId;
345  typedef CacheState<A> State;
346
347  FactorWeightFst(const Fst<A> &fst)
348      : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {}
349
350  FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions &opts)
351      : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {}
352  FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) {
353    impl_->IncrRefCount();
354  }
355
356  virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_;  }
357
358  virtual StateId Start() const { return impl_->Start(); }
359
360  virtual Weight Final(StateId s) const { return impl_->Final(s); }
361
362  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
363
364  virtual size_t NumInputEpsilons(StateId s) const {
365    return impl_->NumInputEpsilons(s);
366  }
367
368  virtual size_t NumOutputEpsilons(StateId s) const {
369    return impl_->NumOutputEpsilons(s);
370  }
371
372  virtual uint64 Properties(uint64 mask, bool test) const {
373    if (test) {
374      uint64 known, test = TestProperties(*this, mask, &known);
375      impl_->SetProperties(test, known);
376      return test & mask;
377    } else {
378      return impl_->Properties(mask);
379    }
380  }
381
382  virtual const string& Type() const { return impl_->Type(); }
383
384  virtual FactorWeightFst<A, F> *Copy() const {
385    return new FactorWeightFst<A, F>(*this);
386  }
387
388  virtual const SymbolTable* InputSymbols() const {
389    return impl_->InputSymbols();
390  }
391
392  virtual const SymbolTable* OutputSymbols() const {
393    return impl_->OutputSymbols();
394  }
395
396  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
397
398  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
399    impl_->InitArcIterator(s, data);
400  }
401
402 private:
403  FactorWeightFstImpl<A, F> *Impl() { return impl_; }
404
405  FactorWeightFstImpl<A, F> *impl_;
406
407  void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
408};
409
410
411// Specialization for FactorWeightFst.
412template<class A, class F>
413class StateIterator< FactorWeightFst<A, F> >
414    : public CacheStateIterator< FactorWeightFst<A, F> > {
415 public:
416  explicit StateIterator(const FactorWeightFst<A, F> &fst)
417      : CacheStateIterator< FactorWeightFst<A, F> >(fst) {}
418};
419
420
421// Specialization for FactorWeightFst.
422template <class A, class F>
423class ArcIterator< FactorWeightFst<A, F> >
424    : public CacheArcIterator< FactorWeightFst<A, F> > {
425 public:
426  typedef typename A::StateId StateId;
427
428  ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
429      : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) {
430    if (!fst.impl_->HasArcs(s))
431      fst.impl_->Expand(s);
432  }
433
434 private:
435  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
436};
437
438template <class A, class F> inline
439void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
440{
441  data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
442}
443
444
445}  // namespace fst
446
447#endif // FST_LIB_FACTOR_WEIGHT_H__
448