rmepsilon.h revision 8fc5a7f51e62cb4ae44a27bdf4176d04adc80ede
1// rmepsilon.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16//
17// \file
// Functions and classes that implement epsilon-removal.
19
20#ifndef FST_LIB_RMEPSILON_H__
21#define FST_LIB_RMEPSILON_H__
22
23#include <ext/hash_map>
24using __gnu_cxx::hash_map;
25#include <ext/slist>
26using __gnu_cxx::slist;
27
28#include "fst/lib/arcfilter.h"
29#include "fst/lib/cache.h"
30#include "fst/lib/connect.h"
31#include "fst/lib/factor-weight.h"
32#include "fst/lib/invert.h"
33#include "fst/lib/map.h"
34#include "fst/lib/queue.h"
35#include "fst/lib/shortest-distance.h"
36#include "fst/lib/topsort.h"
37
38namespace fst {
39
40template <class Arc, class Queue>
41struct RmEpsilonOptions
42    : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
43  typedef typename Arc::StateId StateId;
44
45  bool connect;  // Connect output
46
47  RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true)
48      : ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> >(
49          q, EpsilonArcFilter<Arc>(), kNoStateId, d), connect(c) {}
50
51};
52
53
54// Computation state of the epsilon-removal algorithm.
55template <class Arc, class Queue>
56class RmEpsilonState {
57 public:
58  typedef typename Arc::Label Label;
59  typedef typename Arc::StateId StateId;
60  typedef typename Arc::Weight Weight;
61
62  RmEpsilonState(const Fst<Arc> &fst,
63                 vector<Weight> *distance,
64                 const RmEpsilonOptions<Arc, Queue> &opts)
65      : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true) {
66  }
67
68  // Compute arcs and final weight for state 's'
69  void Expand(StateId s);
70
71  // Returns arcs of expanded state.
72  vector<Arc> &Arcs() { return arcs_; }
73
74  // Returns final weight of expanded state.
75  const Weight &Final() const { return final_; }
76
77 private:
78  struct Element {
79    Label ilabel;
80    Label olabel;
81    StateId nextstate;
82
83    Element() {}
84
85    Element(Label i, Label o, StateId s)
86        : ilabel(i), olabel(o), nextstate(s) {}
87  };
88
89  class ElementKey {
90   public:
91    size_t operator()(const Element& e) const {
92      return static_cast<size_t>(e.nextstate);
93      return static_cast<size_t>(e.nextstate +
94                                 e.ilabel * kPrime0 +
95                                 e.olabel * kPrime1);
96    }
97
98   private:
99    static const int kPrime0 = 7853;
100    static const int kPrime1 = 7867;
101  };
102
103  class ElementEqual {
104   public:
105    bool operator()(const Element &e1, const Element &e2) const {
106      return (e1.ilabel == e2.ilabel) &&  (e1.olabel == e2.olabel)
107                         && (e1.nextstate == e2.nextstate);
108    }
109  };
110
111 private:
112  typedef hash_map<Element, pair<StateId, ssize_t>,
113                   ElementKey, ElementEqual> ElementMap;
114
115  const Fst<Arc> &fst_;
116  // Distance from state being expanded in epsilon-closure.
117  vector<Weight> *distance_;
118  // Shortest distance algorithm computation state.
119  ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
120  // Maps an element 'e' to a pair 'p' corresponding to a position
121  // in the arcs vector of the state being expanded. 'e' corresponds
122  // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
123  // equal to the state being expanded.
124  ElementMap element_map_;
125  EpsilonArcFilter<Arc> eps_filter_;
126  stack<StateId> eps_queue_;      // Queue used to visit the epsilon-closure
127  vector<bool> visited_;          // '[i] = true' if state 'i' has been visited
128  slist<StateId> visited_states_; // List of visited states
129  vector<Arc> arcs_;              // Arcs of state being expanded
130  Weight final_;                  // Final weight of state being expanded
131
132  void operator=(const RmEpsilonState);  // Disallow
133};
134
135
136template <class Arc, class Queue>
137void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
138   sd_state_.ShortestDistance(source);
139   eps_queue_.push(source);
140   final_ = Weight::Zero();
141   arcs_.clear();
142
143   while (!eps_queue_.empty()) {
144     StateId state = eps_queue_.top();
145     eps_queue_.pop();
146
147     while ((StateId)visited_.size() <= state) visited_.push_back(false);
148     visited_[state] = true;
149     visited_states_.push_front(state);
150
151     for (ArcIterator< Fst<Arc> > ait(fst_, state);
152          !ait.Done();
153          ait.Next()) {
154       Arc arc = ait.Value();
155       arc.weight = Times((*distance_)[state], arc.weight);
156
157       if (eps_filter_(arc)) {
158         while ((StateId)visited_.size() <= arc.nextstate)
159           visited_.push_back(false);
160         if (!visited_[arc.nextstate])
161           eps_queue_.push(arc.nextstate);
162       } else {
163          Element element(arc.ilabel, arc.olabel, arc.nextstate);
164          typename ElementMap::iterator it = element_map_.find(element);
165          if (it == element_map_.end()) {
166            element_map_.insert(
167                pair<Element, pair<StateId, ssize_t> >
168                (element, pair<StateId, ssize_t>(source, arcs_.size())));
169            arcs_.push_back(arc);
170          } else {
171            if (((*it).second).first == source) {
172              Weight &w = arcs_[((*it).second).second].weight;
173              w = Plus(w, arc.weight);
174            } else {
175              ((*it).second).first = source;
176              ((*it).second).second = arcs_.size();
177              arcs_.push_back(arc);
178            }
179          }
180        }
181     }
182     final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
183   }
184
185   while (!visited_states_.empty()) {
186     visited_[visited_states_.front()] = false;
187     visited_states_.pop_front();
188   }
189}
190
191
192// Removes epsilon-transitions (when both the input and output label
193// are an epsilon) from a transducer. The result will be an equivalent
194// FST that has no such epsilon transitions.  This version modifies
195// its input. It allows fine control via the options argument; see
196// below for a simpler interface.
197//
198// The vector 'distance' will be used to hold the shortest distances
199// during the epsilon-closure computation. The state queue discipline
200// and convergence delta are taken in the options argument.
201template <class Arc, class Queue>
202void RmEpsilon(MutableFst<Arc> *fst,
203               vector<typename Arc::Weight> *distance,
204               const RmEpsilonOptions<Arc, Queue> &opts) {
205  typedef typename Arc::StateId StateId;
206  typedef typename Arc::Weight Weight;
207  typedef typename Arc::Label Label;
208
209  // States sorted in topological order when (acyclic) or generic
210  // topological order (cyclic).
211  vector<StateId> states;
212
213  if (fst->Properties(kTopSorted, false) & kTopSorted) {
214    for (StateId i = 0; i < (StateId)fst->NumStates(); i++)
215      states.push_back(i);
216  } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
217    vector<StateId> order;
218    bool acyclic;
219    TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
220    DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
221    if (!acyclic)
222      LOG(FATAL) << "RmEpsilon: not acyclic though property bit is set";
223    states.resize(order.size());
224    for (StateId i = 0; i < (StateId)order.size(); i++)
225      states[order[i]] = i;
226  } else {
227     uint64 props;
228     vector<StateId> scc;
229     SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
230     DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
231     vector<StateId> first(scc.size(), kNoStateId);
232     vector<StateId> next(scc.size(), kNoStateId);
233     for (StateId i = 0; i < (StateId)scc.size(); i++) {
234       if (first[scc[i]] != kNoStateId)
235         next[i] = first[scc[i]];
236       first[scc[i]] = i;
237     }
238     for (StateId i = 0; i < (StateId)first.size(); i++)
239       for (StateId j = first[i]; j != kNoStateId; j = next[j])
240         states.push_back(j);
241  }
242
243  RmEpsilonState<Arc, Queue>
244    rmeps_state(*fst, distance, opts);
245
246  while (!states.empty()) {
247    StateId state = states.back();
248    states.pop_back();
249    rmeps_state.Expand(state);
250    fst->SetFinal(state, rmeps_state.Final());
251    fst->DeleteArcs(state);
252    vector<Arc> &arcs = rmeps_state.Arcs();
253    while (!arcs.empty()) {
254      fst->AddArc(state, arcs.back());
255      arcs.pop_back();
256    }
257  }
258
259  fst->SetProperties(RmEpsilonProperties(
260                         fst->Properties(kFstProperties, false)),
261                     kFstProperties);
262
263  if (opts.connect)
264    Connect(fst);
265}
266
267
268// Removes epsilon-transitions (when both the input and output label
269// are an epsilon) from a transducer. The result will be an equivalent
270// FST that has no such epsilon transitions. This version modifies its
271// input. It has a simplified interface; see above for a version that
272// allows finer control.
273//
274// Complexity:
275// - Time:
//   - Unweighted: O(V^2 + V E)
//   - Acyclic: O(V^2 + V E)
//   - Tropical semiring: O(V^2 log V + V E)
279//   - General: exponential
280// - Space: O(V E)
281// where V = # of states visited, E = # of arcs.
282//
283// References:
284// - Mehryar Mohri. Generic Epsilon-Removal and Input
285//   Epsilon-Normalization Algorithms for Weighted Transducers,
286//   "International Journal of Computer Science", 13(1):129-143 (2002).
287template <class Arc>
288void RmEpsilon(MutableFst<Arc> *fst, bool connect = true) {
289  typedef typename Arc::StateId StateId;
290  typedef typename Arc::Weight Weight;
291  typedef typename Arc::Label Label;
292
293  vector<Weight> distance;
294  AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
295  RmEpsilonOptions<Arc, AutoQueue<StateId> >
296    opts(&state_queue, kDelta, connect);
297
298  RmEpsilon(fst, &distance, opts);
299}
300
301
302struct RmEpsilonFstOptions : CacheOptions {
303  float delta;
304
305  RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
306      : CacheOptions(opts), delta(delta) {}
307
308  explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
309};
310
311
312// Implementation of delayed RmEpsilonFst.
313template <class A>
314class RmEpsilonFstImpl : public CacheImpl<A> {
315 public:
316  using FstImpl<A>::SetType;
317  using FstImpl<A>::SetProperties;
318  using FstImpl<A>::Properties;
319  using FstImpl<A>::SetInputSymbols;
320  using FstImpl<A>::SetOutputSymbols;
321
322  using CacheBaseImpl< CacheState<A> >::HasStart;
323  using CacheBaseImpl< CacheState<A> >::HasFinal;
324  using CacheBaseImpl< CacheState<A> >::HasArcs;
325
326  typedef typename A::Label Label;
327  typedef typename A::Weight Weight;
328  typedef typename A::StateId StateId;
329  typedef CacheState<A> State;
330
331  RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
332      : CacheImpl<A>(opts),
333        fst_(fst.Copy()),
334        rmeps_state_(
335            *fst_,
336            &distance_,
337            RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, opts.delta,
338                                                     false)
339            ) {
340    SetType("rmepsilon");
341    uint64 props = fst.Properties(kFstProperties, false);
342    SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
343  }
344
345  ~RmEpsilonFstImpl() {
346    delete fst_;
347  }
348
349  StateId Start() {
350    if (!HasStart()) {
351      SetStart(fst_->Start());
352    }
353    return CacheImpl<A>::Start();
354  }
355
356  Weight Final(StateId s) {
357    if (!HasFinal(s)) {
358      Expand(s);
359    }
360    return CacheImpl<A>::Final(s);
361  }
362
363  size_t NumArcs(StateId s) {
364    if (!HasArcs(s))
365      Expand(s);
366    return CacheImpl<A>::NumArcs(s);
367  }
368
369  size_t NumInputEpsilons(StateId s) {
370    if (!HasArcs(s))
371      Expand(s);
372    return CacheImpl<A>::NumInputEpsilons(s);
373  }
374
375  size_t NumOutputEpsilons(StateId s) {
376    if (!HasArcs(s))
377      Expand(s);
378    return CacheImpl<A>::NumOutputEpsilons(s);
379  }
380
381  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
382    if (!HasArcs(s))
383      Expand(s);
384    CacheImpl<A>::InitArcIterator(s, data);
385  }
386
387  void Expand(StateId s) {
388    rmeps_state_.Expand(s);
389    SetFinal(s, rmeps_state_.Final());
390    vector<A> &arcs = rmeps_state_.Arcs();
391    while (!arcs.empty()) {
392      AddArc(s, arcs.back());
393      arcs.pop_back();
394    }
395    SetArcs(s);
396  }
397
398 private:
399  const Fst<A> *fst_;
400  vector<Weight> distance_;
401  FifoQueue<StateId> queue_;
402  RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;
403
404  DISALLOW_EVIL_CONSTRUCTORS(RmEpsilonFstImpl);
405};
406
407
408// Removes epsilon-transitions (when both the input and output label
409// are an epsilon) from a transducer. The result will be an equivalent
410// FST that has no such epsilon transitions.  This version is a
411// delayed Fst.
412//
413// Complexity:
414// - Time:
415//   - Unweighted: O(v^2 + v e)
416//   - General: exponential
417// - Space: O(v e)
418// where v = # of states visited, e = # of arcs visited. Constant time
419// to visit an input state or arc is assumed and exclusive of caching.
420//
421// References:
422// - Mehryar Mohri. Generic Epsilon-Removal and Input
423//   Epsilon-Normalization Algorithms for Weighted Transducers,
424//   "International Journal of Computer Science", 13(1):129-143 (2002).
425template <class A>
426class RmEpsilonFst : public Fst<A> {
427 public:
428  friend class ArcIterator< RmEpsilonFst<A> >;
429  friend class CacheStateIterator< RmEpsilonFst<A> >;
430  friend class CacheArcIterator< RmEpsilonFst<A> >;
431
432  typedef A Arc;
433  typedef typename A::Weight Weight;
434  typedef typename A::StateId StateId;
435  typedef CacheState<A> State;
436
437  RmEpsilonFst(const Fst<A> &fst)
438      : impl_(new RmEpsilonFstImpl<A>(fst, RmEpsilonFstOptions())) {}
439
440  RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
441      : impl_(new RmEpsilonFstImpl<A>(fst, opts)) {}
442
443  explicit RmEpsilonFst(const RmEpsilonFst<A> &fst) : impl_(fst.impl_) {
444    impl_->IncrRefCount();
445  }
446
447  virtual ~RmEpsilonFst() { if (!impl_->DecrRefCount()) delete impl_;  }
448
449  virtual StateId Start() const { return impl_->Start(); }
450
451  virtual Weight Final(StateId s) const { return impl_->Final(s); }
452
453  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
454
455  virtual size_t NumInputEpsilons(StateId s) const {
456    return impl_->NumInputEpsilons(s);
457  }
458
459  virtual size_t NumOutputEpsilons(StateId s) const {
460    return impl_->NumOutputEpsilons(s);
461  }
462
463  virtual uint64 Properties(uint64 mask, bool test) const {
464    if (test) {
465      uint64 known, test = TestProperties(*this, mask, &known);
466      impl_->SetProperties(test, known);
467      return test & mask;
468    } else {
469      return impl_->Properties(mask);
470    }
471  }
472
473  virtual const string& Type() const { return impl_->Type(); }
474
475  virtual RmEpsilonFst<A> *Copy() const {
476    return new RmEpsilonFst<A>(*this);
477  }
478
479  virtual const SymbolTable* InputSymbols() const {
480    return impl_->InputSymbols();
481  }
482
483  virtual const SymbolTable* OutputSymbols() const {
484    return impl_->OutputSymbols();
485  }
486
487  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
488
489  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
490    impl_->InitArcIterator(s, data);
491  }
492
493 protected:
494  RmEpsilonFstImpl<A> *Impl() { return impl_; }
495
496 private:
497  RmEpsilonFstImpl<A> *impl_;
498
499  void operator=(const RmEpsilonFst<A> &fst);  // disallow
500};
501
502
503// Specialization for RmEpsilonFst.
504template<class A>
505class StateIterator< RmEpsilonFst<A> >
506    : public CacheStateIterator< RmEpsilonFst<A> > {
507 public:
508  explicit StateIterator(const RmEpsilonFst<A> &fst)
509      : CacheStateIterator< RmEpsilonFst<A> >(fst) {}
510};
511
512
513// Specialization for RmEpsilonFst.
514template <class A>
515class ArcIterator< RmEpsilonFst<A> >
516    : public CacheArcIterator< RmEpsilonFst<A> > {
517 public:
518  typedef typename A::StateId StateId;
519
520  ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
521      : CacheArcIterator< RmEpsilonFst<A> >(fst, s) {
522    if (!fst.impl_->HasArcs(s))
523      fst.impl_->Expand(s);
524  }
525
526 private:
527  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
528};
529
530
531template <class A> inline
532void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
533  data->base = new StateIterator< RmEpsilonFst<A> >(*this);
534}
535
536
// Useful alias when using StdArc: the delayed epsilon-removal FST
// instantiated on StdArc.
typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
539
540}  // namespace fst
541
542#endif  // FST_LIB_RMEPSILON_H__
543