1// shortest-path.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Functions to find shortest paths in a PDT.
20
21#ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
22#define FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
23
24#include <fst/shortest-path.h>
25#include <fst/extensions/pdt/paren.h>
26#include <fst/extensions/pdt/pdt.h>
27
28#include <tr1/unordered_map>
29using std::tr1::unordered_map;
30using std::tr1::unordered_multimap;
31#include <tr1/unordered_set>
32using std::tr1::unordered_set;
33using std::tr1::unordered_multiset;
34#include <stack>
35#include <vector>
36using std::vector;
37
38namespace fst {
39
// Options for PdtShortestPath. The 'Arc' and 'Queue' template arguments
// are not used by the members themselves; they tie an options object to
// the PdtShortestPath<Arc, Queue> instantiation that accepts it.
template <class Arc, class Queue>
struct PdtShortestPathOptions {
  // Defaults: parentheses are not kept, path garbage collection is on.
  PdtShortestPathOptions(bool kp = false, bool gc = true)
      : keep_parentheses(kp),
        path_gc(gc) {}

  bool keep_parentheses;  // Keep paren labels on the output path?
  bool path_gc;           // Garbage collect unreachable search states?
};
48
49
50// Class to store PDT shortest path results. Stores shortest path
51// tree info 'Distance()', Parent(), and ArcParent() information keyed
52// on two types:
53// (1) By SearchState: This is a usual node in a shortest path tree but:
54//    (a) is w.r.t a PDT search state - a pair of a PDT state and
55//        a 'start' state, which is either the PDT start state or
56//        the destination state of an open parenthesis.
57//    (b) the Distance() is from this 'start' state to the search state.
58//    (c) Parent().state is kNoLabel for the 'start' state.
59//
60// (2) By ParenSpec: This connects shortest path trees depending on the
61// the parenthesis taken. Given the parenthesis spec:
62//    (a) the Distance() is from the Parent() 'start' state to the
63//     parenthesis destination state.
64//    (b) the ArcParent() is the parenthesis arc.
65template <class Arc>
66class PdtShortestPathData {
67 public:
68  static const uint8 kFinal;
69
70  typedef typename Arc::StateId StateId;
71  typedef typename Arc::Weight Weight;
72  typedef typename Arc::Label Label;
73
74  struct SearchState {
75    SearchState() : state(kNoStateId), start(kNoStateId) {}
76
77    SearchState(StateId s, StateId t) : state(s), start(t) {}
78
79    bool operator==(const SearchState &s) const {
80      if (&s == this)
81        return true;
82      return s.state == this->state && s.start == this->start;
83    }
84
85    StateId state;  // PDT state
86    StateId start;  // PDT paren 'source' state
87  };
88
89
90  // Specifies paren id, source and dest 'start' states of a paren.
91  // These are the 'start' states of the respective sub-graphs.
92  struct ParenSpec {
93    ParenSpec()
94        : paren_id(kNoLabel), src_start(kNoStateId), dest_start(kNoStateId) {}
95
96    ParenSpec(Label id, StateId s, StateId d)
97        : paren_id(id), src_start(s), dest_start(d) {}
98
99    Label paren_id;        // Id of parenthesis
100    StateId src_start;     // sub-graph 'start' state for paren source.
101    StateId dest_start;    // sub-graph 'start' state for paren dest.
102
103    bool operator==(const ParenSpec &x) const {
104      if (&x == this)
105        return true;
106      return x.paren_id == this->paren_id &&
107          x.src_start == this->src_start &&
108          x.dest_start == this->dest_start;
109    }
110  };
111
112  struct SearchData {
113    SearchData() : distance(Weight::Zero()),
114                   parent(kNoStateId, kNoStateId),
115                   paren_id(kNoLabel),
116                   flags(0) {}
117
118    Weight distance;     // Distance to this state from PDT 'start' state
119    SearchState parent;  // Parent state in shortest path tree
120    int16 paren_id;      // If parent arc has paren, paren ID, o.w. kNoLabel
121    uint8 flags;         // First byte reserved for PdtShortestPathData use
122  };
123
124  PdtShortestPathData(bool gc)
125      : state_(kNoStateId, kNoStateId),
126        paren_(kNoLabel, kNoStateId, kNoStateId),
127        gc_(gc),
128        nstates_(0),
129        ngc_(0),
130        finished_(false) {}
131
132  ~PdtShortestPathData() {
133    VLOG(1) << "opm size: " << paren_map_.size();
134    VLOG(1) << "# of search states: " << nstates_;
135    if (gc_)
136      VLOG(1) << "# of GC'd search states: " << ngc_;
137  }
138
139  void Clear() {
140    search_map_.clear();
141    search_multimap_.clear();
142    paren_map_.clear();
143    state_ = SearchState(kNoStateId, kNoStateId);
144    nstates_ = 0;
145    ngc_ = 0;
146  }
147
148  Weight Distance(SearchState s) const {
149    SearchData *data = GetSearchData(s);
150    return data->distance;
151  }
152
153  Weight Distance(const ParenSpec &paren) const {
154    SearchData *data = GetSearchData(paren);
155    return data->distance;
156  }
157
158  SearchState Parent(SearchState s) const {
159    SearchData *data = GetSearchData(s);
160    return data->parent;
161  }
162
163  SearchState Parent(const ParenSpec &paren) const {
164    SearchData *data = GetSearchData(paren);
165    return data->parent;
166  }
167
168  Label ParenId(SearchState s) const {
169    SearchData *data = GetSearchData(s);
170    return data->paren_id;
171  }
172
173  uint8 Flags(SearchState s) const {
174    SearchData *data = GetSearchData(s);
175    return data->flags;
176  }
177
178  void SetDistance(SearchState s, Weight w) {
179    SearchData *data = GetSearchData(s);
180    data->distance = w;
181  }
182
183  void SetDistance(const ParenSpec &paren, Weight w) {
184    SearchData *data = GetSearchData(paren);
185    data->distance = w;
186  }
187
188  void SetParent(SearchState s, SearchState p) {
189    SearchData *data = GetSearchData(s);
190    data->parent = p;
191  }
192
193  void SetParent(const ParenSpec &paren, SearchState p) {
194    SearchData *data = GetSearchData(paren);
195    data->parent = p;
196  }
197
198  void SetParenId(SearchState s, Label p) {
199    if (p >= 32768)
200      FSTERROR() << "PdtShortestPathData: Paren ID does not fits in an int16";
201    SearchData *data = GetSearchData(s);
202    data->paren_id = p;
203  }
204
205  void SetFlags(SearchState s, uint8 f, uint8 mask) {
206    SearchData *data = GetSearchData(s);
207    data->flags &= ~mask;
208    data->flags |= f & mask;
209  }
210
211  void GC(StateId s);
212
213  void Finish() { finished_ = true; }
214
215 private:
216  static const Arc kNoArc;
217  static const size_t kPrime0;
218  static const size_t kPrime1;
219  static const uint8 kInited;
220  static const uint8 kMarked;
221
222  // Hash for search state
223  struct SearchStateHash {
224    size_t operator()(const SearchState &s) const {
225      return s.state + s.start * kPrime0;
226    }
227  };
228
229  // Hash for paren map
230  struct ParenHash {
231    size_t operator()(const ParenSpec &paren) const {
232      return paren.paren_id + paren.src_start * kPrime0 +
233          paren.dest_start * kPrime1;
234    }
235  };
236
237  typedef unordered_map<SearchState, SearchData, SearchStateHash> SearchMap;
238
239  typedef unordered_multimap<StateId, StateId> SearchMultimap;
240
241  // Hash map from paren spec to open paren data
242  typedef unordered_map<ParenSpec, SearchData, ParenHash> ParenMap;
243
244  SearchData *GetSearchData(SearchState s) const {
245    if (s == state_)
246      return state_data_;
247    if (finished_) {
248      typename SearchMap::iterator it = search_map_.find(s);
249      if (it == search_map_.end())
250        return &null_search_data_;
251      state_ = s;
252      return state_data_ = &(it->second);
253    } else {
254      state_ = s;
255      state_data_ = &search_map_[s];
256      if (!(state_data_->flags & kInited)) {
257        ++nstates_;
258        if (gc_)
259          search_multimap_.insert(make_pair(s.start, s.state));
260        state_data_->flags = kInited;
261      }
262      return state_data_;
263    }
264  }
265
266  SearchData *GetSearchData(ParenSpec paren) const {
267    if (paren == paren_)
268      return paren_data_;
269    if (finished_) {
270      typename ParenMap::iterator it = paren_map_.find(paren);
271      if (it == paren_map_.end())
272        return &null_search_data_;
273      paren_ = paren;
274      return state_data_ = &(it->second);
275    } else {
276      paren_ = paren;
277      return paren_data_ = &paren_map_[paren];
278    }
279  }
280
281  mutable SearchMap search_map_;            // Maps from search state to data
282  mutable SearchMultimap search_multimap_;  // Maps from 'start' to subgraph
283  mutable ParenMap paren_map_;              // Maps paren spec to search data
284  mutable SearchState state_;               // Last state accessed
285  mutable SearchData *state_data_;          // Last state data accessed
286  mutable ParenSpec paren_;                 // Last paren spec accessed
287  mutable SearchData *paren_data_;          // Last paren data accessed
288  bool gc_;                                 // Allow GC?
289  mutable size_t nstates_;                  // Total number of search states
290  size_t ngc_;                              // Number of GC'd search states
291  mutable SearchData null_search_data_;     // Null search data
292  bool finished_;                           // Read-only access when true
293
294  DISALLOW_COPY_AND_ASSIGN(PdtShortestPathData);
295};
296
297// Deletes inaccessible search data from a given 'start' (open paren dest)
298// state. Assumes 'final' (close paren source or PDT final) states have
299// been flagged 'kFinal'.
300template<class Arc>
301void  PdtShortestPathData<Arc>::GC(StateId start) {
302  if (!gc_)
303    return;
304  vector<StateId> final;
305  for (typename SearchMultimap::iterator mmit = search_multimap_.find(start);
306       mmit != search_multimap_.end() && mmit->first == start;
307       ++mmit) {
308    SearchState s(mmit->second, start);
309    const SearchData &data = search_map_[s];
310    if (data.flags & kFinal)
311      final.push_back(s.state);
312  }
313
314  // Mark phase
315  for (size_t i = 0; i < final.size(); ++i) {
316    SearchState s(final[i], start);
317    while (s.state != kNoLabel) {
318      SearchData *sdata = &search_map_[s];
319      if (sdata->flags & kMarked)
320        break;
321      sdata->flags |= kMarked;
322      SearchState p = sdata->parent;
323      if (p.start != start && p.start != kNoLabel) {  // entering sub-subgraph
324        ParenSpec paren(sdata->paren_id, s.start, p.start);
325        SearchData *pdata = &paren_map_[paren];
326        s = pdata->parent;
327      } else {
328        s = p;
329      }
330    }
331  }
332
333  // Sweep phase
334  typename SearchMultimap::iterator mmit = search_multimap_.find(start);
335  while (mmit != search_multimap_.end() && mmit->first == start) {
336    SearchState s(mmit->second, start);
337    typename SearchMap::iterator mit = search_map_.find(s);
338    const SearchData &data = mit->second;
339    if (!(data.flags & kMarked)) {
340      search_map_.erase(mit);
341      ++ngc_;
342    }
343    search_multimap_.erase(mmit++);
344  }
345}
346
347template<class Arc> const Arc PdtShortestPathData<Arc>::kNoArc
348    = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);
349
350template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime0 = 7853;
351
352template<class Arc> const size_t PdtShortestPathData<Arc>::kPrime1 = 7867;
353
354template<class Arc> const uint8 PdtShortestPathData<Arc>::kInited = 0x01;
355
356template<class Arc> const uint8 PdtShortestPathData<Arc>::kFinal =  0x02;
357
358template<class Arc> const uint8 PdtShortestPathData<Arc>::kMarked = 0x04;
359
360
361// This computes the single source shortest (balanced) path (SSSP)
362// through a weighted PDT that has a bounded stack (i.e. is expandable
363// as an FST). It is a generalization of the classic SSSP graph
364// algorithm that removes a state s from a queue (defined by a
365// user-provided queue type) and relaxes the destination states of
366// transitions leaving s. In this PDT version, states that have
367// entering open parentheses are treated as source states for a
368// sub-graph SSSP problem with the shortest path up to the open
369// parenthesis being first saved. When a close parenthesis is then
370// encountered any balancing open parenthesis is examined for this
371// saved information and multiplied back. In this way, each sub-graph
372// is entered only once rather than repeatedly.  If every state in the
373// input PDT has the property that there is a unique 'start' state for
374// it with entering open parentheses, then this algorithm is quite
375// straight-forward. In general, this will not be the case, so the
376// algorithm (implicitly) creates a new graph where each state is a
377// pair of an original state and a possible parenthesis 'start' state
378// for that state.
379template<class Arc, class Queue>
380class PdtShortestPath {
381 public:
382  typedef typename Arc::StateId StateId;
383  typedef typename Arc::Weight Weight;
384  typedef typename Arc::Label Label;
385
386  typedef PdtShortestPathData<Arc> SpData;
387  typedef typename SpData::SearchState SearchState;
388  typedef typename SpData::ParenSpec ParenSpec;
389
390  typedef typename PdtBalanceData<Arc>::SetIterator CloseSourceIterator;
391
392  PdtShortestPath(const Fst<Arc> &ifst,
393                  const vector<pair<Label, Label> > &parens,
394                  const PdtShortestPathOptions<Arc, Queue> &opts)
395      : kFinal(SpData::kFinal),
396        ifst_(ifst.Copy()),
397        parens_(parens),
398        keep_parens_(opts.keep_parentheses),
399        start_(ifst.Start()),
400        sp_data_(opts.path_gc),
401        error_(false) {
402
403    if ((Weight::Properties() & (kPath | kRightSemiring))
404        != (kPath | kRightSemiring)) {
405      FSTERROR() << "PdtShortestPath: Weight needs to have the path"
406                 << " property and be right distributive: " << Weight::Type();
407      error_ = true;
408    }
409
410    for (Label i = 0; i < parens.size(); ++i) {
411      const pair<Label, Label>  &p = parens[i];
412      paren_id_map_[p.first] = i;
413      paren_id_map_[p.second] = i;
414    }
415  };
416
417  ~PdtShortestPath() {
418    VLOG(1) << "# of input states: " << CountStates(*ifst_);
419    VLOG(1) << "# of enqueued: " << nenqueued_;
420    VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
421    delete ifst_;
422  }
423
424  void ShortestPath(MutableFst<Arc> *ofst) {
425    Init(ofst);
426    GetDistance(start_);
427    GetPath();
428    sp_data_.Finish();
429    if (error_) ofst->SetProperties(kError, kError);
430  }
431
432  const PdtShortestPathData<Arc> &GetShortestPathData() const {
433    return sp_data_;
434  }
435
436  PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }
437
438 private:
439  static const Arc kNoArc;
440  static const uint8 kEnqueued;
441  static const uint8 kExpanded;
442  static const uint8 kFinished;
443  const uint8 kFinal;
444
445 public:
446  // Hash multimap from close paren label to an paren arc.
447  typedef unordered_multimap<ParenState<Arc>, Arc,
448                        typename ParenState<Arc>::Hash> CloseParenMultimap;
449
450  const CloseParenMultimap &GetCloseParenMultimap() const {
451    return close_paren_multimap_;
452  }
453
454 private:
455  void Init(MutableFst<Arc> *ofst);
456  void GetDistance(StateId start);
457  void ProcFinal(SearchState s);
458  void ProcArcs(SearchState s);
459  void ProcOpenParen(Label paren_id, SearchState s, Arc arc, Weight w);
460  void ProcCloseParen(Label paren_id, SearchState s, const Arc &arc, Weight w);
461  void ProcNonParen(SearchState s, const Arc &arc, Weight w);
462  void Relax(SearchState s, SearchState t, Arc arc, Weight w, Label paren_id);
463  void Enqueue(SearchState d);
464  void GetPath();
465  Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);
466
467  Fst<Arc> *ifst_;
468  MutableFst<Arc> *ofst_;
469  const vector<pair<Label, Label> > &parens_;
470  bool keep_parens_;
471  Queue *state_queue_;                   // current state queue
472  StateId start_;
473  Weight f_distance_;
474  SearchState f_parent_;
475  SpData sp_data_;
476  unordered_map<Label, Label> paren_id_map_;
477  CloseParenMultimap close_paren_multimap_;
478  PdtBalanceData<Arc> balance_data_;
479  ssize_t nenqueued_;
480  bool error_;
481
482  DISALLOW_COPY_AND_ASSIGN(PdtShortestPath);
483};
484
485template<class Arc, class Queue>
486void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) {
487  ofst_ = ofst;
488  ofst->DeleteStates();
489  ofst->SetInputSymbols(ifst_->InputSymbols());
490  ofst->SetOutputSymbols(ifst_->OutputSymbols());
491
492  if (ifst_->Start() == kNoStateId)
493    return;
494
495  f_distance_ = Weight::Zero();
496  f_parent_ = SearchState(kNoStateId, kNoStateId);
497
498  sp_data_.Clear();
499  close_paren_multimap_.clear();
500  balance_data_.Clear();
501  nenqueued_ = 0;
502
503  // Find open parens per destination state and close parens per source state.
504  for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
505    StateId s = siter.Value();
506    for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
507         !aiter.Done(); aiter.Next()) {
508      const Arc &arc = aiter.Value();
509      typename unordered_map<Label, Label>::const_iterator pit
510          = paren_id_map_.find(arc.ilabel);
511      if (pit != paren_id_map_.end()) {               // Is a paren?
512        Label paren_id = pit->second;
513        if (arc.ilabel == parens_[paren_id].first) {  // Open paren
514          balance_data_.OpenInsert(paren_id, arc.nextstate);
515        } else {                                      // Close paren
516          ParenState<Arc> paren_state(paren_id, s);
517          close_paren_multimap_.insert(make_pair(paren_state, arc));
518        }
519      }
520    }
521  }
522}
523
524// Computes the shortest distance stored in a recursive way. Each
525// sub-graph (i.e. different paren 'start' state) begins with weight One().
526template<class Arc, class Queue>
527void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
528  if (start == kNoStateId)
529    return;
530
531  Queue state_queue;
532  state_queue_ = &state_queue;
533  SearchState q(start, start);
534  Enqueue(q);
535  sp_data_.SetDistance(q, Weight::One());
536
537  while (!state_queue_->Empty()) {
538    StateId state = state_queue_->Head();
539    state_queue_->Dequeue();
540    SearchState s(state, start);
541    sp_data_.SetFlags(s, 0, kEnqueued);
542    ProcFinal(s);
543    ProcArcs(s);
544    sp_data_.SetFlags(s, kExpanded, kExpanded);
545  }
546  sp_data_.SetFlags(q, kFinished, kFinished);
547  balance_data_.FinishInsert(start);
548  sp_data_.GC(start);
549}
550
551// Updates best complete path.
552template<class Arc, class Queue>
553void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) {
554  if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
555    Weight w = Times(sp_data_.Distance(s),
556                     ifst_->Final(s.state));
557    if (f_distance_ != Plus(f_distance_, w)) {
558      if (f_parent_.state != kNoStateId)
559        sp_data_.SetFlags(f_parent_, 0, kFinal);
560      sp_data_.SetFlags(s, kFinal, kFinal);
561
562      f_distance_ = Plus(f_distance_, w);
563      f_parent_ = s;
564    }
565  }
566}
567
568// Processes all arcs leaving the state s.
569template<class Arc, class Queue>
570void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) {
571  for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
572       !aiter.Done();
573       aiter.Next()) {
574    Arc arc = aiter.Value();
575    Weight w = Times(sp_data_.Distance(s), arc.weight);
576
577    typename unordered_map<Label, Label>::const_iterator pit
578        = paren_id_map_.find(arc.ilabel);
579    if (pit != paren_id_map_.end()) {  // Is a paren?
580      Label paren_id = pit->second;
581      if (arc.ilabel == parens_[paren_id].first)
582        ProcOpenParen(paren_id, s, arc, w);
583      else
584        ProcCloseParen(paren_id, s, arc, w);
585    } else {
586      ProcNonParen(s, arc, w);
587    }
588  }
589}
590
591// Saves the shortest path info for reaching this parenthesis
592// and starts a new SSSP in the sub-graph pointed to by the parenthesis
593// if previously unvisited. Otherwise it finds any previously encountered
594// closing parentheses and relaxes them using the recursively stored
595// shortest distance to them.
596template<class Arc, class Queue> inline
597void PdtShortestPath<Arc, Queue>::ProcOpenParen(
598    Label paren_id, SearchState s, Arc arc, Weight w) {
599
600  SearchState d(arc.nextstate, arc.nextstate);
601  ParenSpec paren(paren_id, s.start, d.start);
602  Weight pdist = sp_data_.Distance(paren);
603  if (pdist != Plus(pdist, w)) {
604    sp_data_.SetDistance(paren, w);
605    sp_data_.SetParent(paren, s);
606    Weight dist = sp_data_.Distance(d);
607    if (dist == Weight::Zero()) {
608      Queue *state_queue = state_queue_;
609      GetDistance(d.start);
610      state_queue_ = state_queue;
611    } else if (!(sp_data_.Flags(d) & kFinished)) {
612      FSTERROR() << "PdtShortestPath: open parenthesis recursion: not bounded stack";
613      error_ = true;
614    }
615
616    for (CloseSourceIterator set_iter =
617             balance_data_.Find(paren_id, arc.nextstate);
618         !set_iter.Done(); set_iter.Next()) {
619      SearchState cpstate(set_iter.Element(), d.start);
620      ParenState<Arc> paren_state(paren_id, cpstate.state);
621      for (typename CloseParenMultimap::const_iterator cpit =
622               close_paren_multimap_.find(paren_state);
623           cpit != close_paren_multimap_.end() && paren_state == cpit->first;
624           ++cpit) {
625        const Arc &cparc = cpit->second;
626        Weight cpw = Times(w, Times(sp_data_.Distance(cpstate),
627                                    cparc.weight));
628        Relax(cpstate, s, cparc, cpw, paren_id);
629      }
630    }
631  }
632}
633
634// Saves the correspondence between each closing parenthesis and its
635// balancing open parenthesis info. Relaxes any close parenthesis
636// destination state that has a balancing previously encountered open
637// parenthesis.
638template<class Arc, class Queue> inline
639void PdtShortestPath<Arc, Queue>::ProcCloseParen(
640    Label paren_id, SearchState s, const Arc &arc, Weight w) {
641  ParenState<Arc> paren_state(paren_id, s.start);
642  if (!(sp_data_.Flags(s) & kExpanded)) {
643    balance_data_.CloseInsert(paren_id, s.start, s.state);
644    sp_data_.SetFlags(s, kFinal, kFinal);
645  }
646}
647
648// For non-parentheses, classical relaxation.
649template<class Arc, class Queue> inline
650void PdtShortestPath<Arc, Queue>::ProcNonParen(
651    SearchState s, const Arc &arc, Weight w) {
652  Relax(s, s, arc, w, kNoLabel);
653}
654
655// Classical relaxation on the search graph for 'arc' from state 's'.
656// State 't' is in the same sub-graph as the nextstate should be (i.e.
657// has the same paren 'start'.
658template<class Arc, class Queue> inline
659void PdtShortestPath<Arc, Queue>::Relax(
660    SearchState s, SearchState t, Arc arc, Weight w, Label paren_id) {
661  SearchState d(arc.nextstate, t.start);
662  Weight dist = sp_data_.Distance(d);
663  if (dist != Plus(dist, w)) {
664    sp_data_.SetParent(d, s);
665    sp_data_.SetParenId(d, paren_id);
666    sp_data_.SetDistance(d, Plus(dist, w));
667    Enqueue(d);
668  }
669}
670
671template<class Arc, class Queue> inline
672void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) {
673  if (!(sp_data_.Flags(s) & kEnqueued)) {
674    state_queue_->Enqueue(s.state);
675    sp_data_.SetFlags(s, kEnqueued, kEnqueued);
676    ++nenqueued_;
677  } else {
678    state_queue_->Update(s.state);
679  }
680}
681
682// Follows parent pointers to find the shortest path. Uses a stack
683// since the shortest distance is stored recursively.
684template<class Arc, class Queue>
685void PdtShortestPath<Arc, Queue>::GetPath() {
686  SearchState s = f_parent_, d = SearchState(kNoStateId, kNoStateId);
687  StateId s_p = kNoStateId, d_p = kNoStateId;
688  Arc arc(kNoArc);
689  Label paren_id = kNoLabel;
690  stack<ParenSpec> paren_stack;
691  while (s.state != kNoStateId) {
692    d_p = s_p;
693    s_p = ofst_->AddState();
694    if (d.state == kNoStateId) {
695      ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
696    } else {
697      if (paren_id != kNoLabel) {                     // paren?
698        if (arc.ilabel == parens_[paren_id].first) {  // open paren
699          paren_stack.pop();
700        } else {                                      // close paren
701          ParenSpec paren(paren_id, d.start, s.start);
702          paren_stack.push(paren);
703        }
704        if (!keep_parens_)
705          arc.ilabel = arc.olabel = 0;
706      }
707      arc.nextstate = d_p;
708      ofst_->AddArc(s_p, arc);
709    }
710    d = s;
711    s = sp_data_.Parent(d);
712    paren_id = sp_data_.ParenId(d);
713    if (s.state != kNoStateId) {
714      arc = GetPathArc(s, d, paren_id, false);
715    } else if (!paren_stack.empty()) {
716      ParenSpec paren = paren_stack.top();
717      s = sp_data_.Parent(paren);
718      paren_id = paren.paren_id;
719      arc = GetPathArc(s, d, paren_id, true);
720    }
721  }
722  ofst_->SetStart(s_p);
723  ofst_->SetProperties(
724      ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
725      kFstProperties);
726}
727
728
729// Finds transition with least weight between two states with label matching
730// paren_id and open/close paren type or a non-paren if kNoLabel.
731template<class Arc, class Queue>
732Arc PdtShortestPath<Arc, Queue>::GetPathArc(
733    SearchState s, SearchState d, Label paren_id, bool open_paren) {
734  Arc path_arc = kNoArc;
735  for (ArcIterator< Fst<Arc> > aiter(*ifst_, s.state);
736       !aiter.Done();
737       aiter.Next()) {
738    const Arc &arc = aiter.Value();
739    if (arc.nextstate != d.state)
740      continue;
741    Label arc_paren_id = kNoLabel;
742    typename unordered_map<Label, Label>::const_iterator pit
743        = paren_id_map_.find(arc.ilabel);
744    if (pit != paren_id_map_.end()) {
745      arc_paren_id = pit->second;
746      bool arc_open_paren = arc.ilabel == parens_[arc_paren_id].first;
747      if (arc_open_paren != open_paren)
748        continue;
749    }
750    if (arc_paren_id != paren_id)
751      continue;
752    if (arc.weight == Plus(arc.weight, path_arc.weight))
753      path_arc = arc;
754  }
755  if (path_arc.nextstate == kNoStateId) {
756    FSTERROR() << "PdtShortestPath::GetPathArc failed to find arc";
757    error_ = true;
758  }
759  return path_arc;
760}
761
762template<class Arc, class Queue>
763const Arc PdtShortestPath<Arc, Queue>::kNoArc
764    = Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId);
765
766template<class Arc, class Queue>
767const uint8 PdtShortestPath<Arc, Queue>::kEnqueued = 0x10;
768
769template<class Arc, class Queue>
770const uint8 PdtShortestPath<Arc, Queue>::kExpanded = 0x20;
771
772template<class Arc, class Queue>
773const uint8 PdtShortestPath<Arc, Queue>::kFinished = 0x40;
774
775template<class Arc, class Queue>
776void ShortestPath(const Fst<Arc> &ifst,
777                  const vector<pair<typename Arc::Label,
778                                    typename Arc::Label> > &parens,
779                  MutableFst<Arc> *ofst,
780                  const PdtShortestPathOptions<Arc, Queue> &opts) {
781  PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
782  psp.ShortestPath(ofst);
783}
784
785template<class Arc>
786void ShortestPath(const Fst<Arc> &ifst,
787                  const vector<pair<typename Arc::Label,
788                                    typename Arc::Label> > &parens,
789                  MutableFst<Arc> *ofst) {
790  typedef FifoQueue<typename Arc::StateId> Queue;
791  PdtShortestPathOptions<Arc, Queue> opts;
792  PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
793  psp.ShortestPath(ofst);
794}
795
796}  // namespace fst
797
798#endif  // FST_EXTENSIONS_PDT_SHORTEST_PATH_H__
799