1// shortest-distance.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
19// Functions and classes to find shortest distance in an FST.
20
21#ifndef FST_LIB_SHORTEST_DISTANCE_H__
22#define FST_LIB_SHORTEST_DISTANCE_H__
23
24#include <deque>
25using std::deque;
26#include <vector>
27using std::vector;
28
29#include <fst/arcfilter.h>
30#include <fst/cache.h>
31#include <fst/queue.h>
32#include <fst/reverse.h>
33#include <fst/test-properties.h>
34
35
36namespace fst {
37
38template <class Arc, class Queue, class ArcFilter>
39struct ShortestDistanceOptions {
40  typedef typename Arc::StateId StateId;
41
42  Queue *state_queue;    // Queue discipline used; owned by caller
43  ArcFilter arc_filter;  // Arc filter (e.g., limit to only epsilon graph)
44  StateId source;        // If kNoStateId, use the Fst's initial state
45  float delta;           // Determines the degree of convergence required
46  bool first_path;       // For a semiring with the path property (o.w.
47                         // undefined), compute the shortest-distances along
48                         // along the first path to a final state found
49                         // by the algorithm. That path is the shortest-path
50                         // only if the FST has a unique final state (or all
51                         // the final states have the same final weight), the
52                         // queue discipline is shortest-first and all the
53                         // weights in the FST are between One() and Zero()
54                         // according to NaturalLess.
55
56  ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId,
57                          float d = kDelta)
58      : state_queue(q), arc_filter(filt), source(src), delta(d),
59        first_path(false) {}
60};
61
62
63// Computation state of the shortest-distance algorithm. Reusable
64// information is maintained across calls to member function
65// ShortestDistance(source) when 'retain' is true for improved
66// efficiency when calling multiple times from different source states
67// (e.g., in epsilon removal). Contrary to usual conventions, 'fst'
68// may not be freed before this class. Vector 'distance' should not be
69// modified by the user between these calls.
70// The Error() method returns true if an error was encountered.
71template<class Arc, class Queue, class ArcFilter>
72class ShortestDistanceState {
73 public:
74  typedef typename Arc::StateId StateId;
75  typedef typename Arc::Weight Weight;
76
77  ShortestDistanceState(
78      const Fst<Arc> &fst,
79      vector<Weight> *distance,
80      const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts,
81      bool retain)
82      : fst_(fst), distance_(distance), state_queue_(opts.state_queue),
83        arc_filter_(opts.arc_filter), delta_(opts.delta),
84        first_path_(opts.first_path), retain_(retain), source_id_(0),
85        error_(false) {
86    distance_->clear();
87  }
88
89  ~ShortestDistanceState() {}
90
91  void ShortestDistance(StateId source);
92
93  bool Error() const { return error_; }
94
95 private:
96  const Fst<Arc> &fst_;
97  vector<Weight> *distance_;
98  Queue *state_queue_;
99  ArcFilter arc_filter_;
100  float delta_;
101  bool first_path_;
102  bool retain_;               // Retain and reuse information across calls
103
104  vector<Weight> rdistance_;  // Relaxation distance.
105  vector<bool> enqueued_;     // Is state enqueued?
106  vector<StateId> sources_;   // Source ID for ith state in 'distance_',
107                              //  'rdistance_', and 'enqueued_' if retained.
108  StateId source_id_;         // Unique ID characterizing each call to SD
109
110  bool error_;
111};
112
113// Compute the shortest distance. If 'source' is kNoStateId, use
114// the initial state of the Fst.
115template <class Arc, class Queue, class ArcFilter>
116void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
117    StateId source) {
118  if (fst_.Start() == kNoStateId) {
119    if (fst_.Properties(kError, false)) error_ = true;
120    return;
121  }
122
123  if (!(Weight::Properties() & kRightSemiring)) {
124    FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
125               << Weight::Type();
126    error_ = true;
127    return;
128  }
129
130  if (first_path_ && !(Weight::Properties() & kPath)) {
131    FSTERROR() << "ShortestDistance: first_path option disallowed when "
132               << "Weight does not have the path property: "
133               << Weight::Type();
134    error_ = true;
135    return;
136  }
137
138  state_queue_->Clear();
139
140  if (!retain_) {
141    distance_->clear();
142    rdistance_.clear();
143    enqueued_.clear();
144  }
145
146  if (source == kNoStateId)
147    source = fst_.Start();
148
149  while (distance_->size() <= source) {
150    distance_->push_back(Weight::Zero());
151    rdistance_.push_back(Weight::Zero());
152    enqueued_.push_back(false);
153  }
154  if (retain_) {
155    while (sources_.size() <= source)
156      sources_.push_back(kNoStateId);
157    sources_[source] = source_id_;
158  }
159  (*distance_)[source] = Weight::One();
160  rdistance_[source] = Weight::One();
161  enqueued_[source] = true;
162
163  state_queue_->Enqueue(source);
164
165  while (!state_queue_->Empty()) {
166    StateId s = state_queue_->Head();
167    state_queue_->Dequeue();
168    while (distance_->size() <= s) {
169      distance_->push_back(Weight::Zero());
170      rdistance_.push_back(Weight::Zero());
171      enqueued_.push_back(false);
172    }
173    if (first_path_ && (fst_.Final(s) != Weight::Zero()))
174      break;
175    enqueued_[s] = false;
176    Weight r = rdistance_[s];
177    rdistance_[s] = Weight::Zero();
178    for (ArcIterator< Fst<Arc> > aiter(fst_, s);
179         !aiter.Done();
180         aiter.Next()) {
181      const Arc &arc = aiter.Value();
182      if (!arc_filter_(arc))
183        continue;
184      while (distance_->size() <= arc.nextstate) {
185        distance_->push_back(Weight::Zero());
186        rdistance_.push_back(Weight::Zero());
187        enqueued_.push_back(false);
188      }
189      if (retain_) {
190        while (sources_.size() <= arc.nextstate)
191          sources_.push_back(kNoStateId);
192        if (sources_[arc.nextstate] != source_id_) {
193          (*distance_)[arc.nextstate] = Weight::Zero();
194          rdistance_[arc.nextstate] = Weight::Zero();
195          enqueued_[arc.nextstate] = false;
196          sources_[arc.nextstate] = source_id_;
197        }
198      }
199      Weight &nd = (*distance_)[arc.nextstate];
200      Weight &nr = rdistance_[arc.nextstate];
201      Weight w = Times(r, arc.weight);
202      if (!ApproxEqual(nd, Plus(nd, w), delta_)) {
203        nd = Plus(nd, w);
204        nr = Plus(nr, w);
205        if (!nd.Member() || !nr.Member()) {
206          error_ = true;
207          return;
208        }
209        if (!enqueued_[arc.nextstate]) {
210          state_queue_->Enqueue(arc.nextstate);
211          enqueued_[arc.nextstate] = true;
212        } else {
213          state_queue_->Update(arc.nextstate);
214        }
215      }
216    }
217  }
218  ++source_id_;
219  if (fst_.Properties(kError, false)) error_ = true;
220}
221
222
223// Shortest-distance algorithm: this version allows fine control
224// via the options argument. See below for a simpler interface.
225//
226// This computes the shortest distance from the 'opts.source' state to
227// each visited state S and stores the value in the 'distance' vector.
228// An unvisited state S has distance Zero(), which will be stored in
229// the 'distance' vector if S is less than the maximum visited state.
230// The state queue discipline, arc filter, and convergence delta are
231// taken in the options argument.
232// The 'distance' vector will contain a unique element for which
233// Member() is false if an error was encountered.
234//
// The weights must be right distributive and k-closed (i.e., 1 +
236// x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
237//
// The algorithm is from Mohri, "Semiring Framework and Algorithms for
// Shortest-Distance Problems", Journal of Automata, Languages and
// Combinatorics 7(3):321-350, 2002. The complexity of the algorithm
// depends on the properties of the semiring and the queue discipline
242// used. Refer to the paper for more details.
243template<class Arc, class Queue, class ArcFilter>
244void ShortestDistance(
245    const Fst<Arc> &fst,
246    vector<typename Arc::Weight> *distance,
247    const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
248
249  ShortestDistanceState<Arc, Queue, ArcFilter>
250    sd_state(fst, distance, opts, false);
251  sd_state.ShortestDistance(opts.source);
252  if (sd_state.Error()) {
253    distance->clear();
254    distance->resize(1, Arc::Weight::NoWeight());
255  }
256}
257
258// Shortest-distance algorithm: simplified interface. See above for a
259// version that allows finer control.
260//
261// If 'reverse' is false, this computes the shortest distance from the
262// initial state to each state S and stores the value in the
263// 'distance' vector. If 'reverse' is true, this computes the shortest
264// distance from each state to the final states.  An unvisited state S
265// has distance Zero(), which will be stored in the 'distance' vector
266// if S is less than the maximum visited state.  The state queue
267// discipline is automatically-selected.
268// The 'distance' vector will contain a unique element for which
269// Member() is false if an error was encountered.
270//
// The weights must be right (left) distributive if reverse is
272// false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 +
273// x + x^2 + ... + x^k).
274//
// The algorithm is from Mohri, "Semiring Framework and Algorithms for
// Shortest-Distance Problems", Journal of Automata, Languages and
// Combinatorics 7(3):321-350, 2002. The complexity of the algorithm
// depends on the properties of the semiring and the queue discipline
279// used. Refer to the paper for more details.
280template<class Arc>
281void ShortestDistance(const Fst<Arc> &fst,
282                      vector<typename Arc::Weight> *distance,
283                      bool reverse = false,
284                      float delta = kDelta) {
285  typedef typename Arc::StateId StateId;
286  typedef typename Arc::Weight Weight;
287
288  if (!reverse) {
289    AnyArcFilter<Arc> arc_filter;
290    AutoQueue<StateId> state_queue(fst, distance, arc_filter);
291    ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> >
292      opts(&state_queue, arc_filter);
293    opts.delta = delta;
294    ShortestDistance(fst, distance, opts);
295  } else {
296    typedef ReverseArc<Arc> ReverseArc;
297    typedef typename ReverseArc::Weight ReverseWeight;
298    AnyArcFilter<ReverseArc> rarc_filter;
299    VectorFst<ReverseArc> rfst;
300    Reverse(fst, &rfst);
301    vector<ReverseWeight> rdistance;
302    AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
303    ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>,
304      AnyArcFilter<ReverseArc> >
305      ropts(&state_queue, rarc_filter);
306    ropts.delta = delta;
307    ShortestDistance(rfst, &rdistance, ropts);
308    distance->clear();
309    if (rdistance.size() == 1 && !rdistance[0].Member()) {
310      distance->resize(1, Arc::Weight::NoWeight());
311      return;
312    }
313    while (distance->size() < rdistance.size() - 1)
314      distance->push_back(rdistance[distance->size() + 1].Reverse());
315  }
316}
317
318
319// Return the sum of the weight of all successful paths in an FST, i.e.,
320// the shortest-distance from the initial state to the final states.
321// Returns a weight such that Member() is false if an error was encountered.
322template <class Arc>
323typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kDelta) {
324  typedef typename Arc::Weight Weight;
325  typedef typename Arc::StateId StateId;
326  vector<Weight> distance;
327  if (Weight::Properties() & kRightSemiring) {
328    ShortestDistance(fst, &distance, false, delta);
329    if (distance.size() == 1 && !distance[0].Member())
330      return Arc::Weight::NoWeight();
331    Weight sum = Weight::Zero();
332    for (StateId s = 0; s < distance.size(); ++s)
333      sum = Plus(sum, Times(distance[s], fst.Final(s)));
334    return sum;
335  } else {
336    ShortestDistance(fst, &distance, true, delta);
337    StateId s = fst.Start();
338    if (distance.size() == 1 && !distance[0].Member())
339      return Arc::Weight::NoWeight();
340    return s != kNoStateId && s < distance.size() ?
341        distance[s] : Weight::Zero();
342  }
343}
344
345
346}  // namespace fst
347
348#endif  // FST_LIB_SHORTEST_DISTANCE_H__
349