1// rmepsilon.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
// Functions and classes that implement epsilon-removal.
20
21#ifndef FST_LIB_RMEPSILON_H__
22#define FST_LIB_RMEPSILON_H__
23
24#include <tr1/unordered_map>
25using std::tr1::unordered_map;
26using std::tr1::unordered_multimap;
27#include <fst/slist.h>
28#include <stack>
29#include <string>
30#include <utility>
31using std::pair; using std::make_pair;
32#include <vector>
33using std::vector;
34
35#include <fst/arcfilter.h>
36#include <fst/cache.h>
37#include <fst/connect.h>
38#include <fst/factor-weight.h>
39#include <fst/invert.h>
40#include <fst/prune.h>
41#include <fst/queue.h>
42#include <fst/shortest-distance.h>
43#include <fst/topsort.h>
44
45
46namespace fst {
47
48template <class Arc, class Queue>
49class RmEpsilonOptions
50    : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
51 public:
52  typedef typename Arc::StateId StateId;
53  typedef typename Arc::Weight Weight;
54
55  bool connect;              // Connect output
56  Weight weight_threshold;   // Pruning weight threshold.
57  StateId state_threshold;   // Pruning state threshold.
58
59  explicit RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true,
60                            Weight w = Weight::Zero(),
61                            StateId n = kNoStateId)
62      : ShortestDistanceOptions< Arc, Queue, EpsilonArcFilter<Arc> >(
63          q, EpsilonArcFilter<Arc>(), kNoStateId, d),
64        connect(c), weight_threshold(w), state_threshold(n) {}
65 private:
66  RmEpsilonOptions();  // disallow
67};
68
69// Computation state of the epsilon-removal algorithm.
70template <class Arc, class Queue>
71class RmEpsilonState {
72 public:
73  typedef typename Arc::Label Label;
74  typedef typename Arc::StateId StateId;
75  typedef typename Arc::Weight Weight;
76
77  RmEpsilonState(const Fst<Arc> &fst,
78                 vector<Weight> *distance,
79                 const RmEpsilonOptions<Arc, Queue> &opts)
80      : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true),
81        expand_id_(0) {}
82
83  // Compute arcs and final weight for state 's'
84  void Expand(StateId s);
85
86  // Returns arcs of expanded state.
87  vector<Arc> &Arcs() { return arcs_; }
88
89  // Returns final weight of expanded state.
90  const Weight &Final() const { return final_; }
91
92  // Return true if an error has occured.
93  bool Error() const { return sd_state_.Error(); }
94
95 private:
96  static const size_t kPrime0 = 7853;
97  static const size_t kPrime1 = 7867;
98
99  struct Element {
100    Label ilabel;
101    Label olabel;
102    StateId nextstate;
103
104    Element() {}
105
106    Element(Label i, Label o, StateId s)
107        : ilabel(i), olabel(o), nextstate(s) {}
108  };
109
110  class ElementKey {
111   public:
112    size_t operator()(const Element& e) const {
113      return static_cast<size_t>(e.nextstate +
114                                 e.ilabel * kPrime0 +
115                                 e.olabel * kPrime1);
116    }
117
118   private:
119  };
120
121  class ElementEqual {
122   public:
123    bool operator()(const Element &e1, const Element &e2) const {
124      return (e1.ilabel == e2.ilabel) &&  (e1.olabel == e2.olabel)
125                         && (e1.nextstate == e2.nextstate);
126    }
127  };
128
129  typedef unordered_map<Element, pair<StateId, size_t>,
130                   ElementKey, ElementEqual> ElementMap;
131
132  const Fst<Arc> &fst_;
133  // Distance from state being expanded in epsilon-closure.
134  vector<Weight> *distance_;
135  // Shortest distance algorithm computation state.
136  ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;
137  // Maps an element 'e' to a pair 'p' corresponding to a position
138  // in the arcs vector of the state being expanded. 'e' corresponds
139  // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
140  // equal to the state being expanded.
141  ElementMap element_map_;
142  EpsilonArcFilter<Arc> eps_filter_;
143  stack<StateId> eps_queue_;      // Queue used to visit the epsilon-closure
144  vector<bool> visited_;          // '[i] = true' if state 'i' has been visited
145  slist<StateId> visited_states_; // List of visited states
146  vector<Arc> arcs_;              // Arcs of state being expanded
147  Weight final_;                  // Final weight of state being expanded
148  StateId expand_id_;             // Unique ID for each call to Expand
149
150  DISALLOW_COPY_AND_ASSIGN(RmEpsilonState);
151};
152
153template <class Arc, class Queue>
154const size_t RmEpsilonState<Arc, Queue>::kPrime0;
155template <class Arc, class Queue>
156const size_t RmEpsilonState<Arc, Queue>::kPrime1;
157
158
159template <class Arc, class Queue>
160void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
161   final_ = Weight::Zero();
162   arcs_.clear();
163   sd_state_.ShortestDistance(source);
164   if (sd_state_.Error())
165     return;
166   eps_queue_.push(source);
167
168   while (!eps_queue_.empty()) {
169     StateId state = eps_queue_.top();
170     eps_queue_.pop();
171
172     while (visited_.size() <= state) visited_.push_back(false);
173     if (visited_[state]) continue;
174     visited_[state] = true;
175     visited_states_.push_front(state);
176
177     for (ArcIterator< Fst<Arc> > ait(fst_, state);
178          !ait.Done();
179          ait.Next()) {
180       Arc arc = ait.Value();
181       arc.weight = Times((*distance_)[state], arc.weight);
182
183       if (eps_filter_(arc)) {
184         while (visited_.size() <= arc.nextstate)
185           visited_.push_back(false);
186         if (!visited_[arc.nextstate])
187           eps_queue_.push(arc.nextstate);
188       } else {
189          Element element(arc.ilabel, arc.olabel, arc.nextstate);
190          typename ElementMap::iterator it = element_map_.find(element);
191          if (it == element_map_.end()) {
192            element_map_.insert(
193                pair<Element, pair<StateId, size_t> >
194                (element, pair<StateId, size_t>(expand_id_, arcs_.size())));
195            arcs_.push_back(arc);
196          } else {
197            if (((*it).second).first == expand_id_) {
198              Weight &w = arcs_[((*it).second).second].weight;
199              w = Plus(w, arc.weight);
200            } else {
201              ((*it).second).first = expand_id_;
202              ((*it).second).second = arcs_.size();
203              arcs_.push_back(arc);
204            }
205          }
206        }
207     }
208     final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state)));
209   }
210
211   while (!visited_states_.empty()) {
212     visited_[visited_states_.front()] = false;
213     visited_states_.pop_front();
214   }
215   ++expand_id_;
216}
217
218// Removes epsilon-transitions (when both the input and output label
219// are an epsilon) from a transducer. The result will be an equivalent
220// FST that has no such epsilon transitions.  This version modifies
221// its input. It allows fine control via the options argument; see
222// below for a simpler interface.
223//
224// The vector 'distance' will be used to hold the shortest distances
225// during the epsilon-closure computation. The state queue discipline
226// and convergence delta are taken in the options argument.
227template <class Arc, class Queue>
228void RmEpsilon(MutableFst<Arc> *fst,
229               vector<typename Arc::Weight> *distance,
230               const RmEpsilonOptions<Arc, Queue> &opts) {
231  typedef typename Arc::StateId StateId;
232  typedef typename Arc::Weight Weight;
233  typedef typename Arc::Label Label;
234
235  if (fst->Start() == kNoStateId) {
236    return;
237  }
238
239  // 'noneps_in[s]' will be set to true iff 's' admits a non-epsilon
240  // incoming transition or is the start state.
241  vector<bool> noneps_in(fst->NumStates(), false);
242  noneps_in[fst->Start()] = true;
243  for (StateId i = 0; i < fst->NumStates(); ++i) {
244    for (ArcIterator<Fst<Arc> > aiter(*fst, i);
245         !aiter.Done();
246         aiter.Next()) {
247      if (aiter.Value().ilabel != 0 || aiter.Value().olabel != 0)
248        noneps_in[aiter.Value().nextstate] = true;
249    }
250  }
251
252  // States sorted in topological order when (acyclic) or generic
253  // topological order (cyclic).
254  vector<StateId> states;
255  states.reserve(fst->NumStates());
256
257  if (fst->Properties(kTopSorted, false) & kTopSorted) {
258    for (StateId i = 0; i < fst->NumStates(); i++)
259      states.push_back(i);
260  } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
261    vector<StateId> order;
262    bool acyclic;
263    TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
264    DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
265    // Sanity check: should be acyclic if property bit is set.
266    if(!acyclic) {
267      FSTERROR() << "RmEpsilon: inconsistent acyclic property bit";
268      fst->SetProperties(kError, kError);
269      return;
270    }
271    states.resize(order.size());
272    for (StateId i = 0; i < order.size(); i++)
273      states[order[i]] = i;
274  } else {
275     uint64 props;
276     vector<StateId> scc;
277     SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
278     DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
279     vector<StateId> first(scc.size(), kNoStateId);
280     vector<StateId> next(scc.size(), kNoStateId);
281     for (StateId i = 0; i < scc.size(); i++) {
282       if (first[scc[i]] != kNoStateId)
283         next[i] = first[scc[i]];
284       first[scc[i]] = i;
285     }
286     for (StateId i = 0; i < first.size(); i++)
287       for (StateId j = first[i]; j != kNoStateId; j = next[j])
288         states.push_back(j);
289  }
290
291  RmEpsilonState<Arc, Queue>
292    rmeps_state(*fst, distance, opts);
293
294  while (!states.empty()) {
295    StateId state = states.back();
296    states.pop_back();
297    if (!noneps_in[state])
298      continue;
299    rmeps_state.Expand(state);
300    fst->SetFinal(state, rmeps_state.Final());
301    fst->DeleteArcs(state);
302    vector<Arc> &arcs = rmeps_state.Arcs();
303    fst->ReserveArcs(state, arcs.size());
304    while (!arcs.empty()) {
305      fst->AddArc(state, arcs.back());
306      arcs.pop_back();
307    }
308  }
309
310  for (StateId s = 0; s < fst->NumStates(); ++s) {
311    if (!noneps_in[s])
312      fst->DeleteArcs(s);
313  }
314
315  if(rmeps_state.Error())
316    fst->SetProperties(kError, kError);
317  fst->SetProperties(
318      RmEpsilonProperties(fst->Properties(kFstProperties, false)),
319      kFstProperties);
320
321  if (opts.weight_threshold != Weight::Zero() ||
322      opts.state_threshold != kNoStateId)
323    Prune(fst, opts.weight_threshold, opts.state_threshold);
324  if (opts.connect && (opts.weight_threshold == Weight::Zero() ||
325                       opts.state_threshold != kNoStateId))
326    Connect(fst);
327}
328
329// Removes epsilon-transitions (when both the input and output label
330// are an epsilon) from a transducer. The result will be an equivalent
331// FST that has no such epsilon transitions. This version modifies its
332// input. It has a simplified interface; see above for a version that
333// allows finer control.
334//
335// Complexity:
336// - Time:
337//   - Unweighted: O(V2 + V E)
338//   - Acyclic: O(V2 + V E)
339//   - Tropical semiring: O(V2 log V + V E)
340//   - General: exponential
341// - Space: O(V E)
342// where V = # of states visited, E = # of arcs.
343//
344// References:
345// - Mehryar Mohri. Generic Epsilon-Removal and Input
346//   Epsilon-Normalization Algorithms for Weighted Transducers,
347//   "International Journal of Computer Science", 13(1):129-143 (2002).
348template <class Arc>
349void RmEpsilon(MutableFst<Arc> *fst,
350               bool connect = true,
351               typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
352               typename Arc::StateId state_threshold = kNoStateId,
353               float delta = kDelta) {
354  typedef typename Arc::StateId StateId;
355  typedef typename Arc::Weight Weight;
356  typedef typename Arc::Label Label;
357
358  vector<Weight> distance;
359  AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>());
360  RmEpsilonOptions<Arc, AutoQueue<StateId> >
361      opts(&state_queue, delta, connect, weight_threshold, state_threshold);
362
363  RmEpsilon(fst, &distance, opts);
364}
365
366
367struct RmEpsilonFstOptions : CacheOptions {
368  float delta;
369
370  RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta)
371      : CacheOptions(opts), delta(delta) {}
372
373  explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {}
374};
375
376
377// Implementation of delayed RmEpsilonFst.
378template <class A>
379class RmEpsilonFstImpl : public CacheImpl<A> {
380 public:
381  using FstImpl<A>::SetType;
382  using FstImpl<A>::SetProperties;
383  using FstImpl<A>::SetInputSymbols;
384  using FstImpl<A>::SetOutputSymbols;
385
386  using CacheBaseImpl< CacheState<A> >::PushArc;
387  using CacheBaseImpl< CacheState<A> >::HasArcs;
388  using CacheBaseImpl< CacheState<A> >::HasFinal;
389  using CacheBaseImpl< CacheState<A> >::HasStart;
390  using CacheBaseImpl< CacheState<A> >::SetArcs;
391  using CacheBaseImpl< CacheState<A> >::SetFinal;
392  using CacheBaseImpl< CacheState<A> >::SetStart;
393
394  typedef typename A::Label Label;
395  typedef typename A::Weight Weight;
396  typedef typename A::StateId StateId;
397  typedef CacheState<A> State;
398
399  RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
400      : CacheImpl<A>(opts),
401        fst_(fst.Copy()),
402        delta_(opts.delta),
403        rmeps_state_(
404            *fst_,
405            &distance_,
406            RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) {
407    SetType("rmepsilon");
408    uint64 props = fst.Properties(kFstProperties, false);
409    SetProperties(RmEpsilonProperties(props, true), kCopyProperties);
410    SetInputSymbols(fst.InputSymbols());
411    SetOutputSymbols(fst.OutputSymbols());
412  }
413
414  RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
415      : CacheImpl<A>(impl),
416        fst_(impl.fst_->Copy(true)),
417        delta_(impl.delta_),
418        rmeps_state_(
419            *fst_,
420            &distance_,
421            RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) {
422    SetType("rmepsilon");
423    SetProperties(impl.Properties(), kCopyProperties);
424    SetInputSymbols(impl.InputSymbols());
425    SetOutputSymbols(impl.OutputSymbols());
426  }
427
428  ~RmEpsilonFstImpl() {
429    delete fst_;
430  }
431
432  StateId Start() {
433    if (!HasStart()) {
434      SetStart(fst_->Start());
435    }
436    return CacheImpl<A>::Start();
437  }
438
439  Weight Final(StateId s) {
440    if (!HasFinal(s)) {
441      Expand(s);
442    }
443    return CacheImpl<A>::Final(s);
444  }
445
446  size_t NumArcs(StateId s) {
447    if (!HasArcs(s))
448      Expand(s);
449    return CacheImpl<A>::NumArcs(s);
450  }
451
452  size_t NumInputEpsilons(StateId s) {
453    if (!HasArcs(s))
454      Expand(s);
455    return CacheImpl<A>::NumInputEpsilons(s);
456  }
457
458  size_t NumOutputEpsilons(StateId s) {
459    if (!HasArcs(s))
460      Expand(s);
461    return CacheImpl<A>::NumOutputEpsilons(s);
462  }
463
464  uint64 Properties() const { return Properties(kFstProperties); }
465
466  // Set error if found; return FST impl properties.
467  uint64 Properties(uint64 mask) const {
468    if ((mask & kError) &&
469        (fst_->Properties(kError, false) || rmeps_state_.Error()))
470      SetProperties(kError, kError);
471    return FstImpl<A>::Properties(mask);
472  }
473
474  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
475    if (!HasArcs(s))
476      Expand(s);
477    CacheImpl<A>::InitArcIterator(s, data);
478  }
479
480  void Expand(StateId s) {
481    rmeps_state_.Expand(s);
482    SetFinal(s, rmeps_state_.Final());
483    vector<A> &arcs = rmeps_state_.Arcs();
484    while (!arcs.empty()) {
485      PushArc(s, arcs.back());
486      arcs.pop_back();
487    }
488    SetArcs(s);
489  }
490
491 private:
492  const Fst<A> *fst_;
493  float delta_;
494  vector<Weight> distance_;
495  FifoQueue<StateId> queue_;
496  RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;
497
498  void operator=(const RmEpsilonFstImpl<A> &);  // disallow
499};
500
501
502// Removes epsilon-transitions (when both the input and output label
503// are an epsilon) from a transducer. The result will be an equivalent
504// FST that has no such epsilon transitions.  This version is a
505// delayed Fst.
506//
507// Complexity:
508// - Time:
509//   - Unweighted: O(v^2 + v e)
510//   - General: exponential
511// - Space: O(v e)
512// where v = # of states visited, e = # of arcs visited. Constant time
513// to visit an input state or arc is assumed and exclusive of caching.
514//
515// References:
516// - Mehryar Mohri. Generic Epsilon-Removal and Input
517//   Epsilon-Normalization Algorithms for Weighted Transducers,
518//   "International Journal of Computer Science", 13(1):129-143 (2002).
519//
520// This class attaches interface to implementation and handles
521// reference counting, delegating most methods to ImplToFst.
522template <class A>
523class RmEpsilonFst : public ImplToFst< RmEpsilonFstImpl<A> > {
524 public:
525  friend class ArcIterator< RmEpsilonFst<A> >;
526  friend class StateIterator< RmEpsilonFst<A> >;
527
528  typedef A Arc;
529  typedef typename A::StateId StateId;
530  typedef CacheState<A> State;
531  typedef RmEpsilonFstImpl<A> Impl;
532
533  RmEpsilonFst(const Fst<A> &fst)
534      : ImplToFst<Impl>(new Impl(fst, RmEpsilonFstOptions())) {}
535
536  RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
537      : ImplToFst<Impl>(new Impl(fst, opts)) {}
538
539  // See Fst<>::Copy() for doc.
540  RmEpsilonFst(const RmEpsilonFst<A> &fst, bool safe = false)
541      : ImplToFst<Impl>(fst, safe) {}
542
543  // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
544  virtual RmEpsilonFst<A> *Copy(bool safe = false) const {
545    return new RmEpsilonFst<A>(*this, safe);
546  }
547
548  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
549
550  virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
551    GetImpl()->InitArcIterator(s, data);
552  }
553
554 private:
555  // Makes visible to friends.
556  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
557
558  void operator=(const RmEpsilonFst<A> &fst);  // disallow
559};
560
561// Specialization for RmEpsilonFst.
562template<class A>
563class StateIterator< RmEpsilonFst<A> >
564    : public CacheStateIterator< RmEpsilonFst<A> > {
565 public:
566  explicit StateIterator(const RmEpsilonFst<A> &fst)
567      : CacheStateIterator< RmEpsilonFst<A> >(fst, fst.GetImpl()) {}
568};
569
570
571// Specialization for RmEpsilonFst.
572template <class A>
573class ArcIterator< RmEpsilonFst<A> >
574    : public CacheArcIterator< RmEpsilonFst<A> > {
575 public:
576  typedef typename A::StateId StateId;
577
578  ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
579      : CacheArcIterator< RmEpsilonFst<A> >(fst.GetImpl(), s) {
580    if (!fst.GetImpl()->HasArcs(s))
581      fst.GetImpl()->Expand(s);
582  }
583
584 private:
585  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
586};
587
588
589template <class A> inline
590void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
591  data->base = new StateIterator< RmEpsilonFst<A> >(*this);
592}
593
594
// Convenience alias: delayed epsilon removal over StdArc FSTs.
typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
597
598}  // namespace fst
599
600#endif  // FST_LIB_RMEPSILON_H__
601