// pdtscript.h revision 5b6dc79427b8f7eeb6a7ff68034ab8548ce670ea

2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13//
14// Copyright 2005-2010 Google, Inc.
15// Author: jpr@google.com (Jake Ratkiewicz)
16// Convenience file for including all PDT operations at once, and/or
17// registering them for new arc types.
18
19#ifndef FST_EXTENSIONS_PDT_PDTSCRIPT_H_
20#define FST_EXTENSIONS_PDT_PDTSCRIPT_H_
21
22#include <utility>
23using std::pair; using std::make_pair;
24#include <vector>
25using std::vector;
26
27#include <fst/compose.h>  // for ComposeOptions
28#include <fst/util.h>
29
30#include <fst/script/fst-class.h>
31#include <fst/script/arg-packs.h>
32#include <fst/script/shortest-path.h>
33
34#include <fst/extensions/pdt/compose.h>
35#include <fst/extensions/pdt/expand.h>
36#include <fst/extensions/pdt/info.h>
37#include <fst/extensions/pdt/replace.h>
38#include <fst/extensions/pdt/reverse.h>
39#include <fst/extensions/pdt/shortest-path.h>
40
41
42namespace fst {
43namespace script {
44
45// PDT COMPOSE
46
47typedef args::Package<const FstClass &,
48                      const FstClass &,
49                      const vector<pair<int64, int64> >&,
50                      MutableFstClass *,
51                      const PdtComposeOptions &,
52                      bool> PdtComposeArgs;
53
54template<class Arc>
55void PdtCompose(PdtComposeArgs *args) {
56  const Fst<Arc> &ifst1 = *(args->arg1.GetFst<Arc>());
57  const Fst<Arc> &ifst2 = *(args->arg2.GetFst<Arc>());
58  MutableFst<Arc> *ofst = args->arg4->GetMutableFst<Arc>();
59
60  vector<pair<typename Arc::Label, typename Arc::Label> > parens(
61      args->arg3.size());
62
63  for (size_t i = 0; i < parens.size(); ++i) {
64    parens[i].first = args->arg3[i].first;
65    parens[i].second = args->arg3[i].second;
66  }
67
68  if (args->arg6) {
69    Compose(ifst1, parens, ifst2, ofst, args->arg5);
70  } else {
71    Compose(ifst1, ifst2, parens, ofst, args->arg5);
72  }
73}
74
// FstClass-level PdtCompose (declaration only; the definition is not in this
// header). The parameters line up one-to-one with the fields of
// PdtComposeArgs: two input FSTs, the parenthesis pairs, the output FST, the
// compose options, and left_pdt, the flag that the Arc-templated PdtCompose
// uses to choose between the two Compose argument orders.
// NOTE(review): presumably dispatches on arc type to the registered
// PdtCompose<Arc> -- confirm in the corresponding source file.
void PdtCompose(const FstClass & ifst1,
                const FstClass & ifst2,
                const vector<pair<int64, int64> > &parens,
                MutableFstClass *ofst,
                const PdtComposeOptions &copts,
                bool left_pdt);
81
82// PDT EXPAND
83
84struct PdtExpandOptions {
85  bool connect;
86  bool keep_parentheses;
87  WeightClass weight_threshold;
88
89  PdtExpandOptions(bool c = true, bool k = false,
90                   WeightClass w = WeightClass::Zero())
91      : connect(c), keep_parentheses(k), weight_threshold(w) {}
92};
93
94typedef args::Package<const FstClass &,
95                      const vector<pair<int64, int64> >&,
96                      MutableFstClass *, PdtExpandOptions> PdtExpandArgs;
97
98template<class Arc>
99void PdtExpand(PdtExpandArgs *args) {
100  const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
101  MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
102
103  vector<pair<typename Arc::Label, typename Arc::Label> > parens(
104      args->arg2.size());
105  for (size_t i = 0; i < parens.size(); ++i) {
106    parens[i].first = args->arg2[i].first;
107    parens[i].second = args->arg2[i].second;
108  }
109  Expand(fst, parens, ofst,
110         ExpandOptions<Arc>(
111             args->arg4.connect, args->arg4.keep_parentheses,
112             *(args->arg4.weight_threshold.GetWeight<typename Arc::Weight>())));
113}
114
// FstClass-level PdtExpand taking a full PdtExpandOptions (declaration only;
// the definition is not in this header). Parameters correspond to the fields
// of PdtExpandArgs: input FST, parenthesis pairs, output FST, options.
// NOTE(review): presumably dispatches on arc type to the registered
// PdtExpand<Arc> -- confirm in the corresponding source file.
void PdtExpand(const FstClass &ifst,
               const vector<pair<int64, int64> > &parens,
               MutableFstClass *ofst, const PdtExpandOptions &opts);

// Overload of PdtExpand that takes only a connect flag instead of a
// PdtExpandOptions (declaration only).
// NOTE(review): presumably equivalent to calling the overload above with
// PdtExpandOptions(connect) -- confirm against the definition.
void PdtExpand(const FstClass &ifst,
               const vector<pair<int64, int64> > &parens,
               MutableFstClass *ofst, bool connect);
122
123// PDT REPLACE
124
125typedef args::Package<const vector<pair<int64, const FstClass*> > &,
126                      MutableFstClass *,
127                      vector<pair<int64, int64> > *,
128                      const int64 &> PdtReplaceArgs;
129template<class Arc>
130void PdtReplace(PdtReplaceArgs *args) {
131  vector<pair<typename Arc::Label, const Fst<Arc> *> > tuples(
132      args->arg1.size());
133  for (size_t i = 0; i < tuples.size(); ++i) {
134    tuples[i].first = args->arg1[i].first;
135    tuples[i].second = (args->arg1[i].second)->GetFst<Arc>();
136  }
137  MutableFst<Arc> *ofst = args->arg2->GetMutableFst<Arc>();
138  vector<pair<typename Arc::Label, typename Arc::Label> > parens(
139      args->arg3->size());
140
141  for (size_t i = 0; i < parens.size(); ++i) {
142    parens[i].first = args->arg3->at(i).first;
143    parens[i].second = args->arg3->at(i).second;
144  }
145  Replace(tuples, ofst, &parens, args->arg4);
146
147  // now copy parens back
148  args->arg3->resize(parens.size());
149  for (size_t i = 0; i < parens.size(); ++i) {
150    (*args->arg3)[i].first = parens[i].first;
151    (*args->arg3)[i].second = parens[i].second;
152  }
153}
154
// FstClass-level PdtReplace (declaration only; the definition is not in this
// header). Parameters correspond to the fields of PdtReplaceArgs:
// (label, FST) tuples, output FST, in/out parenthesis pairs, and root label.
// The Arc-templated PdtReplace rewrites *parens after calling Replace.
// NOTE(review): presumably dispatches on arc type to the registered
// PdtReplace<Arc> -- confirm in the corresponding source file.
void PdtReplace(const vector<pair<int64, const FstClass*> > &fst_tuples,
                MutableFstClass *ofst,
                vector<pair<int64, int64> > *parens,
                const int64 &root);
159
160// PDT REVERSE
161
162typedef args::Package<const FstClass &,
163                      const vector<pair<int64, int64> >&,
164                      MutableFstClass *> PdtReverseArgs;
165
166template<class Arc>
167void PdtReverse(PdtReverseArgs *args) {
168  const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
169  MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
170
171  vector<pair<typename Arc::Label, typename Arc::Label> > parens(
172      args->arg2.size());
173  for (size_t i = 0; i < parens.size(); ++i) {
174    parens[i].first = args->arg2[i].first;
175    parens[i].second = args->arg2[i].second;
176  }
177  Reverse(fst, parens, ofst);
178}
179
// FstClass-level PdtReverse (declaration only; the definition is not in this
// header). Parameters correspond to the fields of PdtReverseArgs: input FST,
// parenthesis pairs, output FST.
// NOTE(review): presumably dispatches on arc type to the registered
// PdtReverse<Arc> -- confirm in the corresponding source file.
void PdtReverse(const FstClass &ifst,
                const vector<pair<int64, int64> > &parens,
                MutableFstClass *ofst);
183
184
185// PDT SHORTESTPATH
186
187struct PdtShortestPathOptions {
188  QueueType queue_type;
189  bool keep_parentheses;
190  bool path_gc;
191
192  PdtShortestPathOptions(QueueType qt = FIFO_QUEUE,
193                         bool kp = false, bool gc = true)
194      : queue_type(qt), keep_parentheses(kp), path_gc(gc) {}
195};
196
197typedef args::Package<const FstClass &,
198                      const vector<pair<int64, int64> >&,
199                      MutableFstClass *,
200                      const PdtShortestPathOptions &> PdtShortestPathArgs;
201
202template<class Arc>
203void PdtShortestPath(PdtShortestPathArgs *args) {
204  typedef typename Arc::StateId StateId;
205  typedef typename Arc::Label Label;
206  typedef typename Arc::Weight Weight;
207
208  const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
209  MutableFst<Arc> *ofst = args->arg3->GetMutableFst<Arc>();
210  const PdtShortestPathOptions &opts = args->arg4;
211
212
213  vector<pair<Label, Label> > parens(args->arg2.size());
214  for (size_t i = 0; i < parens.size(); ++i) {
215    parens[i].first = args->arg2[i].first;
216    parens[i].second = args->arg2[i].second;
217  }
218
219  switch (opts.queue_type) {
220    default:
221      FSTERROR() << "Unknown queue type: " << opts.queue_type;
222    case FIFO_QUEUE: {
223      typedef FifoQueue<StateId> Queue;
224      fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
225                                                         opts.path_gc);
226      ShortestPath(fst, parens, ofst, spopts);
227      return;
228    }
229    case LIFO_QUEUE: {
230      typedef LifoQueue<StateId> Queue;
231      fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
232                                                         opts.path_gc);
233      ShortestPath(fst, parens, ofst, spopts);
234      return;
235    }
236    case STATE_ORDER_QUEUE: {
237      typedef StateOrderQueue<StateId> Queue;
238      fst::PdtShortestPathOptions<Arc, Queue> spopts(opts.keep_parentheses,
239                                                         opts.path_gc);
240      ShortestPath(fst, parens, ofst, spopts);
241      return;
242    }
243  }
244}
245
// FstClass-level PdtShortestPath (declaration only; the definition is not in
// this header). Parameters correspond to the fields of PdtShortestPathArgs:
// input FST, parenthesis pairs, output FST, options. When opts is omitted a
// default-constructed PdtShortestPathOptions is used.
// NOTE(review): presumably dispatches on arc type to the registered
// PdtShortestPath<Arc> -- confirm in the corresponding source file.
void PdtShortestPath(const FstClass &ifst,
                     const vector<pair<int64, int64> > &parens,
                     MutableFstClass *ofst,
                     const PdtShortestPathOptions &opts =
                     PdtShortestPathOptions());
