1// pdt.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Common classes for PDT expansion/traversal.
20
21#ifndef FST_EXTENSIONS_PDT_PDT_H__
22#define FST_EXTENSIONS_PDT_PDT_H__
23
24#include <tr1/unordered_map>
25using std::tr1::unordered_map;
26using std::tr1::unordered_multimap;
27#include <map>
28#include <set>
29
30#include <fst/compat.h>
31#include <fst/state-table.h>
32#include <fst/fst.h>
33
34namespace fst {
35
36// Provides bijection between parenthesis stacks and signed integral
37// stack IDs. Each stack ID is unique to each distinct stack.  The
38// open-close parenthesis label pairs are passed in 'parens'.
39template <typename K, typename L>
40class PdtStack {
41 public:
42  typedef K StackId;
43  typedef L Label;
44
45  // The stacks are stored in a tree. The nodes are stored in vector
46  // 'nodes_'. Each node represents the top of some stack and is
47  // ID'ed by its position in the vector. Its parent node represents
48  // the stack with the top 'popped' and its children are stored in
49  // 'child_map_' accessed by stack_id and label. The paren_id is
50  // the position in 'parens' of the parenthesis for that node.
51  struct StackNode {
52    StackId parent_id;
53    size_t paren_id;
54
55    StackNode(StackId p, size_t i) : parent_id(p), paren_id(i) {}
56  };
57
58  PdtStack(const vector<pair<Label, Label> > &parens)
59      : parens_(parens), min_paren_(kNoLabel), max_paren_(kNoLabel) {
60    for (size_t i = 0; i < parens.size(); ++i) {
61      const pair<Label, Label>  &p = parens[i];
62      paren_map_[p.first] = i;
63      paren_map_[p.second] = i;
64
65      if (min_paren_ == kNoLabel || p.first < min_paren_)
66        min_paren_ = p.first;
67      if (p.second < min_paren_)
68        min_paren_ = p.second;
69
70      if (max_paren_ == kNoLabel || p.first > max_paren_)
71        max_paren_ = p.first;
72      if (p.second > max_paren_)
73        max_paren_ = p.second;
74    }
75    nodes_.push_back(StackNode(-1, -1));  // Tree root.
76  }
77
78  // Returns stack ID given the current stack ID (0 if empty) and
79  // label read. 'Pushes' onto a stack if the label is an open
80  // parenthesis, returning the new stack ID. 'Pops' the stack if the
81  // label is a close parenthesis that matches the top of the stack,
82  // returning the parent stack ID. Returns -1 if label is an
83  // unmatched close parenthesis. Otherwise, returns the current stack
84  // ID.
85  StackId Find(StackId stack_id, Label label) {
86    if (min_paren_ == kNoLabel || label < min_paren_ || label > max_paren_)
87      return stack_id;                       // Non-paren.
88
89    typename unordered_map<Label, size_t>::const_iterator pit
90        = paren_map_.find(label);
91    if (pit == paren_map_.end())             // Non-paren.
92      return stack_id;
93    ssize_t paren_id = pit->second;
94
95    if (label == parens_[paren_id].first) {  // Open paren.
96      StackId &child_id = child_map_[make_pair(stack_id, label)];
97      if (child_id == 0) {                   // Child not found, push label.
98        child_id = nodes_.size();
99        nodes_.push_back(StackNode(stack_id, paren_id));
100      }
101      return child_id;
102    }
103
104    const StackNode &node = nodes_[stack_id];
105    if (paren_id == node.paren_id)           // Matching close paren.
106      return node.parent_id;
107
108    return -1;                               // Non-matching close paren.
109  }
110
111  // Returns the stack ID obtained by "popping" the label at the top
112  // of the current stack ID.
113  StackId Pop(StackId stack_id) const {
114    return nodes_[stack_id].parent_id;
115  }
116
117  // Returns the paren ID at the top of the stack for 'stack_id'
118  ssize_t Top(StackId stack_id) const {
119    return nodes_[stack_id].paren_id;
120  }
121
122  ssize_t ParenId(Label label) const {
123    typename unordered_map<Label, size_t>::const_iterator pit
124        = paren_map_.find(label);
125    if (pit == paren_map_.end())  // Non-paren.
126      return -1;
127    return pit->second;
128  }
129
130 private:
131  struct ChildHash {
132    size_t operator()(const pair<StackId, Label> &p) const {
133      return p.first + p.second * kPrime;
134    }
135  };
136
137  static const size_t kPrime;
138
139  vector<pair<Label, Label> > parens_;
140  vector<StackNode> nodes_;
141  unordered_map<Label, size_t> paren_map_;
142  unordered_map<pair<StackId, Label>,
143           StackId, ChildHash> child_map_;   // Child of stack node wrt label
144  Label min_paren_;                          // For faster paren. check
145  Label max_paren_;                          // For faster paren. check
146};
147
148template <typename T, typename L>
149const size_t PdtStack<T, L>::kPrime = 7853;
150
151
152// State tuple for PDT expansion
153template <typename S, typename K>
154struct PdtStateTuple {
155  typedef S StateId;
156  typedef K StackId;
157
158  StateId state_id;
159  StackId stack_id;
160
161  PdtStateTuple()
162      : state_id(kNoStateId), stack_id(-1) {}
163
164  PdtStateTuple(StateId fs, StackId ss)
165      : state_id(fs), stack_id(ss) {}
166};
167
168// Equality of PDT state tuples.
169template <typename S, typename K>
170inline bool operator==(const PdtStateTuple<S, K>& x,
171                       const PdtStateTuple<S, K>& y) {
172  if (&x == &y)
173    return true;
174  return x.state_id == y.state_id && x.stack_id == y.stack_id;
175}
176
177
// Hash functor for PDT state tuples. 'Tuple' must expose 'state_id' and
// 'stack_id' members; the hash is state_id + stack_id * kPrime.
template <class Tuple>
class PdtStateHash {
  // Multiplier applied to the stack ID before the state ID is added.
  static const size_t kPrime;

 public:
  size_t operator()(const Tuple &tuple) const {
    return tuple.stack_id * kPrime + tuple.state_id;
  }
};

template <class Tuple>
const size_t PdtStateHash<Tuple>::kPrime = 7853;
192
193
194// Tuple to PDT state bijection.
195template <class S, class K>
196class PdtStateTable
197    : public CompactHashStateTable<PdtStateTuple<S, K>,
198                                   PdtStateHash<PdtStateTuple<S, K> > > {
199 public:
200  typedef S StateId;
201  typedef K StackId;
202
203  PdtStateTable() {}
204
205  PdtStateTable(const PdtStateTable<S, K> &table) {}
206
207 private:
208  void operator=(const PdtStateTable<S, K> &table);  // disallow
209};
210
211}  // namespace fst
212
213#endif  // FST_EXTENSIONS_PDT_PDT_H__
214