1// matcher.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Classes to allow matching labels leaving FST states.
20
21#ifndef FST_LIB_MATCHER_H__
22#define FST_LIB_MATCHER_H__
23
24#include <algorithm>
25#include <set>
26
27#include <fst/mutable-fst.h>  // for all internal FST accessors
28
29
30namespace fst {
31
32// MATCHERS - these can find and iterate through requested labels at
33// FST states. In the simplest form, these are just some associative
34// map or search keyed on labels. More generally, they may
35// implement matching special labels that represent sets of labels
36// such as 'sigma' (all), 'rho' (rest), or 'phi' (fail).
37// The Matcher interface is:
38//
39// template <class F>
40// class Matcher {
41//  public:
42//   typedef F FST;
//   typedef typename F::Arc Arc;
44//   typedef typename Arc::StateId StateId;
45//   typedef typename Arc::Label Label;
46//   typedef typename Arc::Weight Weight;
47//
48//   // Required constructors.
49//   Matcher(const F &fst, MatchType type);
50//   // If safe=true, the copy is thread-safe. See Fst<>::Copy()
51//   // for further doc.
52//   Matcher(const Matcher &matcher, bool safe = false);
53//
54//   // If safe=true, the copy is thread-safe. See Fst<>::Copy()
55//   // for further doc.
56//   Matcher<F> *Copy(bool safe = false) const;
57//
58//   // Returns the match type that can be provided (depending on
59//   // compatibility of the input FST). It is either
60//   // the requested match type, MATCH_NONE, or MATCH_UNKNOWN.
61//   // If 'test' is false, a constant time test is performed, but
62//   // MATCH_UNKNOWN may be returned. If 'test' is true,
63//   // a definite answer is returned, but may involve more costly
64//   // computation (e.g., visiting the Fst).
65//   MatchType Type(bool test) const;
66//   // Specifies the current state.
67//   void SetState(StateId s);
68//
69//   // This finds matches to a label at the current state.
70//   // Returns true if a match found. kNoLabel matches any
71//   // 'non-consuming' transitions, e.g., epsilon transitions,
72//   // which do not require a matching symbol.
73//   bool Find(Label label);
74//   // These iterate through any matches found:
75//   bool Done() const;         // No more matches.
//   const Arc& Value() const;  // Current arc (when !Done)
77//   void Next();               // Advance to next arc (when !Done)
78//   // Initially and after SetState() the iterator methods
79//   // have undefined behavior until Find() is called.
80//
81//   // Return matcher FST.
82//   const F& GetFst() const;
83//   // This specifies the known Fst properties as viewed from this
84//   // matcher. It takes as argument the input Fst's known properties.
85//   uint64 Properties(uint64 props) const;
86// };
87
88//
89// MATCHER FLAGS (see also kLookAheadFlags in lookahead-matcher.h)
90//
91// Matcher prefers being used as the matching side in composition.
92const uint32 kPreferMatch  = 0x00000001;
93
94// Matcher needs to be used as the matching side in composition.
95const uint32 kRequireMatch = 0x00000002;
96
97// Flags used for basic matchers (see also lookahead.h).
98const uint32 kMatcherFlags = kPreferMatch | kRequireMatch;
99
100// Matcher interface, templated on the Arc definition; used
101// for matcher specializations that are returned by the
102// InitMatcher Fst method.
103template <class A>
104class MatcherBase {
105 public:
106  typedef A Arc;
107  typedef typename A::StateId StateId;
108  typedef typename A::Label Label;
109  typedef typename A::Weight Weight;
110
111  virtual ~MatcherBase() {}
112
113  virtual MatcherBase<A> *Copy(bool safe = false) const = 0;
114  virtual MatchType Type(bool test) const = 0;
115  void SetState(StateId s) { SetState_(s); }
116  bool Find(Label label) { return Find_(label); }
117  bool Done() const { return Done_(); }
118  const A& Value() const { return Value_(); }
119  void Next() { Next_(); }
120  virtual const Fst<A> &GetFst() const = 0;
121  virtual uint64 Properties(uint64 props) const = 0;
122  virtual uint32 Flags() const { return 0; }
123 private:
124  virtual void SetState_(StateId s) = 0;
125  virtual bool Find_(Label label) = 0;
126  virtual bool Done_() const = 0;
127  virtual const A& Value_() const  = 0;
128  virtual void Next_()  = 0;
129};
130
131
132// A matcher that expects sorted labels on the side to be matched.
133// If match_type == MATCH_INPUT, epsilons match the implicit self loop
134// Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
135// actual epsilon transitions. If match_type == MATCH_OUTPUT, then
136// Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
137template <class F>
138class SortedMatcher : public MatcherBase<typename F::Arc> {
139 public:
140  typedef F FST;
141  typedef typename F::Arc Arc;
142  typedef typename Arc::StateId StateId;
143  typedef typename Arc::Label Label;
144  typedef typename Arc::Weight Weight;
145
146  // Labels >= binary_label will be searched for by binary search,
147  // o.w. linear search is used.
148  SortedMatcher(const F &fst, MatchType match_type,
149                Label binary_label = 1)
150      : fst_(fst.Copy()),
151        s_(kNoStateId),
152        aiter_(0),
153        match_type_(match_type),
154        binary_label_(binary_label),
155        match_label_(kNoLabel),
156        narcs_(0),
157        loop_(kNoLabel, 0, Weight::One(), kNoStateId),
158        error_(false) {
159    switch(match_type_) {
160      case MATCH_INPUT:
161      case MATCH_NONE:
162        break;
163      case MATCH_OUTPUT:
164        swap(loop_.ilabel, loop_.olabel);
165        break;
166      default:
167        FSTERROR() << "SortedMatcher: bad match type";
168        match_type_ = MATCH_NONE;
169        error_ = true;
170    }
171  }
172
173  SortedMatcher(const SortedMatcher<F> &matcher, bool safe = false)
174      : fst_(matcher.fst_->Copy(safe)),
175        s_(kNoStateId),
176        aiter_(0),
177        match_type_(matcher.match_type_),
178        binary_label_(matcher.binary_label_),
179        match_label_(kNoLabel),
180        narcs_(0),
181        loop_(matcher.loop_),
182        error_(matcher.error_) {}
183
184  virtual ~SortedMatcher() {
185    if (aiter_)
186      delete aiter_;
187    delete fst_;
188  }
189
190  virtual SortedMatcher<F> *Copy(bool safe = false) const {
191    return new SortedMatcher<F>(*this, safe);
192  }
193
194  virtual MatchType Type(bool test) const {
195    if (match_type_ == MATCH_NONE)
196      return match_type_;
197
198    uint64 true_prop =  match_type_ == MATCH_INPUT ?
199        kILabelSorted : kOLabelSorted;
200    uint64 false_prop = match_type_ == MATCH_INPUT ?
201        kNotILabelSorted : kNotOLabelSorted;
202    uint64 props = fst_->Properties(true_prop | false_prop, test);
203
204    if (props & true_prop)
205      return match_type_;
206    else if (props & false_prop)
207      return MATCH_NONE;
208    else
209      return MATCH_UNKNOWN;
210  }
211
212  void SetState(StateId s) {
213    if (s_ == s)
214      return;
215    s_ = s;
216    if (match_type_ == MATCH_NONE) {
217      FSTERROR() << "SortedMatcher: bad match type";
218      error_ = true;
219    }
220    if (aiter_)
221      delete aiter_;
222    aiter_ = new ArcIterator<F>(*fst_, s);
223    aiter_->SetFlags(kArcNoCache, kArcNoCache);
224    narcs_ = internal::NumArcs(*fst_, s);
225    loop_.nextstate = s;
226  }
227
228  bool Find(Label match_label) {
229    exact_match_ = true;
230    if (error_) {
231      current_loop_ = false;
232      match_label_ = kNoLabel;
233      return false;
234    }
235    current_loop_ = match_label == 0;
236    match_label_ = match_label == kNoLabel ? 0 : match_label;
237    if (Search()) {
238      return true;
239    } else {
240      return current_loop_;
241    }
242  }
243
244  // Positions matcher to the first position where inserting
245  // match_label would maintain the sort order.
246  void LowerBound(Label match_label) {
247    exact_match_ = false;
248    current_loop_ = false;
249    if (error_) {
250      match_label_ = kNoLabel;
251      return;
252    }
253    match_label_ = match_label;
254    Search();
255  }
256
257  // After Find(), returns false if no more exact matches.
258  // After LowerBound(), returns false if no more arcs.
259  bool Done() const {
260    if (current_loop_)
261      return false;
262    if (aiter_->Done())
263      return true;
264    if (!exact_match_)
265      return false;
266    aiter_->SetFlags(
267      match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
268      kArcValueFlags);
269    Label label = match_type_ == MATCH_INPUT ?
270        aiter_->Value().ilabel : aiter_->Value().olabel;
271    return label != match_label_;
272  }
273
274  const Arc& Value() const {
275    if (current_loop_) {
276      return loop_;
277    }
278    aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
279    return aiter_->Value();
280  }
281
282  void Next() {
283    if (current_loop_)
284      current_loop_ = false;
285    else
286      aiter_->Next();
287  }
288
289  virtual const F &GetFst() const { return *fst_; }
290
291  virtual uint64 Properties(uint64 inprops) const {
292    uint64 outprops = inprops;
293    if (error_) outprops |= kError;
294    return outprops;
295  }
296
297  size_t Position() const { return aiter_ ? aiter_->Position() : 0; }
298
299 private:
300  virtual void SetState_(StateId s) { SetState(s); }
301  virtual bool Find_(Label label) { return Find(label); }
302  virtual bool Done_() const { return Done(); }
303  virtual const Arc& Value_() const { return Value(); }
304  virtual void Next_() { Next(); }
305
306  bool Search();
307
308  const F *fst_;
309  StateId s_;                     // Current state
310  ArcIterator<F> *aiter_;         // Iterator for current state
311  MatchType match_type_;          // Type of match to perform
312  Label binary_label_;            // Least label for binary search
313  Label match_label_;             // Current label to be matched
314  size_t narcs_;                  // Current state arc count
315  Arc loop_;                      // For non-consuming symbols
316  bool current_loop_;             // Current arc is the implicit loop
317  bool exact_match_;              // Exact match or lower bound?
318  bool error_;                    // Error encountered
319
320  void operator=(const SortedMatcher<F> &);  // Disallow
321};
322
323// Returns true iff match to match_label_. Positions arc iterator at
324// lower bound regardless.
325template <class F> inline
326bool SortedMatcher<F>::Search() {
327  aiter_->SetFlags(
328      match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
329      kArcValueFlags);
330  if (match_label_ >= binary_label_) {
331    // Binary search for match.
332    size_t low = 0;
333    size_t high = narcs_;
334    while (low < high) {
335      size_t mid = (low + high) / 2;
336      aiter_->Seek(mid);
337      Label label = match_type_ == MATCH_INPUT ?
338          aiter_->Value().ilabel : aiter_->Value().olabel;
339      if (label > match_label_) {
340        high = mid;
341      } else if (label < match_label_) {
342        low = mid + 1;
343      } else {
344        // find first matching label (when non-determinism)
345        for (size_t i = mid; i > low; --i) {
346          aiter_->Seek(i - 1);
347          label = match_type_ == MATCH_INPUT ? aiter_->Value().ilabel :
348              aiter_->Value().olabel;
349          if (label != match_label_) {
350            aiter_->Seek(i);
351            return true;
352          }
353        }
354        return true;
355      }
356    }
357    aiter_->Seek(low);
358    return false;
359  } else {
360    // Linear search for match.
361    for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
362      Label label = match_type_ == MATCH_INPUT ?
363          aiter_->Value().ilabel : aiter_->Value().olabel;
364      if (label == match_label_) {
365        return true;
366      }
367      if (label > match_label_)
368        break;
369    }
370    return false;
371  }
372}
373
374
375// Specifies whether during matching we rewrite both the input and output sides.
enum MatcherRewriteMode {
  // Rewrites both sides iff acceptor.
  MATCHER_REWRITE_AUTO = 0,
  // Both sides are always candidates for rewriting.
  MATCHER_REWRITE_ALWAYS = 1,
  // Only the matched side is rewritten.
  MATCHER_REWRITE_NEVER = 2
};
381
382
383// For any requested label that doesn't match at a state, this matcher
384// considers all transitions that match the label 'rho_label' (rho =
385// 'rest').  Each such rho transition found is returned with the
386// rho_label rewritten as the requested label (both sides if an
387// acceptor, or if 'rewrite_both' is true and both input and output
388// labels of the found transition are 'rho_label').  If 'rho_label' is
389// kNoLabel, this special matching is not done.  RhoMatcher is
390// templated itself on a matcher, which is used to perform the
391// underlying matching.  By default, the underlying matcher is
392// constructed by RhoMatcher.  The user can instead pass in this
393// object; in that case, RhoMatcher takes its ownership.
394template <class M>
395class RhoMatcher : public MatcherBase<typename M::Arc> {
396 public:
397  typedef typename M::FST FST;
398  typedef typename M::Arc Arc;
399  typedef typename Arc::StateId StateId;
400  typedef typename Arc::Label Label;
401  typedef typename Arc::Weight Weight;
402
403  RhoMatcher(const FST &fst,
404             MatchType match_type,
405             Label rho_label = kNoLabel,
406             MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
407             M *matcher = 0)
408      : matcher_(matcher ? matcher : new M(fst, match_type)),
409        match_type_(match_type),
410        rho_label_(rho_label),
411        error_(false) {
412    if (match_type == MATCH_BOTH) {
413      FSTERROR() << "RhoMatcher: bad match type";
414      match_type_ = MATCH_NONE;
415      error_ = true;
416    }
417    if (rho_label == 0) {
418      FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
419      rho_label_ = kNoLabel;
420      error_ = true;
421    }
422
423    if (rewrite_mode == MATCHER_REWRITE_AUTO)
424      rewrite_both_ = fst.Properties(kAcceptor, true);
425    else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
426      rewrite_both_ = true;
427    else
428      rewrite_both_ = false;
429  }
430
431  RhoMatcher(const RhoMatcher<M> &matcher, bool safe = false)
432      : matcher_(new M(*matcher.matcher_, safe)),
433        match_type_(matcher.match_type_),
434        rho_label_(matcher.rho_label_),
435        rewrite_both_(matcher.rewrite_both_),
436        error_(matcher.error_) {}
437
438  virtual ~RhoMatcher() {
439    delete matcher_;
440  }
441
442  virtual RhoMatcher<M> *Copy(bool safe = false) const {
443    return new RhoMatcher<M>(*this, safe);
444  }
445
446  virtual MatchType Type(bool test) const { return matcher_->Type(test); }
447
448  void SetState(StateId s) {
449    matcher_->SetState(s);
450    has_rho_ = rho_label_ != kNoLabel;
451  }
452
453  bool Find(Label match_label) {
454    if (match_label == rho_label_ && rho_label_ != kNoLabel) {
455      FSTERROR() << "RhoMatcher::Find: bad label (rho)";
456      error_ = true;
457      return false;
458    }
459    if (matcher_->Find(match_label)) {
460      rho_match_ = kNoLabel;
461      return true;
462    } else if (has_rho_ && match_label != 0 && match_label != kNoLabel &&
463               (has_rho_ = matcher_->Find(rho_label_))) {
464      rho_match_ = match_label;
465      return true;
466    } else {
467      return false;
468    }
469  }
470
471  bool Done() const { return matcher_->Done(); }
472
473  const Arc& Value() const {
474    if (rho_match_ == kNoLabel) {
475      return matcher_->Value();
476    } else {
477      rho_arc_ = matcher_->Value();
478      if (rewrite_both_) {
479        if (rho_arc_.ilabel == rho_label_)
480          rho_arc_.ilabel = rho_match_;
481        if (rho_arc_.olabel == rho_label_)
482          rho_arc_.olabel = rho_match_;
483      } else if (match_type_ == MATCH_INPUT) {
484        rho_arc_.ilabel = rho_match_;
485      } else {
486        rho_arc_.olabel = rho_match_;
487      }
488      return rho_arc_;
489    }
490  }
491
492  void Next() { matcher_->Next(); }
493
494  virtual const FST &GetFst() const { return matcher_->GetFst(); }
495
496  virtual uint64 Properties(uint64 props) const;
497
498  virtual uint32 Flags() const {
499    if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE)
500      return matcher_->Flags();
501    return matcher_->Flags() | kRequireMatch;
502  }
503
504 private:
505  virtual void SetState_(StateId s) { SetState(s); }
506  virtual bool Find_(Label label) { return Find(label); }
507  virtual bool Done_() const { return Done(); }
508  virtual const Arc& Value_() const { return Value(); }
509  virtual void Next_() { Next(); }
510
511  M *matcher_;
512  MatchType match_type_;  // Type of match requested
513  Label rho_label_;       // Label that represents the rho transition
514  bool rewrite_both_;     // Rewrite both sides when both are 'rho_label_'
515  bool has_rho_;          // Are there possibly rhos at the current state?
516  Label rho_match_;       // Current label that matches rho transition
517  mutable Arc rho_arc_;   // Arc to return when rho match
518  bool error_;            // Error encountered
519
520  void operator=(const RhoMatcher<M> &);  // Disallow
521};
522
523template <class M> inline
524uint64 RhoMatcher<M>::Properties(uint64 inprops) const {
525  uint64 outprops = matcher_->Properties(inprops);
526  if (error_) outprops |= kError;
527
528  if (match_type_ == MATCH_NONE) {
529    return outprops;
530  } else if (match_type_ == MATCH_INPUT) {
531    if (rewrite_both_) {
532      return outprops & ~(kODeterministic | kNonODeterministic | kString |
533                       kILabelSorted | kNotILabelSorted |
534                       kOLabelSorted | kNotOLabelSorted);
535    } else {
536      return outprops & ~(kODeterministic | kAcceptor | kString |
537                       kILabelSorted | kNotILabelSorted);
538    }
539  } else if (match_type_ == MATCH_OUTPUT) {
540    if (rewrite_both_) {
541      return outprops & ~(kIDeterministic | kNonIDeterministic | kString |
542                       kILabelSorted | kNotILabelSorted |
543                       kOLabelSorted | kNotOLabelSorted);
544    } else {
545      return outprops & ~(kIDeterministic | kAcceptor | kString |
546                       kOLabelSorted | kNotOLabelSorted);
547    }
548  } else {
549    // Shouldn't ever get here.
550    FSTERROR() << "RhoMatcher:: bad match type: " << match_type_;
551    return 0;
552  }
553}
554
555
556// For any requested label, this matcher considers all transitions
// that match the label 'sigma_label' (sigma = "any"), in addition
// to transitions with the requested label.  Each such sigma
559// transition found is returned with the sigma_label rewritten as the
560// requested label (both sides if an acceptor, or if 'rewrite_both' is
561// true and both input and output labels of the found transition are
562// 'sigma_label').  If 'sigma_label' is kNoLabel, this special
563// matching is not done.  SigmaMatcher is templated itself on a
564// matcher, which is used to perform the underlying matching.  By
565// default, the underlying matcher is constructed by SigmaMatcher.
566// The user can instead pass in this object; in that case,
567// SigmaMatcher takes its ownership.
568template <class M>
569class SigmaMatcher : public MatcherBase<typename M::Arc> {
570 public:
571  typedef typename M::FST FST;
572  typedef typename M::Arc Arc;
573  typedef typename Arc::StateId StateId;
574  typedef typename Arc::Label Label;
575  typedef typename Arc::Weight Weight;
576
577  SigmaMatcher(const FST &fst,
578               MatchType match_type,
579               Label sigma_label = kNoLabel,
580               MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
581               M *matcher = 0)
582      : matcher_(matcher ? matcher : new M(fst, match_type)),
583        match_type_(match_type),
584        sigma_label_(sigma_label),
585        error_(false) {
586    if (match_type == MATCH_BOTH) {
587      FSTERROR() << "SigmaMatcher: bad match type";
588      match_type_ = MATCH_NONE;
589      error_ = true;
590    }
591    if (sigma_label == 0) {
592      FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
593      sigma_label_ = kNoLabel;
594      error_ = true;
595    }
596
597    if (rewrite_mode == MATCHER_REWRITE_AUTO)
598      rewrite_both_ = fst.Properties(kAcceptor, true);
599    else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
600      rewrite_both_ = true;
601    else
602      rewrite_both_ = false;
603  }
604
605  SigmaMatcher(const SigmaMatcher<M> &matcher, bool safe = false)
606      : matcher_(new M(*matcher.matcher_, safe)),
607        match_type_(matcher.match_type_),
608        sigma_label_(matcher.sigma_label_),
609        rewrite_both_(matcher.rewrite_both_),
610        error_(matcher.error_) {}
611
612  virtual ~SigmaMatcher() {
613    delete matcher_;
614  }
615
616  virtual SigmaMatcher<M> *Copy(bool safe = false) const {
617    return new SigmaMatcher<M>(*this, safe);
618  }
619
620  virtual MatchType Type(bool test) const { return matcher_->Type(test); }
621
622  void SetState(StateId s) {
623    matcher_->SetState(s);
624    has_sigma_ =
625        sigma_label_ != kNoLabel ? matcher_->Find(sigma_label_) : false;
626  }
627
628  bool Find(Label match_label) {
629    match_label_ = match_label;
630    if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
631      FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
632      error_ = true;
633      return false;
634    }
635    if (matcher_->Find(match_label)) {
636      sigma_match_ = kNoLabel;
637      return true;
638    } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
639               matcher_->Find(sigma_label_)) {
640      sigma_match_ = match_label;
641      return true;
642    } else {
643      return false;
644    }
645  }
646
647  bool Done() const {
648    return matcher_->Done();
649  }
650
651  const Arc& Value() const {
652    if (sigma_match_ == kNoLabel) {
653      return matcher_->Value();
654    } else {
655      sigma_arc_ = matcher_->Value();
656      if (rewrite_both_) {
657        if (sigma_arc_.ilabel == sigma_label_)
658          sigma_arc_.ilabel = sigma_match_;
659        if (sigma_arc_.olabel == sigma_label_)
660          sigma_arc_.olabel = sigma_match_;
661      } else if (match_type_ == MATCH_INPUT) {
662        sigma_arc_.ilabel = sigma_match_;
663      } else {
664        sigma_arc_.olabel = sigma_match_;
665      }
666      return sigma_arc_;
667    }
668  }
669
670  void Next() {
671    matcher_->Next();
672    if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
673        (match_label_ > 0)) {
674      matcher_->Find(sigma_label_);
675      sigma_match_ = match_label_;
676    }
677  }
678
679  virtual const FST &GetFst() const { return matcher_->GetFst(); }
680
681  virtual uint64 Properties(uint64 props) const;
682
683  virtual uint32 Flags() const {
684    if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE)
685      return matcher_->Flags();
686    // kRequireMatch temporarily disabled until issues
687    // in //speech/gaudi/annotation/util/denorm are resolved.
688    // return matcher_->Flags() | kRequireMatch;
689    return matcher_->Flags();
690  }
691
692private:
693  virtual void SetState_(StateId s) { SetState(s); }
694  virtual bool Find_(Label label) { return Find(label); }
695  virtual bool Done_() const { return Done(); }
696  virtual const Arc& Value_() const { return Value(); }
697  virtual void Next_() { Next(); }
698
699  M *matcher_;
700  MatchType match_type_;   // Type of match requested
701  Label sigma_label_;      // Label that represents the sigma transition
702  bool rewrite_both_;      // Rewrite both sides when both are 'sigma_label_'
703  bool has_sigma_;         // Are there sigmas at the current state?
704  Label sigma_match_;      // Current label that matches sigma transition
705  mutable Arc sigma_arc_;  // Arc to return when sigma match
706  Label match_label_;      // Label being matched
707  bool error_;             // Error encountered
708
709  void operator=(const SigmaMatcher<M> &);  // disallow
710};
711
712template <class M> inline
713uint64 SigmaMatcher<M>::Properties(uint64 inprops) const {
714  uint64 outprops = matcher_->Properties(inprops);
715  if (error_) outprops |= kError;
716
717  if (match_type_ == MATCH_NONE) {
718    return outprops;
719  } else if (rewrite_both_) {
720    return outprops & ~(kIDeterministic | kNonIDeterministic |
721                     kODeterministic | kNonODeterministic |
722                     kILabelSorted | kNotILabelSorted |
723                     kOLabelSorted | kNotOLabelSorted |
724                     kString);
725  } else if (match_type_ == MATCH_INPUT) {
726    return outprops & ~(kIDeterministic | kNonIDeterministic |
727                     kODeterministic | kNonODeterministic |
728                     kILabelSorted | kNotILabelSorted |
729                     kString | kAcceptor);
730  } else if (match_type_ == MATCH_OUTPUT) {
731    return outprops & ~(kIDeterministic | kNonIDeterministic |
732                     kODeterministic | kNonODeterministic |
733                     kOLabelSorted | kNotOLabelSorted |
734                     kString | kAcceptor);
735  } else {
736    // Shouldn't ever get here.
737    FSTERROR() << "SigmaMatcher:: bad match type: " << match_type_;
738    return 0;
739  }
740}
741
742
743// For any requested label that doesn't match at a state, this matcher
744// considers the *unique* transition that matches the label 'phi_label'
745// (phi = 'fail'), and recursively looks for a match at its
746// destination.  When 'phi_loop' is true, if no match is found but a
747// phi self-loop is found, then the phi transition found is returned
748// with the phi_label rewritten as the requested label (both sides if
749// an acceptor, or if 'rewrite_both' is true and both input and output
750// labels of the found transition are 'phi_label').  If 'phi_label' is
751// kNoLabel, this special matching is not done.  PhiMatcher is
752// templated itself on a matcher, which is used to perform the
753// underlying matching.  By default, the underlying matcher is
754// constructed by PhiMatcher. The user can instead pass in this
755// object; in that case, PhiMatcher takes its ownership.
756// Warning: phi non-determinism not supported (for simplicity).
757template <class M>
758class PhiMatcher : public MatcherBase<typename M::Arc> {
759 public:
760  typedef typename M::FST FST;
761  typedef typename M::Arc Arc;
762  typedef typename Arc::StateId StateId;
763  typedef typename Arc::Label Label;
764  typedef typename Arc::Weight Weight;
765
766  PhiMatcher(const FST &fst,
767             MatchType match_type,
768             Label phi_label = kNoLabel,
769             bool phi_loop = true,
770             MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
771             M *matcher = 0)
772      : matcher_(matcher ? matcher : new M(fst, match_type)),
773        match_type_(match_type),
774        phi_label_(phi_label),
775        state_(kNoStateId),
776        phi_loop_(phi_loop),
777        error_(false) {
778    if (match_type == MATCH_BOTH) {
779      FSTERROR() << "PhiMatcher: bad match type";
780      match_type_ = MATCH_NONE;
781      error_ = true;
782    }
783
784    if (rewrite_mode == MATCHER_REWRITE_AUTO)
785      rewrite_both_ = fst.Properties(kAcceptor, true);
786    else if (rewrite_mode == MATCHER_REWRITE_ALWAYS)
787      rewrite_both_ = true;
788    else
789      rewrite_both_ = false;
790   }
791
792  PhiMatcher(const PhiMatcher<M> &matcher, bool safe = false)
793      : matcher_(new M(*matcher.matcher_, safe)),
794        match_type_(matcher.match_type_),
795        phi_label_(matcher.phi_label_),
796        rewrite_both_(matcher.rewrite_both_),
797        state_(kNoStateId),
798        phi_loop_(matcher.phi_loop_),
799        error_(matcher.error_) {}
800
801  virtual ~PhiMatcher() {
802    delete matcher_;
803  }
804
805  virtual PhiMatcher<M> *Copy(bool safe = false) const {
806    return new PhiMatcher<M>(*this, safe);
807  }
808
809  virtual MatchType Type(bool test) const { return matcher_->Type(test); }
810
811  void SetState(StateId s) {
812    matcher_->SetState(s);
813    state_ = s;
814    has_phi_ = phi_label_ != kNoLabel;
815  }
816
817  bool Find(Label match_label);
818
819  bool Done() const { return matcher_->Done(); }
820
821  const Arc& Value() const {
822    if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
823      return matcher_->Value();
824    } else if (phi_match_ == 0) {  // Virtual epsilon loop
825      phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_);
826      if (match_type_ == MATCH_OUTPUT)
827        swap(phi_arc_.ilabel, phi_arc_.olabel);
828      return phi_arc_;
829    } else {
830      phi_arc_ = matcher_->Value();
831      phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
832      if (phi_match_ != kNoLabel) {  // Phi loop match
833        if (rewrite_both_) {
834          if (phi_arc_.ilabel == phi_label_)
835            phi_arc_.ilabel = phi_match_;
836          if (phi_arc_.olabel == phi_label_)
837            phi_arc_.olabel = phi_match_;
838        } else if (match_type_ == MATCH_INPUT) {
839          phi_arc_.ilabel = phi_match_;
840        } else {
841          phi_arc_.olabel = phi_match_;
842        }
843      }
844      return phi_arc_;
845    }
846  }
847
848  void Next() { matcher_->Next(); }
849
850  virtual const FST &GetFst() const { return matcher_->GetFst(); }
851
852  virtual uint64 Properties(uint64 props) const;
853
854  virtual uint32 Flags() const {
855    if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE)
856      return matcher_->Flags();
857    return matcher_->Flags() | kRequireMatch;
858  }
859
860private:
861  virtual void SetState_(StateId s) { SetState(s); }
862  virtual bool Find_(Label label) { return Find(label); }
863  virtual bool Done_() const { return Done(); }
864  virtual const Arc& Value_() const { return Value(); }
865  virtual void Next_() { Next(); }
866
867  M *matcher_;
868  MatchType match_type_;  // Type of match requested
869  Label phi_label_;       // Label that represents the phi transition
870  bool rewrite_both_;     // Rewrite both sides when both are 'phi_label_'
871  bool has_phi_;          // Are there possibly phis at the current state?
872  Label phi_match_;       // Current label that matches phi loop
873  mutable Arc phi_arc_;   // Arc to return
874  StateId state_;         // State where looking for matches
875  Weight phi_weight_;     // Product of the weights of phi transitions taken
876  bool phi_loop_;         // When true, phi self-loop are allowed and treated
877                          // as rho (required for Aho-Corasick)
878  bool error_;             // Error encountered
879
880  void operator=(const PhiMatcher<M> &);  // disallow
881};
882
// Searches for arcs matching 'match_label' at the current state
// ('state_'). If the label is not found directly, phi (failure)
// transitions are followed, state by state, until a match is found or a
// state without a phi transition is reached. Returns true iff a match
// was found.
//
// Side effects: resets 'phi_match_' and 'phi_weight_' on every call;
// accumulates the weights of the phi arcs taken into 'phi_weight_'; sets
// 'phi_match_' when the match is a phi self-loop (with 'phi_loop_' set)
// or the virtual epsilon loop; sets 'error_' when asked to match the phi
// label itself or when a state has more than one phi transition.
template <class M> inline
bool PhiMatcher<M>::Find(Label match_label) {
  // Matching the phi label itself is not allowed, unless phi matching is
  // disabled (kNoLabel) or the phi label is epsilon (0).
  if (match_label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) {
    FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_;
    error_ = true;
    return false;
  }
  // Start every search from the state given to the matcher, with no phi
  // match recorded and no phi weight accumulated.
  matcher_->SetState(state_);
  phi_match_ = kNoLabel;
  phi_weight_ = Weight::One();
  if (phi_label_ == 0) {          // When 'phi_label_ == 0',
    if (match_label == kNoLabel)  // there are no more true epsilon arcs,
      return false;
    if (match_label == 0) {       // but virtual eps loop need to be returned
      if (!matcher_->Find(kNoLabel)) {
        return matcher_->Find(0);
      } else {
        phi_match_ = 0;
        return true;
      }
    }
  }
  // No phi handling needed: defer entirely to the underlying matcher.
  if (!has_phi_ || match_label == 0 || match_label == kNoLabel)
    return matcher_->Find(match_label);
  // 'state' tracks the state currently being searched while phi
  // transitions are followed; 'state_' itself is left untouched.
  StateId state = state_;
  while (!matcher_->Find(match_label)) {
    // Look for phi transition (if phi_label_ == 0, we need to look
    // for -1 to avoid getting the virtual self-loop)
    if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_))
      return false;
    // A phi self-loop (when allowed) matches the requested label itself.
    if (phi_loop_ && matcher_->Value().nextstate == state) {
      phi_match_ = match_label;
      return true;
    }
    phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
    state = matcher_->Value().nextstate;
    // More than one phi arc at a state is an error. It is recorded, but
    // the search still continues from the first phi arc's destination.
    matcher_->Next();
    if (!matcher_->Done()) {
      FSTERROR() << "PhiMatcher: phi non-determinism not supported";
      error_ = true;
    }
    matcher_->SetState(state);
  }
  return true;
}
928
929template <class M> inline
930uint64 PhiMatcher<M>::Properties(uint64 inprops) const {
931  uint64 outprops = matcher_->Properties(inprops);
932  if (error_) outprops |= kError;
933
934  if (match_type_ == MATCH_NONE) {
935    return outprops;
936  } else if (match_type_ == MATCH_INPUT) {
937    if (phi_label_ == 0) {
938      outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
939      outprops |= kNoEpsilons | kNoIEpsilons;
940    }
941    if (rewrite_both_) {
942      return outprops & ~(kODeterministic | kNonODeterministic | kString |
943                       kILabelSorted | kNotILabelSorted |
944                       kOLabelSorted | kNotOLabelSorted);
945    } else {
946      return outprops & ~(kODeterministic | kAcceptor | kString |
947                       kILabelSorted | kNotILabelSorted |
948                       kOLabelSorted | kNotOLabelSorted);
949    }
950  } else if (match_type_ == MATCH_OUTPUT) {
951    if (phi_label_ == 0) {
952      outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons;
953      outprops |= kNoEpsilons | kNoOEpsilons;
954    }
955    if (rewrite_both_) {
956      return outprops & ~(kIDeterministic | kNonIDeterministic | kString |
957                       kILabelSorted | kNotILabelSorted |
958                       kOLabelSorted | kNotOLabelSorted);
959    } else {
960      return outprops & ~(kIDeterministic | kAcceptor | kString |
961                       kILabelSorted | kNotILabelSorted |
962                       kOLabelSorted | kNotOLabelSorted);
963    }
964  } else {
965    // Shouldn't ever get here.
966    FSTERROR() << "PhiMatcher:: bad match type: " << match_type_;
967    return 0;
968  }
969}
970
971
972//
973// MULTI-EPS MATCHER FLAGS
974//
975
// When set, Find(kNoLabel) also returns the arcs labeled with any of the
// registered multi-epsilon labels, not just the underlying kNoLabel match.
const uint32 kMultiEpsList =  0x00000001;

