1// factor-weight.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
19// Classes to factor weights in an FST.
20
21#ifndef FST_LIB_FACTOR_WEIGHT_H__
22#define FST_LIB_FACTOR_WEIGHT_H__
23
24#include <algorithm>
25#include <tr1/unordered_map>
26using std::tr1::unordered_map;
27using std::tr1::unordered_multimap;
28#include <string>
29#include <utility>
30using std::pair; using std::make_pair;
31#include <vector>
32using std::vector;
33
34#include <fst/cache.h>
35#include <fst/test-properties.h>
36
37
38namespace fst {
39
40const uint32 kFactorFinalWeights = 0x00000001;
41const uint32 kFactorArcWeights   = 0x00000002;
42
43template <class Arc>
44struct FactorWeightOptions : CacheOptions {
45  typedef typename Arc::Label Label;
46  float delta;
47  uint32 mode;         // factor arc weights and/or final weights
48  Label final_ilabel;  // input label of arc created when factoring final w's
49  Label final_olabel;  // output label of arc created when factoring final w's
50
51  FactorWeightOptions(const CacheOptions &opts, float d,
52                      uint32 m = kFactorArcWeights | kFactorFinalWeights,
53                      Label il = 0, Label ol = 0)
54      : CacheOptions(opts), delta(d), mode(m), final_ilabel(il),
55        final_olabel(ol) {}
56
57  explicit FactorWeightOptions(
58      float d, uint32 m = kFactorArcWeights | kFactorFinalWeights,
59      Label il = 0, Label ol = 0)
60      : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {}
61
62  FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights,
63                      Label il = 0, Label ol = 0)
64      : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {}
65};
66
67
// A factor iterator takes a weight w as its argument and enumerates a
// sequence of pairs of weights (xi, yi) such that the sum of the
// products xi times yi is equal to w. If w is already fully factored,
// the iterator should yield no pairs, i.e. Done() should be true
// immediately after construction.
72//
73// template <class W>
74// class FactorIterator {
75//  public:
76//   FactorIterator(W w);
77//   bool Done() const;
78//   void Next();
79//   pair<W, W> Value() const;
80//   void Reset();
// };
82
83
// Trivial factor iterator: it is always exhausted, so every weight is
// treated as already fully factored.
template <class W>
class IdentityFactor {
 public:
  // The weight is ignored; there is nothing to factor.
  IdentityFactor(const W &) {}

  // Always true: no factorization is ever produced.
  bool Done() const {
    return true;
  }

  void Next() {}

  // Unused, since Done() is always true. Returns (One, One) so that the
  // factor iterator interface is complete.
  pair<W, W> Value() const {
    return pair<W, W>(W::One(), W::One());
  }

