rmepsilon.h revision 73018b4a1d088cdda0e7bd059fddf1f308a8195a
1// rmepsilon.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16//
17// \file
// Functions and classes that implement epsilon-removal.
19
20#ifndef FST_LIB_RMEPSILON_H__
21#define FST_LIB_RMEPSILON_H__
22
23#include <unordered_map>
24#include <forward_list>
25
26#include "fst/lib/arcfilter.h"
27#include "fst/lib/cache.h"
28#include "fst/lib/connect.h"
29#include "fst/lib/factor-weight.h"
30#include "fst/lib/invert.h"
31#include "fst/lib/map.h"
32#include "fst/lib/queue.h"
33#include "fst/lib/shortest-distance.h"
34#include "fst/lib/topsort.h"
35
36namespace fst {
37
38template <class Arc, class Queue>
39struct RmEpsilonOptions
40    : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
41  typedef typename Arc::StateId StateId;
42
43  bool connect;  // Connect output
44
45  RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true)
46      : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> >(
47          q, EpsilonArcFilter<Arc>(), kNoStateId, d), connect(c) {}
48
49};
50
51
52// Computation state of the epsilon-removal algorithm.
53template <class Arc, class Queue>
54class RmEpsilonState {
55 public:
56  typedef typename Arc::Label Label;
57  typedef typename Arc::StateId StateId;
58  typedef typename Arc::Weight Weight;
59
60  RmEpsilonState(const Fst<Arc> &fst,
61                 vector<Weight> *distance,
62                 const RmEpsilonOptions<Arc, Queue> &opts)
63      : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true) {
64  }
65
66  // Compute arcs and final weight for state 's'
67  void Expand(StateId s);
68
69  // Returns arcs of expanded state.
70  vector<Arc> &Arcs() { return arcs_; }
71
72  // Returns final weight of expanded state.
73  const Weight &Final() const { return final_; }
74
75 private:
76  struct Element {
77    Label ilabel;
78    Label olabel;
79    StateId nextstate;
80
81    Element() {}
82
83    Element(Label i, Label o, StateId s)
84        : ilabel(i), olabel(o), nextstate(s) {}
85  };
86
87  class ElementKey {
88   public:
89    size_t operator()(const Element& e) const {
90      return static_cast<size_t>(e.nextstate);
91      return static_cast<size_t>(e.nextstate +
92                                 e.ilabel * kPrime0 +
93                                 e.olabel * kPrime1);
94    }
95
96   private:
97    static const int kPrime0 = 7853;
98    static const int kPrime1 = 7867;
99  };
100
101  class ElementEqual {
102   public:
103    bool operator()(const Element &e1, const Element &e2) const {
104      return (e1.ilabel == e2.ilabel) &&  (e1.olabel == e2.olabel)
105                         && (e1.nextstate == e2.nextstate);
106    }
107  };
108
109 private:
110  typedef std::unordered_map<Element, pair<StateId, ssize_t>,
111                             ElementKey, ElementEqual> ElementMap;
112
113  const Fst<Arc> &fst_;
114  // Distance from state being expanded in epsilon-closure.
115  vector<Weight> *distance_;
116  // Shortest distance algorithm computation state.
117  ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
118  // Maps an element 'e' to a pair 'p' corresponding to a position
119  // in the arcs vector of the state being expanded. 'e' corresponds
120  // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
121  // equal to the state being expanded.
122  ElementMap element_map_;
123  EpsilonArcFilter<Arc> eps_filter_;
124  stack<StateId> eps_queue_;      // Queue used to visit the epsilon-closure
125  vector<bool> visited_;          // '[i] = true' if state 'i' has been visited
126  std::forward_list<StateId> visited_states_; // List of visited states
127  vector<Arc> arcs_;              // Arcs of state being expanded
128  Weight final_;                  // Final weight of state being expanded
129
130  void operator=(const RmEpsilonState);  // Disallow
131};
132
133
134template <class Arc, class Queue>
135void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
136   sd_state_.ShortestDistance(source);
137   eps_queue_.push(source);
138   final_ = Weight::Zero();
139   arcs_.clear();
140
141   while (!eps_queue_.empty()) {
142     StateId state = eps_queue_.top();
143     eps_queue_.pop();
144
145     while ((StateId)visited_.size() <= state) visited_.push_back(false);
146     visited_[state] = true;
147     visited_states_.push_front(state);
148
149     for (ArcIterator< Fst<Arc> > ait(fst_, state);
150          !ait.Done();
151          ait.Next()) {
152       Arc arc = ait.Value();
153       arc.weight = Times((*distance_)[state], arc.weight);
154
155       if (eps_filter_(arc)) {
156         while ((StateId)visited_.size() <= arc.nextstate)
157           visited_.push_back(false);
158         if (!visited_[arc.nextstate])
159           eps_queue_.push(arc.nextstate);
160       } else {
161          Element element(arc.ilabel, arc.olabel, arc.nextstate);
162          typename ElementMap::iterator it = element_map_.find(element);
163          if (it == element_map_.end()) {
164            element_map_.insert(
165                pair<Element, pair<StateId, ssize_t> >
166                (element, pair<StateId, ssize_t>(source, arcs_.size())));
167            arcs_.push_back(arc);
168          } else {
169            if (((*it).second).first == source) {
170              Weight &w = arcs_[((*it).second).second].weight;
171              w = Plus(w, arc.weight);
172            } else {
173              ((*it).second).first = source;
174              ((*it).second).second = arcs_.size();
175              arcs_.push_back(arc);
176            }
177          }
178        }
179     }
180     final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
181   }
182
183   while (!visited_states_.empty()) {
184     visited_[visited_states_.front()] = false;
185     visited_states_.pop_front();
186   }
187}
188
189
190// Removes epsilon-transitions (when both the input and output label
191// are an epsilon) from a transducer. The result will be an equivalent
192// FST that has no such epsilon transitions.  This version modifies
193// its input. It allows fine control via the options argument; see
194// below for a simpler interface.
195//
196// The vector 'distance' will be used to hold the shortest distances
197// during the epsilon-closure computation. The state queue discipline
198// and convergence delta are taken in the options argument.
199template <class Arc, class Queue>
200void RmEpsilon(MutableFst<Arc> *fst,
201               vector<typename Arc::Weight> *distance,
202               const RmEpsilonOptions<Arc, Queue> &opts) {
203  typedef typename Arc::StateId StateId;
204  typedef typename Arc::Weight Weight;
205  typedef typename Arc::Label Label;
206
207  // States sorted in topological order when (acyclic) or generic
208  // topological order (cyclic).
209  vector<StateId> states;
210
211  if (fst->Properties(kTopSorted, false) & kTopSorted) {
212    for (StateId i = 0; i < (StateId)fst->NumStates(); i++)
213      states.push_back(i);
214  } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
215    vector<StateId> order;
216    bool acyclic;
217    TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
218    DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
219    if (!acyclic)
220      LOG(FATAL) << "RmEpsilon: not acyclic though property bit is set";
221    states.resize(order.size());
222    for (StateId i = 0; i < (StateId)order.size(); i++)
223      states[order[i]] = i;
224  } else {
225     uint64 props;
226     vector<StateId> scc;
227     SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
228     DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
229     vector<StateId> first(scc.size(), kNoStateId);
230     vector<StateId> next(scc.size(), kNoStateId);
231     for (StateId i = 0; i < (StateId)scc.size(); i++) {
232       if (first[scc[i]] != kNoStateId)
233         next[i] = first[scc[i]];
234       first[scc[i]] = i;
235     }
236     for (StateId i = 0; i < (StateId)first.size(); i++)
237       for (StateId j = first[i]; j != kNoStateId; j = next[j])
238         states.push_back(j);
239  }
240
241  RmEpsilonState<Arc, Queue>
242    rmeps_state(*fst, distance, opts);
243
244  while (!states.empty()) {
245    StateId state = states.back();
246    states.pop_back();
247    rmeps_state.Expand(state);
248    fst->SetFinal(state, rmeps_state.Final());
249    fst->DeleteArcs(state);
250    vector<Arc> &arcs = rmeps_state.Arcs();
251    while (!arcs.empty()) {
252      fst->AddArc(state, arcs.back());
253      arcs.pop_back();
254    }
255  }
256
257  fst->SetProperties(RmEpsilonProperties(
258                         fst->Properties(kFstProperties, false)),
259                     kFstProperties);
260
261  if (opts.connect)
262    Connect(fst);
263}
264
265
266// Removes epsilon-transitions (when both the input and output label
267// are an epsilon) from a transducer. The result will be an equivalent
268// FST that has no such epsilon transitions. This version modifies its
269// input. It has a simplified interface; see above for a version that
270// allows finer control.
271//
272// Complexity:
273// - Time:
274//   - Unweighted: O(V2 + V E)
275//   - Acyclic: O(V2 + V E)
276//   - Tropical semiring: O(V2 log V + V E)
277//   - General: exponential
278// - Space: O(V E)
279// where V = # of states visited, E = # of arcs.
280//
281// References:
282// - Mehryar Mohri. Generic Epsilon-Removal and Input
283//   Epsilon-Normalization Algorithms for Weighted Transducers,
284//   "International Journal of Computer Science", 13(1):129-143 (2002).
285template <class Arc>
286void RmEpsilon(MutableFst<Arc> *fst, bool connect = true) {
287  typedef typename Arc::StateId StateId;
288  typedef typename Arc::Weight Weight;
289  typedef typename Arc::Label Label;
290
291  vector<Weight> distance;
292  AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
293  RmEpsilonOptions<Arc, AutoQueue<StateId> >
294    opts(&state_queue, kDelta, connect);
295
296  RmEpsilon(fst, &distance, opts);
297}
298
299
300struct RmEpsilonFstOptions : CacheOptions {
301  float delta;
302
303  RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
304      : CacheOptions(opts), delta(delta) {}
305
306  explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
307};
308
309
310// Implementation of delayed RmEpsilonFst.
311template <class A>
312class RmEpsilonFstImpl : public CacheImpl<A> {
313 public:
314  using FstImpl<A>::SetType;
315  using FstImpl<A>::SetProperties;
316  using FstImpl<A>::Properties;
317  using FstImpl<A>::SetInputSymbols;
318  using FstImpl<A>::SetOutputSymbols;
319
320  using CacheBaseImpl< CacheState<A> >::HasStart;
321  using CacheBaseImpl< CacheState<A> >::HasFinal;
322  using CacheBaseImpl< CacheState<A> >::HasArcs;
323
324  typedef typename A::Label Label;
325  typedef typename A::Weight Weight;
326  typedef typename A::StateId StateId;
327  typedef CacheState<A> State;
328
329  RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
330      : CacheImpl<A>(opts),
331        fst_(fst.Copy()),
332        rmeps_state_(
333            *fst_,
334            &distance_,
335            RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, opts.delta,
336                                                     false)
337            ) {
338    SetType("rmepsilon");
339    uint64 props = fst.Properties(kFstProperties, false);
340    SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
341  }
342
343  ~RmEpsilonFstImpl() {
344    delete fst_;
345  }
346
347  StateId Start() {
348    if (!HasStart()) {
349      SetStart(fst_->Start());
350    }
351    return CacheImpl<A>::Start();
352  }
353
354  Weight Final(StateId s) {
355    if (!HasFinal(s)) {
356      Expand(s);
357    }
358    return CacheImpl<A>::Final(s);
359  }
360
361  size_t NumArcs(StateId s) {
362    if (!HasArcs(s))
363      Expand(s);
364    return CacheImpl<A>::NumArcs(s);
365  }
366
367  size_t NumInputEpsilons(StateId s) {
368    if (!HasArcs(s))
369      Expand(s);
370    return CacheImpl<A>::NumInputEpsilons(s);
371  }
372
373  size_t NumOutputEpsilons(StateId s) {
374    if (!HasArcs(s))
375      Expand(s);
376    return CacheImpl<A>::NumOutputEpsilons(s);
377  }
378
379  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
380    if (!HasArcs(s))
381      Expand(s);
382    CacheImpl<A>::InitArcIterator(s, data);
383  }
384
385  void Expand(StateId s) {
386    rmeps_state_.Expand(s);
387    SetFinal(s, rmeps_state_.Final());
388    vector<A> &arcs = rmeps_state_.Arcs();
389    while (!arcs.empty()) {
390      AddArc(s, arcs.back());
391      arcs.pop_back();
392    }
393    SetArcs(s);
394  }
395
396 private:
397  const Fst<A> *fst_;
398  vector<Weight> distance_;
399  FifoQueue<StateId> queue_;
400  RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;
401
402  DISALLOW_EVIL_CONSTRUCTORS(RmEpsilonFstImpl);
403};
404
405
406// Removes epsilon-transitions (when both the input and output label
407// are an epsilon) from a transducer. The result will be an equivalent
408// FST that has no such epsilon transitions.  This version is a
409// delayed Fst.
410//
411// Complexity:
412// - Time:
413//   - Unweighted: O(v^2 + v e)
414//   - General: exponential
415// - Space: O(v e)
416// where v = # of states visited, e = # of arcs visited. Constant time
417// to visit an input state or arc is assumed and exclusive of caching.
418//
419// References:
420// - Mehryar Mohri. Generic Epsilon-Removal and Input
421//   Epsilon-Normalization Algorithms for Weighted Transducers,
422//   "International Journal of Computer Science", 13(1):129-143 (2002).
423template <class A>
424class RmEpsilonFst : public Fst<A> {
425 public:
426  friend class ArcIterator< RmEpsilonFst<A> >;
427  friend class CacheStateIterator< RmEpsilonFst<A> >;
428  friend class CacheArcIterator< RmEpsilonFst<A> >;
429
430  typedef A Arc;
431  typedef typename A::Weight Weight;
432  typedef typename A::StateId StateId;
433  typedef CacheState<A> State;
434
435  RmEpsilonFst(const Fst<A> &fst)
436      : impl_(new RmEpsilonFstImpl<A>(fst, RmEpsilonFstOptions())) {}
437
438  RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
439      : impl_(new RmEpsilonFstImpl<A>(fst, opts)) {}
440
441  explicit RmEpsilonFst(const RmEpsilonFst<A> &fst) : impl_(fst.impl_) {
442    impl_->IncrRefCount();
443  }
444
445  virtual ~RmEpsilonFst() { if (!impl_->DecrRefCount()) delete impl_;  }
446
447  virtual StateId Start() const { return impl_->Start(); }
448
449  virtual Weight Final(StateId s) const { return impl_->Final(s); }
450
451  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
452
453  virtual size_t NumInputEpsilons(StateId s) const {
454    return impl_->NumInputEpsilons(s);
455  }
456
457  virtual size_t NumOutputEpsilons(StateId s) const {
458    return impl_->NumOutputEpsilons(s);
459  }
460
461  virtual uint64 Properties(uint64 mask, bool test) const {
462    if (test) {
463      uint64 known, test = TestProperties(*this, mask, &known);
464      impl_->SetProperties(test, known);
465      return test & mask;
466    } else {
467      return impl_->Properties(mask);
468    }
469  }
470
471  virtual const string& Type() const { return impl_->Type(); }
472
473  virtual RmEpsilonFst<A> *Copy() const {
474    return new RmEpsilonFst<A>(*this);
475  }
476
477  virtual const SymbolTable* InputSymbols() const {
478    return impl_->InputSymbols();
479  }
480
481  virtual const SymbolTable* OutputSymbols() const {
482    return impl_->OutputSymbols();
483  }
484
485  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
486
487  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
488    impl_->InitArcIterator(s, data);
489  }
490
491 protected:
492  RmEpsilonFstImpl<A> *Impl() { return impl_; }
493
494 private:
495  RmEpsilonFstImpl<A> *impl_;
496
497  void operator=(const RmEpsilonFst<A> &fst);  // disallow
498};
499
500
501// Specialization for RmEpsilonFst.
502template<class A>
503class StateIterator< RmEpsilonFst<A> >
504    : public CacheStateIterator< RmEpsilonFst<A> > {
505 public:
506  explicit StateIterator(const RmEpsilonFst<A> &fst)
507      : CacheStateIterator< RmEpsilonFst<A> >(fst) {}
508};
509
510
511// Specialization for RmEpsilonFst.
512template <class A>
513class ArcIterator< RmEpsilonFst<A> >
514    : public CacheArcIterator< RmEpsilonFst<A> > {
515 public:
516  typedef typename A::StateId StateId;
517
518  ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
519      : CacheArcIterator< RmEpsilonFst<A> >(fst, s) {
520    if (!fst.impl_->HasArcs(s))
521      fst.impl_->Expand(s);
522  }
523
524 private:
525  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
526};
527
528
529template <class A> inline
530void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
531  data->base = new StateIterator< RmEpsilonFst<A> >(*this);
532}
533
534
535// Useful alias when using StdArc.
536typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
537
538}  // namespace fst
539
540#endif  // FST_LIB_RMEPSILON_H__
541