// When set, Find(multi_eps) on a registered multi-epsilon label returns
// an implicit kNoLabel loop instead of searching the underlying matcher.
const uint32 kMultiEpsLoop =  0x00000002;
981
982// MultiEpsMatcher: allows treating multiple non-0 labels as
983// non-consuming labels in addition to 0 that is always
984// non-consuming. Precise behavior controlled by 'flags' argument. By
985// default, the underlying matcher is constructed by
986// MultiEpsMatcher. The user can instead pass in this object; in that
987// case, MultiEpsMatcher takes its ownership iff 'own_matcher' is
988// true.
989template <class M>
990class MultiEpsMatcher {
991 public:
992  typedef typename M::FST FST;
993  typedef typename M::Arc Arc;
994  typedef typename Arc::StateId StateId;
995  typedef typename Arc::Label Label;
996  typedef typename Arc::Weight Weight;
997
998  MultiEpsMatcher(const FST &fst, MatchType match_type,
999                  uint32 flags = (kMultiEpsLoop | kMultiEpsList),
1000                  M *matcher = 0, bool own_matcher = true)
1001      : matcher_(matcher ? matcher : new M(fst, match_type)),
1002        flags_(flags),
1003        own_matcher_(matcher ?  own_matcher : true) {
1004    if (match_type == MATCH_INPUT) {
1005      loop_.ilabel = kNoLabel;
1006      loop_.olabel = 0;
1007    } else {
1008      loop_.ilabel = 0;
1009      loop_.olabel = kNoLabel;
1010    }
1011    loop_.weight = Weight::One();
1012    loop_.nextstate = kNoStateId;
1013  }
1014
1015  MultiEpsMatcher(const MultiEpsMatcher<M> &matcher, bool safe = false)
1016      : matcher_(new M(*matcher.matcher_, safe)),
1017        flags_(matcher.flags_),
1018        own_matcher_(true),
1019        multi_eps_labels_(matcher.multi_eps_labels_),
1020        loop_(matcher.loop_) {
1021    loop_.nextstate = kNoStateId;
1022  }
1023
1024  ~MultiEpsMatcher() {
1025    if (own_matcher_)
1026      delete matcher_;
1027  }
1028
1029  MultiEpsMatcher<M> *Copy(bool safe = false) const {
1030    return new MultiEpsMatcher<M>(*this, safe);
1031  }
1032
1033  MatchType Type(bool test) const { return matcher_->Type(test); }
1034
1035  void SetState(StateId s) {
1036    matcher_->SetState(s);
1037    loop_.nextstate = s;
1038  }
1039
1040  bool Find(Label match_label);
1041
1042  bool Done() const {
1043    return done_;
1044  }
1045
1046  const Arc& Value() const {
1047    return current_loop_ ? loop_ : matcher_->Value();
1048  }
1049
1050  void Next() {
1051    if (!current_loop_) {
1052      matcher_->Next();
1053      done_ = matcher_->Done();
1054      if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
1055        ++multi_eps_iter_;
1056        while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1057               !matcher_->Find(*multi_eps_iter_))
1058          ++multi_eps_iter_;
1059        if (multi_eps_iter_ != multi_eps_labels_.End())
1060          done_ = false;
1061        else
1062          done_ = !matcher_->Find(kNoLabel);
1063
1064      }
1065    } else {
1066      done_ = true;
1067    }
1068  }
1069
1070  const FST &GetFst() const { return matcher_->GetFst(); }
1071
1072  uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
1073
1074  uint32 Flags() const { return matcher_->Flags(); }
1075
1076  void AddMultiEpsLabel(Label label) {
1077    if (label == 0) {
1078      FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1079    } else {
1080      multi_eps_labels_.Insert(label);
1081    }
1082  }
1083
1084  void RemoveMultiEpsLabel(Label label) {
1085    if (label == 0) {
1086      FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1087    } else {
1088      multi_eps_labels_.Erase(label);
1089    }
1090  }
1091
1092  void ClearMultiEpsLabels() {
1093    multi_eps_labels_.Clear();
1094  }
1095
1096private:
1097  M *matcher_;
1098  uint32 flags_;
1099  bool own_matcher_;             // Does this class delete the matcher?
1100
1101  // Multi-eps label set
1102  CompactSet<Label, kNoLabel> multi_eps_labels_;
1103  typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;
1104
1105  bool current_loop_;            // Current arc is the implicit loop
1106  mutable Arc loop_;             // For non-consuming symbols
1107  bool done_;                    // Matching done
1108
1109  void operator=(const MultiEpsMatcher<M> &);  // Disallow
1110};
1111
1112template <class M> inline
1113bool MultiEpsMatcher<M>::Find(Label match_label) {
1114  multi_eps_iter_ = multi_eps_labels_.End();
1115  current_loop_ = false;
1116  bool ret;
1117  if (match_label == 0) {
1118    ret = matcher_->Find(0);
1119  } else if (match_label == kNoLabel) {
1120    if (flags_ & kMultiEpsList) {
1121      // return all non-consuming arcs (incl. epsilon)
1122      multi_eps_iter_ = multi_eps_labels_.Begin();
1123      while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1124             !matcher_->Find(*multi_eps_iter_))
1125        ++multi_eps_iter_;
1126      if (multi_eps_iter_ != multi_eps_labels_.End())
1127        ret = true;
1128      else
1129        ret = matcher_->Find(kNoLabel);
1130    } else {
1131      // return all epsilon arcs
1132      ret = matcher_->Find(kNoLabel);
1133    }
1134  } else if ((flags_ & kMultiEpsLoop) &&
1135             multi_eps_labels_.Find(match_label) != multi_eps_labels_.End()) {
1136    // return 'implicit' loop
1137    current_loop_ = true;
1138    ret = true;
1139  } else {
1140    ret = matcher_->Find(match_label);
1141  }
1142  done_ = !ret;
1143  return ret;
1144}
1145
1146
1147// Generic matcher, templated on the FST definition
1148// - a wrapper around pointer to specific one.
1149// Here is a typical use: \code
1150//   Matcher<StdFst> matcher(fst, MATCH_INPUT);
1151//   matcher.SetState(state);
1152//   if (matcher.Find(label))
1153//     for (; !matcher.Done(); matcher.Next()) {
1154//       StdArc &arc = matcher.Value();
1155//       ...
1156//     } \endcode
1157template <class F>
1158class Matcher {
1159 public:
1160  typedef F FST;
1161  typedef typename F::Arc Arc;
1162  typedef typename Arc::StateId StateId;
1163  typedef typename Arc::Label Label;
1164  typedef typename Arc::Weight Weight;
1165
1166  Matcher(const F &fst, MatchType match_type) {
1167    base_ = fst.InitMatcher(match_type);
1168    if (!base_)
1169      base_ = new SortedMatcher<F>(fst, match_type);
1170  }
1171
1172  Matcher(const Matcher<F> &matcher, bool safe = false) {
1173    base_ = matcher.base_->Copy(safe);
1174  }
1175
1176  // Takes ownership of the provided matcher
1177  Matcher(MatcherBase<Arc>* base_matcher) { base_ = base_matcher; }
1178
1179  ~Matcher() { delete base_; }
1180
1181  Matcher<F> *Copy(bool safe = false) const {
1182    return new Matcher<F>(*this, safe);
1183  }
1184
1185  MatchType Type(bool test) const { return base_->Type(test); }
1186  void SetState(StateId s) { base_->SetState(s); }
1187  bool Find(Label label) { return base_->Find(label); }
1188  bool Done() const { return base_->Done(); }
1189  const Arc& Value() const { return base_->Value(); }
1190  void Next() { base_->Next(); }
1191  const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
1192  uint64 Properties(uint64 props) const { return base_->Properties(props); }
1193  uint32 Flags() const { return base_->Flags() & kMatcherFlags; }
1194
1195 private:
1196  MatcherBase<Arc> *base_;
1197
1198  void operator=(const Matcher<Arc> &);  // disallow
1199};
1200
1201}  // namespace fst
1202
1203
1204
1205#endif  // FST_LIB_MATCHER_H__
1206