paren.h revision 3da1eb108d36da35333b2d655202791af854996b
1// paren.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// Common classes for PDT parentheses
19
20// \file
21
22#ifndef FST_EXTENSIONS_PDT_PAREN_H_
23#define FST_EXTENSIONS_PDT_PAREN_H_
24
25#include <algorithm>
26#include <tr1/unordered_map>
27using std::tr1::unordered_map;
28using std::tr1::unordered_multimap;
29#include <tr1/unordered_set>
30using std::tr1::unordered_set;
31using std::tr1::unordered_multiset;
32#include <set>
33
34#include <fst/extensions/pdt/pdt.h>
35#include <fst/extensions/pdt/collection.h>
36#include <fst/fst.h>
37#include <fst/dfs-visit.h>
38
39
40namespace fst {
41
42//
43// ParenState: Pair of an open (close) parenthesis and
44// its destination (source) state.
45//
46
47template <class A>
48class ParenState {
49 public:
50  typedef typename A::Label Label;
51  typedef typename A::StateId StateId;
52
53  struct Hash {
54    size_t operator()(const ParenState<A> &p) const {
55      return p.paren_id + p.state_id * kPrime;
56    }
57  };
58
59  Label paren_id;     // ID of open (close) paren
60  StateId state_id;   // destination (source) state of open (close) paren
61
62  ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}
63
64  ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}
65
66  bool operator==(const ParenState<A> &p) const {
67    if (&p == this)
68      return true;
69    return p.paren_id == this->paren_id && p.state_id == this->state_id;
70  }
71
72  bool operator!=(const ParenState<A> &p) const { return !(p == *this); }
73
74  bool operator<(const ParenState<A> &p) const {
75    return paren_id < this->paren.id ||
76        (p.paren_id == this->paren.id && p.state_id < this->state_id);
77  }
78
79 private:
80  static const size_t kPrime;
81};
82
83template <class A>
84const size_t ParenState<A>::kPrime = 7853;
85
86
// Adapts a position in an STL (multi)map to the FST-style iterator
// protocol (Done/Value/Next/Reset). Iteration starts at the position given
// to the constructor and is considered finished once the end of the map is
// reached or the key differs from the key at the starting position.
// The map is held by reference, so it must outlive the iterator.
template <class M>
class MapIterator {
 public:
  typedef typename M::const_iterator StlIterator;
  typedef typename M::value_type PairType;
  typedef typename PairType::second_type ValueType;

  // 'm' is the map iterated over; 'iter' is the starting position in 'm'
  // (it may be m.end(), in which case the iterator is immediately done).
  MapIterator(const M &m, StlIterator iter)
      : map_(m), begin_(iter), iter_(iter) {}

  // True when there is no further entry sharing the starting key.
  bool Done() const {
    if (iter_ == map_.end())
      return true;
    return iter_->first != begin_->first;
  }

  // Returns a copy of the mapped value at the current position.
  ValueType Value() const {
    return iter_->second;
  }

  // Advances to the next entry.
  void Next() {
    ++iter_;
  }

  // Rewinds to the starting position.
  void Reset() {
    iter_ = begin_;
  }

