expand.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// expand.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Expand a PDT to an FST.
20
21#ifndef FST_EXTENSIONS_PDT_EXPAND_H__
22#define FST_EXTENSIONS_PDT_EXPAND_H__
23
24#include <vector>
25using std::vector;
26
27#include <fst/extensions/pdt/pdt.h>
28#include <fst/extensions/pdt/paren.h>
29#include <fst/extensions/pdt/shortest-path.h>
30#include <fst/extensions/pdt/reverse.h>
31#include <fst/cache.h>
32#include <fst/mutable-fst.h>
33#include <fst/queue.h>
34#include <fst/state-table.h>
35#include <fst/test-properties.h>
36
37namespace fst {
38
39template <class Arc>
40struct ExpandFstOptions : public CacheOptions {
41  bool keep_parentheses;
42  PdtStack<typename Arc::StateId, typename Arc::Label> *stack;
43  PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table;
44
45  ExpandFstOptions(
46      const CacheOptions &opts = CacheOptions(),
47      bool kp = false,
48      PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0,
49      PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0)
50      : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {}
51};
52
53// Properties for an expanded PDT.
54inline uint64 ExpandProperties(uint64 inprops) {
55  return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted);
56}
57
58
59// Implementation class for ExpandFst
60template <class A>
61class ExpandFstImpl
62    : public CacheImpl<A> {
63 public:
64  using FstImpl<A>::SetType;
65  using FstImpl<A>::SetProperties;
66  using FstImpl<A>::Properties;
67  using FstImpl<A>::SetInputSymbols;
68  using FstImpl<A>::SetOutputSymbols;
69
70  using CacheBaseImpl< CacheState<A> >::PushArc;
71  using CacheBaseImpl< CacheState<A> >::HasArcs;
72  using CacheBaseImpl< CacheState<A> >::HasFinal;
73  using CacheBaseImpl< CacheState<A> >::HasStart;
74  using CacheBaseImpl< CacheState<A> >::SetArcs;
75  using CacheBaseImpl< CacheState<A> >::SetFinal;
76  using CacheBaseImpl< CacheState<A> >::SetStart;
77
78  typedef A Arc;
79  typedef typename A::Label Label;
80  typedef typename A::Weight Weight;
81  typedef typename A::StateId StateId;
82  typedef StateId StackId;
83  typedef PdtStateTuple<StateId, StackId> StateTuple;
84
85  ExpandFstImpl(const Fst<A> &fst,
86                const vector<pair<typename Arc::Label,
87                                  typename Arc::Label> > &parens,
88                const ExpandFstOptions<A> &opts)
89      : CacheImpl<A>(opts), fst_(fst.Copy()),
90        stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)),
91        state_table_(opts.state_table ? opts.state_table :
92                     new PdtStateTable<StateId, StackId>()),
93        own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0),
94        keep_parentheses_(opts.keep_parentheses) {
95    SetType("expand");
96
97    uint64 props = fst.Properties(kFstProperties, false);
98    SetProperties(ExpandProperties(props), kCopyProperties);
99
100    SetInputSymbols(fst.InputSymbols());
101    SetOutputSymbols(fst.OutputSymbols());
102  }
103
104  ExpandFstImpl(const ExpandFstImpl &impl)
105      : CacheImpl<A>(impl),
106        fst_(impl.fst_->Copy(true)),
107        stack_(new PdtStack<StateId, Label>(*impl.stack_)),
108        state_table_(new PdtStateTable<StateId, StackId>()),
109        own_stack_(true), own_state_table_(true),
110        keep_parentheses_(impl.keep_parentheses_) {
111    SetType("expand");
112    SetProperties(impl.Properties(), kCopyProperties);
113    SetInputSymbols(impl.InputSymbols());
114    SetOutputSymbols(impl.OutputSymbols());
115  }
116
117  ~ExpandFstImpl() {
118    delete fst_;
119    if (own_stack_)
120      delete stack_;
121    if (own_state_table_)
122      delete state_table_;
123  }
124
125  StateId Start() {
126    if (!HasStart()) {
127      StateId s = fst_->Start();
128      if (s == kNoStateId)
129        return kNoStateId;
130      StateTuple tuple(s, 0);
131      StateId start = state_table_->FindState(tuple);
132      SetStart(start);
133    }
134    return CacheImpl<A>::Start();
135  }
136
137  Weight Final(StateId s) {
138    if (!HasFinal(s)) {
139      const StateTuple &tuple = state_table_->Tuple(s);
140      Weight w = fst_->Final(tuple.state_id);
141      if (w != Weight::Zero() && tuple.stack_id == 0)
142        SetFinal(s, w);
143      else
144        SetFinal(s, Weight::Zero());
145    }
146    return CacheImpl<A>::Final(s);
147  }
148
149  size_t NumArcs(StateId s) {
150    if (!HasArcs(s)) {
151      ExpandState(s);
152    }
153    return CacheImpl<A>::NumArcs(s);
154  }
155
156  size_t NumInputEpsilons(StateId s) {
157    if (!HasArcs(s))
158      ExpandState(s);
159    return CacheImpl<A>::NumInputEpsilons(s);
160  }
161
162  size_t NumOutputEpsilons(StateId s) {
163    if (!HasArcs(s))
164      ExpandState(s);
165    return CacheImpl<A>::NumOutputEpsilons(s);
166  }
167
168  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
169    if (!HasArcs(s))
170      ExpandState(s);
171    CacheImpl<A>::InitArcIterator(s, data);
172  }
173
174  // Computes the outgoing transitions from a state, creating new destination
175  // states as needed.
176  void ExpandState(StateId s) {
177    StateTuple tuple = state_table_->Tuple(s);
178    for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id);
179         !aiter.Done(); aiter.Next()) {
180      Arc arc = aiter.Value();
181      StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel);
182      if (stack_id == -1) {
183        // Non-matching close parenthesis
184        continue;
185      } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) {
186        // Stack push/pop
187        arc.ilabel = arc.olabel = 0;
188      }
189
190      StateTuple ntuple(arc.nextstate, stack_id);
191      arc.nextstate = state_table_->FindState(ntuple);
192      PushArc(s, arc);
193    }
194    SetArcs(s);
195  }
196
197  const PdtStack<StackId, Label> &GetStack() const { return *stack_; }
198
199  const PdtStateTable<StateId, StackId> &GetStateTable() const {
200    return *state_table_;
201  }
202
203 private:
204  const Fst<A> *fst_;
205
206  PdtStack<StackId, Label> *stack_;
207  PdtStateTable<StateId, StackId> *state_table_;
208  bool own_stack_;
209  bool own_state_table_;
210  bool keep_parentheses_;
211
212  void operator=(const ExpandFstImpl<A> &);  // disallow
213};
214
215// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
216// This version is a delayed Fst.  In the PDT, some transitions are
217// labeled with open or close parentheses. To be interpreted as a PDT,
218// the parens must balance on a path. The open-close parenthesis label
219// pairs are passed in 'parens'. The expansion enforces the
220// parenthesis constraints. The PDT must be expandable as an FST.
221//
222// This class attaches interface to implementation and handles
223// reference counting, delegating most methods to ImplToFst.
224template <class A>
225class ExpandFst : public ImplToFst< ExpandFstImpl<A> > {
226 public:
227  friend class ArcIterator< ExpandFst<A> >;
228  friend class StateIterator< ExpandFst<A> >;
229
230  typedef A Arc;
231  typedef typename A::Label Label;
232  typedef typename A::Weight Weight;
233  typedef typename A::StateId StateId;
234  typedef StateId StackId;
235  typedef CacheState<A> State;
236  typedef ExpandFstImpl<A> Impl;
237
238  ExpandFst(const Fst<A> &fst,
239            const vector<pair<typename Arc::Label,
240                              typename Arc::Label> > &parens)
241      : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {}
242
243  ExpandFst(const Fst<A> &fst,
244            const vector<pair<typename Arc::Label,
245                              typename Arc::Label> > &parens,
246            const ExpandFstOptions<A> &opts)
247      : ImplToFst<Impl>(new Impl(fst, parens, opts)) {}
248
249  // See Fst<>::Copy() for doc.
250  ExpandFst(const ExpandFst<A> &fst, bool safe = false)
251      : ImplToFst<Impl>(fst, safe) {}
252
253  // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc.
254  virtual ExpandFst<A> *Copy(bool safe = false) const {
255    return new ExpandFst<A>(*this, safe);
256  }
257
258  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
259
260  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
261    GetImpl()->InitArcIterator(s, data);
262  }
263
264  const PdtStack<StackId, Label> &GetStack() const {
265    return GetImpl()->GetStack();
266  }
267
268  const PdtStateTable<StateId, StackId> &GetStateTable() const {
269    return GetImpl()->GetStateTable();
270  }
271
272 private:
273  // Makes visible to friends.
274  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
275
276  void operator=(const ExpandFst<A> &fst);  // Disallow
277};
278
279
280// Specialization for ExpandFst.
281template<class A>
282class StateIterator< ExpandFst<A> >
283    : public CacheStateIterator< ExpandFst<A> > {
284 public:
285  explicit StateIterator(const ExpandFst<A> &fst)
286      : CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {}
287};
288
289
290// Specialization for ExpandFst.
291template <class A>
292class ArcIterator< ExpandFst<A> >
293    : public CacheArcIterator< ExpandFst<A> > {
294 public:
295  typedef typename A::StateId StateId;
296
297  ArcIterator(const ExpandFst<A> &fst, StateId s)
298      : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) {
299    if (!fst.GetImpl()->HasArcs(s))
300      fst.GetImpl()->ExpandState(s);
301  }
302
303 private:
304  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
305};
306
307
308template <class A> inline
309void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const
310{
311  data->base = new StateIterator< ExpandFst<A> >(*this);
312}
313
314//
315// PrunedExpand Class
316//
317
318// Prunes the delayed expansion of a pushdown transducer (PDT) encoded
319// as an FST into an FST.  In the PDT, some transitions are labeled
320// with open or close parentheses. To be interpreted as a PDT, the
321// parens must balance on a path. The open-close parenthesis label
322// pairs are passed in 'parens'. The expansion enforces the
323// parenthesis constraints.
324//
325// The algorithm works by visiting the delayed ExpandFst using a
326// shortest-stack first queue discipline and relies on the
327// shortest-distance information computed using a reverse
328// shortest-path call to perform the pruning.
329//
330// The algorithm maintains the same state ordering between the ExpandFst
331// being visited 'efst_' and the result of pruning written into the
332// MutableFst 'ofst_' to improve readability of the code.
333//
334template <class A>
335class PrunedExpand {
336 public:
337  typedef A Arc;
338  typedef typename A::Label Label;
339  typedef typename A::StateId StateId;
340  typedef typename A::Weight Weight;
341  typedef StateId StackId;
342  typedef PdtStack<StackId, Label> Stack;
343  typedef PdtStateTable<StateId, StackId> StateTable;
344  typedef typename PdtBalanceData<Arc>::SetIterator SetIterator;
345
346  // Constructor taking as input a PDT specified by 'ifst' and 'parens'.
347  // 'keep_parentheses' specifies whether parentheses are replaced by
348  // epsilons or not during the expansion. 'opts' is the cache options
349  // used to instantiate the underlying ExpandFst.
350  PrunedExpand(const Fst<A> &ifst,
351               const vector<pair<Label, Label> > &parens,
352               bool keep_parentheses = false,
353               const CacheOptions &opts = CacheOptions())
354      : ifst_(ifst.Copy()),
355        keep_parentheses_(keep_parentheses),
356        stack_(parens),
357        efst_(ifst, parens,
358              ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
359        queue_(state_table_, stack_, stack_length_, distance_, fdistance_) {
360    Reverse(*ifst_, parens, &rfst_);
361    VectorFst<Arc> path;
362    reverse_shortest_path_ = new SP(
363        rfst_, parens,
364        PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false));
365    reverse_shortest_path_->ShortestPath(&path);
366    balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse(
367        rfst_.NumStates(), 10, -1);
368
369    InitCloseParenMultimap(parens);
370  }
371
372  ~PrunedExpand() {
373    delete ifst_;
374    delete reverse_shortest_path_;
375    delete balance_data_;
376  }
377
378  // Expands and prunes with weight threshold 'threshold' the input PDT.
379  // Writes the result in 'ofst'.
380  void Expand(MutableFst<A> *ofst, const Weight &threshold);
381
382 private:
383  static const uint8 kEnqueued;
384  static const uint8 kExpanded;
385  static const uint8 kSourceState;
386
387  // Comparison functor used by the queue:
388  // 1. states corresponding to shortest stack first,
389  // 2. among stacks of the same length, reverse lexicographic order is used,
390  // 3. among states with the same stack, shortest-first order is used.
391  class StackCompare {
392   public:
393    StackCompare(const StateTable &st,
394                 const Stack &s, const vector<StackId> &sl,
395                 const vector<Weight> &d, const vector<Weight> &fd)
396        : state_table_(st), stack_(s), stack_length_(sl),
397          distance_(d), fdistance_(fd) {}
398
399    bool operator()(StateId s1, StateId s2) const {
400      StackId si1 = state_table_.Tuple(s1).stack_id;
401      StackId si2 = state_table_.Tuple(s2).stack_id;
402      if (stack_length_[si1] < stack_length_[si2])
403        return true;
404      if  (stack_length_[si1] > stack_length_[si2])
405        return false;
406      // If stack id equal, use A*
407      if (si1 == si2) {
408        Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ?
409            Times(distance_[s1], fdistance_[s1]) : Weight::Zero();
410        Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ?
411            Times(distance_[s2], fdistance_[s2]) : Weight::Zero();
412        return less_(w1, w2);
413      }
414      // If lenghts are equal, use reverse lexico.
415      for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
416        if (stack_.Top(si1) < stack_.Top(si2)) return true;
417        if (stack_.Top(si1) > stack_.Top(si2)) return false;
418      }
419      return false;
420    }
421
422   private:
423    const StateTable &state_table_;
424    const Stack &stack_;
425    const vector<StackId> &stack_length_;
426    const vector<Weight> &distance_;
427    const vector<Weight> &fdistance_;
428    NaturalLess<Weight> less_;
429  };
430
431  class ShortestStackFirstQueue
432      : public ShortestFirstQueue<StateId, StackCompare> {
433   public:
434    ShortestStackFirstQueue(
435        const PdtStateTable<StateId, StackId> &st,
436        const Stack &s,
437        const vector<StackId> &sl,
438        const vector<Weight> &d, const vector<Weight> &fd)
439        : ShortestFirstQueue<StateId, StackCompare>(
440            StackCompare(st, s, sl, d, fd)) {}
441  };
442
443
444  void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens);
445  Weight DistanceToDest(StateId state, StateId source) const;
446  uint8 Flags(StateId s) const;
447  void SetFlags(StateId s, uint8 flags, uint8 mask);
448  Weight Distance(StateId s) const;
449  void SetDistance(StateId s, Weight w);
450  Weight FinalDistance(StateId s) const;
451  void SetFinalDistance(StateId s, Weight w);
452  StateId SourceState(StateId s) const;
453  void SetSourceState(StateId s, StateId p);
454  void AddStateAndEnqueue(StateId s);
455  void Relax(StateId s, const A &arc, Weight w);
456  bool PruneArc(StateId s, const A &arc);
457  void ProcStart();
458  void ProcFinal(StateId s);
459  bool ProcNonParen(StateId s, const A &arc, bool add_arc);
460  bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi);
461  bool ProcCloseParen(StateId s, const A &arc);
462  void ProcDestStates(StateId s, StackId si);
463
464  Fst<A> *ifst_;                   // Input PDT
465  VectorFst<Arc> rfst_;            // Reversed PDT
466  bool keep_parentheses_;          // Keep parentheses in ofst?
467  StateTable state_table_;         // State table for efst_
468  Stack stack_;                    // Stack trie
469  ExpandFst<Arc> efst_;            // Expanded PDT
470  vector<StackId> stack_length_;   // Length of stack for given stack id
471  vector<Weight> distance_;        // Distance from initial state in efst_/ofst
472  vector<Weight> fdistance_;       // Distance to final states in efst_/ofst
473  ShortestStackFirstQueue queue_;  // Queue used to visit efst_
474  vector<uint8> flags_;            // Status flags for states in efst_/ofst
475  vector<StateId> sources_;        // PDT source state for each expanded state
476
477  typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP;
478  typedef typename SP::CloseParenMultimap ParenMultimap;
479  SP *reverse_shortest_path_;  // Shortest path for rfst_
480  PdtBalanceData<Arc> *balance_data_;   // Not owned by shortest_path_
481  ParenMultimap close_paren_multimap_;  // Maps open paren arcs to
482  // balancing close paren arcs.
483
484  MutableFst<Arc> *ofst_;  // Output fst
485  Weight limit_;           // Weight limit
486
487  typedef unordered_map<StateId, Weight> DestMap;
488  DestMap dest_map_;
489  StackId current_stack_id_;
490  // 'current_stack_id_' is the stack id of the states currently at the top
491  // of queue, i.e., the states currently being popped and processed.
492  // 'dest_map_' maps a state 's' in 'ifst_' that is the source
493  // of a close parentheses matching the top of 'current_stack_id_; to
494  // the shortest-distance from '(s, current_stack_id_)' to the final
495  // states in 'efst_'.
496  ssize_t current_paren_id_;  // Paren id at top of current stack
497  ssize_t cached_stack_id_;
498  StateId cached_source_;
499  slist<pair<StateId, Weight> > cached_dest_list_;
500  // 'cached_dest_list_' contains the set of pair of destination
501  // states and weight to final states for source state
502  // 'cached_source_' and paren id 'cached_paren_id': the set of
503  // source state of a close parenthesis with paren id
504  // 'cached_paren_id' balancing an incoming open parenthesis with
505  // paren id 'cached_paren_id' in state 'cached_source_'.
506
507  NaturalLess<Weight> less_;
508};
509
510template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01;
511template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02;
512template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04;
513
514
515// Initializes close paren multimap, mapping pairs (s,paren_id) to
516// all the arcs out of s labeled with close parenthese for paren_id.
517template <class A>
518void PrunedExpand<A>::InitCloseParenMultimap(
519    const vector<pair<Label, Label> > &parens) {
520  unordered_map<Label, Label> paren_id_map;
521  for (Label i = 0; i < parens.size(); ++i) {
522    const pair<Label, Label>  &p = parens[i];
523    paren_id_map[p.first] = i;
524    paren_id_map[p.second] = i;
525  }
526
527  for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
528    StateId s = siter.Value();
529    for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
530         !aiter.Done(); aiter.Next()) {
531      const Arc &arc = aiter.Value();
532      typename unordered_map<Label, Label>::const_iterator pit
533          = paren_id_map.find(arc.ilabel);
534      if (pit == paren_id_map.end()) continue;
535      if (arc.ilabel == parens[pit->second].second) {  // Close paren
536        ParenState<Arc> paren_state(pit->second, s);
537        close_paren_multimap_.insert(make_pair(paren_state, arc));
538      }
539    }
540  }
541}
542
543
544// Returns the weight of the shortest balanced path from 'source' to 'dest'
545// in 'ifst_', 'dest' must be the source state of a close paren arc.
546template <class A>
547typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source,
548                                                   StateId dest) const {
549  typename SP::SearchState s(source + 1, dest + 1);
550  VLOG(2) << "D(" << source << ", " << dest << ") ="
551            << reverse_shortest_path_->GetShortestPathData().Distance(s);
552  return reverse_shortest_path_->GetShortestPathData().Distance(s);
553}
554
555// Returns the flags for state 's' in 'ofst_'.
556template <class A>
557uint8 PrunedExpand<A>::Flags(StateId s) const {
558  return s < flags_.size() ? flags_[s] : 0;
559}
560
561// Modifies the flags for state 's' in 'ofst_'.
562template <class A>
563void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) {
564  while (flags_.size() <= s) flags_.push_back(0);
565  flags_[s] &= ~mask;
566  flags_[s] |= flags & mask;
567}
568
569
570// Returns the shortest distance from the initial state to 's' in 'ofst_'.
571template <class A>
572typename A::Weight PrunedExpand<A>::Distance(StateId s) const {
573  return s < distance_.size() ? distance_[s] : Weight::Zero();
574}
575
576// Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'.
577template <class A>
578void PrunedExpand<A>::SetDistance(StateId s, Weight w) {
579  while (distance_.size() <= s ) distance_.push_back(Weight::Zero());
580  distance_[s] = w;
581}
582
583
584// Returns the shortest distance from 's' to the final states in 'ofst_'.
585template <class A>
586typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const {
587  return s < fdistance_.size() ? fdistance_[s] : Weight::Zero();
588}
589
590// Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'.
591template <class A>
592void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) {
593  while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
594  fdistance_[s] = w;
595}
596
597// Returns the PDT "source" state of state 's' in 'ofst_'.
598template <class A>
599typename A::StateId PrunedExpand<A>::SourceState(StateId s) const {
600  return s < sources_.size() ? sources_[s] : kNoStateId;
601}
602
603// Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'.
604template <class A>
605void PrunedExpand<A>::SetSourceState(StateId s, StateId p) {
606  while (sources_.size() <= s) sources_.push_back(kNoStateId);
607  sources_[s] = p;
608}
609
610// Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue,
611// modifying the flags for 's' accordingly.
612template <class A>
613void PrunedExpand<A>::AddStateAndEnqueue(StateId s) {
614  if (!(Flags(s) & (kEnqueued | kExpanded))) {
615    while (ofst_->NumStates() <= s) ofst_->AddState();
616    queue_.Enqueue(s);
617    SetFlags(s, kEnqueued, kEnqueued);
618  } else if (Flags(s) & kEnqueued) {
619    queue_.Update(s);
620  }
621  // TODO(allauzen): Check everything is fine when kExpanded?
622}
623
624// Relaxes arc 'arc' out of state 's' in 'ofst_':
625// * if the distance to 's' times the weight of 'arc' is smaller than
626//   the currently stored distance for 'arc.nextstate',
627//   updates 'Distance(arc.nextstate)' with new estimate;
628// * if 'fd' is less than the currently stored distance from 'arc.nextstate'
629//   to the final state, updates with new estimate.
630template <class A>
631void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) {
632  Weight nd = Times(Distance(s), arc.weight);
633  if (less_(nd, Distance(arc.nextstate))) {
634    SetDistance(arc.nextstate, nd);
635    SetSourceState(arc.nextstate, SourceState(s));
636  }
637  if (less_(fd, FinalDistance(arc.nextstate)))
638    SetFinalDistance(arc.nextstate, fd);
639  VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
640            << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate)
641            << ", nd = " << nd;
642}
643
644// Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to
645// be pruned.
646template <class A>
647bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) {
648  VLOG(2) << "Prune ?";
649  Weight fd = Weight::Zero();
650
651  if ((cached_source_ != SourceState(s)) ||
652      (cached_stack_id_ != current_stack_id_)) {
653    cached_source_ = SourceState(s);
654    cached_stack_id_ = current_stack_id_;
655    cached_dest_list_.clear();
656    if (cached_source_ != ifst_->Start()) {
657      for (SetIterator set_iter =
658               balance_data_->Find(current_paren_id_, cached_source_);
659           !set_iter.Done(); set_iter.Next()) {
660        StateId dest = set_iter.Element();
661        typename DestMap::const_iterator iter = dest_map_.find(dest);
662        cached_dest_list_.push_front(*iter);
663      }
664    } else {
665      // TODO(allauzen): queue discipline should prevent this never
666      // from happening; replace by a check.
667      cached_dest_list_.push_front(
668          make_pair(rfst_.Start() -1, Weight::One()));
669    }
670  }
671
672  for (typename slist<pair<StateId, Weight> >::const_iterator iter =
673           cached_dest_list_.begin();
674       iter != cached_dest_list_.end();
675       ++iter) {
676    fd = Plus(fd,
677              Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id,
678                                   iter->first),
679                    iter->second));
680  }
681  Relax(s, arc, fd);
682  Weight w = Times(Distance(s), Times(arc.weight, fd));
683  return less_(limit_, w);
684}
685
686// Adds start state of 'efst_' to 'ofst_', enqueues it and initializes
687// the distance data structures.
688template <class A>
689void PrunedExpand<A>::ProcStart() {
690  StateId s = efst_.Start();
691  AddStateAndEnqueue(s);
692  ofst_->SetStart(s);
693  SetSourceState(s, ifst_->Start());
694
695  current_stack_id_ = 0;
696  current_paren_id_ = -1;
697  stack_length_.push_back(0);
698  dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed
699
700  cached_source_ = ifst_->Start();
701  cached_stack_id_ = 0;
702  cached_dest_list_.push_front(
703          make_pair(rfst_.Start() -1, Weight::One()));
704
705  PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0);
706  SetFinalDistance(state_table_.FindState(tuple), Weight::One());
707  SetDistance(s, Weight::One());
708  SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1));
709  VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1);
710}
711
712// Makes 's' final in 'ofst_' if shortest accepting path ending in 's'
713// is below threshold.
714template <class A>
715void PrunedExpand<A>::ProcFinal(StateId s) {
716  Weight final = efst_.Final(s);
717  if ((final == Weight::Zero()) || less_(limit_, Times(Distance(s), final)))
718    return;
719  ofst_->SetFinal(s, final);
720}
721
722// Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is
723// below the threshold.  When 'add_arc' is true, 'arc' is added to 'ofst_'.
724template <class A>
725bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) {
726  VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate
727          << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
728          << ", add_arc = " << (add_arc ? "true" : "false");
729  if (PruneArc(s, arc)) return false;
730  if(add_arc) ofst_->AddArc(s, arc);
731  AddStateAndEnqueue(arc.nextstate);
732  return true;
733}
734
735// Processes an open paren arc 'arc' out of state 's' in 'ofst_'.
736// When 'arc' is labeled with an open paren,
737// 1. considers each (shortest) balanced path starting in 's' by
738//    taking 'arc' and ending by a close paren balancing the open
739//    paren of 'arc' as a meta-arc, processes and prunes each meta-arc
740//    as a non-paren arc, inserting its destination to the queue;
741// 2. if at least one of these meta-arcs has not been pruned,
742//    adds the destination of 'arc' to 'ofst_' as a new source state
743//    for the stack id 'nsi' and inserts it in the queue.
744template <class A>
745bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si,
746                                    StackId nsi) {
747  // Update the stack lenght when needed: |nsi| = |si| + 1.
748  while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
749  if (stack_length_[nsi] == -1)
750    stack_length_[nsi] = stack_length_[si] + 1;
751
752  StateId ns = arc.nextstate;
753  VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
754            << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")";
755  bool proc_arc = false;
756  Weight fd = Weight::Zero();
757  ssize_t paren_id = stack_.ParenId(arc.ilabel);
758  slist<StateId> sources;
759  for (SetIterator set_iter =
760           balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
761       !set_iter.Done(); set_iter.Next()) {
762    sources.push_front(set_iter.Element());
763  }
764  for (typename slist<StateId>::const_iterator sources_iter = sources.begin();
765       sources_iter != sources.end();
766       ++ sources_iter) {
767    StateId source = *sources_iter;
768    VLOG(2) << "Close paren source: " << source;
769    ParenState<Arc> paren_state(paren_id, source);
770    for (typename ParenMultimap::const_iterator iter =
771             close_paren_multimap_.find(paren_state);
772         iter != close_paren_multimap_.end() && paren_state == iter->first;
773         ++iter) {
774      Arc meta_arc = iter->second;
775      PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
776      meta_arc.nextstate =  state_table_.FindState(tuple);
777      VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source;
778      VLOG(2) << "Meta arc weight = " << arc.weight << " Times "
779                << DistanceToDest(state_table_.Tuple(ns).state_id, source)
780                << " Times " << meta_arc.weight;
781      meta_arc.weight = Times(
782          arc.weight,
783          Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
784                meta_arc.weight));
785      proc_arc |= ProcNonParen(s, meta_arc, false);
786      fd = Plus(fd, Times(
787          Times(
788              DistanceToDest(state_table_.Tuple(ns).state_id, source),
789              iter->second.weight),
790          FinalDistance(meta_arc.nextstate)));
791    }
792  }
793  if (proc_arc) {
794    VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
795    ofst_->AddArc(
796      s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
797    AddStateAndEnqueue(arc.nextstate);
798    Weight nd = Times(Distance(s), arc.weight);
799    if(less_(nd, Distance(arc.nextstate)))
800      SetDistance(arc.nextstate, nd);
801    // FinalDistance not necessary for source state since pruning
802    // decided using the meta-arcs above.  But this is a problem with
803    // A*, hence:
804    if (less_(fd, FinalDistance(arc.nextstate)))
805      SetFinalDistance(arc.nextstate, fd);
806    SetFlags(arc.nextstate, kSourceState, kSourceState);
807  }
808  return proc_arc;
809}
810
811// Checks that shortest path through close paren arc in 'efst_' is
812// below threshold, if so adds it to 'ofst_'.
813template <class A>
814bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) {
815  Weight w = Times(Distance(s),
816                   Times(arc.weight, FinalDistance(arc.nextstate)));
817  if (less_(limit_, w))
818    return false;
819  ofst_->AddArc(
820      s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
821  return true;
822}
823
824// When 's' in 'ofst_' is a source state for stack id 'si', identifies
825// all the corresponding possible destination states, that is, all the
826// states in 'ifst_' that have an outgoing close paren arc balancing
827// the incoming open paren taken to get to 's', and for each such
828// state 't', computes the shortest distance from (t, si) to the final
829// states in 'ofst_'. Stores this information in 'dest_map_'.
830template <class A>
831void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) {
832  if (!(Flags(s) & kSourceState)) return;
833  if (si != current_stack_id_) {
834    dest_map_.clear();
835    current_stack_id_ = si;
836    current_paren_id_ = stack_.Top(current_stack_id_);
837    VLOG(2) << "StackID " << si << " dequeued for first time";
838  }
839  // TODO(allauzen): clean up source state business; rename current function to
840  // ProcSourceState.
841  SetSourceState(s, state_table_.Tuple(s).state_id);
842
843  ssize_t paren_id = stack_.Top(si);
844  for (SetIterator set_iter =
845           balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
846       !set_iter.Done(); set_iter.Next()) {
847    StateId dest_state = set_iter.Element();
848    if (dest_map_.find(dest_state) != dest_map_.end())
849      continue;
850    Weight dest_weight = Weight::Zero();
851    ParenState<Arc> paren_state(paren_id, dest_state);
852    for (typename ParenMultimap::const_iterator iter =
853             close_paren_multimap_.find(paren_state);
854         iter != close_paren_multimap_.end() && paren_state == iter->first;
855         ++iter) {
856      const Arc &arc = iter->second;
857      PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si));
858      dest_weight = Plus(dest_weight,
859                         Times(arc.weight,
860                               FinalDistance(state_table_.FindState(tuple))));
861    }
862    dest_map_[dest_state] = dest_weight;
863    VLOG(2) << "State " << dest_state << " is a dest state for stack id "
864              << si << " with weight " << dest_weight;
865  }
866}
867
868// Expands and prunes with weight threshold 'threshold' the input PDT.
869// Writes the result in 'ofst'.
870template <class A>
871void PrunedExpand<A>::Expand(
872    MutableFst<A> *ofst, const typename A::Weight &threshold) {
873  ofst_ = ofst;
874  ofst_->DeleteStates();
875  ofst_->SetInputSymbols(ifst_->InputSymbols());
876  ofst_->SetOutputSymbols(ifst_->OutputSymbols());
877
878  limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold);
879  flags_.clear();
880
881  ProcStart();
882
883  while (!queue_.Empty()) {
884    StateId s = queue_.Head();
885    queue_.Dequeue();
886    SetFlags(s, kExpanded, kExpanded | kEnqueued);
887    VLOG(2) << s << " dequeued!";
888
889    ProcFinal(s);
890    StackId stack_id = state_table_.Tuple(s).stack_id;
891    ProcDestStates(s, stack_id);
892
893    for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s);
894         !aiter.Done();
895         aiter.Next()) {
896      Arc arc = aiter.Value();
897      StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id;
898      if (stack_id == nextstack_id)
899        ProcNonParen(s, arc, true);
900      else if (stack_id == stack_.Pop(nextstack_id))
901        ProcOpenParen(s, arc, stack_id, nextstack_id);
902      else
903        ProcCloseParen(s, arc);
904    }
905    VLOG(2) << "d[" << s << "] = " << Distance(s)
906            << ", fd[" << s << "] = " << FinalDistance(s);
907  }
908}
909
910//
911// Expand() Functions
912//
913
// Options for the Expand() functions below:
//   connect:          if true, Connect() is applied to the expanded result.
//   keep_parentheses: if true, parenthesis labels are kept on the expanded
//                     arcs.
//   weight_threshold: pruning threshold. The default, Arc::Weight::Zero(),
//                     selects expansion without pruning.
template <class Arc>
struct ExpandOptions {
  bool connect;
  bool keep_parentheses;
  typename Arc::Weight weight_threshold;

