1// replace.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Recursively replace Fst arcs with other Fst(s) returning a PDT.
20
21#ifndef FST_EXTENSIONS_PDT_REPLACE_H__
22#define FST_EXTENSIONS_PDT_REPLACE_H__
23
24#include <tr1/unordered_map>
25using std::tr1::unordered_map;
26using std::tr1::unordered_multimap;
27
28#include <fst/replace.h>
29
30namespace fst {
31
// Hash functor for the (size_t, S) pairs that key the table of
// parenthesis IDs.
template <typename S>
struct ReplaceParenHash {
  // Mixes both members of the key: the second member is scaled by a
  // fixed prime before the first member is added in.
  size_t operator()(const pair<size_t, S> &key) const {
    return key.second * kPrime + key.first;
  }

 private:
  // Multiplier applied to the second member of the key.
  static const size_t kPrime = 7853;
};

// Out-of-class definition of the static constant declared above.
template <typename S> const size_t ReplaceParenHash<S>::kPrime;
43
44// Builds a pushdown transducer (PDT) from an RTN specification
45// identical to that in fst/lib/replace.h. The result is a PDT
46// encoded as the FST 'ofst' where some transitions are labeled with
47// open or close parentheses. To be interpreted as a PDT, the parens
48// must balance on a path (see PdtExpand()). The open/close
49// parenthesis label pairs are returned in 'parens'.
50template <class Arc>
51void Replace(const vector<pair<typename Arc::Label,
52             const Fst<Arc>* > >& ifst_array,
53             MutableFst<Arc> *ofst,
54             vector<pair<typename Arc::Label,
55             typename Arc::Label> > *parens,
56             typename Arc::Label root) {
57  typedef typename Arc::Label Label;
58  typedef typename Arc::StateId StateId;
59  typedef typename Arc::Weight Weight;
60
61  ofst->DeleteStates();
62  parens->clear();
63
64  unordered_map<Label, size_t> label2id;
65  for (size_t i = 0; i < ifst_array.size(); ++i)
66    label2id[ifst_array[i].first] = i;
67
68  Label max_label = kNoLabel;
69  size_t max_non_term_count = 0;
70
71  // Queue of non-terminals to replace
72  deque<size_t> non_term_queue;
73  // Map of non-terminals to replace to count
74  unordered_map<Label, size_t> non_term_map;
75  non_term_queue.push_back(root);
76  non_term_map[root] = 1;;
77
78  // PDT state corr. to ith replace FST start state.
79  vector<StateId> fst_start(ifst_array.size(), kNoLabel);
80  // PDT state, weight pairs corr. to ith replace FST final state & weights.
81  vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size());
82
83  // Builds single Fst combining all referenced input Fsts. Leaves in the
84  // non-termnals for now.  Tabulate the PDT states that correspond to
85  // the start and final states of the input Fsts.
86  for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) {
87    Label label = non_term_queue.front();
88    non_term_queue.pop_front();
89    size_t fst_id = label2id[label];
90
91    const Fst<Arc> *ifst = ifst_array[fst_id].second;
92    for (StateIterator< Fst<Arc> > siter(*ifst);
93         !siter.Done(); siter.Next()) {
94      StateId is = siter.Value();
95      StateId os = ofst->AddState();
96      if (is == ifst->Start()) {
97        fst_start[fst_id] = os;
98        if (label == root)
99          ofst->SetStart(os);
100      }
101      if (ifst->Final(is) != Weight::Zero()) {
102        if (label == root)
103          ofst->SetFinal(os, ifst->Final(is));
104        fst_final[fst_id].push_back(make_pair(os, ifst->Final(is)));
105      }
106      for (ArcIterator< Fst<Arc> > aiter(*ifst, is);
107           !aiter.Done(); aiter.Next()) {
108        Arc arc = aiter.Value();
109        if (max_label == kNoLabel || arc.olabel > max_label)
110          max_label = arc.olabel;
111        typename unordered_map<Label, size_t>::const_iterator it =
112            label2id.find(arc.olabel);
113        if (it != label2id.end()) {
114          size_t nfst_id = it->second;
115          if (ifst_array[nfst_id].second->Start() == -1)
116            continue;
117          size_t count = non_term_map[arc.olabel]++;
118          if (count == 0)
119            non_term_queue.push_back(arc.olabel);
120          if (count > max_non_term_count)
121            max_non_term_count = count;
122        }
123        arc.nextstate += soff;
124        ofst->AddArc(os, arc);
125      }
126    }
127  }
128
129  // Changes each non-terminal transition to an open parenthesis
130  // transition redirected to the PDT state that corresponds to the
131  // start state of the input FST for the non-terminal. Adds close parenthesis
132  // transitions from the PDT states corr. to the final states of the
133  // input FST for the non-terminal to the former destination state of the
134  // non-terminal transition.
135
136  typedef MutableArcIterator< MutableFst<Arc> > MIter;
137  typedef unordered_map<pair<size_t, StateId >, size_t,
138                   ReplaceParenHash<StateId> > ParenMap;
139
140  // Parenthesis pair ID per fst, state pair.
141  ParenMap paren_map;
142  // # of parenthesis pairs per fst.
143  vector<size_t> nparens(ifst_array.size(), 0);
144  // Initial open parenthesis label
145  Label first_open_paren = max_label + 1;
146  Label first_close_paren = max_label + max_non_term_count + 1;
147
148  for (StateIterator< Fst<Arc> > siter(*ofst);
149       !siter.Done(); siter.Next()) {
150    StateId os = siter.Value();
151    MIter *aiter = new MIter(ofst, os);
152    for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) {
153      Arc arc = aiter->Value();
154      typename unordered_map<Label, size_t>::const_iterator lit =
155          label2id.find(arc.olabel);
156      if (lit != label2id.end()) {
157        size_t nfst_id = lit->second;
158
159        // Get parentheses. Ensures distinct parenthesis pair per
160        // non-terminal and destination state but otherwise reuses them.
161        Label open_paren = kNoLabel, close_paren = kNoLabel;
162        pair<size_t, StateId> paren_key(nfst_id, arc.nextstate);
163        typename ParenMap::const_iterator pit = paren_map.find(paren_key);
164        if (pit != paren_map.end()) {
165          size_t paren_id = pit->second;
166          open_paren = (*parens)[paren_id].first;
167          close_paren = (*parens)[paren_id].second;
168        } else {
169          size_t paren_id = nparens[nfst_id]++;
170          open_paren = first_open_paren + paren_id;
171          close_paren = first_close_paren + paren_id;
172          paren_map[paren_key] = paren_id;
173          if (paren_id >= parens->size())
174            parens->push_back(make_pair(open_paren, close_paren));
175        }
176
177        // Sets open parenthesis.
178        Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]);
179        aiter->SetValue(sarc);
180
181        // Adds close parentheses.
182        for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) {
183          pair<StateId, Weight> &p = fst_final[nfst_id][i];
184          Arc farc(close_paren, close_paren, p.second, arc.nextstate);
185
186          ofst->AddArc(p.first, farc);
187          if (os == p.first) {  // Invalidated iterator
188            delete aiter;
189            aiter = new MIter(ofst, os);
190            aiter->Seek(n);
191          }
192        }
193      }
194    }
195    delete aiter;
196  }
197}
198
199}  // namespace fst
200
201#endif  // FST_EXTENSIONS_PDT_REPLACE_H__
202