matcher.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// matcher.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Classes to allow matching labels leaving FST states.
20
21#ifndef FST_LIB_MATCHER_H__
22#define FST_LIB_MATCHER_H__
23
24#include <algorithm>
25#include <set>
26
27#include <fst/mutable-fst.h>  // for all internal FST accessors
28
29
30namespace fst {
31
32// MATCHERS - these can find and iterate through requested labels at
33// FST states. In the simplest form, these are just some associative
34// map or search keyed on labels. More generally, they may
35// implement matching special labels that represent sets of labels
36// such as 'sigma' (all), 'rho' (rest), or 'phi' (fail).
37// The Matcher interface is:
38//
39// template <class F>
40// class Matcher {
41//  public:
42//   typedef F FST;
//   typedef typename F::Arc Arc;
44//   typedef typename Arc::StateId StateId;
45//   typedef typename Arc::Label Label;
46//   typedef typename Arc::Weight Weight;
47//
48//   // Required constructors.
49//   Matcher(const F &fst, MatchType type);
50//   // If safe=true, the copy is thread-safe. See Fst<>::Copy()
51//   // for further doc.
52//   Matcher(const Matcher &matcher, bool safe = false);
53//
54//   // If safe=true, the copy is thread-safe. See Fst<>::Copy()
55//   // for further doc.
56//   Matcher<F> *Copy(bool safe = false) const;
57//
58//   // Returns the match type that can be provided (depending on
59//   // compatibility of the input FST). It is either
60//   // the requested match type, MATCH_NONE, or MATCH_UNKNOWN.
61//   // If 'test' is false, a constant time test is performed, but
62//   // MATCH_UNKNOWN may be returned. If 'test' is true,
63//   // a definite answer is returned, but may involve more costly
64//   // computation (e.g., visiting the Fst).
65//   MatchType Type(bool test) const;
66//   // Specifies the current state.
67//   void SetState(StateId s);
68//
69//   // This finds matches to a label at the current state.
70//   // Returns true if a match found. kNoLabel matches any
71//   // 'non-consuming' transitions, e.g., epsilon transitions,
72//   // which do not require a matching symbol.
73//   bool Find(Label label);
74//   // These iterate through any matches found:
75//   bool Done() const;         // No more matches.
//   const Arc& Value() const;  // Current arc (when !Done)
77//   void Next();               // Advance to next arc (when !Done)
78//
79//   // Return matcher FST.
80//   const F& GetFst() const;
81//   // This specifies the known Fst properties as viewed from this
82//   // matcher. It takes as argument the input Fst's known properties.
83//   uint64 Properties(uint64 props) const;
84// };
85
86// Flags used for basic matchers (see also lookahead.h).
87const uint32 kMatcherFlags = 0x00000000;
88
89// Matcher interface, templated on the Arc definition; used
90// for matcher specializations that are returned by the
91// InitMatcher Fst method.
92template <class A>
93class MatcherBase {
94 public:
95  typedef A Arc;
96  typedef typename A::StateId StateId;
97  typedef typename A::Label Label;
98  typedef typename A::Weight Weight;
99
100  virtual ~MatcherBase() {}
101
102  virtual MatcherBase<A> *Copy(bool safe = false) const = 0;
103  virtual MatchType Type(bool test) const = 0;
104  void SetState(StateId s) { SetState_(s); }
105  bool Find(Label label) { return Find_(label); }
106  bool Done() const { return Done_(); }
107  const A& Value() const { return Value_(); }
108  void Next() { Next_(); }
109  virtual const Fst<A> &GetFst() const = 0;
110  virtual uint64 Properties(uint64 props) const = 0;
111  virtual uint32 Flags() const { return 0; }
112 private:
113  virtual void SetState_(StateId s) = 0;
114  virtual bool Find_(Label label) = 0;
115  virtual bool Done_() const = 0;
116  virtual const A& Value_() const  = 0;
117  virtual void Next_()  = 0;
118};
119
120
121// A matcher that expects sorted labels on the side to be matched.
122// If match_type == MATCH_INPUT, epsilons match the implicit self loop
123// Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
124// actual epsilon transitions. If match_type == MATCH_OUTPUT, then
125// Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
126template <class F>
127class SortedMatcher : public MatcherBase<typename F::Arc> {
128 public:
129  typedef F FST;
130  typedef typename F::Arc Arc;
131  typedef typename Arc::StateId StateId;
132  typedef typename Arc::Label Label;
133  typedef typename Arc::Weight Weight;
134
135  // Labels >= binary_label will be searched for by binary search,
136  // o.w. linear search is used.
137  SortedMatcher(const F &fst, MatchType match_type,
138                Label binary_label = 1)
139      : fst_(fst.Copy()),
140        s_(kNoStateId),
141        aiter_(0),
142        match_type_(match_type),
143        binary_label_(binary_label),
144        match_label_(kNoLabel),
145        narcs_(0),
146        loop_(kNoLabel, 0, Weight::One(), kNoStateId),
147        error_(false) {
148    switch(match_type_) {
149      case MATCH_INPUT:
150      case MATCH_NONE:
151        break;
152      case MATCH_OUTPUT:
153        swap(loop_.ilabel, loop_.olabel);
154        break;
155      default:
156        FSTERROR() << "SortedMatcher: bad match type";
157        match_type_ = MATCH_NONE;
158        error_ = true;
159    }
160  }
161
162  SortedMatcher(const SortedMatcher<F> &matcher, bool safe = false)
163      : fst_(matcher.fst_->Copy(safe)),
164        s_(kNoStateId),
165        aiter_(0),
166        match_type_(matcher.match_type_),
167        binary_label_(matcher.binary_label_),
168        match_label_(kNoLabel),
169        narcs_(0),
170        loop_(matcher.loop_),
171        error_(matcher.error_) {}
172
173  virtual ~SortedMatcher() {
174    if (aiter_)
175      delete aiter_;
176    delete fst_;
177  }
178
179  virtual SortedMatcher<F> *Copy(bool safe = false) const {
180    return new SortedMatcher<F>(*this, safe);
181  }
182
183  virtual MatchType Type(bool test) const {
184    if (match_type_ == MATCH_NONE)
185      return match_type_;
186
187    uint64 true_prop =  match_type_ == MATCH_INPUT ?
188        kILabelSorted : kOLabelSorted;
189    uint64 false_prop = match_type_ == MATCH_INPUT ?
190        kNotILabelSorted : kNotOLabelSorted;
191    uint64 props = fst_->Properties(true_prop | false_prop, test);
192
193    if (props & true_prop)
194      return match_type_;
195    else if (props & false_prop)
196      return MATCH_NONE;
197    else
198      return MATCH_UNKNOWN;
199  }
200
201  void SetState(StateId s) {
202    if (s_ == s)
203      return;
204    s_ = s;
205    if (match_type_ == MATCH_NONE) {
206      FSTERROR() << "SortedMatcher: bad match type";
207      error_ = true;
208    }
209    if (aiter_)
210      delete aiter_;
211    aiter_ = new ArcIterator<F>(*fst_, s);
212    aiter_->SetFlags(kArcNoCache, kArcNoCache);
213    narcs_ = internal::NumArcs(*fst_, s);
214    loop_.nextstate = s;
215  }
216
217  bool Find(Label match_label);
218
219  bool Done() const {
220    if (current_loop_)
221      return false;
222    if (aiter_->Done())
223      return true;
224    aiter_->SetFlags(
225      match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
226      kArcValueFlags);
227    Label label = match_type_ == MATCH_INPUT ?
228        aiter_->Value().ilabel : aiter_->Value().olabel;
229    return label != match_label_;
230  }
231
232  const Arc& Value() const {
233    if (current_loop_) {
234      return loop_;
235    }
236    aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
237    return aiter_->Value();
238  }
239
240  void Next() {
241    if (current_loop_)
242      current_loop_ = false;
243    else
244      aiter_->Next();
245  }
246
247  virtual const F &GetFst() const { return *fst_; }
248
249  virtual uint64 Properties(uint64 inprops) const {
250    uint64 outprops = inprops;
251    if (error_) outprops |= kError;
252    return outprops;
253  }
254
255 private:
256  virtual void SetState_(StateId s) { SetState(s); }
257  virtual bool Find_(Label label) { return Find(label); }
258  virtual bool Done_() const { return Done(); }
259  virtual const Arc& Value_() const { return Value(); }
260  virtual void Next_() { Next(); }
261
262  const F *fst_;
263  StateId s_;                     // Current state
264  ArcIterator<F> *aiter_;         // Iterator for current state
265  MatchType match_type_;          // Type of match to perform
266  Label binary_label_;            // Least label for binary search
267  Label match_label_;             // Current label to be matched
268  size_t narcs_;                  // Current state arc count
269  Arc loop_;                      // For non-consuming symbols
270  bool current_loop_;             // Current arc is the implicit loop
271  bool error_;                    // Error encountered
272
273  void operator=(const SortedMatcher<F> &);  // Disallow
274};
275
276template <class F> inline
277bool SortedMatcher<F>::Find(Label match_label) {
278  if (error_) {
279    current_loop_ = false;
280    match_label_ = kNoLabel;
281    return false;
282  }
283  current_loop_ = match_label == 0;
284  match_label_ = match_label == kNoLabel ? 0 : match_label;
285  aiter_->SetFlags(
286      match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
287      kArcValueFlags);
288  if (match_label_ >= binary_label_) {
289    // Binary search for match.
290    size_t low = 0;
291    size_t high = narcs_;
292    while (low < high) {
293      size_t mid = (low + high) / 2;
294      aiter_->Seek(mid);
295      Label label = match_type_ == MATCH_INPUT ?
296          aiter_->Value().ilabel : aiter_->Value().olabel;
297      if (label > match_label_) {
298        high = mid;
299      } else if (label < match_label_) {
300        low = mid + 1;
301      } else {
302        // find first matching label (when non-determinism)
303        for (size_t i = mid; i > low; --i) {
304          aiter_->Seek(i - 1);
305          label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel :
306              aiter_->Value().olabel;
307          if (label != match_label_) {
308            aiter_->Seek(i);
309            return true;
310          }
311        }
312        return true;
313      }
314    }
315    return current_loop_;
316  } else {
317    // Linear search for match.
318    for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
319      Label label = match_type_ == MATCH_INPUT ?
320          aiter_->Value().ilabel : aiter_->Value().olabel;
321      if (label == match_label_) {
322        return true;
323      }
324      if (label > match_label_)
325        break;
326    }
327    return current_loop_;
328  }
329}
330
331
// Controls whether a special-label matcher rewrites both the input and
// the output side of a matched arc.
enum MatcherRewriteMode {
  // Rewrites both sides iff acceptor.
  MATCHER_REWRITE_AUTO = 0,
  // Always rewrites both sides.
  MATCHER_REWRITE_ALWAYS = 1,
  // Never rewrites both sides.
  MATCHER_REWRITE_NEVER = 2
};
338
339
340// For any requested label that doesn't match at a state, this matcher
341// considers all transitions that match the label 'rho_label' (rho =
342// 'rest').  Each such rho transition found is returned with the
343// rho_label rewritten as the requested label (both sides if an
344// acceptor, or if 'rewrite_both' is true and both input and output
345// labels of the found transition are 'rho_label').  If 'rho_label' is
346// kNoLabel, this special matching is not done.  RhoMatcher is
347// templated itself on a matcher, which is used to perform the
348// underlying matching.  By default, the underlying matcher is
349// constructed by RhoMatcher.  The user can instead pass in this
350// object; in that case, RhoMatcher takes its ownership.
351template <class M>
352class RhoMatcher : public MatcherBase<typename M::Arc> {
353 public:
354  typedef typename M::FST FST;
355  typedef typename M::Arc Arc;
356  typedef typename Arc::StateId StateId;
357  typedef typename Arc::Label Label;
358  typedef typename Arc::Weight Weight;
359
360  RhoMatcher(const FST &fst,
361             MatchType match_type,
362             Label rho_label = kNoLabel,
363             MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
364             M *matcher = 0)
365      : matcher_(matcher ? matcher : new M(fst, match_type)),
366        match_type_(match_type),
367        rho_label_(rho_label),
368        error_(false) {
369    if (match_type == MATCH_BOTH) {
370      FSTERROR() << "RhoMatcher: bad match type";
371      match_type_ = MATCH_NONE;
372      error_ = true;
373    }
374    if (rho_label == 0) {
375      FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
376      rho_label_ = kNoLabel;
377      error_ = true;
378    }
379
380    if (rewrite_mode == MATCHER_REWRITE_AUTO)
381      rewrite_both_ = fst.Properties(kAcceptor, true);
382    else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
383      rewrite_both_ = true;
384    else
385      rewrite_both_ = false;
386  }
387
388  RhoMatcher(const RhoMatcher<M> &matcher, bool safe = false)
389      : matcher_(new M(*matcher.matcher_, safe)),
390        match_type_(matcher.match_type_),
391        rho_label_(matcher.rho_label_),
392        rewrite_both_(matcher.rewrite_both_),
393        error_(matcher.error_) {}
394
395  virtual ~RhoMatcher() {
396    delete matcher_;
397  }
398
399  virtual RhoMatcher<M> *Copy(bool safe = false) const {
400    return new RhoMatcher<M>(*this, safe);
401  }
402
403  virtual MatchType Type(bool test) const { return matcher_->Type(test); }
404
405  void SetState(StateId s) {
406    matcher_->SetState(s);
407    has_rho_ = rho_label_ != kNoLabel;
408  }
409
410  bool Find(Label match_label) {
411    if (match_label == rho_label_ && rho_label_ != kNoLabel) {
412      FSTERROR() << "RhoMatcher::Find: bad label (rho)";
413      error_ = true;
414      return false;
415    }
416    if (matcher_->Find(match_label)) {
417      rho_match_ = kNoLabel;
418      return true;
419    } else if (has_rho_ && match_label != 0 && match_label != kNoLabel &&
420               (has_rho_ = matcher_->Find(rho_label_))) {
421      rho_match_ = match_label;
422      return true;
423    } else {
424      return false;
425    }
426  }
427
428  bool Done() const { return matcher_->Done(); }
429
430  const Arc& Value() const {
431    if (rho_match_ == kNoLabel) {
432      return matcher_->Value();
433    } else {
434      rho_arc_ = matcher_->Value();
435      if (rewrite_both_) {
436        if (rho_arc_.ilabel == rho_label_)
437          rho_arc_.ilabel = rho_match_;
438        if (rho_arc_.olabel == rho_label_)
439          rho_arc_.olabel = rho_match_;
440      } else if (match_type_ == MATCH_INPUT) {
441        rho_arc_.ilabel = rho_match_;
442      } else {
443        rho_arc_.olabel = rho_match_;
444      }
445      return rho_arc_;
446    }
447  }
448
449  void Next() { matcher_->Next(); }
450
451  virtual const FST &GetFst() const { return matcher_->GetFst(); }
452
453  virtual uint64 Properties(uint64 props) const;
454
455 private:
456  virtual void SetState_(StateId s) { SetState(s); }
457  virtual bool Find_(Label label) { return Find(label); }
458  virtual bool Done_() const { return Done(); }
459  virtual const Arc& Value_() const { return Value(); }
460  virtual void Next_() { Next(); }
461
462  M *matcher_;
463  MatchType match_type_;  // Type of match requested
464  Label rho_label_;       // Label that represents the rho transition
465  bool rewrite_both_;     // Rewrite both sides when both are 'rho_label_'
466  bool has_rho_;          // Are there possibly rhos at the current state?
467  Label rho_match_;       // Current label that matches rho transition
468  mutable Arc rho_arc_;   // Arc to return when rho match
469  bool error_;            // Error encountered
470
471  void operator=(const RhoMatcher<M> &);  // Disallow
472};
473
474template <class M> inline
475uint64 RhoMatcher<M>::Properties(uint64 inprops) const {
476  uint64 outprops = matcher_->Properties(inprops);
477  if (error_) outprops |= kError;
478
479  if (match_type_ == MATCH_NONE) {
480    return outprops;
481  } else if (match_type_ == MATCH_INPUT) {
482    if (rewrite_both_) {
483      return outprops & ~(kODeterministic | kNonODeterministic | kString |
484                       kILabelSorted | kNotILabelSorted |
485                       kOLabelSorted | kNotOLabelSorted);
486    } else {
487      return outprops & ~(kODeterministic | kAcceptor | kString |
488                       kILabelSorted | kNotILabelSorted);
489    }
490  } else if (match_type_ == MATCH_OUTPUT) {
491    if (rewrite_both_) {
492      return outprops & ~(kIDeterministic | kNonIDeterministic | kString |
493                       kILabelSorted | kNotILabelSorted |
494                       kOLabelSorted | kNotOLabelSorted);
495    } else {
496      return outprops & ~(kIDeterministic | kAcceptor | kString |
497                       kOLabelSorted | kNotOLabelSorted);
498    }
499  } else {
500    // Shouldn't ever get here.
501    FSTERROR() << "RhoMatcher:: bad match type: " << match_type_;
502    return 0;
503  }
504}
505
506
507// For any requested label, this matcher considers all transitions
508// that match the label 'sigma_label' (sigma = "any"), and this in
509// additions to transitions with the requested label.  Each such sigma
510// transition found is returned with the sigma_label rewritten as the
511// requested label (both sides if an acceptor, or if 'rewrite_both' is
512// true and both input and output labels of the found transition are
513// 'sigma_label').  If 'sigma_label' is kNoLabel, this special
514// matching is not done.  SigmaMatcher is templated itself on a
515// matcher, which is used to perform the underlying matching.  By
516// default, the underlying matcher is constructed by SigmaMatcher.
517// The user can instead pass in this object; in that case,
518// SigmaMatcher takes its ownership.
519template <class M>
520class SigmaMatcher : public MatcherBase<typename M::Arc> {
521 public:
522  typedef typename M::FST FST;
523  typedef typename M::Arc Arc;
524  typedef typename Arc::StateId StateId;
525  typedef typename Arc::Label Label;
526  typedef typename Arc::Weight Weight;
527
528  SigmaMatcher(const FST &fst,
529               MatchType match_type,
530               Label sigma_label = kNoLabel,
531               MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
532               M *matcher = 0)
533      : matcher_(matcher ? matcher : new M(fst, match_type)),
534        match_type_(match_type),
535        sigma_label_(sigma_label),
536        error_(false) {
537    if (match_type == MATCH_BOTH) {
538      FSTERROR() << "SigmaMatcher: bad match type";
539      match_type_ = MATCH_NONE;
540      error_ = true;
541    }
542    if (sigma_label == 0) {
543      FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
544      sigma_label_ = kNoLabel;
545      error_ = true;
546    }
547
548    if (rewrite_mode == MATCHER_REWRITE_AUTO)
549      rewrite_both_ = fst.Properties(kAcceptor, true);
550    else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
551      rewrite_both_ = true;
552    else
553      rewrite_both_ = false;
554  }
555
556  SigmaMatcher(const SigmaMatcher<M> &matcher, bool safe = false)
557      : matcher_(new M(*matcher.matcher_, safe)),
558        match_type_(matcher.match_type_),
559        sigma_label_(matcher.sigma_label_),
560        rewrite_both_(matcher.rewrite_both_),
561        error_(matcher.error_) {}
562
563  virtual ~SigmaMatcher() {
564    delete matcher_;
565  }
566
567  virtual SigmaMatcher<M> *Copy(bool safe = false) const {
568    return new SigmaMatcher<M>(*this, safe);
569  }
570
571  virtual MatchType Type(bool test) const { return matcher_->Type(test); }
572
573  void SetState(StateId s) {
574    matcher_->SetState(s);
575    has_sigma_ =
576        sigma_label_ != kNoLabel ? matcher_->Find(sigma_label_) : false;
577  }
578
579  bool Find(Label match_label) {
580    match_label_ = match_label;
581    if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
582      FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
583      error_ = true;
584      return false;
585    }
586    if (matcher_->Find(match_label)) {
587      sigma_match_ = kNoLabel;
588      return true;
589    } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
590               matcher_->Find(sigma_label_)) {
591      sigma_match_ = match_label;
592      return true;
593    } else {
594      return false;
595    }
596  }
597
598  bool Done() const {
599    return matcher_->Done();
600  }
601
602  const Arc& Value() const {
603    if (sigma_match_ == kNoLabel) {
604      return matcher_->Value();
605    } else {
606      sigma_arc_ = matcher_->Value();
607      if (rewrite_both_) {
608        if (sigma_arc_.ilabel == sigma_label_)
609          sigma_arc_.ilabel = sigma_match_;
610        if (sigma_arc_.olabel == sigma_label_)
611          sigma_arc_.olabel = sigma_match_;
612      } else if (match_type_ == MATCH_INPUT) {
613        sigma_arc_.ilabel = sigma_match_;
614      } else {
615        sigma_arc_.olabel = sigma_match_;
616      }
617      return sigma_arc_;
618    }
619  }
620
621  void Next() {
622    matcher_->Next();
623    if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
624        (match_label_ > 0)) {
625      matcher_->Find(sigma_label_);
626      sigma_match_ = match_label_;
627    }
628  }
629
630  virtual const FST &GetFst() const { return matcher_->GetFst(); }
631
632  virtual uint64 Properties(uint64 props) const;
633
634private:
635  virtual void SetState_(StateId s) { SetState(s); }
636  virtual bool Find_(Label label) { return Find(label); }
637  virtual bool Done_() const { return Done(); }
638  virtual const Arc& Value_() const { return Value(); }
639  virtual void Next_() { Next(); }
640
641  M *matcher_;
642  MatchType match_type_;   // Type of match requested
643  Label sigma_label_;      // Label that represents the sigma transition
644  bool rewrite_both_;      // Rewrite both sides when both are 'sigma_label_'
645  bool has_sigma_;         // Are there sigmas at the current state?
646  Label sigma_match_;      // Current label that matches sigma transition
647  mutable Arc sigma_arc_;  // Arc to return when sigma match
648  Label match_label_;      // Label being matched
649  bool error_;             // Error encountered
650
651  void operator=(const SigmaMatcher<M> &);  // disallow
652};
653
654template <class M> inline
655uint64 SigmaMatcher<M>::Properties(uint64 inprops) const {
656  uint64 outprops = matcher_->Properties(inprops);
657  if (error_) outprops |= kError;
658
659  if (match_type_ == MATCH_NONE) {
660    return outprops;
661  } else if (rewrite_both_) {
662    return outprops & ~(kIDeterministic | kNonIDeterministic |
663                     kODeterministic | kNonODeterministic |
664                     kILabelSorted | kNotILabelSorted |
665                     kOLabelSorted | kNotOLabelSorted |
666                     kString);
667  } else if (match_type_ == MATCH_INPUT) {
668    return outprops & ~(kIDeterministic | kNonIDeterministic |
669                     kODeterministic | kNonODeterministic |
670                     kILabelSorted | kNotILabelSorted |
671                     kString | kAcceptor);
672  } else if (match_type_ == MATCH_OUTPUT) {
673    return outprops & ~(kIDeterministic | kNonIDeterministic |
674                     kODeterministic | kNonODeterministic |
675                     kOLabelSorted | kNotOLabelSorted |
676                     kString | kAcceptor);
677  } else {
678    // Shouldn't ever get here.
679    FSTERROR() << "SigmaMatcher:: bad match type: " << match_type_;
680    return 0;
681  }
682}
683
684
685// For any requested label that doesn't match at a state, this matcher
686// considers the *unique* transition that matches the label 'phi_label'
687// (phi = 'fail'), and recursively looks for a match at its
688// destination.  When 'phi_loop' is true, if no match is found but a
689// phi self-loop is found, then the phi transition found is returned
690// with the phi_label rewritten as the requested label (both sides if
691// an acceptor, or if 'rewrite_both' is true and both input and output
692// labels of the found transition are 'phi_label').  If 'phi_label' is
693// kNoLabel, this special matching is not done.  PhiMatcher is
694// templated itself on a matcher, which is used to perform the
695// underlying matching.  By default, the underlying matcher is
696// constructed by PhiMatcher. The user can instead pass in this
697// object; in that case, PhiMatcher takes its ownership.
698// Warning: phi non-determinism not supported (for simplicity).
699template <class M>
700class PhiMatcher : public MatcherBase<typename M::Arc> {
701 public:
702  typedef typename M::FST FST;
703  typedef typename M::Arc Arc;
704  typedef typename Arc::StateId StateId;
705  typedef typename Arc::Label Label;
706  typedef typename Arc::Weight Weight;
707
708  PhiMatcher(const FST &fst,
709             MatchType match_type,
710             Label phi_label = kNoLabel,
711             bool phi_loop = true,
712             MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
713             M *matcher = 0)
714      : matcher_(matcher ? matcher : new M(fst, match_type)),
715        match_type_(match_type),
716        phi_label_(phi_label),
717        state_(kNoStateId),
718        phi_loop_(phi_loop),
719        error_(false) {
720    if (match_type == MATCH_BOTH) {
721      FSTERROR() << "PhiMatcher: bad match type";
722      match_type_ = MATCH_NONE;
723      error_ = true;
724    }
725    if (phi_label == 0) {
726      FSTERROR() << "PhiMatcher: 0 cannot be used as phi_label";
727      phi_label_ = kNoLabel;
728      error_ = true;
729    }
730
731    if (rewrite_mode == MATCHER_REWRITE_AUTO)
732      rewrite_both_ = fst.Properties(kAcceptor, true);
733    else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
734      rewrite_both_ = true;
735    else
736      rewrite_both_ = false;
737   }
738
739  PhiMatcher(const PhiMatcher<M> &matcher, bool safe = false)
740      : matcher_(new M(*matcher.matcher_, safe)),
741        match_type_(matcher.match_type_),
742        phi_label_(matcher.phi_label_),
743        rewrite_both_(matcher.rewrite_both_),
744        state_(kNoStateId),
745        phi_loop_(matcher.phi_loop_),
746        error_(matcher.error_) {}
747
748  virtual ~PhiMatcher() {
749    delete matcher_;
750  }
751
752  virtual PhiMatcher<M> *Copy(bool safe = false) const {
753    return new PhiMatcher<M>(*this, safe);
754  }
755
756  virtual MatchType Type(bool test) const { return matcher_->Type(test); }
757
758  void SetState(StateId s) {
759    matcher_->SetState(s);
760    state_ = s;
761    has_phi_ = phi_label_ != kNoLabel;
762  }
763
764  bool Find(Label match_label);
765
766  bool Done() const { return matcher_->Done(); }
767
768  const Arc& Value() const {
769    if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
770      return matcher_->Value();
771    } else {
772      phi_arc_ = matcher_->Value();
773      phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
774      if (phi_match_ != kNoLabel) {
775        if (rewrite_both_) {
776          if (phi_arc_.ilabel == phi_label_)
777            phi_arc_.ilabel = phi_match_;
778          if (phi_arc_.olabel == phi_label_)
779            phi_arc_.olabel = phi_match_;
780        } else if (match_type_ == MATCH_INPUT) {
781          phi_arc_.ilabel = phi_match_;
782        } else {
783          phi_arc_.olabel = phi_match_;
784        }
785      }
786      return phi_arc_;
787    }
788  }
789
790  void Next() { matcher_->Next(); }
791
792  virtual const FST &GetFst() const { return matcher_->GetFst(); }
793
794  virtual uint64 Properties(uint64 props) const;
795
796private:
797  virtual void SetState_(StateId s) { SetState(s); }
798  virtual bool Find_(Label label) { return Find(label); }
799  virtual bool Done_() const { return Done(); }
800  virtual const Arc& Value_() const { return Value(); }
801  virtual void Next_() { Next(); }
802
803  M *matcher_;
804  MatchType match_type_;  // Type of match requested
805  Label phi_label_;       // Label that represents the phi transition
806  bool rewrite_both_;     // Rewrite both sides when both are 'phi_label_'
807  bool has_phi_;          // Are there possibly phis at the current state?
808  Label phi_match_;       // Current label that matches phi loop
809  mutable Arc phi_arc_;   // Arc to return
810  StateId state_;         // State where looking for matches
811  Weight phi_weight_;     // Product of the weights of phi transitions taken
812  bool phi_loop_;         // When true, phi self-loop are allowed and treated
813                          // as rho (required for Aho-Corasick)
814  bool error_;             // Error encountered
815
816  void operator=(const PhiMatcher<M> &);  // disallow
817};
818
819template <class M> inline
820bool PhiMatcher<M>::Find(Label match_label) {
821  if (match_label == phi_label_ && phi_label_ != kNoLabel) {
822    FSTERROR() << "PhiMatcher::Find: bad label (phi)";
823    error_ = true;
824    return false;
825  }
826  matcher_->SetState(state_);
827  phi_match_ = kNoLabel;
828  phi_weight_ = Weight::One();
829  if (!has_phi_ || match_label == 0 || match_label == kNoLabel)
830    return matcher_->Find(match_label);
831  StateId state = state_;
832  while (!matcher_->Find(match_label)) {
833    if (!matcher_->Find(phi_label_))
834      return false;
835    if (phi_loop_ && matcher_->Value().nextstate == state) {
836      phi_match_ = match_label;
837      return true;
838    }
839    phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
840    state = matcher_->Value().nextstate;
841    matcher_->Next();
842    if (!matcher_->Done()) {
843      FSTERROR() << "PhiMatcher: phi non-determinism not supported";
844      error_ = true;
845    }
846    matcher_->SetState(state);
847  }
848  return true;
849}
850
851template <class M> inline
852uint64 PhiMatcher<M>::Properties(uint64 inprops) const {
853  uint64 outprops = matcher_->Properties(inprops);
854  if (error_) outprops |= kError;
855
856  if (match_type_ == MATCH_NONE) {
857    return outprops;
858  } else if (match_type_ == MATCH_INPUT) {
859    if (rewrite_both_) {
860      return outprops & ~(kODeterministic | kNonODeterministic | kString |
861                       kILabelSorted | kNotILabelSorted |
862                       kOLabelSorted | kNotOLabelSorted);
863    } else {
864      return outprops & ~(kODeterministic | kAcceptor | kString |
865                       kILabelSorted | kNotILabelSorted |
866                       kOLabelSorted | kNotOLabelSorted);
867    }
868  } else if (match_type_ == MATCH_OUTPUT) {
869    if (rewrite_both_) {
870      return outprops & ~(kIDeterministic | kNonIDeterministic | kString |
871                       kILabelSorted | kNotILabelSorted |
872                       kOLabelSorted | kNotOLabelSorted);
873    } else {
874      return outprops & ~(kIDeterministic | kAcceptor | kString |
875                       kILabelSorted | kNotILabelSorted |
876                       kOLabelSorted | kNotOLabelSorted);
877    }
878  } else {
879    // Shouldn't ever get here.
880    FSTERROR() << "PhiMatcher:: bad match type: " << match_type_;
881    return 0;
882  }
883}
884
885
886//
887// MULTI-EPS MATCHER FLAGS
888//
889
890// Return multi-epsilon arcs for Find(kNoLabel).
891const uint32 kMultiEpsList =  0x00000001;
892
893// Return a kNolabel loop for Find(multi_eps).
894const uint32 kMultiEpsLoop =  0x00000002;
895
896// MultiEpsMatcher: allows treating multiple non-0 labels as
897// non-consuming labels in addition to 0 that is always
898// non-consuming. Precise behavior controlled by 'flags' argument. By
899// default, the underlying matcher is constructed by
900// MultiEpsMatcher. The user can instead pass in this object; in that
901// case, MultiEpsMatcher takes its ownership iff 'own_matcher' is
902// true.
903template <class M>
904class MultiEpsMatcher {
905 public:
906  typedef typename M::FST FST;
907  typedef typename M::Arc Arc;
908  typedef typename Arc::StateId StateId;
909  typedef typename Arc::Label Label;
910  typedef typename Arc::Weight Weight;
911
912  MultiEpsMatcher(const FST &fst, MatchType match_type,
913                  uint32 flags = (kMultiEpsLoop | kMultiEpsList),
914                  M *matcher = 0, bool own_matcher = true)
915      : matcher_(matcher ? matcher : new M(fst, match_type)),
916        flags_(flags),
917        own_matcher_(matcher ?  own_matcher : true) {
918    if (match_type == MATCH_INPUT) {
919      loop_.ilabel = kNoLabel;
920      loop_.olabel = 0;
921    } else {
922      loop_.ilabel = 0;
923      loop_.olabel = kNoLabel;
924    }
925    loop_.weight = Weight::One();
926    loop_.nextstate = kNoStateId;
927  }
928
929  MultiEpsMatcher(const MultiEpsMatcher<M> &matcher, bool safe = false)
930      : matcher_(new M(*matcher.matcher_, safe)),
931        flags_(matcher.flags_),
932        own_matcher_(true),
933        multi_eps_labels_(matcher.multi_eps_labels_),
934        loop_(matcher.loop_) {
935    loop_.nextstate = kNoStateId;
936  }
937
938  ~MultiEpsMatcher() {
939    if (own_matcher_)
940      delete matcher_;
941  }
942
943  MultiEpsMatcher<M> *Copy(bool safe = false) const {
944    return new MultiEpsMatcher<M>(*this, safe);
945  }
946
947  MatchType Type(bool test) const { return matcher_->Type(test); }
948
949  void SetState(StateId s) {
950    matcher_->SetState(s);
951    loop_.nextstate = s;
952  }
953
954  bool Find(Label match_label);
955
956  bool Done() const {
957    return done_;
958  }
959
960  const Arc& Value() const {
961    return current_loop_ ? loop_ : matcher_->Value();
962  }
963
964  void Next() {
965    if (!current_loop_) {
966      matcher_->Next();
967      done_ = matcher_->Done();
968      if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
969        ++multi_eps_iter_;
970        while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
971               !matcher_->Find(*multi_eps_iter_))
972          ++multi_eps_iter_;
973        if (multi_eps_iter_ != multi_eps_labels_.End())
974          done_ = false;
975        else
976          done_ = !matcher_->Find(kNoLabel);
977
978      }
979    } else {
980      done_ = true;
981    }
982  }
983
984  const FST &GetFst() const { return matcher_->GetFst(); }
985
986  uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
987
988  uint32 Flags() const { return matcher_->Flags(); }
989
990  void AddMultiEpsLabel(Label label) {
991    if (label == 0) {
992      FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
993    } else {
994      multi_eps_labels_.Insert(label);
995    }
996  }
997
998  void ClearMultiEpsLabels() {
999    multi_eps_labels_.Clear();
1000  }
1001
1002private:
1003  // Specialized for 'set' - log lookup
1004  bool IsMultiEps(const set<Label> &multi_eps_labels, Label label) const {
1005    return multi_eps_labels.Find(label) != multi_eps_labels.end();
1006  }
1007
1008  M *matcher_;
1009  uint32 flags_;
1010  bool own_matcher_;             // Does this class delete the matcher?
1011
1012  // Multi-eps label set
1013  CompactSet<Label, kNoLabel> multi_eps_labels_;
1014  typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;
1015
1016  bool current_loop_;            // Current arc is the implicit loop
1017  mutable Arc loop_;             // For non-consuming symbols
1018  bool done_;                    // Matching done
1019
1020  void operator=(const MultiEpsMatcher<M> &);  // Disallow
1021};
1022
1023template <class M> inline
1024bool MultiEpsMatcher<M>::Find(Label match_label) {
1025  multi_eps_iter_ = multi_eps_labels_.End();
1026  current_loop_ = false;
1027  bool ret;
1028  if (match_label == 0) {
1029    ret = matcher_->Find(0);
1030  } else if (match_label == kNoLabel) {
1031    if (flags_ & kMultiEpsList) {
1032      // return all non-consuming arcs (incl. epsilon)
1033      multi_eps_iter_ = multi_eps_labels_.Begin();
1034      while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1035             !matcher_->Find(*multi_eps_iter_))
1036        ++multi_eps_iter_;
1037      if (multi_eps_iter_ != multi_eps_labels_.End())
1038        ret = true;
1039      else
1040        ret = matcher_->Find(kNoLabel);
1041    } else {
1042      // return all epsilon arcs
1043      ret = matcher_->Find(kNoLabel);
1044    }
1045  } else if ((flags_ & kMultiEpsLoop) &&
1046             multi_eps_labels_.Find(match_label) != multi_eps_labels_.End()) {
1047    // return 'implicit' loop
1048    current_loop_ = true;
1049    ret = true;
1050  } else {
1051    ret = matcher_->Find(match_label);
1052  }
1053  done_ = !ret;
1054  return ret;
1055}
1056
1057
1058// Generic matcher, templated on the FST definition
1059// - a wrapper around pointer to specific one.
1060// Here is a typical use: \code
1061//   Matcher<StdFst> matcher(fst, MATCH_INPUT);
1062//   matcher.SetState(state);
1063//   if (matcher.Find(label))
1064//     for (; !matcher.Done(); matcher.Next()) {
1065//       StdArc &arc = matcher.Value();
1066//       ...
1067//     } \endcode
1068template <class F>
1069class Matcher {
1070 public:
1071  typedef F FST;
1072  typedef typename F::Arc Arc;
1073  typedef typename Arc::StateId StateId;
1074  typedef typename Arc::Label Label;
1075  typedef typename Arc::Weight Weight;
1076
1077  Matcher(const F &fst, MatchType match_type) {
1078    base_ = fst.InitMatcher(match_type);
1079    if (!base_)
1080      base_ = new SortedMatcher<F>(fst, match_type);
1081  }
1082
1083  Matcher(const Matcher<F> &matcher, bool safe = false) {
1084    base_ = matcher.base_->Copy(safe);
1085  }
1086
1087  // Takes ownership of the provided matcher
1088  Matcher(MatcherBase<Arc>* base_matcher) { base_ = base_matcher; }
1089
1090  ~Matcher() { delete base_; }
1091
1092  Matcher<F> *Copy(bool safe = false) const {
1093    return new Matcher<F>(*this, safe);
1094  }
1095
1096  MatchType Type(bool test) const { return base_->Type(test); }
1097  void SetState(StateId s) { base_->SetState(s); }
1098  bool Find(Label label) { return base_->Find(label); }
1099  bool Done() const { return base_->Done(); }
1100  const Arc& Value() const { return base_->Value(); }
1101  void Next() { base_->Next(); }
1102  const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
1103  uint64 Properties(uint64 props) const { return base_->Properties(props); }
1104  uint32 Flags() const { return base_->Flags() & kMatcherFlags; }
1105
1106 private:
1107  MatcherBase<Arc> *base_;
1108
1109  void operator=(const Matcher<Arc> &);  // disallow
1110};
1111
1112}  // namespace fst
1113
1114
1115
1116#endif  // FST_LIB_MATCHER_H__
1117