shortest-distance.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// shortest-distance.h 2 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Copyright 2005-2010 Google, Inc. 16// Author: allauzen@google.com (Cyril Allauzen) 17// 18// \file 19// Functions and classes to find shortest distance in an FST. 20 21#ifndef FST_LIB_SHORTEST_DISTANCE_H__ 22#define FST_LIB_SHORTEST_DISTANCE_H__ 23 24#include <deque> 25#include <vector> 26using std::vector; 27 28#include <fst/arcfilter.h> 29#include <fst/cache.h> 30#include <fst/queue.h> 31#include <fst/reverse.h> 32#include <fst/test-properties.h> 33 34 35namespace fst { 36 37template <class Arc, class Queue, class ArcFilter> 38struct ShortestDistanceOptions { 39 typedef typename Arc::StateId StateId; 40 41 Queue *state_queue; // Queue discipline used; owned by caller 42 ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph) 43 StateId source; // If kNoStateId, use the Fst's initial state 44 float delta; // Determines the degree of convergence required 45 bool first_path; // For a semiring with the path property (o.w. 46 // undefined), compute the shortest-distances along 47 // along the first path to a final state found 48 // by the algorithm. That path is the shortest-path 49 // only if the FST has a unique final state (or all 50 // the final states have the same final weight), the 51 // queue discipline is shortest-first and all the 52 // weights in the FST are between One() and Zero() 53 // according to NaturalLess. 54 55 ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId, 56 float d = kDelta) 57 : state_queue(q), arc_filter(filt), source(src), delta(d), 58 first_path(false) {} 59}; 60 61 62// Computation state of the shortest-distance algorithm. Reusable 63// information is maintained across calls to member function 64// ShortestDistance(source) when 'retain' is true for improved 65// efficiency when calling multiple times from different source states 66// (e.g., in epsilon removal). Contrary to usual conventions, 'fst' 67// may not be freed before this class. Vector 'distance' should not be 68// modified by the user between these calls. 69// The Error() method returns true if an error was encountered. 70template<class Arc, class Queue, class ArcFilter> 71class ShortestDistanceState { 72 public: 73 typedef typename Arc::StateId StateId; 74 typedef typename Arc::Weight Weight; 75 76 ShortestDistanceState( 77 const Fst<Arc> &fst, 78 vector<Weight> *distance, 79 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, 80 bool retain) 81 : fst_(fst), distance_(distance), state_queue_(opts.state_queue), 82 arc_filter_(opts.arc_filter), delta_(opts.delta), 83 first_path_(opts.first_path), retain_(retain), source_id_(0), 84 error_(false) { 85 distance_->clear(); 86 } 87 88 ~ShortestDistanceState() {} 89 90 void ShortestDistance(StateId source); 91 92 bool Error() const { return error_; } 93 94 private: 95 const Fst<Arc> &fst_; 96 vector<Weight> *distance_; 97 Queue *state_queue_; 98 ArcFilter arc_filter_; 99 float delta_; 100 bool first_path_; 101 bool retain_; // Retain and reuse information across calls 102 103 vector<Weight> rdistance_; // Relaxation distance. 104 vector<bool> enqueued_; // Is state enqueued? 105 vector<StateId> sources_; // Source ID for ith state in 'distance_', 106 // 'rdistance_', and 'enqueued_' if retained. 107 StateId source_id_; // Unique ID characterizing each call to SD 108 109 bool error_; 110}; 111 112// Compute the shortest distance. If 'source' is kNoStateId, use 113// the initial state of the Fst. 114template <class Arc, class Queue, class ArcFilter> 115void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance( 116 StateId source) { 117 if (fst_.Start() == kNoStateId) { 118 if (fst_.Properties(kError, false)) error_ = true; 119 return; 120 } 121 122 if (!(Weight::Properties() & kRightSemiring)) { 123 FSTERROR() << "ShortestDistance: Weight needs to be right distributive: " 124 << Weight::Type(); 125 error_ = true; 126 return; 127 } 128 129 if (first_path_ && !(Weight::Properties() & kPath)) { 130 FSTERROR() << "ShortestDistance: first_path option disallowed when " 131 << "Weight does not have the path property: " 132 << Weight::Type(); 133 error_ = true; 134 return; 135 } 136 137 state_queue_->Clear(); 138 139 if (!retain_) { 140 distance_->clear(); 141 rdistance_.clear(); 142 enqueued_.clear(); 143 } 144 145 if (source == kNoStateId) 146 source = fst_.Start(); 147 148 while (distance_->size() <= source) { 149 distance_->push_back(Weight::Zero()); 150 rdistance_.push_back(Weight::Zero()); 151 enqueued_.push_back(false); 152 } 153 if (retain_) { 154 while (sources_.size() <= source) 155 sources_.push_back(kNoStateId); 156 sources_[source] = source_id_; 157 } 158 (*distance_)[source] = Weight::One(); 159 rdistance_[source] = Weight::One(); 160 enqueued_[source] = true; 161 162 state_queue_->Enqueue(source); 163 164 while (!state_queue_->Empty()) { 165 StateId s = state_queue_->Head(); 166 state_queue_->Dequeue(); 167 while (distance_->size() <= s) { 168 distance_->push_back(Weight::Zero()); 169 rdistance_.push_back(Weight::Zero()); 170 enqueued_.push_back(false); 171 } 172 if (first_path_ && (fst_.Final(s) != Weight::Zero())) 173 break; 174 enqueued_[s] = false; 175 Weight r = rdistance_[s]; 176 rdistance_[s] = Weight::Zero(); 177 for (ArcIterator< Fst<Arc> > aiter(fst_, s); 178 !aiter.Done(); 179 aiter.Next()) { 180 const Arc &arc = aiter.Value(); 181 if (!arc_filter_(arc) || arc.weight == Weight::Zero()) 182 continue; 183 while (distance_->size() <= arc.nextstate) { 184 distance_->push_back(Weight::Zero()); 185 rdistance_.push_back(Weight::Zero()); 186 enqueued_.push_back(false); 187 } 188 if (retain_) { 189 while (sources_.size() <= arc.nextstate) 190 sources_.push_back(kNoStateId); 191 if (sources_[arc.nextstate] != source_id_) { 192 (*distance_)[arc.nextstate] = Weight::Zero(); 193 rdistance_[arc.nextstate] = Weight::Zero(); 194 enqueued_[arc.nextstate] = false; 195 sources_[arc.nextstate] = source_id_; 196 } 197 } 198 Weight &nd = (*distance_)[arc.nextstate]; 199 Weight &nr = rdistance_[arc.nextstate]; 200 Weight w = Times(r, arc.weight); 201 if (!ApproxEqual(nd, Plus(nd, w), delta_)) { 202 nd = Plus(nd, w); 203 nr = Plus(nr, w); 204 if (!nd.Member() || !nr.Member()) { 205 error_ = true; 206 return; 207 } 208 if (!enqueued_[arc.nextstate]) { 209 state_queue_->Enqueue(arc.nextstate); 210 enqueued_[arc.nextstate] = true; 211 } else { 212 state_queue_->Update(arc.nextstate); 213 } 214 } 215 } 216 } 217 ++source_id_; 218 if (fst_.Properties(kError, false)) error_ = true; 219} 220 221 222// Shortest-distance algorithm: this version allows fine control 223// via the options argument. See below for a simpler interface. 224// 225// This computes the shortest distance from the 'opts.source' state to 226// each visited state S and stores the value in the 'distance' vector. 227// An unvisited state S has distance Zero(), which will be stored in 228// the 'distance' vector if S is less than the maximum visited state. 229// The state queue discipline, arc filter, and convergence delta are 230// taken in the options argument. 231// The 'distance' vector will contain a unique element for which 232// Member() is false if an error was encountered. 233// 234// The weights must must be right distributive and k-closed (i.e., 1 + 235// x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k). 236// 237// The algorithm is from Mohri, "Semiring Framweork and Algorithms for 238// Shortest-Distance Problems", Journal of Automata, Languages and 239// Combinatorics 7(3):321-350, 2002. The complexity of algorithm 240// depends on the properties of the semiring and the queue discipline 241// used. Refer to the paper for more details. 242template<class Arc, class Queue, class ArcFilter> 243void ShortestDistance( 244 const Fst<Arc> &fst, 245 vector<typename Arc::Weight> *distance, 246 const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) { 247 248 ShortestDistanceState<Arc, Queue, ArcFilter> 249 sd_state(fst, distance, opts, false); 250 sd_state.ShortestDistance(opts.source); 251 if (sd_state.Error()) { 252 distance->clear(); 253 distance->resize(1, Arc::Weight::NoWeight()); 254 } 255} 256 257// Shortest-distance algorithm: simplified interface. See above for a 258// version that allows finer control. 259// 260// If 'reverse' is false, this computes the shortest distance from the 261// initial state to each state S and stores the value in the 262// 'distance' vector. If 'reverse' is true, this computes the shortest 263// distance from each state to the final states. An unvisited state S 264// has distance Zero(), which will be stored in the 'distance' vector 265// if S is less than the maximum visited state. The state queue 266// discipline is automatically-selected. 267// The 'distance' vector will contain a unique element for which 268// Member() is false if an error was encountered. 269// 270// The weights must must be right (left) distributive if reverse is 271// false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + 272// x + x^2 + ... + x^k). 273// 274// The algorithm is from Mohri, "Semiring Framweork and Algorithms for 275// Shortest-Distance Problems", Journal of Automata, Languages and 276// Combinatorics 7(3):321-350, 2002. The complexity of algorithm 277// depends on the properties of the semiring and the queue discipline 278// used. Refer to the paper for more details. 279template<class Arc> 280void ShortestDistance(const Fst<Arc> &fst, 281 vector<typename Arc::Weight> *distance, 282 bool reverse = false, 283 float delta = kDelta) { 284 typedef typename Arc::StateId StateId; 285 typedef typename Arc::Weight Weight; 286 287 if (!reverse) { 288 AnyArcFilter<Arc> arc_filter; 289 AutoQueue<StateId> state_queue(fst, distance, arc_filter); 290 ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> > 291 opts(&state_queue, arc_filter); 292 opts.delta = delta; 293 ShortestDistance(fst, distance, opts); 294 } else { 295 typedef ReverseArc<Arc> ReverseArc; 296 typedef typename ReverseArc::Weight ReverseWeight; 297 AnyArcFilter<ReverseArc> rarc_filter; 298 VectorFst<ReverseArc> rfst; 299 Reverse(fst, &rfst); 300 vector<ReverseWeight> rdistance; 301 AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter); 302 ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>, 303 AnyArcFilter<ReverseArc> > 304 ropts(&state_queue, rarc_filter); 305 ropts.delta = delta; 306 ShortestDistance(rfst, &rdistance, ropts); 307 distance->clear(); 308 if (rdistance.size() == 1 && !rdistance[0].Member()) { 309 distance->resize(1, Arc::Weight::NoWeight()); 310 return; 311 } 312 while (distance->size() < rdistance.size() - 1) 313 distance->push_back(rdistance[distance->size() + 1].Reverse()); 314 } 315} 316 317 318// Return the sum of the weight of all successful paths in an FST, i.e., 319// the shortest-distance from the initial state to the final states. 320// Returns a weight such that Member() is false if an error was encountered. 321template <class Arc> 322typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kDelta) { 323 typedef typename Arc::Weight Weight; 324 typedef typename Arc::StateId StateId; 325 vector<Weight> distance; 326 if (Weight::Properties() & kRightSemiring) { 327 ShortestDistance(fst, &distance, false, delta); 328 if (distance.size() == 1 && !distance[0].Member()) 329 return Arc::Weight::NoWeight(); 330 Weight sum = Weight::Zero(); 331 for (StateId s = 0; s < distance.size(); ++s) 332 sum = Plus(sum, Times(distance[s], fst.Final(s))); 333 return sum; 334 } else { 335 ShortestDistance(fst, &distance, true, delta); 336 StateId s = fst.Start(); 337 if (distance.size() == 1 && !distance[0].Member()) 338 return Arc::Weight::NoWeight(); 339 return s != kNoStateId && s < distance.size() ? 340 distance[s] : Weight::Zero(); 341 } 342} 343 344 345} // namespace fst 346 347#endif // FST_LIB_SHORTEST_DISTANCE_H__ 348