1// connect.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Classes and functions to remove unsuccessful paths from an Fst.
20
21#ifndef FST_LIB_CONNECT_H__
22#define FST_LIB_CONNECT_H__
23
24#include <vector>
25using std::vector;
26
27#include <fst/dfs-visit.h>
28#include <fst/union-find.h>
29#include <fst/mutable-fst.h>
30
31
32namespace fst {
33
34// Finds and returns connected components. Use with Visit().
35template <class A>
36class CcVisitor {
37 public:
38  typedef A Arc;
39  typedef typename Arc::Weight Weight;
40  typedef typename A::StateId StateId;
41
42  // cc[i]: connected component number for state i.
43  CcVisitor(vector<StateId> *cc)
44      : comps_(new UnionFind<StateId>(0, kNoStateId)),
45        cc_(cc),
46        nstates_(0) { }
47
48  // comps: connected components equiv classes.
49  CcVisitor(UnionFind<StateId> *comps)
50      : comps_(comps),
51        cc_(0),
52        nstates_(0) { }
53
54  ~CcVisitor() {
55    if (cc_)  // own comps_?
56      delete comps_;
57  }
58
59  void InitVisit(const Fst<A> &fst) { }
60
61  bool InitState(StateId s, StateId root) {
62    ++nstates_;
63    if (comps_->FindSet(s) == kNoStateId)
64      comps_->MakeSet(s);
65    return true;
66  }
67
68  bool WhiteArc(StateId s, const A &arc) {
69    comps_->MakeSet(arc.nextstate);
70    comps_->Union(s, arc.nextstate);
71    return true;
72  }
73
74  bool GreyArc(StateId s, const A &arc) {
75    comps_->Union(s, arc.nextstate);
76    return true;
77  }
78
79  bool BlackArc(StateId s, const A &arc) {
80    comps_->Union(s, arc.nextstate);
81    return true;
82  }
83
84  void FinishState(StateId s) { }
85
86  void FinishVisit() {
87    if (cc_)
88      GetCcVector(cc_);
89  }
90
91  // cc[i]: connected component number for state i.
92  // Returns number of components.
93  int GetCcVector(vector<StateId> *cc) {
94    cc->clear();
95    cc->resize(nstates_, kNoStateId);
96    StateId ncomp = 0;
97    for (StateId i = 0; i < nstates_; ++i) {
98      StateId rep = comps_->FindSet(i);
99      StateId &comp = (*cc)[rep];
100      if (comp == kNoStateId) {
101        comp = ncomp;
102        ++ncomp;
103      }
104      (*cc)[i] = comp;
105    }
106    return ncomp;
107  }
108
109 private:
110  UnionFind<StateId> *comps_;   // Components
111  vector<StateId> *cc_;         // State's cc number
112  StateId nstates_;             // State count
113};
114
115
116// Finds and returns strongly-connected components, accessible and
117// coaccessible states and related properties. Uses Tarjan's single
118// DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer
119// Algorithms", 189pp). Use with DfsVisit();
120template <class A>
121class SccVisitor {
122 public:
123  typedef A Arc;
124  typedef typename A::Weight Weight;
125  typedef typename A::StateId StateId;
126
127  // scc[i]: strongly-connected component number for state i.
128  //   SCC numbers will be in topological order for acyclic input.
129  // access[i]: accessibility of state i.
130  // coaccess[i]: coaccessibility of state i.
131  // Any of above can be NULL.
132  // props: related property bits (cyclicity, initial cyclicity,
133  //   accessibility, coaccessibility) set/cleared (o.w. unchanged).
134  SccVisitor(vector<StateId> *scc, vector<bool> *access,
135             vector<bool> *coaccess, uint64 *props)
136      : scc_(scc), access_(access), coaccess_(coaccess), props_(props) {}
137  SccVisitor(uint64 *props)
138      : scc_(0), access_(0), coaccess_(0), props_(props) {}
139
140  void InitVisit(const Fst<A> &fst);
141
142  bool InitState(StateId s, StateId root);
143
144  bool TreeArc(StateId s, const A &arc) { return true; }
145
146  bool BackArc(StateId s, const A &arc) {
147    StateId t = arc.nextstate;
148    if ((*dfnumber_)[t] < (*lowlink_)[s])
149      (*lowlink_)[s] = (*dfnumber_)[t];
150    if ((*coaccess_)[t])
151      (*coaccess_)[s] = true;
152    *props_ |= kCyclic;
153    *props_ &= ~kAcyclic;
154    if (arc.nextstate == start_) {
155      *props_ |= kInitialCyclic;
156      *props_ &= ~kInitialAcyclic;
157    }
158    return true;
159  }
160
161  bool ForwardOrCrossArc(StateId s, const A &arc) {
162    StateId t = arc.nextstate;
163    if ((*dfnumber_)[t] < (*dfnumber_)[s] /* cross edge */ &&
164        (*onstack_)[t] && (*dfnumber_)[t] < (*lowlink_)[s])
165      (*lowlink_)[s] = (*dfnumber_)[t];
166    if ((*coaccess_)[t])
167      (*coaccess_)[s] = true;
168    return true;
169  }
170
171  void FinishState(StateId s, StateId p, const A *);
172
173  void FinishVisit() {
174    // Numbers SCC's in topological order when acyclic.
175    if (scc_)
176      for (StateId i = 0; i < scc_->size(); ++i)
177        (*scc_)[i] = nscc_ - 1 - (*scc_)[i];
178    if (coaccess_internal_)
179      delete coaccess_;
180    delete dfnumber_;
181    delete lowlink_;
182    delete onstack_;
183    delete scc_stack_;
184  }
185
186 private:
187  vector<StateId> *scc_;        // State's scc number
188  vector<bool> *access_;        // State's accessibility
189  vector<bool> *coaccess_;      // State's coaccessibility
190  uint64 *props_;
191  const Fst<A> *fst_;
192  StateId start_;
193  StateId nstates_;             // State count
194  StateId nscc_;                // SCC count
195  bool coaccess_internal_;
196  vector<StateId> *dfnumber_;   // state discovery times
197  vector<StateId> *lowlink_;    // lowlink[s] == dfnumber[s] => SCC root
198  vector<bool> *onstack_;       // is a state on the SCC stack
199  vector<StateId> *scc_stack_;  // SCC stack (w/ random access)
200};
201
202template <class A> inline
203void SccVisitor<A>::InitVisit(const Fst<A> &fst) {
204  if (scc_)
205    scc_->clear();
206  if (access_)
207    access_->clear();
208  if (coaccess_) {
209    coaccess_->clear();
210    coaccess_internal_ = false;
211  } else {
212    coaccess_ = new vector<bool>;
213    coaccess_internal_ = true;
214  }
215  *props_ |= kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible;
216  *props_ &= ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible);
217  fst_ = &fst;
218  start_ = fst.Start();
219  nstates_ = 0;
220  nscc_ = 0;
221  dfnumber_ = new vector<StateId>;
222  lowlink_ = new vector<StateId>;
223  onstack_ = new vector<bool>;
224  scc_stack_ = new vector<StateId>;
225}
226
227template <class A> inline
228bool SccVisitor<A>::InitState(StateId s, StateId root) {
229  scc_stack_->push_back(s);
230  while (dfnumber_->size() <= s) {
231    if (scc_)
232      scc_->push_back(-1);
233    if (access_)
234      access_->push_back(false);
235    coaccess_->push_back(false);
236    dfnumber_->push_back(-1);
237    lowlink_->push_back(-1);
238    onstack_->push_back(false);
239  }
240  (*dfnumber_)[s] = nstates_;
241  (*lowlink_)[s] = nstates_;
242  (*onstack_)[s] = true;
243  if (root == start_) {
244    if (access_)
245      (*access_)[s] = true;
246  } else {
247    if (access_)
248      (*access_)[s] = false;
249    *props_ |= kNotAccessible;
250    *props_ &= ~kAccessible;
251  }
252  ++nstates_;
253  return true;
254}
255
256template <class A> inline
257void SccVisitor<A>::FinishState(StateId s, StateId p, const A *) {
258  if (fst_->Final(s) != Weight::Zero())
259    (*coaccess_)[s] = true;
260  if ((*dfnumber_)[s] == (*lowlink_)[s]) {  // root of new SCC
261    bool scc_coaccess = false;
262    size_t i = scc_stack_->size();
263    StateId t;
264    do {
265      t = (*scc_stack_)[--i];
266      if ((*coaccess_)[t])
267        scc_coaccess = true;
268    } while (s != t);
269    do {
270      t = scc_stack_->back();
271      if (scc_)
272        (*scc_)[t] = nscc_;
273      if (scc_coaccess)
274        (*coaccess_)[t] = true;
275      (*onstack_)[t] = false;
276      scc_stack_->pop_back();
277    } while (s != t);
278    if (!scc_coaccess) {
279      *props_ |= kNotCoAccessible;
280      *props_ &= ~kCoAccessible;
281    }
282    ++nscc_;
283  }
284  if (p != kNoStateId) {
285    if ((*coaccess_)[s])
286      (*coaccess_)[p] = true;
287    if ((*lowlink_)[s] < (*lowlink_)[p])
288      (*lowlink_)[p] = (*lowlink_)[s];
289  }
290}
291
292
293// Trims an FST, removing states and arcs that are not on successful
294// paths. This version modifies its input.
295//
296// Complexity:
297// - Time:  O(V + E)
298// - Space: O(V + E)
299// where V = # of states and E = # of arcs.
300template<class Arc>
301void Connect(MutableFst<Arc> *fst) {
302  typedef typename Arc::StateId StateId;
303
304  vector<bool> access;
305  vector<bool> coaccess;
306  uint64 props = 0;
307  SccVisitor<Arc> scc_visitor(0, &access, &coaccess, &props);
308  DfsVisit(*fst, &scc_visitor);
309  vector<StateId> dstates;
310  for (StateId s = 0; s < access.size(); ++s)
311    if (!access[s] || !coaccess[s])
312      dstates.push_back(s);
313  fst->DeleteStates(dstates);
314  fst->SetProperties(kAccessible | kCoAccessible, kAccessible | kCoAccessible);
315}
316
317}  // namespace fst
318
319#endif  // FST_LIB_CONNECT_H__
320