1// minimize.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file Functions and classes to minimize a finite state acceptor
17
18#ifndef FST_LIB_MINIMIZE_H__
19#define FST_LIB_MINIMIZE_H__
20
21#include <algorithm>
22#include <map>
23#include <queue>
24
25#include "fst/lib/arcsort.h"
26#include "fst/lib/arcsum.h"
27#include "fst/lib/connect.h"
28#include "fst/lib/dfs-visit.h"
29#include "fst/lib/encode.h"
30#include "fst/lib/factor-weight.h"
31#include "fst/lib/fst.h"
32#include "fst/lib/mutable-fst.h"
33#include "fst/lib/partition.h"
34#include "fst/lib/push.h"
35#include "fst/lib/queue.h"
36#include "fst/lib/reverse.h"
37
38namespace fst {
39
40// comparator for creating partition based on sorting on
41// - states
42// - final weight
43// - out degree,
44// -  (input label, output label, weight, destination_block)
45template <class A>
46class StateComparator {
47 public:
48  typedef typename A::StateId StateId;
49  typedef typename A::Weight Weight;
50
51  static const int32 kCompareFinal     = 0x0000001;
52  static const int32 kCompareOutDegree = 0x0000002;
53  static const int32 kCompareArcs      = 0x0000004;
54  static const int32 kCompareAll       = (kCompareFinal |
55                                          kCompareOutDegree |
56                                          kCompareArcs);
57
58  StateComparator(const Fst<A>& fst,
59                  const Partition<typename A::StateId>& partition,
60                  int32 flags = kCompareAll)
61      : fst_(fst), partition_(partition), flags_(flags) {}
62
63  // compare state x with state y based on sort criteria
64  bool operator()(const StateId x, const StateId y) const {
65    // check for final state equivalence
66    if (flags_ & kCompareFinal) {
67      const ssize_t xfinal = fst_.Final(x).Hash();
68      const ssize_t yfinal = fst_.Final(y).Hash();
69      if      (xfinal < yfinal) return true;
70      else if (xfinal > yfinal) return false;
71    }
72
73    if (flags_ & kCompareOutDegree) {
74      // check for # arcs
75      if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
76      if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
77
78      if (flags_ & kCompareArcs) {
79        // # arcs are equal, check for arc match
80        for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
81             !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
82          const A& arc1 = aiter1.Value();
83          const A& arc2 = aiter2.Value();
84          if (arc1.ilabel < arc2.ilabel) return true;
85          if (arc1.ilabel > arc2.ilabel) return false;
86
87          if (partition_.class_id(arc1.nextstate) <
88              partition_.class_id(arc2.nextstate)) return true;
89          if (partition_.class_id(arc1.nextstate) >
90              partition_.class_id(arc2.nextstate)) return false;
91        }
92      }
93    }
94
95    return false;
96  }
97
98 private:
99  const Fst<A>& fst_;
100  const Partition<typename A::StateId>& partition_;
101  const int32 flags_;
102};
103
104// Computes equivalence classes for cyclic Fsts. For cyclic minimization
105// we use the classic HopCroft minimization algorithm, which is of
106//
107//   O(E)log(N),
108//
109// where E is the number of edges in the machine and N is number of states.
110//
111// The following paper describes the original algorithm
112//  An N Log N algorithm for minimizing states in a finite automaton
113//  by John HopCroft, January 1971
114//
115template <class A, class Queue>
116class CyclicMinimizer {
117 public:
118  typedef typename A::Label Label;
119  typedef typename A::StateId StateId;
120  typedef typename A::StateId ClassId;
121  typedef typename A::Weight Weight;
122  typedef ReverseArc<A> RevA;
123
124  CyclicMinimizer(const ExpandedFst<A>& fst) {
125    Initialize(fst);
126    Compute(fst);
127  }
128
129  ~CyclicMinimizer() {
130    delete aiter_queue_;
131  }
132
133  const Partition<StateId>& partition() const {
134    return P_;
135  }
136
137  // helper classes
138 private:
139  typedef ArcIterator<Fst<RevA> > ArcIter;
140  class ArcIterCompare {
141   public:
142    ArcIterCompare(const Partition<StateId>& partition)
143        : partition_(partition) {}
144
145    ArcIterCompare(const ArcIterCompare& comp)
146        : partition_(comp.partition_) {}
147
148    // compare two iterators based on there input labels, and proto state
149    // (partition class Ids)
150    bool operator()(const ArcIter* x, const ArcIter* y) const {
151      const RevA& xarc = x->Value();
152      const RevA& yarc = y->Value();
153      return (xarc.ilabel > yarc.ilabel);
154    }
155
156   private:
157    const Partition<StateId>& partition_;
158  };
159
160  typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
161  ArcIterQueue;
162
163  // helper methods
164 private:
165  // prepartitions the space into equivalence classes with
166  //   same final weight
167  //   same # arcs per state
168  //   same outgoing arcs
169  void PrePartition(const Fst<A>& fst) {
170    VLOG(5) << "PrePartition";
171
172    typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
173    StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
174    EquivalenceMap equiv_map(comp);
175
176    StateIterator<Fst<A> > siter(fst);
177    StateId class_id = P_.AddClass();
178    P_.Add(siter.Value(), class_id);
179    equiv_map[siter.Value()] = class_id;
180    L_.Enqueue(class_id);
181    for (siter.Next(); !siter.Done(); siter.Next()) {
182      StateId  s = siter.Value();
183      typename EquivalenceMap::const_iterator it = equiv_map.find(s);
184      if (it == equiv_map.end()) {
185        class_id = P_.AddClass();
186        P_.Add(s, class_id);
187        equiv_map[s] = class_id;
188        L_.Enqueue(class_id);
189      } else {
190        P_.Add(s, it->second);
191        equiv_map[s] = it->second;
192      }
193    }
194
195    VLOG(5) << "Initial Partition: " << P_.num_classes();
196  }
197
198  // - Create inverse transition Tr_ = rev(fst)
199  // - loop over states in fst and split on final, creating two blocks
200  //   in the partition corresponding to final, non-final
201  void Initialize(const Fst<A>& fst) {
202    // construct Tr
203    Reverse(fst, &Tr_);
204    ILabelCompare<RevA> ilabel_comp;
205    ArcSort(&Tr_, ilabel_comp);
206
207    // initial split (F, S - F)
208    P_.Initialize(Tr_.NumStates() - 1);
209
210    // prep partition
211    PrePartition(fst);
212
213    // allocate arc iterator queue
214    ArcIterCompare comp(P_);
215    aiter_queue_ = new ArcIterQueue(comp);
216  }
217
218  // partition all classes with destination C
219  void Split(ClassId C) {
220    // Prep priority queue. Open arc iterator for each state in C, and
221    // insert into priority queue.
222    for (PartitionIterator<StateId> siter(P_, C);
223         !siter.Done(); siter.Next()) {
224      StateId s = siter.Value();
225      if (Tr_.NumArcs(s + 1))
226        aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
227    }
228
229    // Now pop arc iterator from queue, split entering equivalence class
230    // re-insert updated iterator into queue.
231    Label prev_label = -1;
232    while (!aiter_queue_->empty()) {
233      ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
234      aiter_queue_->pop();
235      if (aiter->Done()) {
236        delete aiter;
237        continue;
238     }
239
240      const RevA& arc = aiter->Value();
241      StateId from_state = aiter->Value().nextstate - 1;
242      Label   from_label = arc.ilabel;
243      if (prev_label != from_label)
244        P_.FinalizeSplit(&L_);
245
246      StateId from_class = P_.class_id(from_state);
247      if (P_.class_size(from_class) > 1)
248        P_.SplitOn(from_state);
249
250      prev_label = from_label;
251      aiter->Next();
252      if (aiter->Done())
253        delete aiter;
254      else
255        aiter_queue_->push(aiter);
256    }
257    P_.FinalizeSplit(&L_);
258  }
259
260  // Main loop for hopcroft minimization.
261  void Compute(const Fst<A>& fst) {
262    // process active classes (FIFO, or FILO)
263    while (!L_.Empty()) {
264      ClassId C = L_.Head();
265      L_.Dequeue();
266
267      // split on C, all labels in C
268      Split(C);
269    }
270  }
271
272  // helper data
273 private:
274  // Partioning of states into equivalence classes
275  Partition<StateId> P_;
276
277  // L = set of active classes to be processed in partition P
278  Queue L_;
279
280  // reverse transition function
281  VectorFst<RevA> Tr_;
282
283  // Priority queue of open arc iterators for all states in the 'splitter'
284  // equivalence class
285  ArcIterQueue* aiter_queue_;
286};
287
288
289// Computes equivalence classes for acyclic Fsts. The implementation details
290// for this algorithms is documented by the following paper.
291//
292// Minimization of acyclic deterministic automata in linear time
293//  Dominque Revuz
294//
295// Complexity O(|E|)
296//
297template <class A>
298class AcyclicMinimizer {
299 public:
300  typedef typename A::Label Label;
301  typedef typename A::StateId StateId;
302  typedef typename A::StateId ClassId;
303  typedef typename A::Weight Weight;
304
305  AcyclicMinimizer(const ExpandedFst<A>& fst) {
306    Initialize(fst);
307    Refine(fst);
308  }
309
310  const Partition<StateId>& partition() {
311    return partition_;
312  }
313
314  // helper classes
315 private:
316  // DFS visitor to compute the height (distance) to final state.
317  class HeightVisitor {
318   public:
319    HeightVisitor() : max_height_(0), num_states_(0) { }
320
321    // invoked before dfs visit
322    void InitVisit(const Fst<A>& fst) {}
323
324    // invoked when state is discovered (2nd arg is DFS tree root)
325    bool InitState(StateId s, StateId root) {
326      // extend height array and initialize height (distance) to 0
327      for (size_t i = height_.size(); i <= (size_t)s; ++i)
328        height_.push_back(-1);
329
330      if (s >= (StateId)num_states_) num_states_ = s + 1;
331      return true;
332    }
333
334    // invoked when tree arc examined (to undiscoverted state)
335    bool TreeArc(StateId s, const A& arc) {
336      return true;
337    }
338
339    // invoked when back arc examined (to unfinished state)
340    bool BackArc(StateId s, const A& arc) {
341      return true;
342    }
343
344    // invoked when forward or cross arc examined (to finished state)
345    bool ForwardOrCrossArc(StateId s, const A& arc) {
346      if (height_[arc.nextstate] + 1 > height_[s])
347        height_[s] = height_[arc.nextstate] + 1;
348      return true;
349    }
350
351    // invoked when state finished (parent is kNoStateId for tree root)
352    void FinishState(StateId s, StateId parent, const A* parent_arc) {
353      if (height_[s] == -1) height_[s] = 0;
354      StateId h = height_[s] +  1;
355      if (parent >= 0) {
356        if (h > height_[parent]) height_[parent] = h;
357        if (h > (StateId)max_height_)     max_height_ = h;
358      }
359    }
360
361    // invoked after DFS visit
362    void FinishVisit() {}
363
364    size_t max_height() const { return max_height_; }
365
366    const vector<StateId>& height() const { return height_; }
367
368    size_t num_states() const { return num_states_; }
369
370   private:
371    vector<StateId> height_;
372    size_t max_height_;
373    size_t num_states_;
374  };
375
376  // helper methods
377 private:
378  // cluster states according to height (distance to final state)
379  void Initialize(const Fst<A>& fst) {
380    // compute height (distance to final state)
381    HeightVisitor hvisitor;
382    DfsVisit(fst, &hvisitor);
383
384    // create initial partition based on height
385    partition_.Initialize(hvisitor.num_states());
386    partition_.AllocateClasses(hvisitor.max_height() + 1);
387    const vector<StateId>& hstates = hvisitor.height();
388    for (size_t s = 0; s < hstates.size(); ++s)
389      partition_.Add(s, hstates[s]);
390  }
391
392  // refine states based on arc sort (out degree, arc equivalence)
393  void Refine(const Fst<A>& fst) {
394    typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
395    StateComparator<A> comp(fst, partition_);
396
397    // start with tail (height = 0)
398    size_t height = partition_.num_classes();
399    for (size_t h = 0; h < height; ++h) {
400      EquivalenceMap equiv_classes(comp);
401
402      // sort states within equivalence class
403      PartitionIterator<StateId> siter(partition_, h);
404      equiv_classes[siter.Value()] = h;
405      for (siter.Next(); !siter.Done(); siter.Next()) {
406        const StateId s = siter.Value();
407        typename EquivalenceMap::const_iterator it = equiv_classes.find(s);
408        if (it == equiv_classes.end())
409          equiv_classes[s] = partition_.AddClass();
410        else
411          equiv_classes[s] = it->second;
412      }
413
414      // create refined partition
415      for (siter.Reset(); !siter.Done();) {
416        const StateId s = siter.Value();
417        const StateId old_class = partition_.class_id(s);
418        const StateId new_class = equiv_classes[s];
419
420        // a move operation can invalidate the iterator, so
421        // we first update the iterator to the next element
422        // before we move the current element out of the list
423        siter.Next();
424        if (old_class != new_class)
425          partition_.Move(s, new_class);
426      }
427    }
428  }
429
430 private:
431  Partition<StateId> partition_;
432};
433
434
435// Given a partition and a mutable fst, merge states of Fst inplace
436// (i.e. destructively). Merging works by taking the first state in
437// a class of the partition to be the representative state for the class.
438// Each arc is then reconnected to this state. All states in the class
439// are merged by adding there arcs to the representative state.
440template <class A>
441void MergeStates(
442    const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
443  typedef typename A::StateId StateId;
444
445  vector<StateId> state_map(partition.num_classes());
446  for (size_t i = 0; i < (size_t)partition.num_classes(); ++i) {
447    PartitionIterator<StateId> siter(partition, i);
448    state_map[i] = siter.Value();  // first state in partition;
449  }
450
451  // relabel destination states
452  for (size_t c = 0; c < (size_t)partition.num_classes(); ++c) {
453    for (PartitionIterator<StateId> siter(partition, c);
454         !siter.Done(); siter.Next()) {
455      StateId s = siter.Value();
456      for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
457           !aiter.Done(); aiter.Next()) {
458        A arc = aiter.Value();
459        arc.nextstate = state_map[partition.class_id(arc.nextstate)];
460
461        if (s == state_map[c])  // first state just set destination
462          aiter.SetValue(arc);
463        else
464          fst->AddArc(state_map[c], arc);
465      }
466    }
467  }
468  fst->SetStart(state_map[partition.class_id(fst->Start())]);
469
470  Connect(fst);
471}
472
473template <class A>
474void AcceptorMinimize(MutableFst<A>* fst) {
475  typedef typename A::StateId StateId;
476  if (!(fst->Properties(kAcceptor | kUnweighted, true)))
477    LOG(FATAL) << "Input Fst is not an unweighted acceptor";
478
479  // connect fst before minimization, handles disconnected states
480  Connect(fst);
481  if (fst->NumStates() == 0) return;
482
483  if (fst->Properties(kAcyclic, true)) {
484    // Acyclic minimization (revuz)
485    VLOG(2) << "Acyclic Minimization";
486    AcyclicMinimizer<A> minimizer(*fst);
487    MergeStates(minimizer.partition(), fst);
488
489  } else {
490    // Cyclic minimizaton (hopcroft)
491    VLOG(2) << "Cyclic Minimization";
492    CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
493    MergeStates(minimizer.partition(), fst);
494  }
495
496  // sort arcs before summing
497  ArcSort(fst, ILabelCompare<A>());
498
499  // sum in appropriate semiring
500  ArcSum(fst);
501}
502
503
504// In place minimization of unweighted, deterministic acceptors
505//
506// For acyclic automata we use an algorithm from Dominique Revuz that is
507// linear in the number of arcs (edges) in the machine.
508//  Complexity = O(E)
509//
510// For cyclic automata we use the classical hopcroft minimization.
511//  Complexity = O(|E|log(|N|)
512//
513template <class A>
514void Minimize(MutableFst<A>* fst, MutableFst<A>* sfst = 0) {
515  uint64 props = fst->Properties(kAcceptor | kIDeterministic|
516                                 kWeighted | kUnweighted, true);
517  if (!(props & kIDeterministic))
518    LOG(FATAL) << "Input Fst is not deterministic";
519
520  if (!(props & kAcceptor)) {  // weighted transducer
521    VectorFst< GallicArc<A, STRING_LEFT> > gfst;
522    Map(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>());
523    fst->DeleteStates();
524    gfst.SetProperties(kAcceptor, kAcceptor);
525    Push(&gfst, REWEIGHT_TO_INITIAL);
526    Map(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >());
527    EncodeMapper< GallicArc<A, STRING_LEFT> >
528      encoder(kEncodeLabels | kEncodeWeights, ENCODE);
529    Encode(&gfst, &encoder);
530    AcceptorMinimize(&gfst);
531    Decode(&gfst, encoder);
532
533    if (sfst == 0) {
534      FactorWeightFst< GallicArc<A, STRING_LEFT>,
535        GallicFactor<typename A::Label,
536        typename A::Weight, STRING_LEFT> > fwfst(gfst);
537      Map(fwfst, fst, FromGallicMapper<A, STRING_LEFT>());
538    } else {
539      sfst->SetOutputSymbols(fst->OutputSymbols());
540      GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
541      Map(gfst, fst, mapper);
542      fst->SetOutputSymbols(sfst->InputSymbols());
543    }
544  } else if (props & kWeighted) {  // weighted acceptor
545    Push(fst, REWEIGHT_TO_INITIAL);
546    Map(fst, QuantizeMapper<A>());
547    EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
548    Encode(fst, &encoder);
549    AcceptorMinimize(fst);
550    Decode(fst, encoder);
551  } else {  // unweighted acceptor
552    AcceptorMinimize(fst);
553  }
554}
555
556}  // namespace fst
557
558#endif  // FST_LIB_MINIMIZE_H__
559