shortest-path.h revision 8fc5a7f51e62cb4ae44a27bdf4176d04adc80ede
1// shortest-path.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16//
17// \file
18// Functions to find shortest paths in an FST.
19
20#ifndef FST_LIB_SHORTEST_PATH_H__
21#define FST_LIB_SHORTEST_PATH_H__
22
23#include <functional>
24
25#include "fst/lib/cache.h"
26#include "fst/lib/queue.h"
27#include "fst/lib/shortest-distance.h"
28#include "fst/lib/test-properties.h"
29
30namespace fst {
31
32template <class Arc, class Queue, class ArcFilter>
33struct ShortestPathOptions
34    : public ShortestDistanceOptions<Arc, Queue, ArcFilter> {
35  typedef typename Arc::StateId StateId;
36
37  size_t nshortest;      // return n-shortest paths
38  bool unique;           // only return paths with distinct input strings
39  bool has_distance;     // distance vector already contains the
40                         // shortest distance from the initial state
41
42  ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false,
43                      bool hasdist = false, float d = kDelta)
44      : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d),
45        nshortest(n), unique(u), has_distance(hasdist)  {}
46};
47
48
49// Shortest-path algorithm: normally not called directly; prefer
50// 'ShortestPath' below with n=1. 'ofst' contains the shortest path in
51// 'ifst'. 'distance' returns the shortest distances from the source
52// state to each state in 'ifst'. 'opts' is used to specify options
53// such as the queue discipline, the arc filter and delta.
54//
55// The shortest path is the lowest weight path w.r.t. the natural
56// semiring order.
57//
58// The weights need to be right distributive and have the path (kPath)
59// property.
60template<class Arc, class Queue, class ArcFilter>
61void SingleShortestPath(const Fst<Arc> &ifst,
62                  MutableFst<Arc> *ofst,
63                  vector<typename Arc::Weight> *distance,
64                  ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
65  typedef typename Arc::StateId StateId;
66  typedef typename Arc::Weight Weight;
67
68  ofst->DeleteStates();
69  ofst->SetInputSymbols(ifst.InputSymbols());
70  ofst->SetOutputSymbols(ifst.OutputSymbols());
71
72  if (ifst.Start() == kNoStateId)
73    return;
74
75  vector<Weight> rdistance;
76  vector<bool> enqueued;
77  vector<StateId> parent;
78  vector<Arc> arc_parent;
79
80  Queue *state_queue = opts.state_queue;
81  StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source;
82  Weight f_distance = Weight::Zero();
83  StateId f_parent = kNoStateId;
84
85  distance->clear();
86  state_queue->Clear();
87  if (opts.nshortest != 1)
88    LOG(FATAL) << "SingleShortestPath: for nshortest > 1, use ShortestPath"
89               << " instead";
90  if ((Weight::Properties() & (kPath | kRightSemiring))
91       != (kPath | kRightSemiring))
92      LOG(FATAL) << "SingleShortestPath: Weight needs to have the path"
93                 << " property and be right distributive: " << Weight::Type();
94
95  while (distance->size() < source) {
96    distance->push_back(Weight::Zero());
97    enqueued.push_back(false);
98    parent.push_back(kNoStateId);
99    arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
100  }
101  distance->push_back(Weight::One());
102  parent.push_back(kNoStateId);
103  arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId));
104  state_queue->Enqueue(source);
105  enqueued.push_back(true);
106
107  while (!state_queue->Empty()) {
108    StateId s = state_queue->Head();
109    state_queue->Dequeue();
110    enqueued[s] = false;
111    Weight sd = (*distance)[s];
112    for (ArcIterator< Fst<Arc> > aiter(ifst, s);
113         !aiter.Done();
114         aiter.Next()) {
115      const Arc &arc = aiter.Value();
116      while (distance->size() <= arc.nextstate) {
117        distance->push_back(Weight::Zero());
118        enqueued.push_back(false);
119        parent.push_back(kNoStateId);
120        arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(),
121                                 kNoStateId));
122      }
123      Weight &nd = (*distance)[arc.nextstate];
124      Weight w = Times(sd, arc.weight);
125      if (nd != Plus(nd, w)) {
126        nd = Plus(nd, w);
127        parent[arc.nextstate] = s;
128        arc_parent[arc.nextstate] = arc;
129        if (!enqueued[arc.nextstate]) {
130          state_queue->Enqueue(arc.nextstate);
131          enqueued[arc.nextstate] = true;
132        } else {
133          state_queue->Update(arc.nextstate);
134        }
135      }
136    }
137    if (ifst.Final(s) != Weight::Zero()) {
138      Weight w = Times(sd, ifst.Final(s));
139      if (f_distance != Plus(f_distance, w)) {
140        f_distance = Plus(f_distance, w);
141        f_parent = s;
142      }
143    }
144  }
145  (*distance)[source] = Weight::One();
146  parent[source] = kNoStateId;
147
148  StateId s_p = kNoStateId, d_p = kNoStateId;
149  for (StateId s = f_parent, d = kNoStateId;
150       s != kNoStateId;
151       d = s, s = parent[s]) {
152    enqueued[s] = true;
153    d_p = s_p;
154    s_p = ofst->AddState();
155    if (d == kNoStateId) {
156      ofst->SetFinal(s_p, ifst.Final(f_parent));
157    } else {
158      arc_parent[d].nextstate = d_p;
159      ofst->AddArc(s_p, arc_parent[d]);
160    }
161  }
162  ofst->SetStart(s_p);
163}
164
165
166template <class S, class W>
167class ShortestPathCompare {
168 public:
169  typedef S StateId;
170  typedef W Weight;
171  typedef pair<StateId, Weight> Pair;
172
173  ShortestPathCompare(const vector<Pair>& pairs,
174                      const vector<Weight>& distance,
175                      StateId sfinal, float d)
176      : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d)  {}
177
178  bool operator()(const StateId x, const StateId y) const {
179    const Pair &px = pairs_[x];
180    const Pair &py = pairs_[y];
181    Weight wx = Times(distance_[px.first], px.second);
182    Weight wy = Times(distance_[py.first], py.second);
183    // Penalize complete paths to ensure correct results with inexact weights.
184    // This forms a strict weak order so long as ApproxEqual(a, b) =>
185    // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b).
186    if (px.first == superfinal_ && py.first != superfinal_) {
187      return less_(wy, wx) || ApproxEqual(wx, wy, delta_);
188    } else if (py.first == superfinal_ && px.first != superfinal_) {
189      return less_(wy, wx) && !ApproxEqual(wx, wy, delta_);
190    } else {
191      return less_(wy, wx);
192    }
193  }
194
195 private:
196  const vector<Pair> &pairs_;
197  const vector<Weight> &distance_;
198  StateId superfinal_;
199  float delta_;
200  NaturalLess<Weight> less_;
201};
202
203
204// N-Shortest-path algorithm:  this version allow fine control
205// via the otpions argument. See below for a simpler interface.
206//
207// 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns
208// the shortest distances from the source state to each state in
209// 'ifst'. 'opts' is used to specify options such as the number of
210// paths to return, whether they need to have distinct input
211// strings, the queue discipline, the arc filter and the convergence
212// delta.
213//
214// The n-shortest paths are the n-lowest weight paths w.r.t. the
215// natural semiring order. The single path that can be
216// read from the ith of at most n transitions leaving the initial
217// state of 'ofst' is the ith shortest path.
218
219// The weights need to be right distributive and have the path (kPath)
220// property. They need to be left distributive as well for nshortest
221// > 1.
222//
223// The algorithm is from Mohri and Riley, "An Efficient Algorithm for
224// the n-best-strings problem", ICSLP 2002. The algorithm relies on
225// the shortest-distance algorithm. There are some issues with the
226// pseudo-code as written in the paper (viz., line 11).
227template<class Arc, class Queue, class ArcFilter>
228void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
229                  vector<typename Arc::Weight> *distance,
230                  ShortestPathOptions<Arc, Queue, ArcFilter> &opts) {
231  typedef typename Arc::StateId StateId;
232  typedef typename Arc::Weight Weight;
233  typedef pair<StateId, Weight> Pair;
234  typedef ReverseArc<Arc> ReverseArc;
235  typedef typename ReverseArc::Weight ReverseWeight;
236
237  size_t n = opts.nshortest;
238
239  if (n == 1) {
240    SingleShortestPath(ifst, ofst, distance, opts);
241    return;
242  }
243  ofst->DeleteStates();
244  ofst->SetInputSymbols(ifst.InputSymbols());
245  ofst->SetOutputSymbols(ifst.OutputSymbols());
246  if (n <= 0) return;
247  if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring))
248    LOG(FATAL) << "ShortestPath: n-shortest: Weight needs to have the "
249                 << "path property and be distributive: "
250                 << Weight::Type();
251  if (opts.unique)
252    LOG(FATAL) << "ShortestPath: n-shortest-string algorithm not "
253               << "currently implemented";
254
255  // Algorithm works on the reverse of 'fst' : 'rfst' 'distance' is
256  // the distance to the final state in 'rfst' 'ofst' is built as the
257  // reverse of the tree of n-shortest path in 'rfst'.
258
259  if (!opts.has_distance)
260    ShortestDistance(ifst, distance, opts);
261  VectorFst<ReverseArc> rfst;
262  Reverse(ifst, &rfst);
263  distance->insert(distance->begin(), Weight::One());
264  while (distance->size() < rfst.NumStates())
265    distance->push_back(Weight::Zero());
266
267
268  // Each state in 'ofst' corresponds to a path with weight w from the
269  // initial state of 'rfst' to a state s in 'rfst', that can be
270  // characterized by a pair (s,w).  The vector 'pairs' maps each
271  // state in 'ofst' to the corresponding pair maps states in OFST to
272  // the corresponding pair (s,w).
273  vector<Pair> pairs;
274  // 'r[s]', 's' state in 'fst', is the number of states in 'ofst'
275  // which corresponding pair contains 's' ,i.e. , it is number of
276  // paths computed so far to 's'.
277  StateId superfinal = distance->size();  // superfinal must be handled
278  distance->push_back(Weight::One());     // differently when unique=true
279  ShortestPathCompare<StateId, Weight>
280    compare(pairs, *distance, superfinal, opts.delta);
281  vector<StateId> heap;
282  vector<int> r;
283  while (r.size() < distance->size())
284    r.push_back(0);
285  ofst->SetStart(ofst->AddState());
286  StateId final = ofst->AddState();
287  ofst->SetFinal(final, Weight::One());
288  while (pairs.size() <= final)
289    pairs.push_back(Pair(kNoStateId, Weight::Zero()));
290  pairs[final] = Pair(rfst.Start(), Weight::One());
291  heap.push_back(final);
292
293  while (!heap.empty()) {
294    pop_heap(heap.begin(), heap.end(), compare);
295    StateId state = heap.back();
296    Pair p = pairs[state];
297    heap.pop_back();
298
299    ++r[p.first];
300    if (p.first == superfinal)
301      ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state));
302    if ((p.first == superfinal) &&  (r[p.first] == n)) break;
303    if (r[p.first] > n) continue;
304    if (p.first == superfinal)
305      continue;
306
307    for (ArcIterator< Fst<ReverseArc> > aiter(rfst, p.first);
308         !aiter.Done();
309         aiter.Next()) {
310      const ReverseArc &rarc = aiter.Value();
311      Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate);
312      Weight w = Times(p.second, arc.weight);
313      StateId next = ofst->AddState();
314      pairs.push_back(Pair(arc.nextstate, w));
315      arc.nextstate = state;
316      ofst->AddArc(next, arc);
317      heap.push_back(next);
318      push_heap(heap.begin(), heap.end(), compare);
319    }
320
321    Weight finalw = rfst.Final(p.first).Reverse();
322    if (finalw != Weight::Zero()) {
323      Weight w = Times(p.second, finalw);
324      StateId next = ofst->AddState();
325      pairs.push_back(Pair(superfinal, w));
326      ofst->AddArc(next, Arc(0, 0, finalw, state));
327      heap.push_back(next);
328      push_heap(heap.begin(), heap.end(), compare);
329    }
330  }
331  Connect(ofst);
332  distance->erase(distance->begin());
333  distance->pop_back();
334}
335
336// Shortest-path algorithm: simplified interface. See above for a
337// version that allows finer control.
338
339// 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue
340// discipline is automatically selected. When 'unique' == true, only
341// paths with distinct input labels are returned.
342//
343// The n-shortest paths are the n-lowest weight paths w.r.t. the
344// natural semiring order. The single path that can be read from the
345// ith of at most n transitions leaving the initial state of 'ofst' is
346// the ith best path.
347//
348// The weights need to be right distributive and have the path
349// (kPath) property.
350template<class Arc>
351void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
352                  size_t n = 1, bool unique = false) {
353  vector<typename Arc::Weight> distance;
354  AnyArcFilter<Arc> arc_filter;
355  AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter);
356  ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>,
357    AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique);
358  ShortestPath(ifst, ofst, &distance, opts);
359}
360
361}  // namespace fst
362
363#endif  // FST_LIB_SHORTEST_PATH_H__
364