compose.h revision 5b6dc79427b8f7eeb6a7ff68034ab8548ce670ea
1baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// compose.h
2baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner
3baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// Licensed under the Apache License, Version 2.0 (the "License");
4baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// you may not use this file except in compliance with the License.
5baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// You may obtain a copy of the License at
6baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner//
7baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner//     http://www.apache.org/licenses/LICENSE-2.0
8baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner//
9baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// Unless required by applicable law or agreed to in writing, software
10baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// distributed under the License is distributed on an "AS IS" BASIS,
11baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// See the License for the specific language governing permissions and
13baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// limitations under the License.
14baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner//
15baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// Copyright 2005-2010 Google, Inc.
16baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// Author: riley@google.com (Michael Riley)
17baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner//
18baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// \file
19baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// Compose a PDT and an FST.
20baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner
21baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner#ifndef FST_EXTENSIONS_PDT_COMPOSE_H__
22baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner#define FST_EXTENSIONS_PDT_COMPOSE_H__
23baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner
24baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner#include <list>
25baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner
26baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner#include <fst/extensions/pdt/pdt.h>
27baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner#include <fst/compose.h>
28baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner
29baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turnernamespace fst {
30baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner
31baf99eed1bde1af072679f77930bc9210fba267bDavid 'Digit' Turner// Return paren arcs for Find(kNoLabel).
32const uint32 kParenList =  0x00000001;
33
34// Return a kNolabel loop for Find(paren).
35const uint32 kParenLoop =  0x00000002;
36
37// This class is a matcher that treats parens as multi-epsilon labels.
38// It is most efficient if the parens are in a range non-overlapping with
39// the non-paren labels.
40template <class F>
41class ParenMatcher {
42 public:
43  typedef SortedMatcher<F> M;
44  typedef typename M::FST FST;
45  typedef typename M::Arc Arc;
46  typedef typename Arc::StateId StateId;
47  typedef typename Arc::Label Label;
48  typedef typename Arc::Weight Weight;
49
50  ParenMatcher(const FST &fst, MatchType match_type,
51               uint32 flags = (kParenLoop | kParenList))
52      : matcher_(fst, match_type),
53        match_type_(match_type),
54        flags_(flags) {
55    if (match_type == MATCH_INPUT) {
56      loop_.ilabel = kNoLabel;
57      loop_.olabel = 0;
58    } else {
59      loop_.ilabel = 0;
60      loop_.olabel = kNoLabel;
61    }
62    loop_.weight = Weight::One();
63    loop_.nextstate = kNoStateId;
64  }
65
66  ParenMatcher(const ParenMatcher<F> &matcher, bool safe = false)
67      : matcher_(matcher.matcher_, safe),
68        match_type_(matcher.match_type_),
69        flags_(matcher.flags_),
70        open_parens_(matcher.open_parens_),
71        close_parens_(matcher.close_parens_),
72        loop_(matcher.loop_) {
73    loop_.nextstate = kNoStateId;
74  }
75
76  ParenMatcher<F> *Copy(bool safe = false) const {
77    return new ParenMatcher<F>(*this, safe);
78  }
79
80  MatchType Type(bool test) const { return matcher_.Type(test); }
81
82  void SetState(StateId s) {
83    matcher_.SetState(s);
84    loop_.nextstate = s;
85  }
86
87  bool Find(Label match_label);
88
89  bool Done() const {
90    return done_;
91  }
92
93  const Arc& Value() const {
94    return paren_loop_ ? loop_ : matcher_.Value();
95  }
96
97  void Next();
98
99  const FST &GetFst() const { return matcher_.GetFst(); }
100
101  uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
102
103  uint32 Flags() const { return matcher_.Flags(); }
104
105  void AddOpenParen(Label label) {
106    if (label == 0) {
107      FSTERROR() << "ParenMatcher: Bad open paren label: 0";
108    } else {
109      open_parens_.Insert(label);
110    }
111  }
112
113  void AddCloseParen(Label label) {
114    if (label == 0) {
115      FSTERROR() << "ParenMatcher: Bad close paren label: 0";
116    } else {
117      close_parens_.Insert(label);
118    }
119  }
120
121  void RemoveOpenParen(Label label) {
122    if (label == 0) {
123      FSTERROR() << "ParenMatcher: Bad open paren label: 0";
124    } else {
125      open_parens_.Erase(label);
126    }
127  }
128
129  void RemoveCloseParen(Label label) {
130    if (label == 0) {
131      FSTERROR() << "ParenMatcher: Bad close paren label: 0";
132    } else {
133      close_parens_.Erase(label);
134    }
135  }
136
137  void ClearOpenParens() {
138    open_parens_.Clear();
139  }
140
141  void ClearCloseParens() {
142    close_parens_.Clear();
143  }
144
145  bool IsOpenParen(Label label) const {
146    return open_parens_.Member(label);
147  }
148
149  bool IsCloseParen(Label label) const {
150    return close_parens_.Member(label);
151  }
152
153 private:
154  // Advances matcher to next open paren if it exists, returning true.
155  // O.w. returns false.
156  bool NextOpenParen();
157
158  // Advances matcher to next open paren if it exists, returning true.
159  // O.w. returns false.
160  bool NextCloseParen();
161
162  M matcher_;
163  MatchType match_type_;          // Type of match to perform
164  uint32 flags_;
165
166  // open paren label set
167  CompactSet<Label, kNoLabel> open_parens_;
168
169  // close paren label set
170  CompactSet<Label, kNoLabel> close_parens_;
171
172
173  bool open_paren_list_;         // Matching open paren list
174  bool close_paren_list_;        // Matching close paren list
175  bool paren_loop_;              // Current arc is the implicit paren loop
176  mutable Arc loop_;             // For non-consuming symbols
177  bool done_;                    // Matching done
178
179  void operator=(const ParenMatcher<F> &);  // Disallow
180};
181
182template <class M> inline
183bool ParenMatcher<M>::Find(Label match_label) {
184  open_paren_list_ = false;
185  close_paren_list_ = false;
186  paren_loop_ = false;
187  done_ = false;
188
189  // Returns all parenthesis arcs
190  if (match_label == kNoLabel && (flags_ & kParenList)) {
191    if (open_parens_.LowerBound() != kNoLabel) {
192      matcher_.LowerBound(open_parens_.LowerBound());
193      open_paren_list_ = NextOpenParen();
194      if (open_paren_list_) return true;
195    }
196    if (close_parens_.LowerBound() != kNoLabel) {
197      matcher_.LowerBound(close_parens_.LowerBound());
198      close_paren_list_ = NextCloseParen();
199      if (close_paren_list_) return true;
200    }
201  }
202
203  // Returns 'implicit' paren loop
204  if (match_label > 0 && (flags_ & kParenLoop) &&
205      (IsOpenParen(match_label) || IsCloseParen(match_label))) {
206    paren_loop_ = true;
207    return true;
208  }
209
210  // Returns all other labels
211  if (matcher_.Find(match_label))
212    return true;
213
214  done_ = true;
215  return false;
216}
217
218template <class F> inline
219void ParenMatcher<F>::Next() {
220  if (paren_loop_) {
221    paren_loop_ = false;
222    done_ = true;
223  } else if (open_paren_list_) {
224    matcher_.Next();
225    open_paren_list_ = NextOpenParen();
226    if (open_paren_list_) return;
227
228    if (close_parens_.LowerBound() != kNoLabel) {
229      matcher_.LowerBound(close_parens_.LowerBound());
230      close_paren_list_ = NextCloseParen();
231      if (close_paren_list_) return;
232    }
233    done_ = !matcher_.Find(kNoLabel);
234  } else if (close_paren_list_) {
235    matcher_.Next();
236    close_paren_list_ = NextCloseParen();
237    if (close_paren_list_) return;
238    done_ = !matcher_.Find(kNoLabel);
239  } else {
240    matcher_.Next();
241    done_ = matcher_.Done();
242  }
243}
244
245// Advances matcher to next open paren if it exists, returning true.
246// O.w. returns false.
247template <class F> inline
248bool ParenMatcher<F>::NextOpenParen() {
249  for (; !matcher_.Done(); matcher_.Next()) {
250    Label label = match_type_ == MATCH_INPUT ?
251        matcher_.Value().ilabel : matcher_.Value().olabel;
252    if (label > open_parens_.UpperBound())
253      return false;
254    if (IsOpenParen(label))
255      return true;
256  }
257  return false;
258}
259
260// Advances matcher to next close paren if it exists, returning true.
261// O.w. returns false.
262template <class F> inline
263bool ParenMatcher<F>::NextCloseParen() {
264  for (; !matcher_.Done(); matcher_.Next()) {
265    Label label = match_type_ == MATCH_INPUT ?
266        matcher_.Value().ilabel : matcher_.Value().olabel;
267    if (label > close_parens_.UpperBound())
268      return false;
269    if (IsCloseParen(label))
270      return true;
271  }
272  return false;
273}
274
275
276template <class F>
277class ParenFilter {
278 public:
279  typedef typename F::FST1 FST1;
280  typedef typename F::FST2 FST2;
281  typedef typename F::Arc Arc;
282  typedef typename Arc::StateId StateId;
283  typedef typename Arc::Label Label;
284  typedef typename Arc::Weight Weight;
285  typedef typename F::Matcher1 Matcher1;
286  typedef typename F::Matcher2 Matcher2;
287  typedef typename F::FilterState FilterState1;
288  typedef StateId StackId;
289  typedef PdtStack<StackId, Label> ParenStack;
290  typedef IntegerFilterState<StackId> FilterState2;
291  typedef PairFilterState<FilterState1, FilterState2> FilterState;
292  typedef ParenFilter<F> Filter;
293
294  ParenFilter(const FST1 &fst1, const FST2 &fst2,
295              Matcher1 *matcher1 = 0,  Matcher2 *matcher2 = 0,
296              const vector<pair<Label, Label> > *parens = 0,
297              bool expand = false, bool keep_parens = true)
298      : filter_(fst1, fst2, matcher1, matcher2),
299        parens_(parens ? *parens : vector<pair<Label, Label> >()),
300        expand_(expand),
301        keep_parens_(keep_parens),
302        f_(FilterState::NoState()),
303        stack_(parens_),
304        paren_id_(-1) {
305    if (parens) {
306      for (size_t i = 0; i < parens->size(); ++i) {
307        const pair<Label, Label>  &p = (*parens)[i];
308        parens_.push_back(p);
309        GetMatcher1()->AddOpenParen(p.first);
310        GetMatcher2()->AddOpenParen(p.first);
311        if (!expand_) {
312          GetMatcher1()->AddCloseParen(p.second);
313          GetMatcher2()->AddCloseParen(p.second);
314        }
315      }
316    }
317  }
318
319  ParenFilter(const Filter &filter, bool safe = false)
320      : filter_(filter.filter_, safe),
321        parens_(filter.parens_),
322        expand_(filter.expand_),
323        keep_parens_(filter.keep_parens_),
324        f_(FilterState::NoState()),
325        stack_(filter.parens_),
326        paren_id_(-1) { }
327
328  FilterState Start() const {
329    return FilterState(filter_.Start(), FilterState2(0));
330  }
331
332  void SetState(StateId s1, StateId s2, const FilterState &f) {
333    f_ = f;
334    filter_.SetState(s1, s2, f_.GetState1());
335    if (!expand_)
336      return;
337
338    ssize_t paren_id = stack_.Top(f.GetState2().GetState());
339    if (paren_id != paren_id_) {
340      if (paren_id_ != -1) {
341        GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second);
342        GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second);
343      }
344      paren_id_ = paren_id;
345      if (paren_id_ != -1) {
346        GetMatcher1()->AddCloseParen(parens_[paren_id_].second);
347        GetMatcher2()->AddCloseParen(parens_[paren_id_].second);
348      }
349    }
350  }
351
352  FilterState FilterArc(Arc *arc1, Arc *arc2) const {
353    FilterState1 f1 = filter_.FilterArc(arc1, arc2);
354    const FilterState2 &f2 = f_.GetState2();
355    if (f1 == FilterState1::NoState())
356      return FilterState::NoState();
357
358    if (arc1->olabel == kNoLabel && arc2->ilabel) {         // arc2 parentheses
359      if (keep_parens_) {
360        arc1->ilabel = arc2->ilabel;
361      } else if (arc2->ilabel) {
362        arc2->olabel = arc1->ilabel;
363      }
364      return FilterParen(arc2->ilabel, f1, f2);
365    } else if (arc2->ilabel == kNoLabel && arc1->olabel) {  // arc1 parentheses
366      if (keep_parens_) {
367        arc2->olabel = arc1->olabel;
368      } else {
369        arc1->ilabel = arc2->olabel;
370      }
371      return FilterParen(arc1->olabel, f1, f2);
372    } else {
373      return FilterState(f1, f2);
374    }
375  }
376
377  void FilterFinal(Weight *w1, Weight *w2) const {
378    if (f_.GetState2().GetState() != 0)
379      *w1 = Weight::Zero();
380    filter_.FilterFinal(w1, w2);
381  }
382
383  // Return resp matchers. Ownership stays with filter.
384  Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); }
385  Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); }
386
387  uint64 Properties(uint64 iprops) const {
388    uint64 oprops = filter_.Properties(iprops);
389    return oprops & kILabelInvariantProperties & kOLabelInvariantProperties;
390  }
391
392 private:
393  const FilterState FilterParen(Label label, const FilterState1 &f1,
394                                const FilterState2 &f2) const {
395    if (!expand_)
396      return FilterState(f1, f2);
397
398    StackId stack_id = stack_.Find(f2.GetState(), label);
399    if (stack_id < 0) {
400      return FilterState::NoState();
401    } else {
402      return FilterState(f1, FilterState2(stack_id));
403    }
404  }
405
406  F filter_;
407  vector<pair<Label, Label> > parens_;
408  bool expand_;                    // Expands to FST
409  bool keep_parens_;               // Retains parentheses in output
410  FilterState f_;                  // Current filter state
411  mutable ParenStack stack_;
412  ssize_t paren_id_;
413};
414
415// Class to setup composition options for PDT composition.
416// Default is for the PDT as the first composition argument.
417template <class Arc, bool left_pdt = true>
418class PdtComposeFstOptions : public
419ComposeFstOptions<Arc,
420                  ParenMatcher< Fst<Arc> >,
421                  ParenFilter<AltSequenceComposeFilter<
422                                ParenMatcher< Fst<Arc> > > > > {
423 public:
424  typedef typename Arc::Label Label;
425  typedef ParenMatcher< Fst<Arc> > PdtMatcher;
426  typedef ParenFilter<AltSequenceComposeFilter<PdtMatcher> > PdtFilter;
427  typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
428  using COptions::matcher1;
429  using COptions::matcher2;
430  using COptions::filter;
431
432  PdtComposeFstOptions(const Fst<Arc> &ifst1,
433                    const vector<pair<Label, Label> > &parens,
434                       const Fst<Arc> &ifst2, bool expand = false,
435                       bool keep_parens = true) {
436    matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList);
437    matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop);
438
439    filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
440                           expand, keep_parens);
441  }
442};
443
444// Class to setup composition options for PDT with FST composition.
445// Specialization is for the FST as the first composition argument.
446template <class Arc>
447class PdtComposeFstOptions<Arc, false> : public
448ComposeFstOptions<Arc,
449                  ParenMatcher< Fst<Arc> >,
450                  ParenFilter<SequenceComposeFilter<
451                                ParenMatcher< Fst<Arc> > > > > {
452 public:
453  typedef typename Arc::Label Label;
454  typedef ParenMatcher< Fst<Arc> > PdtMatcher;
455  typedef ParenFilter<SequenceComposeFilter<PdtMatcher> > PdtFilter;
456  typedef ComposeFstOptions<Arc, PdtMatcher, PdtFilter> COptions;
457  using COptions::matcher1;
458  using COptions::matcher2;
459  using COptions::filter;
460
461  PdtComposeFstOptions(const Fst<Arc> &ifst1,
462                       const Fst<Arc> &ifst2,
463                       const vector<pair<Label, Label> > &parens,
464                       bool expand = false, bool keep_parens = true) {
465    matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop);
466    matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList);
467
468    filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens,
469                           expand, keep_parens);
470  }
471};
472
// Pre-defined filter configurations for PDT composition. Compose() maps
// each one onto the 'expand' and 'keep_parens' options.
enum PdtComposeFilter {
  PAREN_FILTER = 0,        // Bar-Hillel construction; keeps parentheses
  EXPAND_FILTER = 1,       // Bar-Hillel + expansion; removes parentheses
  EXPAND_PAREN_FILTER = 2  // Bar-Hillel + expansion; keeps parentheses
};
478
479struct PdtComposeOptions {
480  bool connect;  // Connect output
481  PdtComposeFilter filter_type;  // Which pre-defined filter to use
482
483  explicit PdtComposeOptions(bool c, PdtComposeFilter ft = PAREN_FILTER)
484      : connect(c), filter_type(ft) {}
485  PdtComposeOptions() : connect(true), filter_type(PAREN_FILTER) {}
486};
487
488// Composes pushdown transducer (PDT) encoded as an FST (1st arg) and
489// an FST (2nd arg) with the result also a PDT encoded as an Fst. (3rd arg).
490// In the PDTs, some transitions are labeled with open or close
491// parentheses. To be interpreted as a PDT, the parens must balance on
492// a path (see PdtExpand()). The open-close parenthesis label pairs
493// are passed in 'parens'.
494template <class Arc>
495void Compose(const Fst<Arc> &ifst1,
496             const vector<pair<typename Arc::Label,
497                               typename Arc::Label> > &parens,
498             const Fst<Arc> &ifst2,
499             MutableFst<Arc> *ofst,
500             const PdtComposeOptions &opts = PdtComposeOptions()) {
501  bool expand = opts.filter_type != PAREN_FILTER;
502  bool keep_parens = opts.filter_type != EXPAND_FILTER;
503  PdtComposeFstOptions<Arc, true> copts(ifst1, parens, ifst2,
504                                        expand, keep_parens);
505  copts.gc_limit = 0;
506  *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
507  if (opts.connect)
508    Connect(ofst);
509}
510
511// Composes an FST (1st arg) and pushdown transducer (PDT) encoded as
512// an FST (2nd arg) with the result also a PDT encoded as an Fst (3rd arg).
513// In the PDTs, some transitions are labeled with open or close
514// parentheses. To be interpreted as a PDT, the parens must balance on
515// a path (see ExpandFst()). The open-close parenthesis label pairs
516// are passed in 'parens'.
517template <class Arc>
518void Compose(const Fst<Arc> &ifst1,
519             const Fst<Arc> &ifst2,
520             const vector<pair<typename Arc::Label,
521                               typename Arc::Label> > &parens,
522             MutableFst<Arc> *ofst,
523             const PdtComposeOptions &opts = PdtComposeOptions()) {
524  bool expand = opts.filter_type != PAREN_FILTER;
525  bool keep_parens = opts.filter_type != EXPAND_FILTER;
526  PdtComposeFstOptions<Arc, false> copts(ifst1, ifst2, parens,
527                                         expand, keep_parens);
528  copts.gc_limit = 0;
529  *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
530  if (opts.connect)
531    Connect(ofst);
532}
533
534}  // namespace fst
535
536#endif  // FST_EXTENSIONS_PDT_COMPOSE_H__
537