shortest-path.h revision 5b6dc79427b8f7eeb6a7ff68034ab8548ce670ea
1// shortest-path.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
19// Functions to find shortest paths in an FST.
20
21#ifndef FST_LIB_SHORTEST_PATH_H__
22#define FST_LIB_SHORTEST_PATH_H__
23
#include <algorithm>
#include <functional>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;
29
30#include <fst/cache.h>
31#include <fst/determinize.h>
32#include <fst/queue.h>
33#include <fst/shortest-distance.h>
34#include <fst/test-properties.h>
35
36
37namespace fst {
38
39template <class Arc, class Queue, class ArcFilter>
40struct ShortestPathOptions
41    : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
42  typedef typename Arc::StateId StateId;
43  typedef typename Arc::Weight Weight;
44  size_t nshortest;   // return n-shortest paths
45  bool unique;        // only return paths with distinct input strings
46  bool has_distance;  // distance vector already contains the
47                      // shortest distance from the initial state
48  bool first_path;    // Single shortest path stops after finding the first
49                      // path to a final state. That path is the shortest path
50                      // only when using the ShortestFirstQueue and
51                      // only when all the weights in the FST are between
52                      // One() and Zero() according to NaturalLess.
53  Weight weight_threshold;   // pruning weight threshold.
54  StateId state_threshold;   // pruning state threshold.
55
56  ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false,
57                      bool hasdist = false, float d = kDelta,
58                      bool fp = false, Weight w = Weight::Zero(),
59                      StateId s = kNoStateId)
60      : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d),
61        nshortest(n), unique(u), has_distance(hasdist), first_path(fp),
62        weight_threshold(w), state_threshold(s) {}
63};
64
65
// Single-shortest-path algorithm: normally not called directly; prefer
// 'ShortestPath' below with n=1. On success 'ofst' holds the shortest path
// in 'ifst' from the source state to a final state as a chain of states
// (no states and no start state if no final state is reached).
//
// The source state is 'opts.source', or the initial state of 'ifst' when
// that is kNoStateId. 'distance' is cleared and, on return, holds for each
// state id up to the largest one discovered the tentative distance from
// the source (Weight::Zero() for states never relaxed). When
// 'opts.first_path' stops the search early these need not be final.
//
// From 'opts' this function reads the queue ('opts.state_queue'), the
// source and 'opts.first_path'. It also checks that 'opts.nshortest' == 1
// and that both pruning thresholds are at their defaults; otherwise kError
// is set on 'ofst'. NOTE(review): the arc filter, 'opts.delta',
// 'opts.unique' and 'opts.has_distance' are not read here -- every arc is
// relaxed and weights are compared exactly.
//
// The shortest path is the lowest weight path w.r.t. the natural
// semiring order. With 'opts.first_path' the search stops at the first
// final state dequeued. That is the shortest path only with a
// shortest-first queue and weights between One() and Zero() w.r.t.
// NaturalLess.
//
// The weights need to be right distributive and have the path (kPath)
// property.
template<class Arc, class Queue, class ArcFilter>
void SingleShortestPath(const Fst<Arc> &ifst,
                  MutableFst<Arc> *ofst,
                  vector<typename Arc::Weight> *distance,
                  ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  ofst->DeleteStates();
  ofst->SetInputSymbols(ifst.InputSymbols());
  ofst->SetOutputSymbols(ifst.OutputSymbols());

  // Empty input: nothing to search; only propagate an input error.
  if (ifst.Start() == kNoStateId) {
    if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
    return;
  }

  // Per-state bookkeeping indexed by 'ifst' state id and grown on demand:
  // whether the state is currently queued, its predecessor on the best
  // path found so far, and the arc used to reach it from that predecessor.
  vector<bool> enqueued;
  vector<StateId> parent;
  vector<Arc> arc_parent;

  Queue *state_queue = opts.state_queue;
  StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source;
  // Best complete-path weight seen so far and the final state it ends in.
  Weight f_distance = Weight::Zero();
  StateId f_parent = kNoStateId;

  distance->clear();
  state_queue->Clear();
  if (opts.nshortest != 1) {
    FSTERROR() << "SingleShortestPath: for nshortest > 1, use ShortestPath"
               << " instead";
    ofst->SetProperties(kError, kError);
    return;
  }
  if (opts.weight_threshold != Weight::Zero() ||
      opts.state_threshold != kNoStateId) {
    FSTERROR() <<
        "SingleShortestPath: weight and state thresholds not applicable";
    ofst->SetProperties(kError, kError);
    return;
  }
  if ((Weight::Properties() & (kPath | kRightSemiring))
      != (kPath | kRightSemiring)) {
    FSTERROR() << "SingleShortestPath: Weight needs to have the path"
               << " property and be right distributive: " << Weight::Type();
    ofst->SetProperties(kError, kError);
    return;
  }
  // Pad the bookkeeping vectors up to 'source', then seed the search with
  // the source state at distance One().
  while (distance->size() < source) {
    distance->push_back(Weight::Zero());
    enqueued.push_back(false);
    parent.push_back(kNoStateId);
    arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
  }
  distance->push_back(Weight::One());
  parent.push_back(kNoStateId);
  arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
  state_queue->Enqueue(source);
  enqueued.push_back(true);

  while (!state_queue->Empty()) {
    StateId s = state_queue->Head();
    state_queue->Dequeue();
    enqueued[s] = false;
    Weight sd = (*distance)[s];
    // A final state: remember it if it improves the best complete path.
    if (ifst.Final(s) != Weight::Zero()) {
      Weight w = Times(sd, ifst.Final(s));
      if (f_distance != Plus(f_distance, w)) {
        f_distance = Plus(f_distance, w);
        f_parent = s;
      }
      if (!f_distance.Member()) {
        ofst->SetProperties(kError, kError);
        return;
      }
      if (opts.first_path)
        break;
    }
    // Relax every outgoing arc. A strict improvement (w.r.t. Plus) updates
    // the target's distance and back-pointers and (re)queues it.
    for (ArcIterator< Fst<Arc> > aiter(ifst, s);
         !aiter.Done();
         aiter.Next()) {
      const Arc &arc = aiter.Value();
      while (distance->size() <= arc.nextstate) {
        distance->push_back(Weight::Zero());
        enqueued.push_back(false);
        parent.push_back(kNoStateId);
        arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(),
                                 kNoStateId));
      }
      Weight &nd = (*distance)[arc.nextstate];
      Weight w = Times(sd, arc.weight);
      if (nd != Plus(nd, w)) {
        nd = Plus(nd, w);
        if (!nd.Member()) {
          ofst->SetProperties(kError, kError);
          return;
        }
        parent[arc.nextstate] = s;
        arc_parent[arc.nextstate] = arc;
        if (!enqueued[arc.nextstate]) {
          state_queue->Enqueue(arc.nextstate);
          enqueued[arc.nextstate] = true;
        } else {
          state_queue->Update(arc.nextstate);
        }
      }
    }
  }

  // Backtrace from 'f_parent' to the source along 'parent', adding one
  // output state per visited state. 'd' is the 'ifst' state visited in the
  // previous iteration and 'd_p' its output state; the arc that entered 'd'
  // is redirected to 'd_p'. The first state added gets the final weight,
  // and the last one added (the source) becomes the start state.
  StateId s_p = kNoStateId, d_p = kNoStateId;
  for (StateId s = f_parent, d = kNoStateId;
       s != kNoStateId;
       d = s, s = parent[s]) {
    d_p = s_p;
    s_p = ofst->AddState();
    if (d == kNoStateId) {
      ofst->SetFinal(s_p, ifst.Final(f_parent));
    } else {
      arc_parent[d].nextstate = d_p;
      ofst->AddArc(s_p, arc_parent[d]);
    }
  }
  ofst->SetStart(s_p);
  if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
  ofst->SetProperties(
      ShortestPathProperties(ofst->Properties(kFstProperties, false)),
      kFstProperties);
}
205
206
207template <class S, class W>
208class ShortestPathCompare {
209 public:
210  typedef S StateId;
211  typedef W Weight;
212  typedef pair<StateId, Weight> Pair;
213
214  ShortestPathCompare(const vector<Pair>& pairs,
215                      const vector<Weight>& distance,
216                      StateId sfinal, float d)
217      : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d)  {}
218
219  bool operator()(const StateId x, const StateId y) const {
220    const Pair &px = pairs_[x];
221    const Pair &py = pairs_[y];
222    Weight dx = px.first == superfinal_ ? Weight::One() :
223        px.first < distance_.size() ? distance_[px.first] : Weight::Zero();
224    Weight dy = py.first == superfinal_ ? Weight::One() :
225        py.first < distance_.size() ? distance_[py.first] : Weight::Zero();
226    Weight wx = Times(dx, px.second);
227    Weight wy = Times(dy, py.second);
228    // Penalize complete paths to ensure correct results with inexact weights.
229    // This forms a strict weak order so long as ApproxEqual(a, b) =>
230    // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
231    if (px.first == superfinal_ && py.first != superfinal_) {
232      return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
233    } else if (py.first == superfinal_ && px.first != superfinal_) {
234      return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
235    } else {
236      return less_(wy, wx);
237    }
238  }
239
240 private:
241  const vector<Pair> &pairs_;
242  const vector<Weight> &distance_;
243  StateId superfinal_;
244  float delta_;
245  NaturalLess<Weight> less_;
246};
247
248
249// N-Shortest-path algorithm: implements the core n-shortest path
250// algorithm. The output is built REVERSED. See below for versions with
251// more options and not reversed.
252//
253// 'ofst' contains the REVERSE of 'n'-shortest paths in 'ifst'.
254// 'distance' must contain the shortest distance from each state to a final
255// state in 'ifst'. 'delta' is the convergence delta.
256//
257// The n-shortest paths are the n-lowest weight paths w.r.t. the
258// natural semiring order. The single path that can be read from the
259// ith of at most n transitions leaving the initial state of 'ofst' is
260// the ith shortest path. Disregarding the initial state and initial
261// transitions, the n-shortest paths, in fact, form a tree rooted at
262// the single final state.
263//
264// The weights need to be left and right distributive (kSemiring) and
265// have the path (kPath) property.
266//
267// The algorithm is from Mohri and Riley, "An Efficient Algorithm for
268// the n-best-strings problem", ICSLP 2002. The algorithm relies on
269// the shortest-distance algorithm. There are some issues with the
270// pseudo-code as written in the paper (viz., line 11).
271//
272// IMPLEMENTATION NOTE: The input fst 'ifst' can be a delayed fst and
273// and at any state in its expansion the values of distance vector need only
274// be defined at that time for the states that are known to exist.
275template<class Arc, class RevArc>
276void NShortestPath(const Fst<RevArc> &ifst,
277                   MutableFst<Arc> *ofst,
278                   const vector<typename Arc::Weight> &distance,
279                   size_t n,
280                   float delta = kDelta,
281                   typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
282                   typename Arc::StateId state_threshold = kNoStateId) {
283  typedef typename Arc::StateId StateId;
284  typedef typename Arc::Weight Weight;
285  typedef pair<StateId, Weight> Pair;
286  typedef typename RevArc::Weight RevWeight;
287
288  if (n <= 0) return;
289  if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
290    FSTERROR() << "NShortestPath: Weight needs to have the "
291                 << "path property and be distributive: "
292                 << Weight::Type();
293    ofst->SetProperties(kError, kError);
294    return;
295  }
296  ofst->DeleteStates();
297  ofst->SetInputSymbols(ifst.InputSymbols());
298  ofst->SetOutputSymbols(ifst.OutputSymbols());
299  // Each state in 'ofst' corresponds to a path with weight w from the
300  // initial state of 'ifst' to a state s in 'ifst', that can be
301  // characterized by a pair (s,w).  The vector 'pairs' maps each
302  // state in 'ofst' to the corresponding pair maps states in OFST to
303  // the corresponding pair (s,w).
304  vector<Pair> pairs;
305  // The supefinal state is denoted by -1, 'compare' knows that the
306  // distance from 'superfinal' to the final state is 'Weight::One()',
307  // hence 'distance[superfinal]' is not needed.
308  StateId superfinal = -1;
309  ShortestPathCompare<StateId, Weight>
310    compare(pairs, distance, superfinal, delta);
311  vector<StateId> heap;
312  // 'r[s + 1]', 's' state in 'fst', is the number of states in 'ofst'
313  // which corresponding pair contains 's' ,i.e. , it is number of
314  // paths computed so far to 's'. Valid for 's == -1' (superfinal).
315  vector<int> r;
316  NaturalLess<Weight> less;
317  if (ifst.Start() == kNoStateId ||
318      distance.size() <= ifst.Start() ||
319      distance[ifst.Start()] == Weight::Zero() ||
320      less(weight_threshold, Weight::One()) ||
321      state_threshold == 0) {
322    if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
323    return;
324  }
325  ofst->SetStart(ofst->AddState());
326  StateId final = ofst->AddState();
327  ofst->SetFinal(final, Weight::One());
328  while (pairs.size() <= final)
329    pairs.push_back(Pair(kNoStateId, Weight::Zero()));
330  pairs[final] = Pair(ifst.Start(), Weight::One());
331  heap.push_back(final);
332  Weight limit = Times(distance[ifst.Start()], weight_threshold);
333
334  while (!heap.empty()) {
335    pop_heap(heap.begin(), heap.end(), compare);
336    StateId state = heap.back();
337    Pair p = pairs[state];
338    heap.pop_back();
339    Weight d = p.first == superfinal ? Weight::One() :
340        p.first < distance.size() ? distance[p.first] : Weight::Zero();
341
342    if (less(limit, Times(d, p.second)) ||
343        (state_threshold != kNoStateId &&
344         ofst->NumStates() >= state_threshold))
345      continue;
346
347    while (r.size() <= p.first + 1) r.push_back(0);
348    ++r[p.first + 1];
349    if (p.first == superfinal)
350      ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
351    if ((p.first == superfinal) && (r[p.first + 1] == n)) break;
352    if (r[p.first + 1] > n) continue;
353    if (p.first == superfinal) continue;
354
355    for (ArcIterator< Fst<RevArc> > aiter(ifst, p.first);
356         !aiter.Done();
357         aiter.Next()) {
358      const RevArc &rarc = aiter.Value();
359      Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
360      Weight w = Times(p.second, arc.weight);
361      StateId next = ofst->AddState();
362      pairs.push_back(Pair(arc.nextstate, w));
363      arc.nextstate = state;
364      ofst->AddArc(next, arc);
365      heap.push_back(next);
366      push_heap(heap.begin(), heap.end(), compare);
367    }
368
369    Weight finalw = ifst.Final(p.first).Reverse();
370    if (finalw != Weight::Zero()) {
371      Weight w = Times(p.second, finalw);
372      StateId next = ofst->AddState();
373      pairs.push_back(Pair(superfinal, w));
374      ofst->AddArc(next, Arc(0, 0, finalw, state));
375      heap.push_back(next);
376      push_heap(heap.begin(), heap.end(), compare);
377    }
378  }
379  Connect(ofst);
380  if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError);
381  ofst->SetProperties(
382      ShortestPathProperties(ofst->Properties(kFstProperties, false)),
383      kFstProperties);
384}
385
386
387// N-Shortest-path algorithm:  this version allow fine control
388// via the options argument. See below for a simpler interface.
389//
390// 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns
391// the shortest distances from the source state to each state in
392// 'ifst'. 'opts' is used to specify options such as the number of
393// paths to return, whether they need to have distinct input
394// strings, the queue discipline, the arc filter and the convergence
395// delta.
396//
397// The n-shortest paths are the n-lowest weight paths w.r.t. the
398// natural semiring order. The single path that can be read from the
399// ith of at most n transitions leaving the initial state of 'ofst' is
400// the ith shortest path. Disregarding the initial state and initial
401// transitions, The n-shortest paths, in fact, form a tree rooted at
402// the single final state.
403
404// The weights need to be right distributive and have the path (kPath)
405// property. They need to be left distributive as well for nshortest
406// > 1.
407//
408// The algorithm is from Mohri and Riley, "An Efficient Algorithm for
409// the n-best-strings problem", ICSLP 2002. The algorithm relies on
410// the shortest-distance algorithm. There are some issues with the
411// pseudo-code as written in the paper (viz., line 11).
412template<class Arc, class Queue, class ArcFilter>
413void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
414                  vector<typename Arc::Weight> *distance,
415                  ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
416  typedef typename Arc::StateId StateId;
417  typedef typename Arc::Weight Weight;
418  typedef ReverseArc<Arc> ReverseArc;
419
420  size_t n = opts.nshortest;
421  if (n == 1) {
422    SingleShortestPath(ifst, ofst, distance, opts);
423    return;
424  }
425  if (n <= 0) return;
426  if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) {
427    FSTERROR() << "ShortestPath: n-shortest: Weight needs to have the "
428               << "path property and be distributive: "
429               << Weight::Type();
430    ofst->SetProperties(kError, kError);
431    return;
432  }
433  if (!opts.has_distance) {
434    ShortestDistance(ifst, distance, opts);
435    if (distance->size() == 1 && !(*distance)[0].Member()) {
436      ofst->SetProperties(kError, kError);
437      return;
438    }
439  }
440  // Algorithm works on the reverse of 'fst' : 'rfst', 'distance' is
441  // the distance to the final state in 'rfst', 'ofst' is built as the
442  // reverse of the tree of n-shortest path in 'rfst'.
443  VectorFst<ReverseArc> rfst;
444  Reverse(ifst, &rfst);
445  Weight d = Weight::Zero();
446  for (ArcIterator< VectorFst<ReverseArc> > aiter(rfst, 0);
447       !aiter.Done(); aiter.Next()) {
448    const ReverseArc &arc = aiter.Value();
449    StateId s = arc.nextstate - 1;
450    if (s < distance->size())
451      d = Plus(d, Times(arc.weight.Reverse(), (*distance)[s]));
452  }
453  distance->insert(distance->begin(), d);
454
455  if (!opts.unique) {
456    NShortestPath(rfst, ofst, *distance, n, opts.delta,
457                  opts.weight_threshold, opts.state_threshold);
458  } else {
459    vector<Weight> ddistance;
460    DeterminizeFstOptions<ReverseArc> dopts(opts.delta);
461    DeterminizeFst<ReverseArc> dfst(rfst, distance, &ddistance, dopts);
462    NShortestPath(dfst, ofst, ddistance, n, opts.delta,
463                  opts.weight_threshold, opts.state_threshold);
464  }
465  distance->erase(distance->begin());
466}
467
468
469// Shortest-path algorithm: simplified interface. See above for a
470// version that allows finer control.
471//
472// 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue
473// discipline is automatically selected. When 'unique' == true, only
474// paths with distinct input labels are returned.
475//
476// The n-shortest paths are the n-lowest weight paths w.r.t. the
477// natural semiring order. The single path that can be read from the
478// ith of at most n transitions leaving the initial state of 'ofst' is
479// the ith best path.
480//
481// The weights need to be right distributive and have the path
482// (kPath) property.
483template<class Arc>
484void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
485                  size_t n = 1, bool unique = false,
486                  bool first_path = false,
487                  typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
488                  typename Arc::StateId state_threshold = kNoStateId) {
489  vector<typename Arc::Weight> distance;
490  AnyArcFilter<Arc> arc_filter;
491  AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter);
492  ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>,
493      AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique, false,
494                               kDelta, first_path, weight_threshold,
495                               state_threshold);
496  ShortestPath(ifst, ofst, &distance, opts);
497}
498
499}  // namespace fst
500
501#endif  // FST_LIB_SHORTEST_PATH_H__
502