251
252// PRINT INFO
253
254typedef args::Package<const FstClass &,
255                      const vector<pair<int64, int64> > &> PrintPdtInfoArgs;
256
257template<class Arc>
258void PrintPdtInfo(PrintPdtInfoArgs *args) {
259  const Fst<Arc> &fst = *(args->arg1.GetFst<Arc>());
260  vector<pair<typename Arc::Label, typename Arc::Label> > parens(
261      args->arg2.size());
262  for (size_t i = 0; i < parens.size(); ++i) {
263    parens[i].first = args->arg2[i].first;
264    parens[i].second = args->arg2[i].second;
265  }
266  PdtInfo<Arc> pdtinfo(fst, parens);
267  PrintPdtInfo(pdtinfo);
268}
269
// FstClass-level PrintPdtInfo (declaration only; the definition is not in
// this header). Parameters correspond to the fields of PrintPdtInfoArgs:
// input FST and parenthesis pairs.
// NOTE(review): presumably dispatches on arc type to the registered
// PrintPdtInfo<Arc> -- confirm in the corresponding source file.
void PrintPdtInfo(const FstClass &ifst,
                  const vector<pair<int64, int64> > &parens);
272
273}  // namespace script
274}  // namespace fst
275
276
// Registers every PDT script operation declared in this header (PdtCompose,
// PdtExpand, PdtReplace, PdtReverse, PdtShortestPath and PrintPdtInfo) for
// ArcType by expanding REGISTER_FST_OPERATION once per operation, pairing
// each with its argument-pack typedef. The last expansion carries no
// trailing semicolon, so the invocation site supplies it.
#define REGISTER_FST_PDT_OPERATIONS(ArcType)                                \
  REGISTER_FST_OPERATION(PdtCompose, ArcType, PdtComposeArgs);              \
  REGISTER_FST_OPERATION(PdtExpand, ArcType, PdtExpandArgs);                \
  REGISTER_FST_OPERATION(PdtReplace, ArcType, PdtReplaceArgs);              \
  REGISTER_FST_OPERATION(PdtReverse, ArcType, PdtReverseArgs);              \
  REGISTER_FST_OPERATION(PdtShortestPath, ArcType, PdtShortestPathArgs);    \
  REGISTER_FST_OPERATION(PrintPdtInfo, ArcType, PrintPdtInfoArgs)
284#endif  // FST_EXTENSIONS_PDT_PDTSCRIPT_H_
285