factor-weight.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// factor-weight.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
19// Classes to factor weights in an FST.
20
21#ifndef FST_LIB_FACTOR_WEIGHT_H__
22#define FST_LIB_FACTOR_WEIGHT_H__
23
24#include <algorithm>
25#include <unordered_map>
26using std::tr1::unordered_map;
27using std::tr1::unordered_multimap;
28#include <fst/slist.h>
29#include <string>
30#include <utility>
31using std::pair; using std::make_pair;
32#include <vector>
33using std::vector;
34
35#include <fst/cache.h>
36#include <fst/test-properties.h>
37
38
39namespace fst {
40
41const uint32 kFactorFinalWeights = 0x00000001;
42const uint32 kFactorArcWeights   = 0x00000002;
43
44template <class Arc>
45struct FactorWeightOptions : CacheOptions {
46  typedef typename Arc::Label Label;
47  float delta;
48  uint32 mode;         // factor arc weights and/or final weights
49  Label final_ilabel;  // input label of arc created when factoring final w's
50  Label final_olabel;  // output label of arc created when factoring final w's
51
52  FactorWeightOptions(const CacheOptions &opts, float d,
53                      uint32 m = kFactorArcWeights | kFactorFinalWeights,
54                      Label il = 0, Label ol = 0)
55      : CacheOptions(opts), delta(d), mode(m), final_ilabel(il),
56        final_olabel(ol) {}
57
58  explicit FactorWeightOptions(
59      float d, uint32 m = kFactorArcWeights | kFactorFinalWeights,
60      Label il = 0, Label ol = 0)
61      : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {}
62
63  FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights,
64                      Label il = 0, Label ol = 0)
65      : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {}
66};
67
68
// A factor iterator takes as argument a weight w and returns a sequence of
// pairs of weights (xi, yi) such that the sum of the products xi times yi is
// equal to w. If w is already fully factored, the iterator yields no pairs:
// Done() is true immediately after construction. Callers read Value() only
// while Done() is false.
//
// template <class W>
// class FactorIterator {
//  public:
//   FactorIterator(W w);
//   bool Done() const;
//   void Next();
//   pair<W, W> Value() const;
//   void Reset();
// };
83
84
// Trivial factor iterator: treats every weight as already fully factored,
// so it never yields any (x, y) pair.
template <class W>
class IdentityFactor {
 public:
  IdentityFactor(const W & /*w*/) {}

  // Always true: there is nothing to iterate over.
  bool Done() const { return true; }

  void Next() {}

  // Never read by callers since Done() is always true; returns (One, One).
  pair<W, W> Value() const {
    const W one = W::One();
    return make_pair(one, W::One());
  }

