1// replace.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// Functions and classes for the recursive replacement of Fsts.
18//
19
20#ifndef FST_LIB_REPLACE_H__
21#define FST_LIB_REPLACE_H__
22
23#include <unordered_map>
24
25#include "fst/lib/fst.h"
26#include "fst/lib/cache.h"
27#include "fst/lib/test-properties.h"
28
29namespace fst {
30
31// By default ReplaceFst will copy the input label of the 'replace arc'.
32// For acceptors we do not want this behaviour. Instead we need to
33// create an epsilon arc when recursing into the appropriate Fst.
34// The epsilon_on_replace option can be used to toggle this behaviour.
35struct ReplaceFstOptions : CacheOptions {
36  int64 root;    // root rule for expansion
37  bool  epsilon_on_replace;
38
39  ReplaceFstOptions(const CacheOptions &opts, int64 r)
40      : CacheOptions(opts), root(r), epsilon_on_replace(false) {}
41  explicit ReplaceFstOptions(int64 r)
42      : root(r), epsilon_on_replace(false) {}
43  ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
44      : root(r), epsilon_on_replace(epsilon_replace_arc) {}
45  ReplaceFstOptions()
46      : root(kNoLabel), epsilon_on_replace(false) {}
47};
48
49//
50// \class ReplaceFstImpl
51// \brief Implementation class for replace class Fst
52//
53// The replace implementation class supports a dynamic
54// expansion of a recursive transition network represented as Fst
// with dynamically replaceable arcs.
56//
57template <class A>
58class ReplaceFstImpl : public CacheImpl<A> {
59 public:
60  using FstImpl<A>::SetType;
61  using FstImpl<A>::SetProperties;
62  using FstImpl<A>::Properties;
63  using FstImpl<A>::SetInputSymbols;
64  using FstImpl<A>::SetOutputSymbols;
65  using FstImpl<A>::InputSymbols;
66  using FstImpl<A>::OutputSymbols;
67
68  using CacheImpl<A>::HasStart;
69  using CacheImpl<A>::HasArcs;
70  using CacheImpl<A>::SetStart;
71
72  typedef typename A::Label   Label;
73  typedef typename A::Weight  Weight;
74  typedef typename A::StateId StateId;
75  typedef CacheState<A> State;
76  typedef A Arc;
77  typedef std::unordered_map<Label, Label> NonTerminalHash;
78
79
80  // \struct StateTuple
81  // \brief Tuple of information that uniquely defines a state
82  struct StateTuple {
83    typedef int PrefixId;
84
85    StateTuple() {}
86    StateTuple(PrefixId p, StateId f, StateId s) :
87        prefix_id(p), fst_id(f), fst_state(s) {}
88
89    PrefixId prefix_id;  // index in prefix table
90    StateId fst_id;      // current fst being walked
91    StateId fst_state;   // current state in fst being walked, not to be
92                         // confused with the state_id of the combined fst
93  };
94
95  // constructor for replace class implementation.
96  // \param fst_tuples array of label/fst tuples, one for each non-terminal
97  ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
98                 const ReplaceFstOptions &opts)
99      : CacheImpl<A>(opts), opts_(opts) {
100    SetType("replace");
101    if (fst_tuples.size() > 0) {
102      SetInputSymbols(fst_tuples[0].second->InputSymbols());
103      SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
104    }
105
106    fst_array_.push_back(0);
107    for (size_t i = 0; i < fst_tuples.size(); ++i)
108      AddFst(fst_tuples[i].first, fst_tuples[i].second);
109
110    SetRoot(opts.root);
111  }
112
113  explicit ReplaceFstImpl(const ReplaceFstOptions &opts)
114      : CacheImpl<A>(opts), opts_(opts), root_(kNoLabel) {
115    fst_array_.push_back(0);
116  }
117
118  ReplaceFstImpl(const ReplaceFstImpl& impl)
119      : opts_(impl.opts_), state_tuples_(impl.state_tuples_),
120        state_hash_(impl.state_hash_),
121        prefix_hash_(impl.prefix_hash_),
122        stackprefix_array_(impl.stackprefix_array_),
123        nonterminal_hash_(impl.nonterminal_hash_),
124        root_(impl.root_) {
125    SetType("replace");
126    SetProperties(impl.Properties(), kCopyProperties);
127    SetInputSymbols(InputSymbols());
128    SetOutputSymbols(OutputSymbols());
129    fst_array_.reserve(impl.fst_array_.size());
130    fst_array_.push_back(0);
131    for (size_t i = 1; i < impl.fst_array_.size(); ++i)
132      fst_array_.push_back(impl.fst_array_[i]->Copy());
133  }
134
135  ~ReplaceFstImpl() {
136    for (size_t i = 1; i < fst_array_.size(); ++i) {
137      delete fst_array_[i];
138    }
139  }
140
141  // Add to Fst array
142  void AddFst(Label label, const Fst<A>* fst) {
143    nonterminal_hash_[label] = fst_array_.size();
144    fst_array_.push_back(fst->Copy());
145    if (fst_array_.size() > 1) {
146      vector<uint64> inprops(fst_array_.size());
147
148      for (size_t i = 1; i < fst_array_.size(); ++i) {
149        inprops[i] = fst_array_[i]->Properties(kCopyProperties, false);
150      }
151      SetProperties(ReplaceProperties(inprops));
152
153      const SymbolTable* isymbols = fst_array_[1]->InputSymbols();
154      const SymbolTable* osymbols = fst_array_[1]->OutputSymbols();
155      for (size_t i = 2; i < fst_array_.size(); ++i) {
156        if (!CompatSymbols(isymbols, fst_array_[i]->InputSymbols())) {
157          LOG(FATAL) << "ReplaceFst::AddFst input symbols of Fst " << i-1
158                     << " does not match input symbols of base Fst (0'th fst)";
159        }
160        if (!CompatSymbols(osymbols, fst_array_[i]->OutputSymbols())) {
161          LOG(FATAL) << "ReplaceFst::AddFst output symbols of Fst " << i-1
162                     << " does not match output symbols of base Fst "
163                     << "(0'th fst)";
164        }
165      }
166    }
167  }
168
169  // Computes the dependency graph of the replace class and returns
170  // true if the dependencies are cyclic. Cyclic dependencies will result
171  // in an un-expandable replace fst.
172  bool CyclicDependencies() const {
173    StdVectorFst depfst;
174
175    // one state for each fst
176    for (size_t i = 1; i < fst_array_.size(); ++i)
177      depfst.AddState();
178
179    // an arc from each state (representing the fst) to the
180    // state representing the fst being replaced
181    for (size_t i = 1; i < fst_array_.size(); ++i) {
182      for (StateIterator<Fst<A> > siter(*(fst_array_[i]));
183           !siter.Done(); siter.Next()) {
184        for (ArcIterator<Fst<A> > aiter(*(fst_array_[i]), siter.Value());
185             !aiter.Done(); aiter.Next()) {
186          const A& arc = aiter.Value();
187
188          typename NonTerminalHash::const_iterator it =
189              nonterminal_hash_.find(arc.olabel);
190          if (it != nonterminal_hash_.end()) {
191            Label j = it->second - 1;
192            depfst.AddArc(i - 1, A(arc.olabel, arc.olabel, Weight::One(), j));
193          }
194        }
195      }
196    }
197
198    depfst.SetStart(root_ - 1);
199    depfst.SetFinal(root_ - 1, Weight::One());
200    return depfst.Properties(kCyclic, true);
201  }
202
203  // set root rule for expansion
204  void SetRoot(Label root) {
205    Label nonterminal = nonterminal_hash_[root];
206    root_ = (nonterminal > 0) ? nonterminal : 1;
207  }
208
209  // Change Fst array
210  void SetFst(Label label, const Fst<A>* fst) {
211    Label nonterminal = nonterminal_hash_[label];
212    delete fst_array_[nonterminal];
213    fst_array_[nonterminal] = fst->Copy();
214  }
215
216  // Return or compute start state of replace fst
217  StateId Start() {
218    if (!HasStart()) {
219      if (fst_array_.size() == 1) {      // no fsts defined for replace
220        SetStart(kNoStateId);
221        return kNoStateId;
222      } else {
223        const Fst<A>* fst = fst_array_[root_];
224        StateId fst_start = fst->Start();
225        if (fst_start == kNoStateId)  // root Fst is empty
226          return kNoStateId;
227
228        int prefix = PrefixId(StackPrefix());
229        StateId start = FindState(StateTuple(prefix, root_, fst_start));
230        SetStart(start);
231        return start;
232      }
233    } else {
234      return CacheImpl<A>::Start();
235    }
236  }
237
238  // return final weight of state (kInfWeight means state is not final)
239  Weight Final(StateId s) {
240    if (!HasFinal(s)) {
241      const StateTuple& tuple  = state_tuples_[s];
242      const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
243      const Fst<A>* fst = fst_array_[tuple.fst_id];
244      StateId fst_state = tuple.fst_state;
245
246      if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
247        SetFinal(s, fst->Final(fst_state));
248      else
249        SetFinal(s, Weight::Zero());
250    }
251    return CacheImpl<A>::Final(s);
252  }
253
254  size_t NumArcs(StateId s) {
255    if (!HasArcs(s))
256      Expand(s);
257    return CacheImpl<A>::NumArcs(s);
258  }
259
260  size_t NumInputEpsilons(StateId s) {
261    if (!HasArcs(s))
262      Expand(s);
263    return CacheImpl<A>::NumInputEpsilons(s);
264  }
265
266  size_t NumOutputEpsilons(StateId s) {
267    if (!HasArcs(s))
268      Expand(s);
269    return CacheImpl<A>::NumOutputEpsilons(s);
270  }
271
272  // return the base arc iterator, if arcs have not been computed yet,
273  // extend/recurse for new arcs.
274  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
275    if (!HasArcs(s))
276      Expand(s);
277    CacheImpl<A>::InitArcIterator(s, data);
278  }
279
280  // Find/create an Fst state given a StateTuple.  Only create a new
281  // state if StateTuple is not found in the state hash.
282  StateId FindState(const StateTuple& tuple) {
283    typename StateTupleHash::iterator it = state_hash_.find(tuple);
284    if (it == state_hash_.end()) {
285      StateId new_state_id = state_tuples_.size();
286      state_tuples_.push_back(tuple);
287      state_hash_[tuple] = new_state_id;
288      return new_state_id;
289    } else {
290      return it->second;
291    }
292  }
293
294  // extend current state (walk arcs one level deep)
295  void Expand(StateId s) {
296    StateTuple tuple  = state_tuples_[s];
297    const Fst<A>* fst = fst_array_[tuple.fst_id];
298    StateId fst_state = tuple.fst_state;
299    if (fst_state == kNoStateId) {
300      SetArcs(s);
301      return;
302    }
303
304    // if state is final, pop up stack
305    const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
306    if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
307      int prefix_id = PopPrefix(stack);
308      const PrefixTuple& top = stack.Top();
309
310      StateId nextstate =
311        FindState(StateTuple(prefix_id, top.fst_id, top.nextstate));
312      AddArc(s, A(0, 0, fst->Final(fst_state), nextstate));
313    }
314
315    // extend arcs leaving the state
316    for (ArcIterator< Fst<A> > aiter(*fst, fst_state);
317         !aiter.Done(); aiter.Next()) {
318      const Arc& arc = aiter.Value();
319      if (arc.olabel == 0) {  // expand local fst
320        StateId nextstate =
321          FindState(StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate));
322        AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate));
323      } else {
324        // check for non terminal
325        typename NonTerminalHash::const_iterator it =
326            nonterminal_hash_.find(arc.olabel);
327        if (it != nonterminal_hash_.end()) {  // recurse into non terminal
328          Label nonterminal = it->second;
329          const Fst<A>* nt_fst = fst_array_[nonterminal];
330          int nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
331                                     tuple.fst_id, arc.nextstate);
332
333          // if start state is valid replace, else arc is implicitly
334          // deleted
335          StateId nt_start = nt_fst->Start();
336          if (nt_start != kNoStateId) {
337            StateId nt_nextstate = FindState(
338                StateTuple(nt_prefix, nonterminal, nt_start));
339            Label ilabel = (opts_.epsilon_on_replace) ? 0 : arc.ilabel;
340            AddArc(s, A(ilabel, 0, arc.weight, nt_nextstate));
341          }
342        } else {
343          StateId nextstate =
344            FindState(
345                StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate));
346          AddArc(s, A(arc.ilabel, arc.olabel, arc.weight, nextstate));
347        }
348      }
349    }
350
351    SetArcs(s);
352  }
353
354
355  // private helper classes
356 private:
357  static const int kPrime0 = 7853;
358  static const int kPrime1 = 7867;
359
360  // \class StateTupleEqual
361  // \brief Compare two StateTuples for equality
362  class StateTupleEqual {
363   public:
364    bool operator()(const StateTuple& x, const StateTuple& y) const {
365      return ((x.prefix_id == y.prefix_id) && (x.fst_id == y.fst_id) &&
366              (x.fst_state == y.fst_state));
367    }
368  };
369
370  // \class StateTupleKey
371  // \brief Hash function for StateTuple to Fst states
372  class StateTupleKey {
373   public:
374    size_t operator()(const StateTuple& x) const {
375      return static_cast<size_t>(x.prefix_id +
376                                 x.fst_id * kPrime0 +
377                                 x.fst_state * kPrime1);
378    }
379  };
380
381  typedef std::unordered_map<StateTuple, StateId, StateTupleKey, StateTupleEqual>
382  StateTupleHash;
383
384  // \class PrefixTuple
385  // \brief Tuple of fst_id and destination state (entry in stack prefix)
386  struct PrefixTuple {
387    PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
388
389    Label   fst_id;
390    StateId nextstate;
391  };
392
393  // \class StackPrefix
394  // \brief Container for stack prefix.
395  class StackPrefix {
396   public:
397    StackPrefix() {}
398
399    // copy constructor
400    StackPrefix(const StackPrefix& x) :
401        prefix_(x.prefix_) {
402    }
403
404    void Push(int fst_id, StateId nextstate) {
405      prefix_.push_back(PrefixTuple(fst_id, nextstate));
406    }
407
408    void Pop() {
409      prefix_.pop_back();
410    }
411
412    const PrefixTuple& Top() const {
413      return prefix_[prefix_.size()-1];
414    }
415
416    size_t Depth() const {
417      return prefix_.size();
418    }
419
420   public:
421    vector<PrefixTuple> prefix_;
422  };
423
424
425  // \class StackPrefixEqual
426  // \brief Compare two stack prefix classes for equality
427  class StackPrefixEqual {
428   public:
429    bool operator()(const StackPrefix& x, const StackPrefix& y) const {
430      if (x.prefix_.size() != y.prefix_.size()) return false;
431      for (size_t i = 0; i < x.prefix_.size(); ++i) {
432        if (x.prefix_[i].fst_id    != y.prefix_[i].fst_id ||
433           x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
434      }
435      return true;
436    }
437  };
438
439  //
440  // \class StackPrefixKey
441  // \brief Hash function for stack prefix to prefix id
442  class StackPrefixKey {
443   public:
444    size_t operator()(const StackPrefix& x) const {
445      int sum = 0;
446      for (size_t i = 0; i < x.prefix_.size(); ++i) {
447        sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
448      }
449      return (size_t) sum;
450    }
451  };
452
453  typedef std::unordered_map<StackPrefix, int, StackPrefixKey, StackPrefixEqual>
454  StackPrefixHash;
455
456  // private methods
457 private:
458  // hash stack prefix (return unique index into stackprefix array)
459  int PrefixId(const StackPrefix& prefix) {
460    typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
461    if (it == prefix_hash_.end()) {
462      int prefix_id = stackprefix_array_.size();
463      stackprefix_array_.push_back(prefix);
464      prefix_hash_[prefix] = prefix_id;
465      return prefix_id;
466    } else {
467      return it->second;
468    }
469  }
470
471  // prefix id after a stack pop
472  int PopPrefix(StackPrefix prefix) {
473    prefix.Pop();
474    return PrefixId(prefix);
475  }
476
477  // prefix id after a stack push
478  int PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
479    prefix.Push(fst_id, nextstate);
480    return PrefixId(prefix);
481  }
482
483
484  // private data
485 private:
486  // runtime options
487  ReplaceFstOptions opts_;
488
489  // maps from StateId to StateTuple
490  vector<StateTuple> state_tuples_;
491
492  // hashes from StateTuple to StateId
493  StateTupleHash state_hash_;
494
495  // cross index of unique stack prefix
496  // could potentially have one copy of prefix array
497  StackPrefixHash prefix_hash_;
498  vector<StackPrefix> stackprefix_array_;
499
500  NonTerminalHash nonterminal_hash_;
501  vector<const Fst<A>*> fst_array_;
502
503  Label root_;
504
505  void operator=(const ReplaceFstImpl<A> &);  // disallow
506};
507
508
509//
510// \class ReplaceFst
// \brief Recursively replaces arcs in the root Fst with other Fsts.
512// This version is a delayed Fst.
513//
514// ReplaceFst supports dynamic replacement of arcs in one Fst with
515// another Fst. This replacement is recursive.  ReplaceFst can be used
516// to support a variety of delayed constructions such as recursive
517// transition networks, union, or closure.  It is constructed with an
518// array of Fst(s). One Fst represents the root (or topology)
519// machine. The root Fst refers to other Fsts by recursively replacing
520// arcs labeled as non-terminals with the matching non-terminal
521// Fst. Currently the ReplaceFst uses the output symbols of the arcs
522// to determine whether the arc is a non-terminal arc or not. A
523// non-terminal can be any label that is not a non-zero terminal label
524// in the output alphabet.
525//
526// Note that the constructor uses a vector of pair<>. These correspond
527// to the tuple of non-terminal Label and corresponding Fst. For example
528// to implement the closure operation we need 2 Fsts. The first root
529// Fst is a single Arc on the start State that self loops, it references
530// the particular machine for which we are performing the closure operation.
531//
532template <class A>
533class ReplaceFst : public Fst<A> {
534 public:
535  friend class ArcIterator< ReplaceFst<A> >;
536  friend class CacheStateIterator< ReplaceFst<A> >;
537  friend class CacheArcIterator< ReplaceFst<A> >;
538
539  typedef A Arc;
540  typedef typename A::Label   Label;
541  typedef typename A::Weight  Weight;
542  typedef typename A::StateId StateId;
543  typedef CacheState<A> State;
544
545  ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
546             Label root)
547      : impl_(new ReplaceFstImpl<A>(fst_array, ReplaceFstOptions(root))) {}
548
549  ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
550             const ReplaceFstOptions &opts)
551      : impl_(new ReplaceFstImpl<A>(fst_array, opts)) {}
552
553  ReplaceFst(const ReplaceFst<A>& fst) :
554      impl_(new ReplaceFstImpl<A>(*(fst.impl_))) {}
555
556  virtual ~ReplaceFst() {
557    delete impl_;
558  }
559
560  virtual StateId Start() const {
561    return impl_->Start();
562  }
563
564  virtual Weight Final(StateId s) const {
565    return impl_->Final(s);
566  }
567
568  virtual size_t NumArcs(StateId s) const {
569    return impl_->NumArcs(s);
570  }
571
572  virtual size_t NumInputEpsilons(StateId s) const {
573    return impl_->NumInputEpsilons(s);
574  }
575
576  virtual size_t NumOutputEpsilons(StateId s) const {
577    return impl_->NumOutputEpsilons(s);
578  }
579
580  virtual uint64 Properties(uint64 mask, bool test) const {
581    if (test) {
582      uint64 known, test = TestProperties(*this, mask, &known);
583      impl_->SetProperties(test, known);
584      return test & mask;
585    } else {
586      return impl_->Properties(mask);
587    }
588  }
589
590  virtual const string& Type() const {
591    return impl_->Type();
592  }
593
594  virtual ReplaceFst<A>* Copy() const {
595    return new ReplaceFst<A>(*this);
596  }
597
598  virtual const SymbolTable* InputSymbols() const {
599    return impl_->InputSymbols();
600  }
601
602  virtual const SymbolTable* OutputSymbols() const {
603    return impl_->OutputSymbols();
604  }
605
606  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
607
608  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
609    impl_->InitArcIterator(s, data);
610  }
611
612  bool CyclicDependencies() const {
613    return impl_->CyclicDependencies();
614  }
615
616 private:
617  ReplaceFstImpl<A>* impl_;
618};
619
620
621// Specialization for ReplaceFst.
622template<class A>
623class StateIterator< ReplaceFst<A> >
624    : public CacheStateIterator< ReplaceFst<A> > {
625 public:
626  explicit StateIterator(const ReplaceFst<A> &fst)
627      : CacheStateIterator< ReplaceFst<A> >(fst) {}
628
629 private:
630  DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
631};
632
633// Specialization for ReplaceFst.
634template <class A>
635class ArcIterator< ReplaceFst<A> >
636    : public CacheArcIterator< ReplaceFst<A> > {
637 public:
638  typedef typename A::StateId StateId;
639
640  ArcIterator(const ReplaceFst<A> &fst, StateId s)
641      : CacheArcIterator< ReplaceFst<A> >(fst, s) {
642    if (!fst.impl_->HasArcs(s))
643      fst.impl_->Expand(s);
644  }
645
646 private:
647  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
648};
649
650template <class A> inline
651void ReplaceFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
652  data->base = new StateIterator< ReplaceFst<A> >(*this);
653}
654
655typedef ReplaceFst<StdArc> StdReplaceFst;
656
657
// Recursively replaces arcs in the root Fst with other Fsts.
659// This version writes the result of replacement to an output MutableFst.
660//
661// Replace supports replacement of arcs in one Fst with another
662// Fst. This replacement is recursive.  Replace takes an array of
663// Fst(s). One Fst represents the root (or topology) machine. The root
664// Fst refers to other Fsts by recursively replacing arcs labeled as
665// non-terminals with the matching non-terminal Fst. Currently Replace
666// uses the output symbols of the arcs to determine whether the arc is
667// a non-terminal arc or not. A non-terminal can be any label that is
668// not a non-zero terminal label in the output alphabet.  Note that
669// input argument is a vector of pair<>. These correspond to the tuple
670// of non-terminal Label and corresponding Fst.
671template<class Arc>
672void Replace(const vector<pair<typename Arc::Label,
673             const Fst<Arc>* > >& ifst_array,
674             MutableFst<Arc> *ofst, typename Arc::Label root) {
675  ReplaceFstOptions opts(root);
676  opts.gc_limit = 0;  // Cache only the last state for fastest copy.
677  *ofst = ReplaceFst<Arc>(ifst_array, opts);
678}
679
680}
681
682#endif  // FST_LIB_REPLACE_H__
683