1// minimize.h
2// minimize.h
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16// Copyright 2005-2010 Google, Inc.
17// Author: johans@google.com (Johan Schalkwyk)
18//
19// \file Functions and classes to minimize a finite state acceptor
20//
21
22#ifndef FST_LIB_MINIMIZE_H__
23#define FST_LIB_MINIMIZE_H__
24
25#include <cmath>
26
27#include <algorithm>
28#include <map>
29#include <queue>
30#include <vector>
31using std::vector;
32
33#include <fst/arcsort.h>
34#include <fst/connect.h>
35#include <fst/dfs-visit.h>
36#include <fst/encode.h>
37#include <fst/factor-weight.h>
38#include <fst/fst.h>
39#include <fst/mutable-fst.h>
40#include <fst/partition.h>
41#include <fst/push.h>
42#include <fst/queue.h>
43#include <fst/reverse.h>
44#include <fst/state-map.h>
45
46
47namespace fst {
48
49// comparator for creating partition based on sorting on
50// - states
51// - final weight
52// - out degree,
53// -  (input label, output label, weight, destination_block)
54template <class A>
55class StateComparator {
56 public:
57  typedef typename A::StateId StateId;
58  typedef typename A::Weight Weight;
59
60  static const uint32 kCompareFinal     = 0x00000001;
61  static const uint32 kCompareOutDegree = 0x00000002;
62  static const uint32 kCompareArcs      = 0x00000004;
63  static const uint32 kCompareAll       = 0x00000007;
64
65  StateComparator(const Fst<A>& fst,
66                  const Partition<typename A::StateId>& partition,
67                  uint32 flags = kCompareAll)
68      : fst_(fst), partition_(partition), flags_(flags) {}
69
70  // compare state x with state y based on sort criteria
71  bool operator()(const StateId x, const StateId y) const {
72    // check for final state equivalence
73    if (flags_ & kCompareFinal) {
74      const size_t xfinal = fst_.Final(x).Hash();
75      const size_t yfinal = fst_.Final(y).Hash();
76      if      (xfinal < yfinal) return true;
77      else if (xfinal > yfinal) return false;
78    }
79
80    if (flags_ & kCompareOutDegree) {
81      // check for # arcs
82      if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
83      if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
84
85      if (flags_ & kCompareArcs) {
86        // # arcs are equal, check for arc match
87        for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
88             !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
89          const A& arc1 = aiter1.Value();
90          const A& arc2 = aiter2.Value();
91          if (arc1.ilabel < arc2.ilabel) return true;
92          if (arc1.ilabel > arc2.ilabel) return false;
93
94          if (partition_.class_id(arc1.nextstate) <
95              partition_.class_id(arc2.nextstate)) return true;
96          if (partition_.class_id(arc1.nextstate) >
97              partition_.class_id(arc2.nextstate)) return false;
98        }
99      }
100    }
101
102    return false;
103  }
104
105 private:
106  const Fst<A>& fst_;
107  const Partition<typename A::StateId>& partition_;
108  const uint32 flags_;
109};
110
// Namespace-scope definitions of the flag constants that are declared and
// initialized inside StateComparator. They give the constants storage in case
// one of them is odr-used (for example bound to a const reference).
template <class A> const uint32 StateComparator<A>::kCompareFinal;
template <class A> const uint32 StateComparator<A>::kCompareOutDegree;
template <class A> const uint32 StateComparator<A>::kCompareArcs;
template <class A> const uint32 StateComparator<A>::kCompareAll;
115
116
117// Computes equivalence classes for cyclic Fsts. For cyclic minimization
118// we use the classic HopCroft minimization algorithm, which is of
119//
120//   O(E)log(N),
121//
122// where E is the number of edges in the machine and N is number of states.
123//
124// The following paper describes the original algorithm
125//  An N Log N algorithm for minimizing states in a finite automaton
126//  by John HopCroft, January 1971
127//
128template <class A, class Queue>
129class CyclicMinimizer {
130 public:
131  typedef typename A::Label Label;
132  typedef typename A::StateId StateId;
133  typedef typename A::StateId ClassId;
134  typedef typename A::Weight Weight;
135  typedef ReverseArc<A> RevA;
136
137  CyclicMinimizer(const ExpandedFst<A>& fst) {
138    Initialize(fst);
139    Compute(fst);
140  }
141
142  ~CyclicMinimizer() {
143    delete aiter_queue_;
144  }
145
146  const Partition<StateId>& partition() const {
147    return P_;
148  }
149
150  // helper classes
151 private:
152  typedef ArcIterator<Fst<RevA> > ArcIter;
153  class ArcIterCompare {
154   public:
155    ArcIterCompare(const Partition<StateId>& partition)
156        : partition_(partition) {}
157
158    ArcIterCompare(const ArcIterCompare& comp)
159        : partition_(comp.partition_) {}
160
161    // compare two iterators based on there input labels, and proto state
162    // (partition class Ids)
163    bool operator()(const ArcIter* x, const ArcIter* y) const {
164      const RevA& xarc = x->Value();
165      const RevA& yarc = y->Value();
166      return (xarc.ilabel > yarc.ilabel);
167    }
168
169   private:
170    const Partition<StateId>& partition_;
171  };
172
173  typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
174  ArcIterQueue;
175
176  // helper methods
177 private:
178  // prepartitions the space into equivalence classes with
179  //   same final weight
180  //   same # arcs per state
181  //   same outgoing arcs
182  void PrePartition(const Fst<A>& fst) {
183    VLOG(5) << "PrePartition";
184
185    typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
186    StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
187    EquivalenceMap equiv_map(comp);
188
189    StateIterator<Fst<A> > siter(fst);
190    StateId class_id = P_.AddClass();
191    P_.Add(siter.Value(), class_id);
192    equiv_map[siter.Value()] = class_id;
193    L_.Enqueue(class_id);
194    for (siter.Next(); !siter.Done(); siter.Next()) {
195      StateId  s = siter.Value();
196      typename EquivalenceMap::const_iterator it = equiv_map.find(s);
197      if (it == equiv_map.end()) {
198        class_id = P_.AddClass();
199        P_.Add(s, class_id);
200        equiv_map[s] = class_id;
201        L_.Enqueue(class_id);
202      } else {
203        P_.Add(s, it->second);
204        equiv_map[s] = it->second;
205      }
206    }
207
208    VLOG(5) << "Initial Partition: " << P_.num_classes();
209  }
210
211  // - Create inverse transition Tr_ = rev(fst)
212  // - loop over states in fst and split on final, creating two blocks
213  //   in the partition corresponding to final, non-final
214  void Initialize(const Fst<A>& fst) {
215    // construct Tr
216    Reverse(fst, &Tr_);
217    ILabelCompare<RevA> ilabel_comp;
218    ArcSort(&Tr_, ilabel_comp);
219
220    // initial split (F, S - F)
221    P_.Initialize(Tr_.NumStates() - 1);
222
223    // prep partition
224    PrePartition(fst);
225
226    // allocate arc iterator queue
227    ArcIterCompare comp(P_);
228    aiter_queue_ = new ArcIterQueue(comp);
229  }
230
231  // partition all classes with destination C
232  void Split(ClassId C) {
233    // Prep priority queue. Open arc iterator for each state in C, and
234    // insert into priority queue.
235    for (PartitionIterator<StateId> siter(P_, C);
236         !siter.Done(); siter.Next()) {
237      StateId s = siter.Value();
238      if (Tr_.NumArcs(s + 1))
239        aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
240    }
241
242    // Now pop arc iterator from queue, split entering equivalence class
243    // re-insert updated iterator into queue.
244    Label prev_label = -1;
245    while (!aiter_queue_->empty()) {
246      ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
247      aiter_queue_->pop();
248      if (aiter->Done()) {
249        delete aiter;
250        continue;
251     }
252
253      const RevA& arc = aiter->Value();
254      StateId from_state = aiter->Value().nextstate - 1;
255      Label   from_label = arc.ilabel;
256      if (prev_label != from_label)
257        P_.FinalizeSplit(&L_);
258
259      StateId from_class = P_.class_id(from_state);
260      if (P_.class_size(from_class) > 1)
261        P_.SplitOn(from_state);
262
263      prev_label = from_label;
264      aiter->Next();
265      if (aiter->Done())
266        delete aiter;
267      else
268        aiter_queue_->push(aiter);
269    }
270    P_.FinalizeSplit(&L_);
271  }
272
273  // Main loop for hopcroft minimization.
274  void Compute(const Fst<A>& fst) {
275    // process active classes (FIFO, or FILO)
276    while (!L_.Empty()) {
277      ClassId C = L_.Head();
278      L_.Dequeue();
279
280      // split on C, all labels in C
281      Split(C);
282    }
283  }
284
285  // helper data
286 private:
287  // Partioning of states into equivalence classes
288  Partition<StateId> P_;
289
290  // L = set of active classes to be processed in partition P
291  Queue L_;
292
293  // reverse transition function
294  VectorFst<RevA> Tr_;
295
296  // Priority queue of open arc iterators for all states in the 'splitter'
297  // equivalence class
298  ArcIterQueue* aiter_queue_;
299};
300
301
302// Computes equivalence classes for acyclic Fsts. The implementation details
303// for this algorithms is documented by the following paper.
304//
305// Minimization of acyclic deterministic automata in linear time
306//  Dominque Revuz
307//
308// Complexity O(|E|)
309//
310template <class A>
311class AcyclicMinimizer {
312 public:
313  typedef typename A::Label Label;
314  typedef typename A::StateId StateId;
315  typedef typename A::StateId ClassId;
316  typedef typename A::Weight Weight;
317
318  AcyclicMinimizer(const ExpandedFst<A>& fst) {
319    Initialize(fst);
320    Refine(fst);
321  }
322
323  const Partition<StateId>& partition() {
324    return partition_;
325  }
326
327  // helper classes
328 private:
329  // DFS visitor to compute the height (distance) to final state.
330  class HeightVisitor {
331   public:
332    HeightVisitor() : max_height_(0), num_states_(0) { }
333
334    // invoked before dfs visit
335    void InitVisit(const Fst<A>& fst) {}
336
337    // invoked when state is discovered (2nd arg is DFS tree root)
338    bool InitState(StateId s, StateId root) {
339      // extend height array and initialize height (distance) to 0
340      for (size_t i = height_.size(); i <= s; ++i)
341        height_.push_back(-1);
342
343      if (s >= num_states_) num_states_ = s + 1;
344      return true;
345    }
346
347    // invoked when tree arc examined (to undiscoverted state)
348    bool TreeArc(StateId s, const A& arc) {
349      return true;
350    }
351
352    // invoked when back arc examined (to unfinished state)
353    bool BackArc(StateId s, const A& arc) {
354      return true;
355    }
356
357    // invoked when forward or cross arc examined (to finished state)
358    bool ForwardOrCrossArc(StateId s, const A& arc) {
359      if (height_[arc.nextstate] + 1 > height_[s])
360        height_[s] = height_[arc.nextstate] + 1;
361      return true;
362    }
363
364    // invoked when state finished (parent is kNoStateId for tree root)
365    void FinishState(StateId s, StateId parent, const A* parent_arc) {
366      if (height_[s] == -1) height_[s] = 0;
367      StateId h = height_[s] +  1;
368      if (parent >= 0) {
369        if (h > height_[parent]) height_[parent] = h;
370        if (h > max_height_)     max_height_ = h;
371      }
372    }
373
374    // invoked after DFS visit
375    void FinishVisit() {}
376
377    size_t max_height() const { return max_height_; }
378
379    const vector<StateId>& height() const { return height_; }
380
381    const size_t num_states() const { return num_states_; }
382
383   private:
384    vector<StateId> height_;
385    size_t max_height_;
386    size_t num_states_;
387  };
388
389  // helper methods
390 private:
391  // cluster states according to height (distance to final state)
392  void Initialize(const Fst<A>& fst) {
393    // compute height (distance to final state)
394    HeightVisitor hvisitor;
395    DfsVisit(fst, &hvisitor);
396
397    // create initial partition based on height
398    partition_.Initialize(hvisitor.num_states());
399    partition_.AllocateClasses(hvisitor.max_height() + 1);
400    const vector<StateId>& hstates = hvisitor.height();
401    for (size_t s = 0; s < hstates.size(); ++s)
402      partition_.Add(s, hstates[s]);
403  }
404
405  // refine states based on arc sort (out degree, arc equivalence)
406  void Refine(const Fst<A>& fst) {
407    typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
408    StateComparator<A> comp(fst, partition_);
409
410    // start with tail (height = 0)
411    size_t height = partition_.num_classes();
412    for (size_t h = 0; h < height; ++h) {
413      EquivalenceMap equiv_classes(comp);
414
415      // sort states within equivalence class
416      PartitionIterator<StateId> siter(partition_, h);
417      equiv_classes[siter.Value()] = h;
418      for (siter.Next(); !siter.Done(); siter.Next()) {
419        const StateId s = siter.Value();
420        typename EquivalenceMap::const_iterator it = equiv_classes.find(s);
421        if (it == equiv_classes.end())
422          equiv_classes[s] = partition_.AddClass();
423        else
424          equiv_classes[s] = it->second;
425      }
426
427      // create refined partition
428      for (siter.Reset(); !siter.Done();) {
429        const StateId s = siter.Value();
430        const StateId old_class = partition_.class_id(s);
431        const StateId new_class = equiv_classes[s];
432
433        // a move operation can invalidate the iterator, so
434        // we first update the iterator to the next element
435        // before we move the current element out of the list
436        siter.Next();
437        if (old_class != new_class)
438          partition_.Move(s, new_class);
439      }
440    }
441  }
442
443 private:
444  Partition<StateId> partition_;
445};
446
447
448// Given a partition and a mutable fst, merge states of Fst inplace
449// (i.e. destructively). Merging works by taking the first state in
450// a class of the partition to be the representative state for the class.
451// Each arc is then reconnected to this state. All states in the class
452// are merged by adding there arcs to the representative state.
453template <class A>
454void MergeStates(
455    const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
456  typedef typename A::StateId StateId;
457
458  vector<StateId> state_map(partition.num_classes());
459  for (size_t i = 0; i < partition.num_classes(); ++i) {
460    PartitionIterator<StateId> siter(partition, i);
461    state_map[i] = siter.Value();  // first state in partition;
462  }
463
464  // relabel destination states
465  for (size_t c = 0; c < partition.num_classes(); ++c) {
466    for (PartitionIterator<StateId> siter(partition, c);
467         !siter.Done(); siter.Next()) {
468      StateId s = siter.Value();
469      for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
470           !aiter.Done(); aiter.Next()) {
471        A arc = aiter.Value();
472        arc.nextstate = state_map[partition.class_id(arc.nextstate)];
473
474        if (s == state_map[c])  // first state just set destination
475          aiter.SetValue(arc);
476        else
477          fst->AddArc(state_map[c], arc);
478      }
479    }
480  }
481  fst->SetStart(state_map[partition.class_id(fst->Start())]);
482
483  Connect(fst);
484}
485
486template <class A>
487void AcceptorMinimize(MutableFst<A>* fst) {
488  typedef typename A::StateId StateId;
489  if (!(fst->Properties(kAcceptor | kUnweighted, true))) {
490    FSTERROR() << "FST is not an unweighted acceptor";
491    fst->SetProperties(kError, kError);
492    return;
493  }
494
495  // connect fst before minimization, handles disconnected states
496  Connect(fst);
497  if (fst->NumStates() == 0) return;
498
499  if (fst->Properties(kAcyclic, true)) {
500    // Acyclic minimization (revuz)
501    VLOG(2) << "Acyclic Minimization";
502    ArcSort(fst, ILabelCompare<A>());
503    AcyclicMinimizer<A> minimizer(*fst);
504    MergeStates(minimizer.partition(), fst);
505
506  } else {
507    // Cyclic minimizaton (hopcroft)
508    VLOG(2) << "Cyclic Minimization";
509    CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
510    MergeStates(minimizer.partition(), fst);
511  }
512
513  // Merge in appropriate semiring
514  ArcUniqueMapper<A> mapper(*fst);
515  StateMap(fst, mapper);
516}
517
518
519// In place minimization of deterministic weighted automata and transducers.
520// For transducers, then the 'sfst' argument is not null, the algorithm
521// produces a compact factorization of the minimal transducer.
522//
523// In the acyclic case, we use an algorithm from Dominique Revuz that
524// is linear in the number of arcs (edges) in the machine.
525//  Complexity = O(E)
526//
527// In the cyclic case, we use the classical hopcroft minimization.
528//  Complexity = O(|E|log(|N|)
529//
530template <class A>
531void Minimize(MutableFst<A>* fst,
532              MutableFst<A>* sfst = 0,
533              float delta = kDelta) {
534  uint64 props = fst->Properties(kAcceptor | kIDeterministic|
535                                 kWeighted | kUnweighted, true);
536  if (!(props & kIDeterministic)) {
537    FSTERROR() << "FST is not deterministic";
538    fst->SetProperties(kError, kError);
539    return;
540  }
541
542  if (!(props & kAcceptor)) {  // weighted transducer
543    VectorFst< GallicArc<A, STRING_LEFT> > gfst;
544    ArcMap(*fst, &gfst, ToGallicMapper<A, STRING_LEFT>());
545    fst->DeleteStates();
546    gfst.SetProperties(kAcceptor, kAcceptor);
547    Push(&gfst, REWEIGHT_TO_INITIAL, delta);
548    ArcMap(&gfst, QuantizeMapper< GallicArc<A, STRING_LEFT> >(delta));
549    EncodeMapper< GallicArc<A, STRING_LEFT> >
550      encoder(kEncodeLabels | kEncodeWeights, ENCODE);
551    Encode(&gfst, &encoder);
552    AcceptorMinimize(&gfst);
553    Decode(&gfst, encoder);
554
555    if (sfst == 0) {
556      FactorWeightFst< GallicArc<A, STRING_LEFT>,
557        GallicFactor<typename A::Label,
558        typename A::Weight, STRING_LEFT> > fwfst(gfst);
559      SymbolTable *osyms = fst->OutputSymbols() ?
560          fst->OutputSymbols()->Copy() : 0;
561      ArcMap(fwfst, fst, FromGallicMapper<A, STRING_LEFT>());
562      fst->SetOutputSymbols(osyms);
563      delete osyms;
564    } else {
565      sfst->SetOutputSymbols(fst->OutputSymbols());
566      GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
567      ArcMap(gfst, fst, &mapper);
568      fst->SetOutputSymbols(sfst->InputSymbols());
569    }
570  } else if (props & kWeighted) {  // weighted acceptor
571    Push(fst, REWEIGHT_TO_INITIAL, delta);
572    ArcMap(fst, QuantizeMapper<A>(delta));
573    EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
574    Encode(fst, &encoder);
575    AcceptorMinimize(fst);
576    Decode(fst, encoder);
577  } else {  // unweighted acceptor
578    AcceptorMinimize(fst);
579  }
580}
581
582}  // namespace fst
583
584#endif  // FST_LIB_MINIMIZE_H__
585