  void Reset() {}
};
95
96
97// Factor a StringWeight w as 'ab' where 'a' is a label.
98template <typename L, StringType S = STRING_LEFT>
99class StringFactor {
100 public:
101  StringFactor(const StringWeight<L, S> &w)
102      : weight_(w), done_(w.Size() <= 1) {}
103
104  bool Done() const { return done_; }
105
106  void Next() { done_ = true; }
107
108  pair< StringWeight<L, S>, StringWeight<L, S> > Value() const {
109    StringWeightIterator<L, S> iter(weight_);
110    StringWeight<L, S> w1(iter.Value());
111    StringWeight<L, S> w2;
112    for (iter.Next(); !iter.Done(); iter.Next())
113      w2.PushBack(iter.Value());
114    return make_pair(w1, w2);
115  }
116
117  void Reset() { done_ = weight_.Size() <= 1; }
118
119 private:
120  StringWeight<L, S> weight_;
121  bool done_;
122};
123
124
125// Factor a GallicWeight using StringFactor.
126template <class L, class W, StringType S = STRING_LEFT>
127class GallicFactor {
128 public:
129  GallicFactor(const GallicWeight<L, W, S> &w)
130      : weight_(w), done_(w.Value1().Size() <= 1) {}
131
132  bool Done() const { return done_; }
133
134  void Next() { done_ = true; }
135
136  pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const {
137    StringFactor<L, S> iter(weight_.Value1());
138    GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2());
139    GallicWeight<L, W, S> w2(iter.Value().second, W::One());
140    return make_pair(w1, w2);
141  }
142
143  void Reset() { done_ = weight_.Value1().Size() <= 1; }
144
145 private:
146  GallicWeight<L, W, S> weight_;
147  bool done_;
148};
149
150
151// Implementation class for FactorWeight
152template <class A, class F>
153class FactorWeightFstImpl
154    : public CacheImpl<A> {
155 public:
156  using FstImpl<A>::SetType;
157  using FstImpl<A>::SetProperties;
158  using FstImpl<A>::SetInputSymbols;
159  using FstImpl<A>::SetOutputSymbols;
160
161  using CacheBaseImpl< CacheState<A> >::PushArc;
162  using CacheBaseImpl< CacheState<A> >::HasStart;
163  using CacheBaseImpl< CacheState<A> >::HasFinal;
164  using CacheBaseImpl< CacheState<A> >::HasArcs;
165  using CacheBaseImpl< CacheState<A> >::SetArcs;
166  using CacheBaseImpl< CacheState<A> >::SetFinal;
167  using CacheBaseImpl< CacheState<A> >::SetStart;
168
169  typedef A Arc;
170  typedef typename A::Label Label;
171  typedef typename A::Weight Weight;
172  typedef typename A::StateId StateId;
173  typedef F FactorIterator;
174
175  struct Element {
176    Element() {}
177
178    Element(StateId s, Weight w) : state(s), weight(w) {}
179
180    StateId state;     // Input state Id
181    Weight weight;     // Residual weight
182  };
183
184  FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts)
185      : CacheImpl<A>(opts),
186        fst_(fst.Copy()),
187        delta_(opts.delta),
188        mode_(opts.mode),
189        final_ilabel_(opts.final_ilabel),
190        final_olabel_(opts.final_olabel) {
191    SetType("factor_weight");
192    uint64 props = fst.Properties(kFstProperties, false);
193    SetProperties(FactorWeightProperties(props), kCopyProperties);
194
195    SetInputSymbols(fst.InputSymbols());
196    SetOutputSymbols(fst.OutputSymbols());
197
198    if (mode_ == 0)
199      LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: "
200                   << "factoring neither arc weights nor final weights.";
201  }
202
203  FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl)
204      : CacheImpl<A>(impl),
205        fst_(impl.fst_->Copy(true)),
206        delta_(impl.delta_),
207        mode_(impl.mode_),
208        final_ilabel_(impl.final_ilabel_),
209        final_olabel_(impl.final_olabel_) {
210    SetType("factor_weight");
211    SetProperties(impl.Properties(), kCopyProperties);
212    SetInputSymbols(impl.InputSymbols());
213    SetOutputSymbols(impl.OutputSymbols());
214  }
215
216  ~FactorWeightFstImpl() {
217    delete fst_;
218  }
219
220  StateId Start() {
221    if (!HasStart()) {
222      StateId s = fst_->Start();
223      if (s == kNoStateId)
224        return kNoStateId;
225      StateId start = FindState(Element(fst_->Start(), Weight::One()));
226      SetStart(start);
227    }
228    return CacheImpl<A>::Start();
229  }
230
231  Weight Final(StateId s) {
232    if (!HasFinal(s)) {
233      const Element &e = elements_[s];
234      // TODO: fix so cast is unnecessary
235      Weight w = e.state == kNoStateId
236                 ? e.weight
237                 : (Weight) Times(e.weight, fst_->Final(e.state));
238      FactorIterator f(w);
239      if (!(mode_ & kFactorFinalWeights) || f.Done())
240        SetFinal(s, w);
241      else
242        SetFinal(s, Weight::Zero());
243    }
244    return CacheImpl<A>::Final(s);
245  }
246
247  size_t NumArcs(StateId s) {
248    if (!HasArcs(s))
249      Expand(s);
250    return CacheImpl<A>::NumArcs(s);
251  }
252
253  size_t NumInputEpsilons(StateId s) {
254    if (!HasArcs(s))
255      Expand(s);
256    return CacheImpl<A>::NumInputEpsilons(s);
257  }
258
259  size_t NumOutputEpsilons(StateId s) {
260    if (!HasArcs(s))
261      Expand(s);
262    return CacheImpl<A>::NumOutputEpsilons(s);
263  }
264
265  uint64 Properties() const { return Properties(kFstProperties); }
266
267  // Set error if found; return FST impl properties.
268  uint64 Properties(uint64 mask) const {
269    if ((mask & kError) && fst_->Properties(kError, false))
270      SetProperties(kError, kError);
271    return FstImpl<Arc>::Properties(mask);
272  }
273
274  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
275    if (!HasArcs(s))
276      Expand(s);
277    CacheImpl<A>::InitArcIterator(s, data);
278  }
279
280
281  // Find state corresponding to an element. Create new state
282  // if element not found.
283  StateId FindState(const Element &e) {
284    if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) {
285      while (unfactored_.size() <= e.state)
286        unfactored_.push_back(kNoStateId);
287      if (unfactored_[e.state] == kNoStateId) {
288        unfactored_[e.state] = elements_.size();
289        elements_.push_back(e);
290      }
291      return unfactored_[e.state];
292    } else {
293      typename ElementMap::iterator eit = element_map_.find(e);
294      if (eit != element_map_.end()) {
295        return (*eit).second;
296      } else {
297        StateId s = elements_.size();
298        elements_.push_back(e);
299        element_map_.insert(pair<const Element, StateId>(e, s));
300        return s;
301      }
302    }
303  }
304
305  // Computes the outgoing transitions from a state, creating new destination
306  // states as needed.
307  void Expand(StateId s) {
308    Element e = elements_[s];
309    if (e.state != kNoStateId) {
310      for (ArcIterator< Fst<A> > ait(*fst_, e.state);
311           !ait.Done();
312           ait.Next()) {
313        const A &arc = ait.Value();
314        Weight w = Times(e.weight, arc.weight);
315        FactorIterator fit(w);
316        if (!(mode_ & kFactorArcWeights) || fit.Done()) {
317          StateId d = FindState(Element(arc.nextstate, Weight::One()));
318          PushArc(s, Arc(arc.ilabel, arc.olabel, w, d));
319        } else {
320          for (; !fit.Done(); fit.Next()) {
321            const pair<Weight, Weight> &p = fit.Value();
322            StateId d = FindState(Element(arc.nextstate,
323                                          p.second.Quantize(delta_)));
324            PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d));
325          }
326        }
327      }
328    }
329
330    if ((mode_ & kFactorFinalWeights) &&
331        ((e.state == kNoStateId) ||
332         (fst_->Final(e.state) != Weight::Zero()))) {
333      Weight w = e.state == kNoStateId
334                 ? e.weight
335                 : Times(e.weight, fst_->Final(e.state));
336      for (FactorIterator fit(w);
337           !fit.Done();
338           fit.Next()) {
339        const pair<Weight, Weight> &p = fit.Value();
340        StateId d = FindState(Element(kNoStateId,
341                                      p.second.Quantize(delta_)));
342        PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d));
343      }
344    }
345    SetArcs(s);
346  }
347
348 private:
349  static const size_t kPrime = 7853;
350
351  // Equality function for Elements, assume weights have been quantized.
352  class ElementEqual {
353   public:
354    bool operator()(const Element &x, const Element &y) const {
355      return x.state == y.state && x.weight == y.weight;
356    }
357  };
358
359  // Hash function for Elements to Fst states.
360  class ElementKey {
361   public:
362    size_t operator()(const Element &x) const {
363      return static_cast<size_t>(x.state * kPrime + x.weight.Hash());
364    }
365   private:
366  };
367
368  typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
369
370  const Fst<A> *fst_;
371  float delta_;
372  uint32 mode_;               // factoring arc and/or final weights
373  Label final_ilabel_;        // ilabel of arc created when factoring final w's
374  Label final_olabel_;        // olabel of arc created when factoring final w's
375  vector<Element> elements_;  // mapping Fst state to Elements
376  ElementMap element_map_;    // mapping Elements to Fst state
377  // mapping between old/new 'StateId' for states that do not need to
378  // be factored when 'mode_' is '0' or 'kFactorFinalWeights'
379  vector<StateId> unfactored_;
380
381  void operator=(const FactorWeightFstImpl<A, F> &);  // disallow
382};
383
384template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime;
385
386
387// FactorWeightFst takes as template parameter a FactorIterator as
388// defined above. The result of weight factoring is a transducer
389// equivalent to the input whose path weights have been factored
390// according to the FactorIterator. States and transitions will be
391// added as necessary. The algorithm is a generalization to arbitrary
392// weights of the second step of the input epsilon-normalization
393// algorithm due to Mohri, "Generic epsilon-removal and input
394// epsilon-normalization algorithms for weighted transducers",
395// International Journal of Computer Science 13(1): 129-143 (2002).
396//
397// This class attaches interface to implementation and handles
398// reference counting, delegating most methods to ImplToFst.
399template <class A, class F>
400class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > {
401 public:
402  friend class ArcIterator< FactorWeightFst<A, F> >;
403  friend class StateIterator< FactorWeightFst<A, F> >;
404
405  typedef A Arc;
406  typedef typename A::Weight Weight;
407  typedef typename A::StateId StateId;
408  typedef CacheState<A> State;
409  typedef FactorWeightFstImpl<A, F> Impl;
410
411  FactorWeightFst(const Fst<A> &fst)
412      : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {}
413
414  FactorWeightFst(const Fst<A> &fst,  const FactorWeightOptions<A> &opts)
415      : ImplToFst<Impl>(new Impl(fst, opts)) {}
416
417  // See Fst<>::Copy() for doc.
418  FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy)
419      : ImplToFst<Impl>(fst, copy) {}
420
421  // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc.
422  virtual FactorWeightFst<A, F> *Copy(bool copy = false) const {
423    return new FactorWeightFst<A, F>(*this, copy);
424  }
425
426  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
427
428  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
429    GetImpl()->InitArcIterator(s, data);
430  }
431
432 private:
433  // Makes visible to friends.
434  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
435
436  void operator=(const FactorWeightFst<A, F> &fst);  // Disallow
437};
438
439
440// Specialization for FactorWeightFst.
441template<class A, class F>
442class StateIterator< FactorWeightFst<A, F> >
443    : public CacheStateIterator< FactorWeightFst<A, F> > {
444 public:
445  explicit StateIterator(const FactorWeightFst<A, F> &fst)
446      : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {}
447};
448
449
450// Specialization for FactorWeightFst.
451template <class A, class F>
452class ArcIterator< FactorWeightFst<A, F> >
453    : public CacheArcIterator< FactorWeightFst<A, F> > {
454 public:
455  typedef typename A::StateId StateId;
456
457  ArcIterator(const FactorWeightFst<A, F> &fst, StateId s)
458      : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) {
459    if (!fst.GetImpl()->HasArcs(s))
460      fst.GetImpl()->Expand(s);
461  }
462
463 private:
464  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
465};
466
467template <class A, class F> inline
468void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const
469{
470  data->base = new StateIterator< FactorWeightFst<A, F> >(*this);
471}
472
473
474}  // namespace fst
475
476#endif // FST_LIB_FACTOR_WEIGHT_H__
477