  void Reset() {}
};
94
95
96// Factor a StringWeight w as 'ab' where 'a' is a label.
97template <typename L, StringType S = STRING_LEFT>
98class StringFactor {
99 public:
100  StringFactor(const StringWeight<L, S> &w)
101      : weight_(w), done_(w.Size() <= 1) {}
102
103  bool Done() const { return done_; }
104
105  void Next() { done_ = true; }
106
107  pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
108    StringWeightIterator<L, S> iter(weight_);
109    StringWeight<L, S> w1(iter.Value());
110    StringWeight<L, S> w2;
111    for (iter.Next(); !iter.Done(); iter.Next())
112      w2.PushBack(iter.Value());
113    return make_pair(w1, w2);
114  }
115
116  void Reset() { done_ = weight_.Size() <= 1; }
117
118 private:
119  StringWeight<L, S> weight_;
120  bool done_;
121};
122
123
124// Factor a GallicWeight using StringFactor.
125template <class L, class W, StringType S = STRING_LEFT>
126class GallicFactor {
127 public:
128  GallicFactor(const GallicWeight<L, W, S> &w)
129      : weight_(w), done_(w.Value1().Size() <= 1) {}
130
131  bool Done() const { return done_; }
132
133  void Next() { done_ = true; }
134
135  pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
136    StringFactor<L, S> iter(weight_.Value1());
137    GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
138    GallicWeight<L, W, S> w2(iter.Value().second, W::One());
139    return make_pair(w1, w2);
140  }
141
142  void Reset() { done_ = weight_.Value1().Size() <= 1; }
143
144 private:
145  GallicWeight<L, W, S> weight_;
146  bool done_;
147};
148
149
150// Implementation class for FactorWeight
151template <class A, class F>
152class FactorWeightFstImpl
153    : public CacheImpl<A> {
154 public:
155  using FstImpl<A>::SetType;
156  using FstImpl<A>::SetProperties;
157  using FstImpl<A>::SetInputSymbols;
158  using FstImpl<A>::SetOutputSymbols;
159
160  using CacheBaseImpl< CacheState<A> >::PushArc;
161  using CacheBaseImpl< CacheState<A> >::HasStart;
162  using CacheBaseImpl< CacheState<A> >::HasFinal;
163  using CacheBaseImpl< CacheState<A> >::HasArcs;
164  using CacheBaseImpl< CacheState<A> >::SetArcs;
165  using CacheBaseImpl< CacheState<A> >::SetFinal;
166  using CacheBaseImpl< CacheState<A> >::SetStart;
167
168  typedef A Arc;
169  typedef typename A::Label Label;
170  typedef typename A::Weight Weight;
171  typedef typename A::StateId StateId;
172  typedef F FactorIterator;
173
174  struct Element {
175    Element() {}
176
177    Element(StateId s, Weight w) : state(s), weight(w) {}
178
179    StateId state;     // Input state Id
180    Weight weight;     // Residual weight
181  };
182
183  FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts)
184      : CacheImpl<A>(opts),
185        fst_(fst.Copy()),
186        delta_(opts.delta),
187        mode_(opts.mode),
188        final_ilabel_(opts.final_ilabel),
189        final_olabel_(opts.final_olabel) {
190    SetType("factor_weight");
191    uint64 props = fst.Properties(kFstProperties, false);
192    SetProperties(FactorWeightProperties(props), kCopyProperties);
193
194    SetInputSymbols(fst.InputSymbols());
195    SetOutputSymbols(fst.OutputSymbols());
196
197    if (mode_ == 0)
198      LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: "
199                   << "factoring neither arc weights nor final weights.";
200  }
201
202  FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl)
203      : CacheImpl<A>(impl),
204        fst_(impl.fst_->Copy(true)),
205        delta_(impl.delta_),
206        mode_(impl.mode_),
207        final_ilabel_(impl.final_ilabel_),
208        final_olabel_(impl.final_olabel_) {
209    SetType("factor_weight");
210    SetProperties(impl.Properties(), kCopyProperties);
211    SetInputSymbols(impl.InputSymbols());
212    SetOutputSymbols(impl.OutputSymbols());
213  }
214
215  ~FactorWeightFstImpl() {
216    delete fst_;
217  }
218
219  StateId Start() {
220    if (!HasStart()) {
221      StateId s = fst_->Start();
222      if (s == kNoStateId)
223        return kNoStateId;
224      StateId start = FindState(Element(fst_->Start(), Weight::One()));
225      SetStart(start);
226    }
227    return CacheImpl<A>::Start();
228  }
229
230  Weight Final(StateId s) {
231    if (!HasFinal(s)) {
232      const Element &e = elements_[s];
233      // TODO: fix so cast is unnecessary
234      Weight w = e.state == kNoStateId
235                 ? e.weight
236                 : (Weight) Times(e.weight, fst_->Final(e.state));
237      FactorIterator f(w);
238      if (!(mode_ & kFactorFinalWeights) || f.Done())
239        SetFinal(s, w);
240      else
241        SetFinal(s, Weight::Zero());
242    }
243    return CacheImpl<A>::Final(s);
244  }
245
246  size_t NumArcs(StateId s) {
247    if (!HasArcs(s))
248      Expand(s);
249    return CacheImpl<A>::NumArcs(s);
250  }
251
252  size_t NumInputEpsilons(StateId s) {
253    if (!HasArcs(s))
254      Expand(s);
255    return CacheImpl<A>::NumInputEpsilons(s);
256  }
257
258  size_t NumOutputEpsilons(StateId s) {
259    if (!HasArcs(s))
260      Expand(s);
261    return CacheImpl<A>::NumOutputEpsilons(s);
262  }
263
264  uint64 Properties() const { return Properties(kFstProperties); }
265
266  // Set error if found; return FST impl properties.
267  uint64 Properties(uint64 mask) const {
268    if ((mask & kError) && fst_->Properties(kError, false))
269      SetProperties(kError, kError);
270    return FstImpl<Arc>::Properties(mask);
271  }
272
273  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
274    if (!HasArcs(s))
275      Expand(s);
276    CacheImpl<A>::InitArcIterator(s, data);
277  }
278
279
280  // Find state corresponding to an element. Create new state
281  // if element not found.
282  StateId FindState(const Element &e) {
283    if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) {
284      while (unfactored_.size() <= e.state)
285        unfactored_.push_back(kNoStateId);
286      if (unfactored_[e.state] == kNoStateId) {
287        unfactored_[e.state] = elements_.size();
288        elements_.push_back(e);
289      }
290      return unfactored_[e.state];
291    } else {
292      typename ElementMap::iterator eit = element_map_.find(e);
293      if (eit != element_map_.end()) {
294        return (*eit).second;
295      } else {
296        StateId s = elements_.size();
297        elements_.push_back(e);
298        element_map_.insert(pair<const Element, StateId>(e, s));
299        return s;
300      }
301    }
302  }
303
304  // Computes the outgoing transitions from a state, creating new destination
305  // states as needed.
306  void Expand(StateId s) {
307    Element e = elements_[s];
308    if (e.state != kNoStateId) {
309      for (ArcIterator< Fst<A> > ait(*fst_, e.state);
310           !ait.Done();
311           ait.Next()) {
312        const A &arc = ait.Value();
313        Weight w = Times(e.weight, arc.weight);
314        FactorIterator fit(w);
315        if (!(mode_ & kFactorArcWeights) || fit.Done()) {
316          StateId d = FindState(Element(arc.nextstate, Weight::One()));
317          PushArc(s, Arc(arc.ilabel, arc.olabel, w, d));
318        } else {
319          for (; !fit.Done(); fit.Next()) {
320            const pair<Weight, Weight> &p = fit.Value();
321            StateId d = FindState(Element(arc.nextstate,
322                                          p.second.Quantize(delta_)));
323            PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
324          }
325        }
326      }
327    }
328
329    if ((mode_ & kFactorFinalWeights) &&
330        ((e.state == kNoStateId) ||
331         (fst_->Final(e.state) != Weight::Zero()))) {
332      Weight w = e.state == kNoStateId
333                 ? e.weight
334                 : Times(e.weight, fst_->Final(e.state));
335      for (FactorIterator fit(w);
336           !fit.Done();
337           fit.Next()) {
338        const pair<Weight, Weight> &p = fit.Value();
339        StateId d = FindState(Element(kNoStateId,
340                                      p.second.Quantize(delta_)));
341        PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d));
342      }
343    }
344    SetArcs(s);
345  }
346
347 private:
348  static const size_t kPrime = 7853;
349
350  // Equality function for Elements, assume weights have been quantized.
351  class ElementEqual {
352   public:
353    bool operator()(const Element &x, const Element &y) const {
354      return x.state == y.state && x.weight == y.weight;
355    }
356  };
357
358  // Hash function for Elements to Fst states.
359  class ElementKey {
360   public:
361    size_t operator()(const Element &x) const {
362      return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
363    }
364   private:
365  };
366
367  typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
368
369  const Fst<A> *fst_;
370  float delta_;
371  uint32 mode_;               // factoring arc and/or final weights
372  Label final_ilabel_;        // ilabel of arc created when factoring final w's
373  Label final_olabel_;        // olabel of arc created when factoring final w's
374  vector<Element> elements_;  // mapping Fst state to Elements
375  ElementMap element_map_;    // mapping Elements to Fst state
376  // mapping between old/new 'StateId' for states that do not need to
377  // be factored when 'mode_' is '0' or 'kFactorFinalWeights'
378  vector<StateId> unfactored_;
379
380  void operator=(const FactorWeightFstImpl<A, F> &);  // disallow
381};
382
383template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime;
384
385
386// FactorWeightFst takes as template parameter a FactorIterator as
387// defined above. The result of weight factoring is a transducer
388// equivalent to the input whose path weights have been factored
389// according to the FactorIterator. States and transitions will be
390// added as necessary. The algorithm is a generalization to arbitrary
391// weights of the second step of the input epsilon-normalization
392// algorithm due to Mohri, "Generic epsilon-removal and input
393// epsilon-normalization algorithms for weighted transducers",
394// International Journal of Computer Science 13(1): 129-143 (2002).
395//
396// This class attaches interface to implementation and handles
397// reference counting, delegating most methods to ImplToFst.
398template <class A, class F>
399class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > {
400 public:
401  friend class ArcIterator< FactorWeightFst<A, F> >;
402  friend class StateIterator< FactorWeightFst<A, F> >;
403
404  typedef A Arc;
405  typedef typename A::Weight Weight;
406  typedef typename A::StateId StateId;
407  typedef CacheState<A> State;
408  typedef FactorWeightFstImpl<A, F> Impl;
409
410  FactorWeightFst(const Fst<A> &fst)
411      : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {}
412
413  FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions<A> &opts)
414      : ImplToFst<Impl>(new Impl(fst, opts)) {}
415
416  // See Fst<>::Copy() for doc.
417  FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy)
418      : ImplToFst<Impl>(fst, copy) {}
419
420  // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
421  virtual FactorWeightFst<A, F> *Copy(bool copy = false) const {
422    return new FactorWeightFst<A, F>(*this, copy);
423  }
424
425  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
426
427  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
428    GetImpl()->InitArcIterator(s, data);
429  }
430
431 private:
432  // Makes visible to friends.
433  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
434
435  void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
436};
437
438
439// Specialization for FactorWeightFst.
440template<class A, class F>
441class StateIterator< FactorWeightFst<A, F> >
442    : public CacheStateIterator< FactorWeightFst<A, F> > {
443 public:
444  explicit StateIterator(const FactorWeightFst<A, F> &fst)
445      : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {}
446};
447
448
449// Specialization for FactorWeightFst.
450template <class A, class F>
451class ArcIterator< FactorWeightFst<A, F> >
452    : public CacheArcIterator< FactorWeightFst<A, F> > {
453 public:
454  typedef typename A::StateId StateId;
455
456  ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
457      : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) {
458    if (!fst.GetImpl()->HasArcs(s))
459      fst.GetImpl()->Expand(s);
460  }
461
462 private:
463  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
464};
465
466template <class A, class F> inline
467void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
468{
469  data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
470}
471
472
473}  // namespace fst
474
475#endif // FST_LIB_FACTOR_WEIGHT_H__
476