 private:
  const M &map_;        // underlying map (not owned)
  StlIterator begin_;   // starting position; its key delimits the iteration
  StlIterator iter_;    // current position
};
111
112//
113// PdtParenReachable: Provides various parenthesis reachability information
114// on a PDT.
115//
116
117template <class A>
118class PdtParenReachable {
119 public:
120  typedef typename A::StateId StateId;
121  typedef typename A::Label Label;
122 public:
123  // Maps from state ID to reachable paren IDs from (to) that state.
124  typedef unordered_multimap<StateId, Label> ParenMultiMap;
125
126  // Maps from paren ID and state ID to reachable state set ID
127  typedef unordered_map<ParenState<A>, ssize_t,
128                   typename ParenState<A>::Hash> StateSetMap;
129
130  // Maps from paren ID and state ID to arcs exiting that state with that
131  // Label.
132  typedef unordered_multimap<ParenState<A>, A,
133                        typename ParenState<A>::Hash> ParenArcMultiMap;
134
135  typedef MapIterator<ParenMultiMap> ParenIterator;
136
137  typedef MapIterator<ParenArcMultiMap> ParenArcIterator;
138
139  typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
140
141  // Computes close (open) parenthesis reachabilty information for
142  // a PDT with bounded stack.
143  PdtParenReachable(const Fst<A> &fst,
144                    const vector<pair<Label, Label> > &parens, bool close)
145      : fst_(fst),
146        parens_(parens),
147        close_(close),
148        error_(false) {
149    for (Label i = 0; i < parens.size(); ++i) {
150      const pair<Label, Label>  &p = parens[i];
151      paren_id_map_[p.first] = i;
152      paren_id_map_[p.second] = i;
153    }
154
155    if (close_) {
156      StateId start = fst.Start();
157      if (start == kNoStateId)
158        return;
159      if (!DFSearch(start)) {
160        FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
161        error_ = true;
162      }
163    } else {
164      FSTERROR() << "PdtParenReachable: open paren info not implemented";
165      error_ = true;
166    }
167  }
168
169  bool const Error() { return error_; }
170
171  // Given a state ID, returns an iterator over paren IDs
172  // for close (open) parens reachable from that state along balanced
173  // paths.
174  ParenIterator FindParens(StateId s) const {
175    return ParenIterator(paren_multimap_, paren_multimap_.find(s));
176  }
177
178  // Given a paren ID and a state ID s, returns an iterator over
179  // states that can be reached along balanced paths from (to) s that
180  // have have close (open) parentheses matching the paren ID exiting
181  // (entering) those states.
182  SetIterator FindStates(Label paren_id, StateId s) const {
183    ParenState<A> paren_state(paren_id, s);
184    typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
185    if (id_it == set_map_.end()) {
186      return state_sets_.FindSet(-1);
187    } else {
188      return state_sets_.FindSet(id_it->second);
189    }
190  }
191
192  // Given a paren Id and a state ID s, return an iterator over
193  // arcs that exit (enter) s and are labeled with a close (open)
194  // parenthesis matching the paren ID.
195  ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
196    ParenState<A> paren_state(paren_id, s);
197    return ParenArcIterator(paren_arc_multimap_,
198                            paren_arc_multimap_.find(paren_state));
199  }
200
201 private:
202  // DFS that gathers paren and state set information.
203  // Bool returns false when cycle detected.
204  bool DFSearch(StateId s);
205
206  // Unions state sets together gathered by the DFS.
207  void ComputeStateSet(StateId s);
208
209  // Gather state set(s) from state 'nexts'.
210  void UpdateStateSet(StateId nexts, set<Label> *paren_set,
211                      vector< set<StateId> > *state_sets) const;
212
213  const Fst<A> &fst_;
214  const vector<pair<Label, Label> > &parens_;         // Paren ID -> Labels
215  bool close_;                                        // Close/open paren info?
216  unordered_map<Label, Label> paren_id_map_;               // Paren labels -> ID
217  ParenMultiMap paren_multimap_;                      // Paren reachability
218  ParenArcMultiMap paren_arc_multimap_;               // Paren Arcs
219  vector<char> state_color_;                          // DFS state
220  mutable Collection<ssize_t, StateId> state_sets_;   // Reachable states -> ID
221  StateSetMap set_map_;                               // ID -> Reachable states
222  bool error_;
223  DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
224};
225
226// DFS that gathers paren and state set information.
227template <class A>
228bool PdtParenReachable<A>::DFSearch(StateId s) {
229  if (s >= state_color_.size())
230    state_color_.resize(s + 1, kDfsWhite);
231
232  if (state_color_[s] == kDfsBlack)
233    return true;
234
235  if (state_color_[s] == kDfsGrey)
236    return false;
237
238  state_color_[s] = kDfsGrey;
239
240  for (ArcIterator<Fst<A> > aiter(fst_, s);
241       !aiter.Done();
242       aiter.Next()) {
243    const A &arc = aiter.Value();
244
245    typename unordered_map<Label, Label>::const_iterator pit
246        = paren_id_map_.find(arc.ilabel);
247    if (pit != paren_id_map_.end()) {               // paren?
248      Label paren_id = pit->second;
249      if (arc.ilabel == parens_[paren_id].first) {  // open paren
250        if (!DFSearch(arc.nextstate))
251          return false;
252        for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
253             !set_iter.Done(); set_iter.Next()) {
254          for (ParenArcIterator paren_arc_iter =
255                   FindParenArcs(paren_id, set_iter.Element());
256               !paren_arc_iter.Done();
257               paren_arc_iter.Next()) {
258            const A &cparc = paren_arc_iter.Value();
259            if (!DFSearch(cparc.nextstate))
260              return false;
261          }
262        }
263      }
264    } else {                                       // non-paren
265      if(!DFSearch(arc.nextstate))
266        return false;
267    }
268  }
269  ComputeStateSet(s);
270  state_color_[s] = kDfsBlack;
271  return true;
272}
273
274// Unions state sets together gathered by the DFS.
275template <class A>
276void PdtParenReachable<A>::ComputeStateSet(StateId s) {
277  set<Label> paren_set;
278  vector< set<StateId> > state_sets(parens_.size());
279  for (ArcIterator< Fst<A> > aiter(fst_, s);
280       !aiter.Done();
281       aiter.Next()) {
282    const A &arc = aiter.Value();
283
284    typename unordered_map<Label, Label>::const_iterator pit
285        = paren_id_map_.find(arc.ilabel);
286    if (pit != paren_id_map_.end()) {               // paren?
287      Label paren_id = pit->second;
288      if (arc.ilabel == parens_[paren_id].first) {  // open paren
289        for (SetIterator set_iter =
290                 FindStates(paren_id, arc.nextstate);
291             !set_iter.Done(); set_iter.Next()) {
292          for (ParenArcIterator paren_arc_iter =
293                   FindParenArcs(paren_id, set_iter.Element());
294               !paren_arc_iter.Done();
295               paren_arc_iter.Next()) {
296            const A &cparc = paren_arc_iter.Value();
297            UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
298          }
299        }
300      } else {                                      // close paren
301        paren_set.insert(paren_id);
302        state_sets[paren_id].insert(s);
303        ParenState<A> paren_state(paren_id, s);
304        paren_arc_multimap_.insert(make_pair(paren_state, arc));
305      }
306    } else {                                        // non-paren
307      UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
308    }
309  }
310
311  vector<StateId> state_set;
312  for (typename set<Label>::iterator paren_iter = paren_set.begin();
313       paren_iter != paren_set.end(); ++paren_iter) {
314    state_set.clear();
315    Label paren_id = *paren_iter;
316    paren_multimap_.insert(make_pair(s, paren_id));
317    for (typename set<StateId>::iterator state_iter
318             = state_sets[paren_id].begin();
319         state_iter != state_sets[paren_id].end();
320         ++state_iter) {
321      state_set.push_back(*state_iter);
322    }
323    ParenState<A> paren_state(paren_id, s);
324    set_map_[paren_state] = state_sets_.FindId(state_set);
325  }
326}
327
328// Gather state set(s) from state 'nexts'.
329template <class A>
330void PdtParenReachable<A>::UpdateStateSet(
331    StateId nexts, set<Label> *paren_set,
332    vector< set<StateId> > *state_sets) const {
333  for(ParenIterator paren_iter = FindParens(nexts);
334      !paren_iter.Done(); paren_iter.Next()) {
335    Label paren_id = paren_iter.Value();
336    paren_set->insert(paren_id);
337    for (SetIterator set_iter = FindStates(paren_id, nexts);
338         !set_iter.Done(); set_iter.Next()) {
339      (*state_sets)[paren_id].insert(set_iter.Element());
340    }
341  }
342}
343
344
345// Store balancing parenthesis data for a PDT. Allows on-the-fly
346// construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
347template <class A>
348class PdtBalanceData {
349 public:
350  typedef typename A::StateId StateId;
351  typedef typename A::Label Label;
352
353  // Hash set for open parens
354  typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;
355
356  // Maps from open paren destination state to parenthesis ID.
357  typedef unordered_multimap<StateId, Label> OpenParenMap;
358
359  // Maps from open paren state to source states of matching close parens
360  typedef unordered_multimap<ParenState<A>, StateId,
361                        typename ParenState<A>::Hash> CloseParenMap;
362
363  // Maps from open paren state to close source set ID
364  typedef unordered_map<ParenState<A>, ssize_t,
365                   typename ParenState<A>::Hash> CloseSourceMap;
366
367  typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
368
369  PdtBalanceData() {}
370
371  void Clear() {
372    open_paren_map_.clear();
373    close_paren_map_.clear();
374  }
375
376  // Adds an open parenthesis with destination state 'open_dest'.
377  void OpenInsert(Label paren_id, StateId open_dest) {
378    ParenState<A> key(paren_id, open_dest);
379    if (!open_paren_set_.count(key)) {
380      open_paren_set_.insert(key);
381      open_paren_map_.insert(make_pair(open_dest, paren_id));
382    }
383  }
384
385  // Adds a matching closing parenthesis with source state
386  // 'close_source' that balances an open_parenthesis with destination
387  // state 'open_dest' if OpenInsert() previously called
388  // (o.w. CloseInsert() does nothing).
389  void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
390    ParenState<A> key(paren_id, open_dest);
391    if (open_paren_set_.count(key))
392      close_paren_map_.insert(make_pair(key, close_source));
393  }
394
395  // Find close paren source states matching an open parenthesis.
396  // Methods that follow, iterate through those matching states.
397  // Should be called only after FinishInsert(open_dest).
398  SetIterator Find(Label paren_id, StateId open_dest) {
399    ParenState<A> close_key(paren_id, open_dest);
400    typename CloseSourceMap::const_iterator id_it =
401        close_source_map_.find(close_key);
402    if (id_it == close_source_map_.end()) {
403      return close_source_sets_.FindSet(-1);
404    } else {
405      return close_source_sets_.FindSet(id_it->second);
406    }
407  }
408
409  // Call when all open and close parenthesis insertions wrt open
410  // parentheses entering 'open_dest' are finished. Must be called
411  // before Find(open_dest). Stores close paren source state sets
412  // efficiently.
413  void FinishInsert(StateId open_dest) {
414    vector<StateId> close_sources;
415    for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
416         oit != open_paren_map_.end() && oit->first == open_dest;) {
417      Label paren_id = oit->second;
418      close_sources.clear();
419      ParenState<A> okey(paren_id, open_dest);
420      open_paren_set_.erase(open_paren_set_.find(okey));
421      for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
422           cit != close_paren_map_.end() && cit->first == okey;) {
423        close_sources.push_back(cit->second);
424        close_paren_map_.erase(cit++);
425      }
426      sort(close_sources.begin(), close_sources.end());
427      typename vector<StateId>::iterator unique_end =
428          unique(close_sources.begin(), close_sources.end());
429      close_sources.resize(unique_end - close_sources.begin());
430
431      if (!close_sources.empty())
432        close_source_map_[okey] = close_source_sets_.FindId(close_sources);
433      open_paren_map_.erase(oit++);
434    }
435  }
436
437  // Return a new balance data object representing the reversed balance
438  // information.
439  PdtBalanceData<A> *Reverse(StateId num_states,
440                               StateId num_split,
441                               StateId state_id_shift) const;
442
443 private:
444  OpenParenSet open_paren_set_;                      // open par. at dest?
445
446  OpenParenMap open_paren_map_;                      // open parens per state
447  ParenState<A> open_dest_;                          // cur open dest. state
448  typename OpenParenMap::const_iterator open_iter_;  // cur open parens/state
449
450  CloseParenMap close_paren_map_;                    // close states/open
451                                                     //  paren and state
452
453  CloseSourceMap close_source_map_;                  // paren, state to set ID
454  mutable Collection<ssize_t, StateId> close_source_sets_;
455};
456
457// Return a new balance data object representing the reversed balance
458// information.
459template <class A>
460PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
461    StateId num_states,
462    StateId num_split,
463    StateId state_id_shift) const {
464  PdtBalanceData<A> *bd = new PdtBalanceData<A>;
465  unordered_set<StateId> close_sources;
466  StateId split_size = num_states / num_split;
467
468  for (StateId i = 0; i < num_states; i+= split_size) {
469    close_sources.clear();
470
471    for (typename CloseSourceMap::const_iterator
472             sit = close_source_map_.begin();
473         sit != close_source_map_.end();
474         ++sit) {
475      ParenState<A> okey = sit->first;
476      StateId open_dest = okey.state_id;
477      Label paren_id = okey.paren_id;
478      for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
479           !set_iter.Done(); set_iter.Next()) {
480        StateId close_source = set_iter.Element();
481        if ((close_source < i) || (close_source >= i + split_size))
482          continue;
483        close_sources.insert(close_source + state_id_shift);
484        bd->OpenInsert(paren_id, close_source + state_id_shift);
485        bd->CloseInsert(paren_id, close_source + state_id_shift,
486                        open_dest + state_id_shift);
487      }
488    }
489
490    for (typename unordered_set<StateId>::const_iterator it
491             = close_sources.begin();
492         it != close_sources.end();
493         ++it) {
494      bd->FinishInsert(*it);
495    }
496
497  }
498  return bd;
499}
500
501
502}  // namespace fst
503
504#endif  // FST_EXTENSIONS_PDT_PAREN_H_
505