1// shortest-distance.h 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Author: allauzen@cs.nyu.edu (Cyril Allauzen) 16// 17// \file 18// Functions and classes to find shortest distance in an FST. 19 20#ifndef FST_LIB_SHORTEST_DISTANCE_H__ 21#define FST_LIB_SHORTEST_DISTANCE_H__ 22 23#include <deque> 24 25#include "fst/lib/arcfilter.h" 26#include "fst/lib/cache.h" 27#include "fst/lib/queue.h" 28#include "fst/lib/reverse.h" 29#include "fst/lib/test-properties.h" 30 31namespace fst { 32 33template <class Arc, class Queue, class ArcFilter> 34struct ShortestDistanceOptions { 35 typedef typename Arc::StateId StateId; 36 37 Queue *state_queue; // Queue discipline used; owned by caller 38 ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph) 39 StateId source; // If kNoStateId, use the Fst's initial state 40 float delta; // Determines the degree of convergence required 41 42 ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId, 43 float d = kDelta) 44 : state_queue(q), arc_filter(filt), source(src), delta(d) {} 45}; 46 47 48// Computation state of the shortest-distance algorithm. Reusable 49// information is maintained across calls to member function 50// ShortestDistance(source) when 'retain' is true for improved 51// efficiency when calling multiple times from different source states 52// (e.g., in epsilon removal). Vector 'distance' should not be 53// modified by the user between these calls. 54template<class Arc, class Queue, class ArcFilter> 55class ShortestDistanceState { 56 public: 57 typedef typename Arc::StateId StateId; 58 typedef typename Arc::Weight Weight; 59 60 ShortestDistanceState( 61 const Fst<Arc> &fst, 62 vector<Weight> *distance, 63 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, 64 bool retain) 65 : fst_(fst.Copy()), distance_(distance), state_queue_(opts.state_queue), 66 arc_filter_(opts.arc_filter), 67 delta_(opts.delta), retain_(retain) { 68 distance_->clear(); 69 } 70 71 ~ShortestDistanceState() { 72 delete fst_; 73 } 74 75 void ShortestDistance(StateId source); 76 77 private: 78 const Fst<Arc> *fst_; 79 vector<Weight> *distance_; 80 Queue *state_queue_; 81 ArcFilter arc_filter_; 82 float delta_; 83 bool retain_; // Retain and reuse information across calls 84 85 vector<Weight> rdistance_; // Relaxation distance. 86 vector<bool> enqueued_; // Is state enqueued? 87 vector<StateId> sources_; // Source state for ith state in 'distance_', 88 // 'rdistance_', and 'enqueued_' if retained. 89}; 90 91// Compute the shortest distance. If 'source' is kNoStateId, use 92// the initial state of the Fst. 93template <class Arc, class Queue, class ArcFilter> 94void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance( 95 StateId source) { 96 if (fst_->Start() == kNoStateId) 97 return; 98 99 if (!(Weight::Properties() & kRightSemiring)) 100 LOG(FATAL) << "ShortestDistance: Weight needs to be right distributive: " 101 << Weight::Type(); 102 103 state_queue_->Clear(); 104 105 if (!retain_) { 106 distance_->clear(); 107 rdistance_.clear(); 108 enqueued_.clear(); 109 } 110 111 if (source == kNoStateId) 112 source = fst_->Start(); 113 114 while ((StateId)distance_->size() <= source) { 115 distance_->push_back(Weight::Zero()); 116 rdistance_.push_back(Weight::Zero()); 117 enqueued_.push_back(false); 118 } 119 if (retain_) { 120 while ((StateId)sources_.size() <= source) 121 sources_.push_back(kNoStateId); 122 sources_[source] = source; 123 } 124 (*distance_)[source] = Weight::One(); 125 rdistance_[source] = Weight::One(); 126 enqueued_[source] = true; 127 128 state_queue_->Enqueue(source); 129 130 while (!state_queue_->Empty()) { 131 StateId s = state_queue_->Head(); 132 state_queue_->Dequeue(); 133 while ((StateId)distance_->size() <= s) { 134 distance_->push_back(Weight::Zero()); 135 rdistance_.push_back(Weight::Zero()); 136 enqueued_.push_back(false); 137 } 138 enqueued_[s] = false; 139 Weight r = rdistance_[s]; 140 rdistance_[s] = Weight::Zero(); 141 for (ArcIterator< Fst<Arc> > aiter(*fst_, s); 142 !aiter.Done(); 143 aiter.Next()) { 144 const Arc &arc = aiter.Value(); 145 if (!arc_filter_(arc) || arc.weight == Weight::Zero()) 146 continue; 147 while ((StateId)distance_->size() <= arc.nextstate) { 148 distance_->push_back(Weight::Zero()); 149 rdistance_.push_back(Weight::Zero()); 150 enqueued_.push_back(false); 151 } 152 if (retain_) { 153 while ((StateId)sources_.size() <= arc.nextstate) 154 sources_.push_back(kNoStateId); 155 if (sources_[arc.nextstate] != source) { 156 (*distance_)[arc.nextstate] = Weight::Zero(); 157 rdistance_[arc.nextstate] = Weight::Zero(); 158 enqueued_[arc.nextstate] = false; 159 sources_[arc.nextstate] = source; 160 } 161 } 162 Weight &nd = (*distance_)[arc.nextstate]; 163 Weight &nr = rdistance_[arc.nextstate]; 164 Weight w = Times(r, arc.weight); 165 if (!ApproxEqual(nd, Plus(nd, w), delta_)) { 166 nd = Plus(nd, w); 167 nr = Plus(nr, w); 168 if (!enqueued_[arc.nextstate]) { 169 state_queue_->Enqueue(arc.nextstate); 170 enqueued_[arc.nextstate] = true; 171 } else { 172 state_queue_->Update(arc.nextstate); 173 } 174 } 175 } 176 } 177} 178 179 180// Shortest-distance algorithm: this version allows fine control 181// via the options argument. See below for a simpler interface. 182// 183// This computes the shortest distance from the 'opts.source' state to 184// each visited state S and stores the value in the 'distance' vector. 185// An unvisited state S has distance Zero(), which will be stored in 186// the 'distance' vector if S is less than the maximum visited state. 187// The state queue discipline, arc filter, and convergence delta are 188// taken in the options argument. 189 190// The weights must must be right distributive and k-closed (i.e., 1 + 191// x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k). 192// 193// The algorithm is from Mohri, "Semiring Framweork and Algorithms for 194// Shortest-Distance Problems", Journal of Automata, Languages and 195// Combinatorics 7(3):321-350, 2002. The complexity of algorithm 196// depends on the properties of the semiring and the queue discipline 197// used. Refer to the paper for more details. 198template<class Arc, class Queue, class ArcFilter> 199void ShortestDistance( 200 const Fst<Arc> &fst, 201 vector<typename Arc::Weight> *distance, 202 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) { 203 204 ShortestDistanceState<Arc, Queue, ArcFilter> 205 sd_state(fst, distance, opts, false); 206 sd_state.ShortestDistance(opts.source); 207} 208 209// Shortest-distance algorithm: simplified interface. See above for a 210// version that allows finer control. 211// 212// If 'reverse' is false, this computes the shortest distance from the 213// initial state to each state S and stores the value in the 214// 'distance' vector. If 'reverse' is true, this computes the shortest 215// distance from each state to the final states. An unvisited state S 216// has distance Zero(), which will be stored in the 'distance' vector 217// if S is less than the maximum visited state. The state queue 218// discipline is automatically-selected. 219// 220// The weights must must be right (left) distributive if reverse is 221// false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + 222// x + x^2 + ... + x^k). 223// 224// The algorithm is from Mohri, "Semiring Framweork and Algorithms for 225// Shortest-Distance Problems", Journal of Automata, Languages and 226// Combinatorics 7(3):321-350, 2002. The complexity of algorithm 227// depends on the properties of the semiring and the queue discipline 228// used. Refer to the paper for more details. 229template<class Arc> 230void ShortestDistance(const Fst<Arc> &fst, 231 vector<typename Arc::Weight> *distance, 232 bool reverse = false) { 233 typedef typename Arc::StateId StateId; 234 typedef typename Arc::Weight Weight; 235 236 if (!reverse) { 237 AnyArcFilter<Arc> arc_filter; 238 AutoQueue<StateId> state_queue(fst, distance, arc_filter); 239 ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> > 240 opts(&state_queue, arc_filter); 241 ShortestDistance(fst, distance, opts); 242 } else { 243 typedef ReverseArc<Arc> ReverseArc; 244 typedef typename ReverseArc::Weight ReverseWeight; 245 AnyArcFilter<ReverseArc> rarc_filter; 246 VectorFst<ReverseArc> rfst; 247 Reverse(fst, &rfst); 248 vector<ReverseWeight> rdistance; 249 AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter); 250 ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>, 251 AnyArcFilter<ReverseArc> > 252 ropts(&state_queue, rarc_filter); 253 ShortestDistance(rfst, &rdistance, ropts); 254 distance->clear(); 255 while (distance->size() < rdistance.size() - 1) 256 distance->push_back(rdistance[distance->size() + 1].Reverse()); 257 } 258} 259 260} // namespace fst 261 262#endif // FST_LIB_SHORTEST_DISTANCE_H__ 263