1// shortest-distance.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16//
17// \file
18// Functions and classes to find shortest distance in an FST.
19
20#ifndef FST_LIB_SHORTEST_DISTANCE_H__
21#define FST_LIB_SHORTEST_DISTANCE_H__
22
23#include <deque>
24
25#include "fst/lib/arcfilter.h"
26#include "fst/lib/cache.h"
27#include "fst/lib/queue.h"
28#include "fst/lib/reverse.h"
29#include "fst/lib/test-properties.h"
30
31namespace fst {
32
33template <class Arc, class Queue, class ArcFilter>
34struct ShortestDistanceOptions {
35  typedef typename Arc::StateId StateId;
36
37  Queue *state_queue;    // Queue discipline used; owned by caller
38  ArcFilter arc_filter;  // Arc filter (e.g., limit to only epsilon graph)
39  StateId source;        // If kNoStateId, use the Fst's initial state
40  float delta;           // Determines the degree of convergence required
41
42  ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId,
43                          float d = kDelta)
44      : state_queue(q), arc_filter(filt), source(src), delta(d) {}
45};
46
47
48// Computation state of the shortest-distance algorithm. Reusable
49// information is maintained across calls to member function
50// ShortestDistance(source) when 'retain' is true for improved
51// efficiency when calling multiple times from different source states
52// (e.g., in epsilon removal). Vector 'distance' should not be
53// modified by the user between these calls.
54template<class Arc, class Queue, class ArcFilter>
55class ShortestDistanceState {
56 public:
57  typedef typename Arc::StateId StateId;
58  typedef typename Arc::Weight Weight;
59
60  ShortestDistanceState(
61      const Fst<Arc> &fst,
62      vector<Weight> *distance,
63      const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts,
64      bool retain)
65      : fst_(fst.Copy()), distance_(distance), state_queue_(opts.state_queue),
66        arc_filter_(opts.arc_filter),
67        delta_(opts.delta), retain_(retain) {
68    distance_->clear();
69  }
70
71  ~ShortestDistanceState() {
72    delete fst_;
73  }
74
75  void ShortestDistance(StateId source);
76
77 private:
78  const Fst<Arc> *fst_;
79  vector<Weight> *distance_;
80  Queue *state_queue_;
81  ArcFilter arc_filter_;
82  float delta_;
83  bool retain_;                  // Retain and reuse information across calls
84
85  vector<Weight> rdistance_;    // Relaxation distance.
86  vector<bool> enqueued_;       // Is state enqueued?
87  vector<StateId> sources_;     // Source state for ith state in 'distance_',
88                                //  'rdistance_', and 'enqueued_' if retained.
89};
90
91// Compute the shortest distance. If 'source' is kNoStateId, use
92// the initial state of the Fst.
93template <class Arc, class Queue, class ArcFilter>
94void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
95    StateId source) {
96  if (fst_->Start() == kNoStateId)
97    return;
98
99  if (!(Weight::Properties() & kRightSemiring))
100    LOG(FATAL) << "ShortestDistance: Weight needs to be right distributive: "
101               << Weight::Type();
102
103  state_queue_->Clear();
104
105  if (!retain_) {
106    distance_->clear();
107    rdistance_.clear();
108    enqueued_.clear();
109  }
110
111  if (source == kNoStateId)
112    source = fst_->Start();
113
114  while ((StateId)distance_->size() <= source) {
115    distance_->push_back(Weight::Zero());
116    rdistance_.push_back(Weight::Zero());
117    enqueued_.push_back(false);
118  }
119  if (retain_) {
120    while ((StateId)sources_.size() <= source)
121      sources_.push_back(kNoStateId);
122    sources_[source] = source;
123  }
124  (*distance_)[source] = Weight::One();
125  rdistance_[source] = Weight::One();
126  enqueued_[source] = true;
127
128  state_queue_->Enqueue(source);
129
130  while (!state_queue_->Empty()) {
131    StateId s = state_queue_->Head();
132    state_queue_->Dequeue();
133    while ((StateId)distance_->size() <= s) {
134      distance_->push_back(Weight::Zero());
135      rdistance_.push_back(Weight::Zero());
136      enqueued_.push_back(false);
137    }
138    enqueued_[s] = false;
139    Weight r = rdistance_[s];
140    rdistance_[s] = Weight::Zero();
141    for (ArcIterator< Fst<Arc> > aiter(*fst_, s);
142         !aiter.Done();
143         aiter.Next()) {
144      const Arc &arc = aiter.Value();
145      if (!arc_filter_(arc) || arc.weight == Weight::Zero())
146        continue;
147      while ((StateId)distance_->size() <= arc.nextstate) {
148        distance_->push_back(Weight::Zero());
149        rdistance_.push_back(Weight::Zero());
150        enqueued_.push_back(false);
151      }
152      if (retain_) {
153        while ((StateId)sources_.size() <= arc.nextstate)
154          sources_.push_back(kNoStateId);
155        if (sources_[arc.nextstate] != source) {
156          (*distance_)[arc.nextstate] = Weight::Zero();
157          rdistance_[arc.nextstate] = Weight::Zero();
158          enqueued_[arc.nextstate] = false;
159          sources_[arc.nextstate] = source;
160        }
161      }
162      Weight &nd = (*distance_)[arc.nextstate];
163      Weight &nr = rdistance_[arc.nextstate];
164      Weight w = Times(r, arc.weight);
165      if (!ApproxEqual(nd, Plus(nd, w), delta_)) {
166        nd = Plus(nd, w);
167        nr = Plus(nr, w);
168        if (!enqueued_[arc.nextstate]) {
169          state_queue_->Enqueue(arc.nextstate);
170          enqueued_[arc.nextstate] = true;
171        } else {
172          state_queue_->Update(arc.nextstate);
173        }
174      }
175    }
176  }
177}
178
179
180// Shortest-distance algorithm: this version allows fine control
181// via the options argument. See below for a simpler interface.
182//
183// This computes the shortest distance from the 'opts.source' state to
184// each visited state S and stores the value in the 'distance' vector.
185// An unvisited state S has distance Zero(), which will be stored in
186// the 'distance' vector if S is less than the maximum visited state.
187// The state queue discipline, arc filter, and convergence delta are
188// taken in the options argument.
189
190// The weights must must be right distributive and k-closed (i.e., 1 +
191// x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
192//
193// The algorithm is from Mohri, "Semiring Framweork and Algorithms for
194// Shortest-Distance Problems", Journal of Automata, Languages and
195// Combinatorics 7(3):321-350, 2002. The complexity of algorithm
196// depends on the properties of the semiring and the queue discipline
197// used. Refer to the paper for more details.
198template<class Arc, class Queue, class ArcFilter>
199void ShortestDistance(
200    const Fst<Arc> &fst,
201    vector<typename Arc::Weight> *distance,
202    const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
203
204  ShortestDistanceState<Arc, Queue, ArcFilter>
205    sd_state(fst, distance, opts, false);
206  sd_state.ShortestDistance(opts.source);
207}
208
209// Shortest-distance algorithm: simplified interface. See above for a
210// version that allows finer control.
211//
212// If 'reverse' is false, this computes the shortest distance from the
213// initial state to each state S and stores the value in the
214// 'distance' vector. If 'reverse' is true, this computes the shortest
215// distance from each state to the final states.  An unvisited state S
216// has distance Zero(), which will be stored in the 'distance' vector
217// if S is less than the maximum visited state.  The state queue
218// discipline is automatically-selected.
219//
220// The weights must must be right (left) distributive if reverse is
221// false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 +
222// x + x^2 + ... + x^k).
223//
224// The algorithm is from Mohri, "Semiring Framweork and Algorithms for
225// Shortest-Distance Problems", Journal of Automata, Languages and
226// Combinatorics 7(3):321-350, 2002. The complexity of algorithm
227// depends on the properties of the semiring and the queue discipline
228// used. Refer to the paper for more details.
229template<class Arc>
230void ShortestDistance(const Fst<Arc> &fst,
231                      vector<typename Arc::Weight> *distance,
232                      bool reverse = false) {
233  typedef typename Arc::StateId StateId;
234  typedef typename Arc::Weight Weight;
235
236  if (!reverse) {
237    AnyArcFilter<Arc> arc_filter;
238    AutoQueue<StateId> state_queue(fst, distance, arc_filter);
239    ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> >
240      opts(&state_queue, arc_filter);
241    ShortestDistance(fst, distance, opts);
242  } else {
243    typedef ReverseArc<Arc> ReverseArc;
244    typedef typename ReverseArc::Weight ReverseWeight;
245    AnyArcFilter<ReverseArc> rarc_filter;
246    VectorFst<ReverseArc> rfst;
247    Reverse(fst, &rfst);
248    vector<ReverseWeight> rdistance;
249    AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
250    ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>,
251      AnyArcFilter<ReverseArc> >
252      ropts(&state_queue, rarc_filter);
253    ShortestDistance(rfst, &rdistance, ropts);
254    distance->clear();
255    while (distance->size() < rdistance.size() - 1)
256      distance->push_back(rdistance[distance->size() + 1].Reverse());
257  }
258}
259
260}  // namespace fst
261
262#endif  // FST_LIB_SHORTEST_DISTANCE_H__
263