1// replace-util.h
2
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16// Copyright 2005-2010 Google, Inc.
17// Author: riley@google.com (Michael Riley)
18//
19
20// \file
21// Utility classes for the recursive replacement of Fsts (RTNs).
22
23#ifndef FST_LIB_REPLACE_UTIL_H__
24#define FST_LIB_REPLACE_UTIL_H__
25
26#include <vector>
27using std::vector;
28#include <tr1/unordered_map>
29using std::tr1::unordered_map;
30using std::tr1::unordered_multimap;
31#include <tr1/unordered_set>
32using std::tr1::unordered_set;
33using std::tr1::unordered_multiset;
34#include <map>
35
36#include <fst/connect.h>
37#include <fst/mutable-fst.h>
38#include <fst/topsort.h>
39
40
41namespace fst {
42
43template <class Arc>
44void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&,
45             MutableFst<Arc> *, typename Arc::Label, bool);
46
47
48// Utility class for the recursive replacement of Fsts (RTNs). The
49// user provides a set of Label, Fst pairs at construction. These are
50// used by methods for testing cyclic dependencies and connectedness
51// and doing RTN connection and specific Fst replacement by label or
52// for various optimization properties. The modified results can be
53// obtained with the GetFstPairs() or GetMutableFstPairs() methods.
54template <class Arc>
55class ReplaceUtil {
56 public:
57  typedef typename Arc::Label Label;
58  typedef typename Arc::Weight Weight;
59  typedef typename Arc::StateId StateId;
60
61  typedef pair<Label, const Fst<Arc>*> FstPair;
62  typedef pair<Label, MutableFst<Arc>*> MutableFstPair;
63  typedef unordered_map<Label, Label> NonTerminalHash;
64
65  // Constructs from mutable Fsts; Fst ownership given to ReplaceUtil.
66  ReplaceUtil(const vector<MutableFstPair> &fst_pairs,
67              Label root_label, bool epsilon_on_replace = false);
68
69  // Constructs from Fsts; Fst ownership retained by caller.
70  ReplaceUtil(const vector<FstPair> &fst_pairs,
71              Label root_label, bool epsilon_on_replace = false);
72
73  // Constructs from ReplaceFst internals; ownership retained by caller.
74  ReplaceUtil(const vector<const Fst<Arc> *> &fst_array,
75              const NonTerminalHash &nonterminal_hash, Label root_fst,
76              bool epsilon_on_replace = false);
77
78  ~ReplaceUtil() {
79    for (Label i = 0; i < fst_array_.size(); ++i)
80      delete fst_array_[i];
81  }
82
83  // True if the non-terminal dependencies are cyclic. Cyclic
84  // dependencies will result in an unexpandable replace fst.
85  bool CyclicDependencies() const {
86    GetDependencies(false);
87    return depprops_ & kCyclic;
88  }
89
90  // Returns true if no useless Fsts, states or transitions.
91  bool Connected() const {
92    GetDependencies(false);
93    uint64 props = kAccessible | kCoAccessible;
94    for (Label i = 0; i < fst_array_.size(); ++i) {
95      if (!fst_array_[i])
96        continue;
97      if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i])
98        return false;
99    }
100    return true;
101  }
102
103  // Removes useless Fsts, states and transitions.
104  void Connect();
105
106  // Replaces Fsts specified by labels.
107  // Does nothing if there are cyclic dependencies.
108  void ReplaceLabels(const vector<Label> &labels);
109
110  // Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and
111  // 'nnonterm' non-terminals (updating in reverse dependency order).
112  // Does nothing if there are cyclic dependencies.
113  void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms);
114
115  // Replaces singleton Fsts.
116  // Does nothing if there are cyclic dependencies.
117  void ReplaceTrivial() { ReplaceBySize(2, 1, 1); }
118
119  // Replaces non-terminals that have at most 'ninstances' instances
120  // (updating in dependency order).
121  // Does nothing if there are cyclic dependencies.
122  void ReplaceByInstances(size_t ninstances);
123
124  // Replaces non-terminals that have only one instance.
125  // Does nothing if there are cyclic dependencies.
126  void ReplaceUnique() { ReplaceByInstances(1); }
127
128  // Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil.
129  void GetFstPairs(vector<FstPair> *fst_pairs);
130
131  // Returns Label, MutableFst pairs; Fst ownership given to caller.
132  void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs);
133
134 private:
135  // Per Fst statistics
136  struct ReplaceStats {
137    StateId nstates;    // # of states
138    StateId nfinal;     // # of final states
139    size_t narcs;       // # of arcs
140    Label nnonterms;    // # of non-terminals in Fst
141    size_t nref;        // # of non-terminal instances referring to this Fst
142
143    // # of times that ith Fst references this Fst
144    map<Label, size_t> inref;
145    // # of times that this Fst references the ith Fst
146    map<Label, size_t> outref;
147
148    ReplaceStats()
149        : nstates(0),
150          nfinal(0),
151          narcs(0),
152          nnonterms(0),
153          nref(0) {}
154  };
155
156  // Check Mutable Fsts exist o.w. create them.
157  void CheckMutableFsts();
158
159  // Computes the dependency graph of the replace Fsts.
160  // If 'stats' is true, dependency statistics computed as well.
161  void GetDependencies(bool stats) const;
162
163  void ClearDependencies() const {
164    depfst_.DeleteStates();
165    stats_.clear();
166    depprops_ = 0;
167    have_stats_ = false;
168  }
169
170  // Get topological order of dependencies. Returns false with cyclic input.
171  bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const;
172
173  // Update statistics assuming that jth Fst will be replaced.
174  void UpdateStats(Label j);
175
176  Label root_label_;                              // root non-terminal
177  Label root_fst_;                                // root Fst ID
178  bool epsilon_on_replace_;                       // see Replace()
179  vector<const Fst<Arc> *> fst_array_;            // Fst per ID
180  vector<MutableFst<Arc> *> mutable_fst_array_;   // MutableFst per ID
181  vector<Label> nonterminal_array_;               // Fst ID to non-terminal
182  NonTerminalHash nonterminal_hash_;              // non-terminal to Fst ID
183  mutable VectorFst<Arc> depfst_;                 // Fst ID dependencies
184  mutable vector<bool> depaccess_;                // Fst ID accessibility
185  mutable uint64 depprops_;                       // dependency Fst props
186  mutable bool have_stats_;                       // have dependency statistics
187  mutable vector<ReplaceStats> stats_;            // Per Fst statistics
188  DISALLOW_COPY_AND_ASSIGN(ReplaceUtil);
189};
190
191template <class Arc>
192ReplaceUtil<Arc>::ReplaceUtil(
193    const vector<MutableFstPair> &fst_pairs,
194    Label root_label, bool epsilon_on_replace)
195    : root_label_(root_label),
196      epsilon_on_replace_(epsilon_on_replace),
197      depprops_(0),
198      have_stats_(false) {
199  fst_array_.push_back(0);
200  mutable_fst_array_.push_back(0);
201  nonterminal_array_.push_back(kNoLabel);
202  for (Label i = 0; i < fst_pairs.size(); ++i) {
203    Label label = fst_pairs[i].first;
204    MutableFst<Arc> *fst = fst_pairs[i].second;
205    nonterminal_hash_[label] = fst_array_.size();
206    nonterminal_array_.push_back(label);
207    fst_array_.push_back(fst);
208    mutable_fst_array_.push_back(fst);
209  }
210  root_fst_ = nonterminal_hash_[root_label_];
211  if (!root_fst_)
212    FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
213}
214
215template <class Arc>
216ReplaceUtil<Arc>::ReplaceUtil(
217    const vector<FstPair> &fst_pairs,
218    Label root_label, bool epsilon_on_replace)
219    : root_label_(root_label),
220      epsilon_on_replace_(epsilon_on_replace),
221      depprops_(0),
222      have_stats_(false) {
223  fst_array_.push_back(0);
224  nonterminal_array_.push_back(kNoLabel);
225  for (Label i = 0; i < fst_pairs.size(); ++i) {
226    Label label = fst_pairs[i].first;
227    const Fst<Arc> *fst = fst_pairs[i].second;
228    nonterminal_hash_[label] = fst_array_.size();
229    nonterminal_array_.push_back(label);
230    fst_array_.push_back(fst->Copy());
231  }
232  root_fst_ = nonterminal_hash_[root_label];
233  if (!root_fst_)
234    FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
235}
236
237template <class Arc>
238ReplaceUtil<Arc>::ReplaceUtil(
239    const vector<const Fst<Arc> *> &fst_array,
240    const NonTerminalHash &nonterminal_hash, Label root_fst,
241    bool epsilon_on_replace)
242    : root_fst_(root_fst),
243      epsilon_on_replace_(epsilon_on_replace),
244      nonterminal_array_(fst_array.size()),
245      nonterminal_hash_(nonterminal_hash),
246      depprops_(0),
247      have_stats_(false) {
248  fst_array_.push_back(0);
249  for (Label i = 1; i < fst_array.size(); ++i)
250    fst_array_.push_back(fst_array[i]->Copy());
251  for (typename NonTerminalHash::const_iterator it =
252           nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it)
253    nonterminal_array_[it->second] = it->first;
254  root_label_ = nonterminal_array_[root_fst_];
255}
256
257template <class Arc>
258void ReplaceUtil<Arc>::GetDependencies(bool stats) const {
259  if (depfst_.NumStates() > 0) {
260    if (stats && !have_stats_)
261      ClearDependencies();
262    else
263      return;
264  }
265
266  have_stats_ = stats;
267  if (have_stats_)
268    stats_.reserve(fst_array_.size());
269
270  for (Label i = 0; i < fst_array_.size(); ++i) {
271    depfst_.AddState();
272    depfst_.SetFinal(i, Weight::One());
273    if (have_stats_)
274      stats_.push_back(ReplaceStats());
275  }
276  depfst_.SetStart(root_fst_);
277
278  // An arc from each state (representing the fst) to the
279  // state representing the fst being replaced
280  for (Label i = 0; i < fst_array_.size(); ++i) {
281    const Fst<Arc> *ifst = fst_array_[i];
282    if (!ifst)
283      continue;
284    for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) {
285      StateId s = siter.Value();
286      if (have_stats_) {
287        ++stats_[i].nstates;
288        if (ifst->Final(s) != Weight::Zero())
289          ++stats_[i].nfinal;
290      }
291      for (ArcIterator<Fst<Arc> > aiter(*ifst, s);
292           !aiter.Done(); aiter.Next()) {
293        if (have_stats_)
294          ++stats_[i].narcs;
295        const Arc& arc = aiter.Value();
296
297        typename NonTerminalHash::const_iterator it =
298            nonterminal_hash_.find(arc.olabel);
299        if (it != nonterminal_hash_.end()) {
300          Label j = it->second;
301          depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j));
302          if (have_stats_) {
303            ++stats_[i].nnonterms;
304            ++stats_[j].nref;
305            ++stats_[j].inref[i];
306            ++stats_[i].outref[j];
307          }
308        }
309      }
310    }
311  }
312
313  // Gets accessibility info
314  SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_);
315  DfsVisit(depfst_, &scc_visitor);
316}
317
318template <class Arc>
319void ReplaceUtil<Arc>::UpdateStats(Label j) {
320  if (!have_stats_) {
321    FSTERROR() << "ReplaceUtil::UpdateStats: stats not available";
322    return;
323  }
324
325  if (j == root_fst_)  // can't replace root
326    return;
327
328  typedef typename map<Label, size_t>::iterator Iter;
329  for (Iter in = stats_[j].inref.begin();
330       in != stats_[j].inref.end();
331       ++in) {
332    Label i = in->first;
333    size_t ni = in->second;
334    stats_[i].nstates += stats_[j].nstates * ni;
335    stats_[i].narcs += (stats_[j].narcs + 1) * ni;  // narcs - 1 + 2 (eps)
336    stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni;
337    stats_[i].outref.erase(stats_[i].outref.find(j));
338    for (Iter out = stats_[j].outref.begin();
339         out != stats_[j].outref.end();
340         ++out) {
341      Label k = out->first;
342      size_t nk = out->second;
343      stats_[i].outref[k] += ni * nk;
344    }
345  }
346
347  for (Iter out = stats_[j].outref.begin();
348       out != stats_[j].outref.end();
349       ++out) {
350    Label k = out->first;
351    size_t nk = out->second;
352    stats_[k].nref -= nk;
353    stats_[k].inref.erase(stats_[k].inref.find(j));
354    for (Iter in = stats_[j].inref.begin();
355         in != stats_[j].inref.end();
356         ++in) {
357      Label i = in->first;
358      size_t ni = in->second;
359      stats_[k].inref[i] += ni * nk;
360      stats_[k].nref += ni * nk;
361    }
362  }
363}
364
365template <class Arc>
366void ReplaceUtil<Arc>::CheckMutableFsts() {
367  if (mutable_fst_array_.size() == 0) {
368    for (Label i = 0; i < fst_array_.size(); ++i) {
369      if (!fst_array_[i]) {
370        mutable_fst_array_.push_back(0);
371      } else {
372        mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i]));
373        delete fst_array_[i];
374        fst_array_[i] = mutable_fst_array_[i];
375      }
376    }
377  }
378}
379
380template <class Arc>
381void ReplaceUtil<Arc>::Connect() {
382  CheckMutableFsts();
383  uint64 props = kAccessible | kCoAccessible;
384  for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
385    if (!mutable_fst_array_[i])
386      continue;
387    if (mutable_fst_array_[i]->Properties(props, false) != props)
388      fst::Connect(mutable_fst_array_[i]);
389  }
390  GetDependencies(false);
391  for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
392    MutableFst<Arc> *fst = mutable_fst_array_[i];
393    if (fst && !depaccess_[i]) {
394      delete fst;
395      fst_array_[i] = 0;
396      mutable_fst_array_[i] = 0;
397    }
398  }
399  ClearDependencies();
400}
401
402template <class Arc>
403bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst,
404                                   vector<Label> *toporder) const {
405  // Finds topological order of dependencies.
406  vector<StateId> order;
407  bool acyclic = false;
408
409  TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
410  DfsVisit(fst, &top_order_visitor);
411  if (!acyclic) {
412    LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies";
413    return false;
414  }
415
416  toporder->resize(order.size());
417  for (Label i = 0; i < order.size(); ++i)
418    (*toporder)[order[i]] = i;
419
420  return true;
421}
422
423template <class Arc>
424void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) {
425  CheckMutableFsts();
426  unordered_set<Label> label_set;
427  for (Label i = 0; i < labels.size(); ++i)
428    if (labels[i] != root_label_)  // can't replace root
429      label_set.insert(labels[i]);
430
431  // Finds Fst dependencies restricted to the labels requested.
432  GetDependencies(false);
433  VectorFst<Arc> pfst(depfst_);
434  for (StateId i = 0; i < pfst.NumStates(); ++i) {
435    vector<Arc> arcs;
436    for (ArcIterator< VectorFst<Arc> > aiter(pfst, i);
437         !aiter.Done(); aiter.Next()) {
438      const Arc &arc = aiter.Value();
439      Label label = nonterminal_array_[arc.nextstate];
440      if (label_set.count(label) > 0)
441        arcs.push_back(arc);
442    }
443    pfst.DeleteArcs(i);
444    for (size_t j = 0; j < arcs.size(); ++j)
445      pfst.AddArc(i, arcs[j]);
446  }
447
448  vector<Label> toporder;
449  if (!GetTopOrder(pfst, &toporder)) {
450    ClearDependencies();
451    return;
452  }
453
454  // Visits Fsts in reverse topological order of dependencies and
455  // performs replacements.
456  for (Label o = toporder.size() - 1; o >= 0;  --o) {
457    vector<FstPair> fst_pairs;
458    StateId s = toporder[o];
459    for (ArcIterator< VectorFst<Arc> > aiter(pfst, s);
460         !aiter.Done(); aiter.Next()) {
461      const Arc &arc = aiter.Value();
462      Label label = nonterminal_array_[arc.nextstate];
463      const Fst<Arc> *fst = fst_array_[arc.nextstate];
464      fst_pairs.push_back(make_pair(label, fst));
465    }
466    if (fst_pairs.empty())
467        continue;
468    Label label = nonterminal_array_[s];
469    const Fst<Arc> *fst = fst_array_[s];
470    fst_pairs.push_back(make_pair(label, fst));
471
472    Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_);
473  }
474  ClearDependencies();
475}
476
477template <class Arc>
478void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs,
479                                     size_t nnonterms) {
480  vector<Label> labels;
481  GetDependencies(true);
482
483  vector<Label> toporder;
484  if (!GetTopOrder(depfst_, &toporder)) {
485    ClearDependencies();
486    return;
487  }
488
489  for (Label o = toporder.size() - 1; o >= 0; --o) {
490    Label j = toporder[o];
491    if (stats_[j].nstates <= nstates &&
492        stats_[j].narcs <= narcs &&
493        stats_[j].nnonterms <= nnonterms) {
494      labels.push_back(nonterminal_array_[j]);
495      UpdateStats(j);
496    }
497  }
498  ReplaceLabels(labels);
499}
500
501template <class Arc>
502void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) {
503  vector<Label> labels;
504  GetDependencies(true);
505
506  vector<Label> toporder;
507  if (!GetTopOrder(depfst_, &toporder)) {
508    ClearDependencies();
509    return;
510  }
511  for (Label o = 0; o < toporder.size(); ++o) {
512    Label j = toporder[o];
513    if (stats_[j].nref <= ninstances) {
514      labels.push_back(nonterminal_array_[j]);
515      UpdateStats(j);
516    }
517  }
518  ReplaceLabels(labels);
519}
520
521template <class Arc>
522void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) {
523  CheckMutableFsts();
524  fst_pairs->clear();
525  for (Label i = 0; i < fst_array_.size(); ++i) {
526    Label label = nonterminal_array_[i];
527    const Fst<Arc> *fst = fst_array_[i];
528    if (!fst)
529      continue;
530    fst_pairs->push_back(make_pair(label, fst));
531  }
532}
533
534template <class Arc>
535void ReplaceUtil<Arc>::GetMutableFstPairs(
536    vector<MutableFstPair> *mutable_fst_pairs) {
537  CheckMutableFsts();
538  mutable_fst_pairs->clear();
539  for (Label i = 0; i < mutable_fst_array_.size(); ++i) {
540    Label label = nonterminal_array_[i];
541    MutableFst<Arc> *fst = mutable_fst_array_[i];
542    if (!fst)
543      continue;
544    mutable_fst_pairs->push_back(make_pair(label, fst->Copy()));
545  }
546}
547
548}  // namespace fst
549
550#endif  // FST_LIB_REPLACE_UTIL_H__
551