  ExpandOptions(bool do_connect = true,
                bool keep_parens = false,
                typename Arc::Weight threshold = Arc::Weight::Zero())
      : connect(do_connect),
        keep_parentheses(keep_parens),
        weight_threshold(threshold) {}
};
924
925// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
926// This version writes the expanded PDT result to a MutableFst.
927// In the PDT, some transitions are labeled with open or close
928// parentheses. To be interpreted as a PDT, the parens must balance on
929// a path. The open-close parenthesis label pairs are passed in
930// 'parens'. The expansion enforces the parenthesis constraints. The
931// PDT must be expandable as an FST.
932template <class Arc>
933void Expand(
934    const Fst<Arc> &ifst,
935    const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
936    MutableFst<Arc> *ofst,
937    const ExpandOptions<Arc> &opts) {
938  typedef typename Arc::Label Label;
939  typedef typename Arc::StateId StateId;
940  typedef typename Arc::Weight Weight;
941  typedef typename ExpandFst<Arc>::StackId StackId;
942
943  ExpandFstOptions<Arc> eopts;
944  eopts.gc_limit = 0;
945  if (opts.weight_threshold == Weight::Zero()) {
946    eopts.keep_parentheses = opts.keep_parentheses;
947    *ofst = ExpandFst<Arc>(ifst, parens, eopts);
948  } else {
949    PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
950    pruned_expand.Expand(ofst, opts.weight_threshold);
951  }
952
953  if (opts.connect)
954    Connect(ofst);
955}
956
957// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
958// This version writes the expanded PDT result to a MutableFst.
959// In the PDT, some transitions are labeled with open or close
960// parentheses. To be interpreted as a PDT, the parens must balance on
961// a path. The open-close parenthesis label pairs are passed in
962// 'parens'. The expansion enforces the parenthesis constraints. The
963// PDT must be expandable as an FST.
964template<class Arc>
965void Expand(
966    const Fst<Arc> &ifst,
967    const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
968    MutableFst<Arc> *ofst,
969    bool connect = true, bool keep_parentheses = false) {
970  Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses));
971}
972
973}  // namespace fst
974
975#endif  // FST_EXTENSIONS_PDT_EXPAND_H__
976