pdt.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// pdt.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Common classes for PDT expansion/traversal.
20
21#ifndef FST_EXTENSIONS_PDT_PDT_H__
22#define FST_EXTENSIONS_PDT_PDT_H__
23
24#include <unordered_map>
25using std::tr1::unordered_map;
26using std::tr1::unordered_multimap;
27#include <map>
28#include <set>
29
30#include <fst/state-table.h>
31#include <fst/fst.h>
32
33namespace fst {
34
35// Provides bijection between parenthesis stacks and signed integral
36// stack IDs. Each stack ID is unique to each distinct stack.  The
37// open-close parenthesis label pairs are passed in 'parens'.
38template <typename K, typename L>
39class PdtStack {
40 public:
41  typedef K StackId;
42  typedef L Label;
43
44  // The stacks are stored in a tree. The nodes are stored in vector
45  // 'nodes_'. Each node represents the top of some stack and is
46  // ID'ed by its position in the vector. Its parent node represents
47  // the stack with the top 'popped' and its children are stored in
48  // 'child_map_' accessed by stack_id and label. The paren_id is
49  // the position in 'parens' of the parenthesis for that node.
50  struct StackNode {
51    StackId parent_id;
52    size_t paren_id;
53
54    StackNode(StackId p, size_t i) : parent_id(p), paren_id(i) {}
55  };
56
57  PdtStack(const vector<pair<Label, Label> > &parens)
58      : parens_(parens), min_paren_(kNoLabel), max_paren_(kNoLabel) {
59    for (size_t i = 0; i < parens.size(); ++i) {
60      const pair<Label, Label>  &p = parens[i];
61      paren_map_[p.first] = i;
62      paren_map_[p.second] = i;
63
64      if (min_paren_ == kNoLabel || p.first < min_paren_)
65        min_paren_ = p.first;
66      if (p.second < min_paren_)
67        min_paren_ = p.second;
68
69      if (max_paren_ == kNoLabel || p.first > max_paren_)
70        max_paren_ = p.first;
71      if (p.second > max_paren_)
72        max_paren_ = p.second;
73    }
74    nodes_.push_back(StackNode(-1, -1));  // Tree root.
75  }
76
77  // Returns stack ID given the current stack ID (0 if empty) and
78  // label read. 'Pushes' onto a stack if the label is an open
79  // parenthesis, returning the new stack ID. 'Pops' the stack if the
80  // label is a close parenthesis that matches the top of the stack,
81  // returning the parent stack ID. Returns -1 if label is an
82  // unmatched close parenthesis. Otherwise, returns the current stack
83  // ID.
84  StackId Find(StackId stack_id, Label label) {
85    if (min_paren_ == kNoLabel || label < min_paren_ || label > max_paren_)
86      return stack_id;                       // Non-paren.
87
88    typename unordered_map<Label, size_t>::const_iterator pit
89        = paren_map_.find(label);
90    if (pit == paren_map_.end())             // Non-paren.
91      return stack_id;
92    ssize_t paren_id = pit->second;
93
94    if (label == parens_[paren_id].first) {  // Open paren.
95      StackId &child_id = child_map_[make_pair(stack_id, label)];
96      if (child_id == 0) {                   // Child not found, push label.
97        child_id = nodes_.size();
98        nodes_.push_back(StackNode(stack_id, paren_id));
99      }
100      return child_id;
101    }
102
103    const StackNode &node = nodes_[stack_id];
104    if (paren_id == node.paren_id)           // Matching close paren.
105      return node.parent_id;
106
107    return -1;                               // Non-matching close paren.
108  }
109
110  // Returns the stack ID obtained by "popping" the label at the top
111  // of the current stack ID.
112  StackId Pop(StackId stack_id) const {
113    return nodes_[stack_id].parent_id;
114  }
115
116  // Returns the paren ID at the top of the stack for 'stack_id'
117  ssize_t Top(StackId stack_id) const {
118    return nodes_[stack_id].paren_id;
119  }
120
121  ssize_t ParenId(Label label) const {
122    typename unordered_map<Label, size_t>::const_iterator pit
123        = paren_map_.find(label);
124    if (pit == paren_map_.end())  // Non-paren.
125      return -1;
126    return pit->second;
127  }
128
129 private:
130  struct ChildHash {
131    size_t operator()(const pair<StackId, Label> &p) const {
132      return p.first + p.second * kPrime;
133    }
134  };
135
136  static const size_t kPrime;
137
138  vector<pair<Label, Label> > parens_;
139  vector<StackNode> nodes_;
140  unordered_map<Label, size_t> paren_map_;
141  unordered_map<pair<StackId, Label>,
142           StackId, ChildHash> child_map_;   // Child of stack node wrt label
143  Label min_paren_;                          // For faster paren. check
144  Label max_paren_;                          // For faster paren. check
145};
146
147template <typename T, typename L>
148const size_t PdtStack<T, L>::kPrime = 7853;
149
150
151// State tuple for PDT expansion
152template <typename S, typename K>
153struct PdtStateTuple {
154  typedef S StateId;
155  typedef K StackId;
156
157  StateId state_id;
158  StackId stack_id;
159
160  PdtStateTuple()
161      : state_id(kNoStateId), stack_id(-1) {}
162
163  PdtStateTuple(StateId fs, StackId ss)
164      : state_id(fs), stack_id(ss) {}
165};
166
167// Equality of PDT state tuples.
168template <typename S, typename K>
169inline bool operator==(const PdtStateTuple<S, K>& x,
170                       const PdtStateTuple<S, K>& y) {
171  if (&x == &y)
172    return true;
173  return x.state_id == y.state_id && x.stack_id == y.stack_id;
174}
175
176
// Hash function object for PDT state tuples. T must provide 'state_id' and
// 'stack_id' members; the hash is stack_id * kPrime + state_id.
template <class T>
class PdtStateHash {
 public:
  size_t operator()(const T &key) const {
    return key.stack_id * kPrime + key.state_id;
  }

 private:
  static const size_t kPrime;  // Multiplier applied to the stack ID.
};

// Out-of-class definition of the stack-ID multiplier.
template <class T>
const size_t PdtStateHash<T>::kPrime = 7853;
191
192
193// Tuple to PDT state bijection.
194template <class S, class K>
195class PdtStateTable
196    : public CompactHashStateTable<PdtStateTuple<S, K>,
197                                   PdtStateHash<PdtStateTuple<S, K> > > {
198 public:
199  typedef S StateId;
200  typedef K StackId;
201
202  PdtStateTable() {}
203
204  PdtStateTable(const PdtStateTable<S, K> &table) {}
205
206 private:
207  void operator=(const PdtStateTable<S, K> &table);  // disallow
208};
209
210}  // namespace fst
211
212#endif  // FST_EXTENSIONS_PDT_PDT_H__
213