1// state-reachable.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Class to determine whether a given (final) state can be reached from some
20// other given state.
21
22#ifndef FST_LIB_STATE_REACHABLE_H__
23#define FST_LIB_STATE_REACHABLE_H__
24
25#include <vector>
26using std::vector;
27
28#include <fst/dfs-visit.h>
29#include <fst/fst.h>
30#include <fst/interval-set.h>
31
32
33namespace fst {
34
35// Computes the (final) states reachable from a given state in an FST.
36// After this visitor has been called, a final state f can be reached
37// from a state s iff (*isets)[s].Member(state2index[f]) is true, where
38// (*isets[s]) is a set of half-open inteval of final state indices
39// and state2index[f] maps from a final state to its index.
40//
41// If state2index is empty, it is filled-in with suitable indices.
42// If it is non-empty, those indices are used; in this case, the
43// final states must have out-degree 0.
44template <class A, typename I = typename A::StateId>
45class IntervalReachVisitor {
46 public:
47  typedef typename A::StateId StateId;
48  typedef typename A::Label Label;
49  typedef typename A::Weight Weight;
50  typedef typename IntervalSet<I>::Interval Interval;
51
52  IntervalReachVisitor(const Fst<A> &fst,
53                       vector< IntervalSet<I> > *isets,
54                       vector<I> *state2index)
55      : fst_(fst),
56        isets_(isets),
57        state2index_(state2index),
58        index_(state2index->empty() ? 1 : -1),
59        error_(false) {
60    isets_->clear();
61  }
62
63  void InitVisit(const Fst<A> &fst) { error_ = false; }
64
65  bool InitState(StateId s, StateId r) {
66    while (isets_->size() <= s)
67      isets_->push_back(IntervalSet<Label>());
68    while (state2index_->size() <= s)
69      state2index_->push_back(-1);
70
71    if (fst_.Final(s) != Weight::Zero()) {
72      // Create tree interval
73      vector<Interval> *intervals = (*isets_)[s].Intervals();
74      if (index_ < 0) {  // Use state2index_ map to set index
75        if (fst_.NumArcs(s) > 0) {
76          FSTERROR() << "IntervalReachVisitor: state2index map must be empty "
77                     << "for this FST";
78          error_ = true;
79          return false;
80        }
81        I index = (*state2index_)[s];
82        if (index < 0) {
83          FSTERROR() << "IntervalReachVisitor: state2index map incomplete";
84          error_ = true;
85          return false;
86        }
87        intervals->push_back(Interval(index, index + 1));
88      } else {           // Use pre-order index
89        intervals->push_back(Interval(index_, index_ + 1));
90        (*state2index_)[s] = index_++;
91      }
92    }
93    return true;
94  }
95
96  bool TreeArc(StateId s, const A &arc) {
97    return true;
98  }
99
100  bool BackArc(StateId s, const A &arc) {
101    FSTERROR() << "IntervalReachVisitor: cyclic input";
102    error_ = true;
103    return false;
104  }
105
106  bool ForwardOrCrossArc(StateId s, const A &arc) {
107    // Non-tree interval
108    (*isets_)[s].Union((*isets_)[arc.nextstate]);
109    return true;
110  }
111
112  void FinishState(StateId s, StateId p, const A *arc) {
113    if (index_ >= 0 && fst_.Final(s) != Weight::Zero()) {
114      vector<Interval> *intervals = (*isets_)[s].Intervals();
115      (*intervals)[0].end = index_;      // Update tree interval end
116    }
117    (*isets_)[s].Normalize();
118    if (p != kNoStateId)
119      (*isets_)[p].Union((*isets_)[s]);  // Propagate intervals to parent
120  }
121
122  void FinishVisit() {}
123
124  bool Error() const { return error_; }
125
126 private:
127  const Fst<A> &fst_;
128  vector< IntervalSet<I> > *isets_;
129  vector<I> *state2index_;
130  I index_;
131  bool error_;
132};
133
134
135// Tests reachability of final states from a given state. To test for
136// reachability from a state s, first do SetState(s). Then a final
137// state f can be reached from state s of FST iff Reach(f) is true.
138template <class A, typename I = typename A::StateId>
139class StateReachable {
140 public:
141  typedef A Arc;
142  typedef I Index;
143  typedef typename A::StateId StateId;
144  typedef typename A::Label Label;
145  typedef typename A::Weight Weight;
146  typedef typename IntervalSet<I>::Interval Interval;
147
148  StateReachable(const Fst<A> &fst)
149      : error_(false) {
150    IntervalReachVisitor<Arc> reach_visitor(fst, &isets_, &state2index_);
151    DfsVisit(fst, &reach_visitor);
152    if (reach_visitor.Error()) error_ = true;
153  }
154
155  StateReachable(const StateReachable<A> &reachable) {
156    FSTERROR() << "Copy constructor for state reachable class "
157               << "not yet implemented.";
158    error_ = true;
159  }
160
161  // Set current state.
162  void SetState(StateId s) { s_ = s; }
163
164  // Can reach this label from current state?
165  bool Reach(StateId s) {
166    if (s >= state2index_.size())
167      return false;
168
169    I i =  state2index_[s];
170    if (i < 0) {
171      FSTERROR() << "StateReachable: state non-final: " << s;
172      error_ = true;
173      return false;
174    }
175    return isets_[s_].Member(i);
176  }
177
178  // Access to the state-to-index mapping. Unassigned states have index -1.
179  vector<I> &State2Index() { return state2index_; }
180
181  // Access to the interval sets. These specify the reachability
182  // to the final states as intervals of the final state indices.
183  const vector< IntervalSet<I> > &IntervalSets() { return isets_; }
184
185  bool Error() const { return error_; }
186
187 private:
188  StateId s_;                                 // Current state
189  vector< IntervalSet<I> > isets_;            // Interval sets per state
190  vector<I> state2index_;                     // Finds index for a final state
191  bool error_;
192
193  void operator=(const StateReachable<A> &);  // Disallow
194};
195
196}  // namespace fst
197
198#endif  // FST_LIB_STATE_REACHABLE_H__
199