shortest-distance.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// shortest-distance.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
19// Functions and classes to find shortest distance in an FST.
20
21#ifndef FST_LIB_SHORTEST_DISTANCE_H__
22#define FST_LIB_SHORTEST_DISTANCE_H__
23
24#include <deque>
25#include <vector>
26using std::vector;
27
28#include <fst/arcfilter.h>
29#include <fst/cache.h>
30#include <fst/queue.h>
31#include <fst/reverse.h>
32#include <fst/test-properties.h>
33
34
35namespace fst {
36
37template <class Arc, class Queue, class ArcFilter>
38struct ShortestDistanceOptions {
39  typedef typename Arc::StateId StateId;
40
41  Queue *state_queue;    // Queue discipline used; owned by caller
42  ArcFilter arc_filter;  // Arc filter (e.g., limit to only epsilon graph)
43  StateId source;        // If kNoStateId, use the Fst's initial state
44  float delta;           // Determines the degree of convergence required
45  bool first_path;       // For a semiring with the path property (o.w.
46                         // undefined), compute the shortest-distances along
47                         // along the first path to a final state found
48                         // by the algorithm. That path is the shortest-path
49                         // only if the FST has a unique final state (or all
50                         // the final states have the same final weight), the
51                         // queue discipline is shortest-first and all the
52                         // weights in the FST are between One() and Zero()
53                         // according to NaturalLess.
54
55  ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId,
56                          float d = kDelta)
57      : state_queue(q), arc_filter(filt), source(src), delta(d),
58        first_path(false) {}
59};
60
61
62// Computation state of the shortest-distance algorithm. Reusable
63// information is maintained across calls to member function
64// ShortestDistance(source) when 'retain' is true for improved
65// efficiency when calling multiple times from different source states
66// (e.g., in epsilon removal). Contrary to usual conventions, 'fst'
67// may not be freed before this class. Vector 'distance' should not be
68// modified by the user between these calls.
69// The Error() method returns true if an error was encountered.
70template<class Arc, class Queue, class ArcFilter>
71class ShortestDistanceState {
72 public:
73  typedef typename Arc::StateId StateId;
74  typedef typename Arc::Weight Weight;
75
76  ShortestDistanceState(
77      const Fst<Arc> &fst,
78      vector<Weight> *distance,
79      const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts,
80      bool retain)
81      : fst_(fst), distance_(distance), state_queue_(opts.state_queue),
82        arc_filter_(opts.arc_filter), delta_(opts.delta),
83        first_path_(opts.first_path), retain_(retain), source_id_(0),
84        error_(false) {
85    distance_->clear();
86  }
87
88  ~ShortestDistanceState() {}
89
90  void ShortestDistance(StateId source);
91
92  bool Error() const { return error_; }
93
94 private:
95  const Fst<Arc> &fst_;
96  vector<Weight> *distance_;
97  Queue *state_queue_;
98  ArcFilter arc_filter_;
99  float delta_;
100  bool first_path_;
101  bool retain_;               // Retain and reuse information across calls
102
103  vector<Weight> rdistance_;  // Relaxation distance.
104  vector<bool> enqueued_;     // Is state enqueued?
105  vector<StateId> sources_;   // Source ID for ith state in 'distance_',
106                              //  'rdistance_', and 'enqueued_' if retained.
107  StateId source_id_;         // Unique ID characterizing each call to SD
108
109  bool error_;
110};
111
112// Compute the shortest distance. If 'source' is kNoStateId, use
113// the initial state of the Fst.
114template <class Arc, class Queue, class ArcFilter>
115void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
116    StateId source) {
117  if (fst_.Start() == kNoStateId) {
118    if (fst_.Properties(kError, false)) error_ = true;
119    return;
120  }
121
122  if (!(Weight::Properties() & kRightSemiring)) {
123    FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
124               << Weight::Type();
125    error_ = true;
126    return;
127  }
128
129  if (first_path_ && !(Weight::Properties() & kPath)) {
130    FSTERROR() << "ShortestDistance: first_path option disallowed when "
131               << "Weight does not have the path property: "
132               << Weight::Type();
133    error_ = true;
134    return;
135  }
136
137  state_queue_->Clear();
138
139  if (!retain_) {
140    distance_->clear();
141    rdistance_.clear();
142    enqueued_.clear();
143  }
144
145  if (source == kNoStateId)
146    source = fst_.Start();
147
148  while (distance_->size() <= source) {
149    distance_->push_back(Weight::Zero());
150    rdistance_.push_back(Weight::Zero());
151    enqueued_.push_back(false);
152  }
153  if (retain_) {
154    while (sources_.size() <= source)
155      sources_.push_back(kNoStateId);
156    sources_[source] = source_id_;
157  }
158  (*distance_)[source] = Weight::One();
159  rdistance_[source] = Weight::One();
160  enqueued_[source] = true;
161
162  state_queue_->Enqueue(source);
163
164  while (!state_queue_->Empty()) {
165    StateId s = state_queue_->Head();
166    state_queue_->Dequeue();
167    while (distance_->size() <= s) {
168      distance_->push_back(Weight::Zero());
169      rdistance_.push_back(Weight::Zero());
170      enqueued_.push_back(false);
171    }
172    if (first_path_ && (fst_.Final(s) != Weight::Zero()))
173      break;
174    enqueued_[s] = false;
175    Weight r = rdistance_[s];
176    rdistance_[s] = Weight::Zero();
177    for (ArcIterator< Fst<Arc> > aiter(fst_, s);
178         !aiter.Done();
179         aiter.Next()) {
180      const Arc &arc = aiter.Value();
181      if (!arc_filter_(arc) || arc.weight == Weight::Zero())
182        continue;
183      while (distance_->size() <= arc.nextstate) {
184        distance_->push_back(Weight::Zero());
185        rdistance_.push_back(Weight::Zero());
186        enqueued_.push_back(false);
187      }
188      if (retain_) {
189        while (sources_.size() <= arc.nextstate)
190          sources_.push_back(kNoStateId);
191        if (sources_[arc.nextstate] != source_id_) {
192          (*distance_)[arc.nextstate] = Weight::Zero();
193          rdistance_[arc.nextstate] = Weight::Zero();
194          enqueued_[arc.nextstate] = false;
195          sources_[arc.nextstate] = source_id_;
196        }
197      }
198      Weight &nd = (*distance_)[arc.nextstate];
199      Weight &nr = rdistance_[arc.nextstate];
200      Weight w = Times(r, arc.weight);
201      if (!ApproxEqual(nd, Plus(nd, w), delta_)) {
202        nd = Plus(nd, w);
203        nr = Plus(nr, w);
204        if (!nd.Member() || !nr.Member()) {
205          error_ = true;
206          return;
207        }
208        if (!enqueued_[arc.nextstate]) {
209          state_queue_->Enqueue(arc.nextstate);
210          enqueued_[arc.nextstate] = true;
211        } else {
212          state_queue_->Update(arc.nextstate);
213        }
214      }
215    }
216  }
217  ++source_id_;
218  if (fst_.Properties(kError, false)) error_ = true;
219}
220
221
222// Shortest-distance algorithm: this version allows fine control
223// via the options argument. See below for a simpler interface.
224//
225// This computes the shortest distance from the 'opts.source' state to
226// each visited state S and stores the value in the 'distance' vector.
227// An unvisited state S has distance Zero(), which will be stored in
228// the 'distance' vector if S is less than the maximum visited state.
229// The state queue discipline, arc filter, and convergence delta are
230// taken in the options argument.
231// The 'distance' vector will contain a unique element for which
232// Member() is false if an error was encountered.
233//
234// The weights must must be right distributive and k-closed (i.e., 1 +
235// x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
236//
237// The algorithm is from Mohri, "Semiring Framweork and Algorithms for
238// Shortest-Distance Problems", Journal of Automata, Languages and
239// Combinatorics 7(3):321-350, 2002. The complexity of algorithm
240// depends on the properties of the semiring and the queue discipline
241// used. Refer to the paper for more details.
242template<class Arc, class Queue, class ArcFilter>
243void ShortestDistance(
244    const Fst<Arc> &fst,
245    vector<typename Arc::Weight> *distance,
246    const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
247
248  ShortestDistanceState<Arc, Queue, ArcFilter>
249    sd_state(fst, distance, opts, false);
250  sd_state.ShortestDistance(opts.source);
251  if (sd_state.Error()) {
252    distance->clear();
253    distance->resize(1, Arc::Weight::NoWeight());
254  }
255}
256
257// Shortest-distance algorithm: simplified interface. See above for a
258// version that allows finer control.
259//
260// If 'reverse' is false, this computes the shortest distance from the
261// initial state to each state S and stores the value in the
262// 'distance' vector. If 'reverse' is true, this computes the shortest
263// distance from each state to the final states.  An unvisited state S
264// has distance Zero(), which will be stored in the 'distance' vector
265// if S is less than the maximum visited state.  The state queue
266// discipline is automatically-selected.
267// The 'distance' vector will contain a unique element for which
268// Member() is false if an error was encountered.
269//
270// The weights must must be right (left) distributive if reverse is
271// false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 +
272// x + x^2 + ... + x^k).
273//
274// The algorithm is from Mohri, "Semiring Framweork and Algorithms for
275// Shortest-Distance Problems", Journal of Automata, Languages and
276// Combinatorics 7(3):321-350, 2002. The complexity of algorithm
277// depends on the properties of the semiring and the queue discipline
278// used. Refer to the paper for more details.
279template<class Arc>
280void ShortestDistance(const Fst<Arc> &fst,
281                      vector<typename Arc::Weight> *distance,
282                      bool reverse = false,
283                      float delta = kDelta) {
284  typedef typename Arc::StateId StateId;
285  typedef typename Arc::Weight Weight;
286
287  if (!reverse) {
288    AnyArcFilter<Arc> arc_filter;
289    AutoQueue<StateId> state_queue(fst, distance, arc_filter);
290    ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> >
291      opts(&state_queue, arc_filter);
292    opts.delta = delta;
293    ShortestDistance(fst, distance, opts);
294  } else {
295    typedef ReverseArc<Arc> ReverseArc;
296    typedef typename ReverseArc::Weight ReverseWeight;
297    AnyArcFilter<ReverseArc> rarc_filter;
298    VectorFst<ReverseArc> rfst;
299    Reverse(fst, &rfst);
300    vector<ReverseWeight> rdistance;
301    AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
302    ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>,
303      AnyArcFilter<ReverseArc> >
304      ropts(&state_queue, rarc_filter);
305    ropts.delta = delta;
306    ShortestDistance(rfst, &rdistance, ropts);
307    distance->clear();
308    if (rdistance.size() == 1 && !rdistance[0].Member()) {
309      distance->resize(1, Arc::Weight::NoWeight());
310      return;
311    }
312    while (distance->size() < rdistance.size() - 1)
313      distance->push_back(rdistance[distance->size() + 1].Reverse());
314  }
315}
316
317
318// Return the sum of the weight of all successful paths in an FST, i.e.,
319// the shortest-distance from the initial state to the final states.
320// Returns a weight such that Member() is false if an error was encountered.
321template <class Arc>
322typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kDelta) {
323  typedef typename Arc::Weight Weight;
324  typedef typename Arc::StateId StateId;
325  vector<Weight> distance;
326  if (Weight::Properties() & kRightSemiring) {
327    ShortestDistance(fst, &distance, false, delta);
328    if (distance.size() == 1 && !distance[0].Member())
329      return Arc::Weight::NoWeight();
330    Weight sum = Weight::Zero();
331    for (StateId s = 0; s < distance.size(); ++s)
332      sum = Plus(sum, Times(distance[s], fst.Final(s)));
333    return sum;
334  } else {
335    ShortestDistance(fst, &distance, true, delta);
336    StateId s = fst.Start();
337    if (distance.size() == 1 && !distance[0].Member())
338      return Arc::Weight::NoWeight();
339    return s != kNoStateId && s < distance.size() ?
340        distance[s] : Weight::Zero();
341  }
342}
343
344
345}  // namespace fst
346
347#endif  // FST_LIB_SHORTEST_DISTANCE_H__
348