1// connect.h 2 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Copyright 2005-2010 Google, Inc. 16// Author: riley@google.com (Michael Riley) 17// 18// \file 19// Classes and functions to remove unsuccessful paths from an Fst. 20 21#ifndef FST_LIB_CONNECT_H__ 22#define FST_LIB_CONNECT_H__ 23 24#include <vector> 25using std::vector; 26 27#include <fst/dfs-visit.h> 28#include <fst/union-find.h> 29#include <fst/mutable-fst.h> 30 31 32namespace fst { 33 34// Finds and returns connected components. Use with Visit(). 35template <class A> 36class CcVisitor { 37 public: 38 typedef A Arc; 39 typedef typename Arc::Weight Weight; 40 typedef typename A::StateId StateId; 41 42 // cc[i]: connected component number for state i. 43 CcVisitor(vector<StateId> *cc) 44 : comps_(new UnionFind<StateId>(0, kNoStateId)), 45 cc_(cc), 46 nstates_(0) { } 47 48 // comps: connected components equiv classes. 49 CcVisitor(UnionFind<StateId> *comps) 50 : comps_(comps), 51 cc_(0), 52 nstates_(0) { } 53 54 ~CcVisitor() { 55 if (cc_) // own comps_? 56 delete comps_; 57 } 58 59 void InitVisit(const Fst<A> &fst) { } 60 61 bool InitState(StateId s, StateId root) { 62 ++nstates_; 63 if (comps_->FindSet(s) == kNoStateId) 64 comps_->MakeSet(s); 65 return true; 66 } 67 68 bool WhiteArc(StateId s, const A &arc) { 69 comps_->MakeSet(arc.nextstate); 70 comps_->Union(s, arc.nextstate); 71 return true; 72 } 73 74 bool GreyArc(StateId s, const A &arc) { 75 comps_->Union(s, arc.nextstate); 76 return true; 77 } 78 79 bool BlackArc(StateId s, const A &arc) { 80 comps_->Union(s, arc.nextstate); 81 return true; 82 } 83 84 void FinishState(StateId s) { } 85 86 void FinishVisit() { 87 if (cc_) 88 GetCcVector(cc_); 89 } 90 91 // cc[i]: connected component number for state i. 92 // Returns number of components. 93 int GetCcVector(vector<StateId> *cc) { 94 cc->clear(); 95 cc->resize(nstates_, kNoStateId); 96 StateId ncomp = 0; 97 for (StateId i = 0; i < nstates_; ++i) { 98 StateId rep = comps_->FindSet(i); 99 StateId &comp = (*cc)[rep]; 100 if (comp == kNoStateId) { 101 comp = ncomp; 102 ++ncomp; 103 } 104 (*cc)[i] = comp; 105 } 106 return ncomp; 107 } 108 109 private: 110 UnionFind<StateId> *comps_; // Components 111 vector<StateId> *cc_; // State's cc number 112 StateId nstates_; // State count 113}; 114 115 116// Finds and returns strongly-connected components, accessible and 117// coaccessible states and related properties. Uses Tarjan's single 118// DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer 119// Algorithms", 189pp). Use with DfsVisit(); 120template <class A> 121class SccVisitor { 122 public: 123 typedef A Arc; 124 typedef typename A::Weight Weight; 125 typedef typename A::StateId StateId; 126 127 // scc[i]: strongly-connected component number for state i. 128 // SCC numbers will be in topological order for acyclic input. 129 // access[i]: accessibility of state i. 130 // coaccess[i]: coaccessibility of state i. 131 // Any of above can be NULL. 132 // props: related property bits (cyclicity, initial cyclicity, 133 // accessibility, coaccessibility) set/cleared (o.w. unchanged). 134 SccVisitor(vector<StateId> *scc, vector<bool> *access, 135 vector<bool> *coaccess, uint64 *props) 136 : scc_(scc), access_(access), coaccess_(coaccess), props_(props) {} 137 SccVisitor(uint64 *props) 138 : scc_(0), access_(0), coaccess_(0), props_(props) {} 139 140 void InitVisit(const Fst<A> &fst); 141 142 bool InitState(StateId s, StateId root); 143 144 bool TreeArc(StateId s, const A &arc) { return true; } 145 146 bool BackArc(StateId s, const A &arc) { 147 StateId t = arc.nextstate; 148 if ((*dfnumber_)[t] < (*lowlink_)[s]) 149 (*lowlink_)[s] = (*dfnumber_)[t]; 150 if ((*coaccess_)[t]) 151 (*coaccess_)[s] = true; 152 *props_ |= kCyclic; 153 *props_ &= ~kAcyclic; 154 if (arc.nextstate == start_) { 155 *props_ |= kInitialCyclic; 156 *props_ &= ~kInitialAcyclic; 157 } 158 return true; 159 } 160 161 bool ForwardOrCrossArc(StateId s, const A &arc) { 162 StateId t = arc.nextstate; 163 if ((*dfnumber_)[t] < (*dfnumber_)[s] /* cross edge */ && 164 (*onstack_)[t] && (*dfnumber_)[t] < (*lowlink_)[s]) 165 (*lowlink_)[s] = (*dfnumber_)[t]; 166 if ((*coaccess_)[t]) 167 (*coaccess_)[s] = true; 168 return true; 169 } 170 171 void FinishState(StateId s, StateId p, const A *); 172 173 void FinishVisit() { 174 // Numbers SCC's in topological order when acyclic. 175 if (scc_) 176 for (StateId i = 0; i < scc_->size(); ++i) 177 (*scc_)[i] = nscc_ - 1 - (*scc_)[i]; 178 if (coaccess_internal_) 179 delete coaccess_; 180 delete dfnumber_; 181 delete lowlink_; 182 delete onstack_; 183 delete scc_stack_; 184 } 185 186 private: 187 vector<StateId> *scc_; // State's scc number 188 vector<bool> *access_; // State's accessibility 189 vector<bool> *coaccess_; // State's coaccessibility 190 uint64 *props_; 191 const Fst<A> *fst_; 192 StateId start_; 193 StateId nstates_; // State count 194 StateId nscc_; // SCC count 195 bool coaccess_internal_; 196 vector<StateId> *dfnumber_; // state discovery times 197 vector<StateId> *lowlink_; // lowlink[s] == dfnumber[s] => SCC root 198 vector<bool> *onstack_; // is a state on the SCC stack 199 vector<StateId> *scc_stack_; // SCC stack (w/ random access) 200}; 201 202template <class A> inline 203void SccVisitor<A>::InitVisit(const Fst<A> &fst) { 204 if (scc_) 205 scc_->clear(); 206 if (access_) 207 access_->clear(); 208 if (coaccess_) { 209 coaccess_->clear(); 210 coaccess_internal_ = false; 211 } else { 212 coaccess_ = new vector<bool>; 213 coaccess_internal_ = true; 214 } 215 *props_ |= kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible; 216 *props_ &= ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible); 217 fst_ = &fst; 218 start_ = fst.Start(); 219 nstates_ = 0; 220 nscc_ = 0; 221 dfnumber_ = new vector<StateId>; 222 lowlink_ = new vector<StateId>; 223 onstack_ = new vector<bool>; 224 scc_stack_ = new vector<StateId>; 225} 226 227template <class A> inline 228bool SccVisitor<A>::InitState(StateId s, StateId root) { 229 scc_stack_->push_back(s); 230 while (dfnumber_->size() <= s) { 231 if (scc_) 232 scc_->push_back(-1); 233 if (access_) 234 access_->push_back(false); 235 coaccess_->push_back(false); 236 dfnumber_->push_back(-1); 237 lowlink_->push_back(-1); 238 onstack_->push_back(false); 239 } 240 (*dfnumber_)[s] = nstates_; 241 (*lowlink_)[s] = nstates_; 242 (*onstack_)[s] = true; 243 if (root == start_) { 244 if (access_) 245 (*access_)[s] = true; 246 } else { 247 if (access_) 248 (*access_)[s] = false; 249 *props_ |= kNotAccessible; 250 *props_ &= ~kAccessible; 251 } 252 ++nstates_; 253 return true; 254} 255 256template <class A> inline 257void SccVisitor<A>::FinishState(StateId s, StateId p, const A *) { 258 if (fst_->Final(s) != Weight::Zero()) 259 (*coaccess_)[s] = true; 260 if ((*dfnumber_)[s] == (*lowlink_)[s]) { // root of new SCC 261 bool scc_coaccess = false; 262 size_t i = scc_stack_->size(); 263 StateId t; 264 do { 265 t = (*scc_stack_)[--i]; 266 if ((*coaccess_)[t]) 267 scc_coaccess = true; 268 } while (s != t); 269 do { 270 t = scc_stack_->back(); 271 if (scc_) 272 (*scc_)[t] = nscc_; 273 if (scc_coaccess) 274 (*coaccess_)[t] = true; 275 (*onstack_)[t] = false; 276 scc_stack_->pop_back(); 277 } while (s != t); 278 if (!scc_coaccess) { 279 *props_ |= kNotCoAccessible; 280 *props_ &= ~kCoAccessible; 281 } 282 ++nscc_; 283 } 284 if (p != kNoStateId) { 285 if ((*coaccess_)[s]) 286 (*coaccess_)[p] = true; 287 if ((*lowlink_)[s] < (*lowlink_)[p]) 288 (*lowlink_)[p] = (*lowlink_)[s]; 289 } 290} 291 292 293// Trims an FST, removing states and arcs that are not on successful 294// paths. This version modifies its input. 295// 296// Complexity: 297// - Time: O(V + E) 298// - Space: O(V + E) 299// where V = # of states and E = # of arcs. 300template<class Arc> 301void Connect(MutableFst<Arc> *fst) { 302 typedef typename Arc::StateId StateId; 303 304 vector<bool> access; 305 vector<bool> coaccess; 306 uint64 props = 0; 307 SccVisitor<Arc> scc_visitor(0, &access, &coaccess, &props); 308 DfsVisit(*fst, &scc_visitor); 309 vector<StateId> dstates; 310 for (StateId s = 0; s < access.size(); ++s) 311 if (!access[s] || !coaccess[s]) 312 dstates.push_back(s); 313 fst->DeleteStates(dstates); 314 fst->SetProperties(kAccessible | kCoAccessible, kAccessible | kCoAccessible); 315} 316 317} // namespace fst 318 319#endif // FST_LIB_CONNECT_H__ 320