lookahead-matcher.h revision dfd8b8327b93660601d016cdc6f29f433b45a8d8
1// lookahead-matcher.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Classes to add lookahead to FST matchers, useful e.g. for improving
20// composition efficiency with certain inputs.
21
22#ifndef FST_LIB_LOOKAHEAD_MATCHER_H__
23#define FST_LIB_LOOKAHEAD_MATCHER_H__
24
25#include <fst/add-on.h>
26#include <fst/const-fst.h>
27#include <fst/fst.h>
28#include <fst/label-reachable.h>
29#include <fst/matcher.h>
30
31
32DECLARE_string(save_relabel_ipairs);
33DECLARE_string(save_relabel_opairs);
34
35namespace fst {
36
37// LOOKAHEAD MATCHERS - these have the interface of Matchers (see
38// matcher.h) and these additional methods:
39//
40// template <class F>
41// class LookAheadMatcher {
42//  public:
43//   typedef F FST;
44//   typedef F::Arc Arc;
45//   typedef typename Arc::StateId StateId;
46//   typedef typename Arc::Label Label;
47//   typedef typename Arc::Weight Weight;
48//
49//  // Required constructors.
50//  LookAheadMatcher(const F &fst, MatchType match_type);
51//   // If safe=true, the copy is thread-safe (except the lookahead Fst is
//   // preserved). See Fst<>::Copy() for further doc.
53//  LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false);
54//
55//  Below are methods for looking ahead for a match to a label and
56//  more generally, to a rational set. Each returns false if there is
57//  definitely not a match and returns true if there possibly is a
58//  match.
59
60//  // LABEL LOOKAHEAD: Can 'label' be read from the current matcher state
61//  // after possibly following epsilon transitions?
62//  bool LookAheadLabel(Label label) const;
63//
64//  // RATIONAL LOOKAHEAD: The next methods allow looking ahead for an
65//  // arbitrary rational set of strings, specified by an FST and a state
66//  // from which to begin the matching. If the lookahead FST is a
67//  // transducer, this looks on the side different from the matcher
68//  // 'match_type' (cf. composition).
69//
70//  // Are there paths P from 's' in the lookahead FST that can be read from
71//  // the cur. matcher state?
72//  bool LookAheadFst(const Fst<Arc>& fst, StateId s);
73//
74//  // Gives an estimate of the combined weight of the paths P in the
75//  // lookahead and matcher FSTs for the last call to LookAheadFst.
76//  // A trivial implementation returns Weight::One(). Non-trivial
77//  // implementations are useful for weight-pushing in composition.
78//  Weight LookAheadWeight() const;
79//
//  // Is there a single non-epsilon arc found in the lookahead FST
//  // that begins P (after possibly following any epsilons) in the last
//  // call to LookAheadFst? If so, return true and copy it to '*arc', o.w.
83//  // return false. A trivial implementation returns false. Non-trivial
84//  // implementations are useful for label-pushing in composition.
85//  bool LookAheadPrefix(Arc *arc);
86//
87//  // Optionally pre-specifies the lookahead FST that will be passed
88//  // to LookAheadFst() for possible precomputation. If copy is true,
89//  // then 'fst' is a copy of the FST used in the previous call to
90//  // this method (useful to avoid unnecessary updates).
91//  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false);
92//
93// };
94
//
// LOOK-AHEAD FLAGS (see also kMatcherFlags in matcher.h):
//
// These are bit flags: they are OR-ed into the value returned by a
// matcher's Flags() method and into the 'F' template argument of the
// lookahead matchers defined below.
//
// Matcher is a lookahead matcher when 'match_type' is MATCH_INPUT.
const uint32 kInputLookAheadMatcher =     0x00000010;

// Matcher is a lookahead matcher when 'match_type' is MATCH_OUTPUT.
const uint32 kOutputLookAheadMatcher =    0x00000020;

// Is a non-trivial implementation of the LookAheadWeight() method
// defined, and should it be used?
const uint32 kLookAheadWeight =           0x00000040;

// Is a non-trivial implementation of the LookAheadPrefix() method
// defined, and should it be used?
const uint32 kLookAheadPrefix =           0x00000080;

// Look-ahead of matcher FST non-epsilon arcs?
const uint32 kLookAheadNonEpsilons =      0x00000100;

// Look-ahead of matcher FST epsilon arcs?
const uint32 kLookAheadEpsilons =         0x00000200;

// Ignore epsilon paths for the lookahead prefix? Note this gives
// correct results in composition only with an appropriate composition
// filter since it depends on the filter blocking the ignored paths.
const uint32 kLookAheadNonEpsilonPrefix = 0x00000400;

// For LabelLookAheadMatcher, save relabeling data to file. The matcher
// forwards this bit to the LabelReachable constructor.
const uint32 kLookAheadKeepRelabelData =  0x00000800;

// Mask covering every bit used for lookahead matcher flags (bits 4
// through 11, which includes all of the flags above).
const uint32 kLookAheadFlags =            0x00000ff0;
128
129// LookAhead Matcher interface, templated on the Arc definition; used
130// for lookahead matcher specializations that are returned by the
131// InitMatcher() Fst method.
132template <class A>
133class LookAheadMatcherBase : public MatcherBase<A> {
134 public:
135  typedef A Arc;
136  typedef typename A::StateId StateId;
137  typedef typename A::Label Label;
138  typedef typename A::Weight Weight;
139
140  LookAheadMatcherBase()
141  : weight_(Weight::One()),
142    prefix_arc_(kNoLabel, kNoLabel, Weight::One(), kNoStateId) {}
143
144  virtual ~LookAheadMatcherBase() {}
145
146  bool LookAheadLabel(Label label) const { return LookAheadLabel_(label); }
147
148  bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
149    return LookAheadFst_(fst, s);
150  }
151
152  Weight LookAheadWeight() const { return weight_; }
153
154  bool LookAheadPrefix(Arc *arc) const {
155    if (prefix_arc_.nextstate != kNoStateId) {
156      *arc = prefix_arc_;
157      return true;
158    } else {
159      return false;
160    }
161  }
162
163  virtual void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) = 0;
164
165 protected:
166  void SetLookAheadWeight(const Weight &w) { weight_ = w; }
167
168  void SetLookAheadPrefix(const Arc &arc) { prefix_arc_ = arc; }
169
170  void ClearLookAheadPrefix() { prefix_arc_.nextstate = kNoStateId; }
171
172 private:
173  virtual bool LookAheadLabel_(Label label) const = 0;
174  virtual bool LookAheadFst_(const Fst<Arc> &fst,
175                             StateId s) = 0;  // This must set l.a. weight and
176                                              // prefix if non-trivial.
177  Weight weight_;                             // Look-ahead weight
178  Arc prefix_arc_;                            // Look-ahead prefix arc
179};
180
181
182// Don't really lookahead, just declare future looks good regardless.
183template <class M>
184class TrivialLookAheadMatcher
185    : public LookAheadMatcherBase<typename M::FST::Arc> {
186 public:
187  typedef typename M::FST FST;
188  typedef typename M::Arc Arc;
189  typedef typename Arc::StateId StateId;
190  typedef typename Arc::Label Label;
191  typedef typename Arc::Weight Weight;
192
193  TrivialLookAheadMatcher(const FST &fst, MatchType match_type)
194      : matcher_(fst, match_type) {}
195
196  TrivialLookAheadMatcher(const TrivialLookAheadMatcher<M> &lmatcher,
197                          bool safe = false)
198      : matcher_(lmatcher.matcher_, safe) {}
199
200  // General matcher methods
201  TrivialLookAheadMatcher<M> *Copy(bool safe = false) const {
202    return new TrivialLookAheadMatcher<M>(*this, safe);
203  }
204
205  MatchType Type(bool test) const { return matcher_.Type(test); }
206  void SetState(StateId s) { return matcher_.SetState(s); }
207  bool Find(Label label) { return matcher_.Find(label); }
208  bool Done() const { return matcher_.Done(); }
209  const Arc& Value() const { return matcher_.Value(); }
210  void Next() { matcher_.Next(); }
211  virtual const FST &GetFst() const { return matcher_.GetFst(); }
212  uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
213  uint32 Flags() const {
214    return matcher_.Flags() | kInputLookAheadMatcher | kOutputLookAheadMatcher;
215  }
216
217  // Look-ahead methods.
218  bool LookAheadLabel(Label label) const { return true;  }
219  bool LookAheadFst(const Fst<Arc> &fst, StateId s) {return true; }
220  Weight LookAheadWeight() const { return Weight::One(); }
221  bool LookAheadPrefix(Arc *arc) const { return false; }
222  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {}
223
224 private:
225  // This allows base class virtual access to non-virtual derived-
226  // class members of the same name. It makes the derived class more
227  // efficient to use but unsafe to further derive.
228  virtual void SetState_(StateId s) { SetState(s); }
229  virtual bool Find_(Label label) { return Find(label); }
230  virtual bool Done_() const { return Done(); }
231  virtual const Arc& Value_() const { return Value(); }
232  virtual void Next_() { Next(); }
233
234  bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
235
236  bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
237    return LookAheadFst(fst, s);
238  }
239
240  Weight LookAheadWeight_() const { return LookAheadWeight(); }
241  bool LookAheadPrefix_(Arc *arc) const { return LookAheadPrefix(arc); }
242
243  M matcher_;
244};
245
246// Look-ahead of one transition. Template argument F accepts flags to
247// control behavior.
248template <class M, uint32 F = kLookAheadNonEpsilons | kLookAheadEpsilons |
249          kLookAheadWeight | kLookAheadPrefix>
250class ArcLookAheadMatcher
251    : public LookAheadMatcherBase<typename M::FST::Arc> {
252 public:
253  typedef typename M::FST FST;
254  typedef typename M::Arc Arc;
255  typedef typename Arc::StateId StateId;
256  typedef typename Arc::Label Label;
257  typedef typename Arc::Weight Weight;
258  typedef NullAddOn MatcherData;
259
260  using LookAheadMatcherBase<Arc>::LookAheadWeight;
261  using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
262  using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
263  using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
264
265  ArcLookAheadMatcher(const FST &fst, MatchType match_type,
266                      MatcherData *data = 0)
267      : matcher_(fst, match_type),
268        fst_(matcher_.GetFst()),
269        lfst_(0),
270        s_(kNoStateId) {}
271
272  ArcLookAheadMatcher(const ArcLookAheadMatcher<M, F> &lmatcher,
273                      bool safe = false)
274      : matcher_(lmatcher.matcher_, safe),
275        fst_(matcher_.GetFst()),
276        lfst_(lmatcher.lfst_),
277        s_(kNoStateId) {}
278
279  // General matcher methods
280  ArcLookAheadMatcher<M, F> *Copy(bool safe = false) const {
281    return new ArcLookAheadMatcher<M, F>(*this, safe);
282  }
283
284  MatchType Type(bool test) const { return matcher_.Type(test); }
285
286  void SetState(StateId s) {
287    s_ = s;
288    matcher_.SetState(s);
289  }
290
291  bool Find(Label label) { return matcher_.Find(label); }
292  bool Done() const { return matcher_.Done(); }
293  const Arc& Value() const { return matcher_.Value(); }
294  void Next() { matcher_.Next(); }
295  const FST &GetFst() const { return fst_; }
296  uint64 Properties(uint64 props) const { return matcher_.Properties(props); }
297  uint32 Flags() const {
298    return matcher_.Flags() | kInputLookAheadMatcher |
299        kOutputLookAheadMatcher | F;
300  }
301
302  // Writable matcher methods
303  MatcherData *GetData() const { return 0; }
304
305  // Look-ahead methods.
306  bool LookAheadLabel(Label label) const { return matcher_.Find(label); }
307
308  // Checks if there is a matching (possibly super-final) transition
309  // at (s_, s).
310  bool LookAheadFst(const Fst<Arc> &fst, StateId s);
311
312  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
313    lfst_ = &fst;
314  }
315
316 private:
317  // This allows base class virtual access to non-virtual derived-
318  // class members of the same name. It makes the derived class more
319  // efficient to use but unsafe to further derive.
320  virtual void SetState_(StateId s) { SetState(s); }
321  virtual bool Find_(Label label) { return Find(label); }
322  virtual bool Done_() const { return Done(); }
323  virtual const Arc& Value_() const { return Value(); }
324  virtual void Next_() { Next(); }
325
326  bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
327  bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
328    return LookAheadFst(fst, s);
329  }
330
331  mutable M matcher_;
332  const FST &fst_;         // Matcher FST
333  const Fst<Arc> *lfst_;   // Look-ahead FST
334  StateId s_;              // Matcher state
335};
336
337template <class M, uint32 F>
338bool ArcLookAheadMatcher<M, F>::LookAheadFst(const Fst<Arc> &fst, StateId s) {
339  if (&fst != lfst_)
340    InitLookAheadFst(fst);
341
342  bool ret = false;
343  ssize_t nprefix = 0;
344  if (F & kLookAheadWeight)
345    SetLookAheadWeight(Weight::Zero());
346  if (F & kLookAheadPrefix)
347    ClearLookAheadPrefix();
348  if (fst_.Final(s_) != Weight::Zero() &&
349      lfst_->Final(s) != Weight::Zero()) {
350    if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
351      return true;
352    ++nprefix;
353    if (F & kLookAheadWeight)
354      SetLookAheadWeight(Plus(LookAheadWeight(),
355                              Times(fst_.Final(s_), lfst_->Final(s))));
356    ret = true;
357  }
358  if (matcher_.Find(kNoLabel)) {
359    if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
360      return true;
361    ++nprefix;
362    if (F & kLookAheadWeight)
363      for (; !matcher_.Done(); matcher_.Next())
364        SetLookAheadWeight(Plus(LookAheadWeight(), matcher_.Value().weight));
365    ret = true;
366  }
367  for (ArcIterator< Fst<Arc> > aiter(*lfst_, s);
368       !aiter.Done();
369       aiter.Next()) {
370    const Arc &arc = aiter.Value();
371    Label label = kNoLabel;
372    switch (matcher_.Type(false)) {
373      case MATCH_INPUT:
374        label = arc.olabel;
375        break;
376      case MATCH_OUTPUT:
377        label = arc.ilabel;
378        break;
379      default:
380        FSTERROR() << "ArcLookAheadMatcher::LookAheadFst: bad match type";
381        return true;
382    }
383    if (label == 0) {
384      if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
385        return true;
386      if (!(F & kLookAheadNonEpsilonPrefix))
387        ++nprefix;
388      if (F & kLookAheadWeight)
389        SetLookAheadWeight(Plus(LookAheadWeight(), arc.weight));
390      ret = true;
391    } else if (matcher_.Find(label)) {
392      if (!(F & (kLookAheadWeight | kLookAheadPrefix)))
393        return true;
394      for (; !matcher_.Done(); matcher_.Next()) {
395        ++nprefix;
396        if (F & kLookAheadWeight)
397          SetLookAheadWeight(Plus(LookAheadWeight(),
398                                  Times(arc.weight,
399                                        matcher_.Value().weight)));
400        if ((F & kLookAheadPrefix) && nprefix == 1)
401          SetLookAheadPrefix(arc);
402      }
403      ret = true;
404    }
405  }
406  if (F & kLookAheadPrefix) {
407    if (nprefix == 1)
408      SetLookAheadWeight(Weight::One());  // Avoids double counting.
409    else
410      ClearLookAheadPrefix();
411  }
412  return ret;
413}
414
415
416// Template argument F accepts flags to control behavior.
417// It must include precisely one of KInputLookAheadMatcher or
418// KOutputLookAheadMatcher.
419template <class M, uint32 F = kLookAheadEpsilons | kLookAheadWeight |
420          kLookAheadPrefix | kLookAheadNonEpsilonPrefix |
421          kLookAheadKeepRelabelData,
422          class S = DefaultAccumulator<typename M::Arc> >
423class LabelLookAheadMatcher
424    : public LookAheadMatcherBase<typename M::FST::Arc> {
425 public:
426  typedef typename M::FST FST;
427  typedef typename M::Arc Arc;
428  typedef typename Arc::StateId StateId;
429  typedef typename Arc::Label Label;
430  typedef typename Arc::Weight Weight;
431  typedef LabelReachableData<Label> MatcherData;
432
433  using LookAheadMatcherBase<Arc>::LookAheadWeight;
434  using LookAheadMatcherBase<Arc>::SetLookAheadPrefix;
435  using LookAheadMatcherBase<Arc>::SetLookAheadWeight;
436  using LookAheadMatcherBase<Arc>::ClearLookAheadPrefix;
437
438  LabelLookAheadMatcher(const FST &fst, MatchType match_type,
439                        MatcherData *data = 0, S *s = 0)
440      : matcher_(fst, match_type),
441        lfst_(0),
442        label_reachable_(0),
443        s_(kNoStateId),
444        error_(false) {
445    if (!(F & (kInputLookAheadMatcher | kOutputLookAheadMatcher))) {
446      FSTERROR() << "LabelLookaheadMatcher: bad matcher flags: " << F;
447      error_ = true;
448    }
449    bool reach_input = match_type == MATCH_INPUT;
450    if (data) {
451      if (reach_input == data->ReachInput())
452        label_reachable_ = new LabelReachable<Arc, S>(data, s);
453    } else if ((reach_input && (F & kInputLookAheadMatcher)) ||
454               (!reach_input && (F & kOutputLookAheadMatcher))) {
455      label_reachable_ = new LabelReachable<Arc, S>(
456          fst, reach_input, s, F & kLookAheadKeepRelabelData);
457    }
458  }
459
460  LabelLookAheadMatcher(const LabelLookAheadMatcher<M, F, S> &lmatcher,
461                        bool safe = false)
462      : matcher_(lmatcher.matcher_, safe),
463        lfst_(lmatcher.lfst_),
464        label_reachable_(
465            lmatcher.label_reachable_ ?
466            new LabelReachable<Arc, S>(*lmatcher.label_reachable_) : 0),
467        s_(kNoStateId),
468        error_(lmatcher.error_) {}
469
470  ~LabelLookAheadMatcher() {
471    delete label_reachable_;
472  }
473
474  // General matcher methods
475  LabelLookAheadMatcher<M, F, S> *Copy(bool safe = false) const {
476    return new LabelLookAheadMatcher<M, F, S>(*this, safe);
477  }
478
479  MatchType Type(bool test) const { return matcher_.Type(test); }
480
481  void SetState(StateId s) {
482    if (s_ == s)
483      return;
484    s_ = s;
485    match_set_state_ = false;
486    reach_set_state_ = false;
487  }
488
489  bool Find(Label label) {
490    if (!match_set_state_) {
491      matcher_.SetState(s_);
492      match_set_state_ = true;
493    }
494    return matcher_.Find(label);
495  }
496
497  bool Done() const { return matcher_.Done(); }
498  const Arc& Value() const { return matcher_.Value(); }
499  void Next() { matcher_.Next(); }
500  const FST &GetFst() const { return matcher_.GetFst(); }
501
502  uint64 Properties(uint64 inprops) const {
503    uint64 outprops = matcher_.Properties(inprops);
504    if (error_ || (label_reachable_ && label_reachable_->Error()))
505      outprops |= kError;
506    return outprops;
507  }
508
509  uint32 Flags() const {
510    if (label_reachable_ && label_reachable_->GetData()->ReachInput())
511      return matcher_.Flags() | F | kInputLookAheadMatcher;
512    else if (label_reachable_ && !label_reachable_->GetData()->ReachInput())
513      return matcher_.Flags() | F | kOutputLookAheadMatcher;
514    else
515      return matcher_.Flags();
516  }
517
518  // Writable matcher methods
519  MatcherData *GetData() const {
520    return label_reachable_ ? label_reachable_->GetData() : 0;
521  };
522
523  // Look-ahead methods.
524  bool LookAheadLabel(Label label) const {
525    if (label == 0)
526      return true;
527
528    if (label_reachable_) {
529      if (!reach_set_state_) {
530        label_reachable_->SetState(s_);
531        reach_set_state_ = true;
532      }
533      return label_reachable_->Reach(label);
534    } else {
535      return true;
536    }
537  }
538
539  // Checks if there is a matching (possibly super-final) transition
540  // at (s_, s).
541  template <class L>
542  bool LookAheadFst(const L &fst, StateId s);
543
544  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
545    lfst_ = &fst;
546    if (label_reachable_)
547      label_reachable_->ReachInit(fst, copy);
548  }
549
550  template <class L>
551  void InitLookAheadFst(const L& fst, bool copy = false) {
552    lfst_ = static_cast<const Fst<Arc> *>(&fst);
553    if (label_reachable_)
554      label_reachable_->ReachInit(fst, copy);
555  }
556
557 private:
558  // This allows base class virtual access to non-virtual derived-
559  // class members of the same name. It makes the derived class more
560  // efficient to use but unsafe to further derive.
561  virtual void SetState_(StateId s) { SetState(s); }
562  virtual bool Find_(Label label) { return Find(label); }
563  virtual bool Done_() const { return Done(); }
564  virtual const Arc& Value_() const { return Value(); }
565  virtual void Next_() { Next(); }
566
567  bool LookAheadLabel_(Label l) const { return LookAheadLabel(l); }
568  bool LookAheadFst_(const Fst<Arc> &fst, StateId s) {
569    return LookAheadFst(fst, s);
570  }
571
572  mutable M matcher_;
573  const Fst<Arc> *lfst_;                     // Look-ahead FST
574  LabelReachable<Arc, S> *label_reachable_;  // Label reachability info
575  StateId s_;                                // Matcher state
576  bool match_set_state_;                     // matcher_.SetState called?
577  mutable bool reach_set_state_;             // reachable_.SetState called?
578  bool error_;
579};
580
581template <class M, uint32 F, class S>
582template <class L> inline
583bool LabelLookAheadMatcher<M, F, S>::LookAheadFst(const L &fst, StateId s) {
584  if (static_cast<const Fst<Arc> *>(&fst) != lfst_)
585    InitLookAheadFst(fst);
586
587  SetLookAheadWeight(Weight::One());
588  ClearLookAheadPrefix();
589
590  if (!label_reachable_)
591    return true;
592
593  label_reachable_->SetState(s_, s);
594  reach_set_state_ = true;
595
596  bool compute_weight = F & kLookAheadWeight;
597  bool compute_prefix = F & kLookAheadPrefix;
598
599  bool reach_input = Type(false) == MATCH_OUTPUT;
600  ArcIterator<L> aiter(fst, s);
601  bool reach_arc = label_reachable_->Reach(&aiter, 0,
602                                           internal::NumArcs(*lfst_, s),
603                                           reach_input, compute_weight);
604  Weight lfinal = internal::Final(*lfst_, s);
605  bool reach_final = lfinal != Weight::Zero() && label_reachable_->ReachFinal();
606  if (reach_arc) {
607    ssize_t begin = label_reachable_->ReachBegin();
608    ssize_t end = label_reachable_->ReachEnd();
609    if (compute_prefix && end - begin == 1 && !reach_final) {
610      aiter.Seek(begin);
611      SetLookAheadPrefix(aiter.Value());
612      compute_weight = false;
613    } else if (compute_weight) {
614      SetLookAheadWeight(label_reachable_->ReachWeight());
615    }
616  }
617  if (reach_final && compute_weight)
618    SetLookAheadWeight(reach_arc ?
619                       Plus(LookAheadWeight(), lfinal) : lfinal);
620
621  return reach_arc || reach_final;
622}
623
624
625// Label-lookahead relabeling class.
626template <class A>
627class LabelLookAheadRelabeler {
628 public:
629  typedef typename A::Label Label;
630  typedef LabelReachableData<Label> MatcherData;
631  typedef AddOnPair<MatcherData, MatcherData> D;
632
633  // Relabels matcher Fst - initialization function object.
634  template <typename I>
635  LabelLookAheadRelabeler(I **impl);
636
637  // Relabels arbitrary Fst. Class L should be a label-lookahead Fst.
638  template <class L>
639  static void Relabel(MutableFst<A> *fst, const L &mfst,
640                      bool relabel_input) {
641    typename L::Impl *impl = mfst.GetImpl();
642    D *data = impl->GetAddOn();
643    LabelReachable<A> reachable(data->First() ?
644                                  data->First() : data->Second());
645    reachable.Relabel(fst, relabel_input);
646  }
647
648  // Returns relabeling pairs (cf. relabel.h::Relabel()).
649  // Class L should be a label-lookahead Fst.
650  // If 'avoid_collisions' is true, extra pairs are added to
651  // ensure no collisions when relabeling automata that have
652  // labels unseen here.
653  template <class L>
654  static void RelabelPairs(const L &mfst, vector<pair<Label, Label> > *pairs,
655                           bool avoid_collisions = false) {
656    typename L::Impl *impl = mfst.GetImpl();
657    D *data = impl->GetAddOn();
658    LabelReachable<A> reachable(data->First() ?
659                                  data->First() : data->Second());
660    reachable.RelabelPairs(pairs, avoid_collisions);
661  }
662};
663
664template <class A>
665template <typename I> inline
666LabelLookAheadRelabeler<A>::LabelLookAheadRelabeler(I **impl) {
667  Fst<A> &fst = (*impl)->GetFst();
668  D *data = (*impl)->GetAddOn();
669  const string name = (*impl)->Type();
670  bool is_mutable = fst.Properties(kMutable, false);
671  MutableFst<A> *mfst = 0;
672  if (is_mutable) {
673    mfst = static_cast<MutableFst<A> *>(&fst);
674  } else {
675    mfst = new VectorFst<A>(fst);
676    data->IncrRefCount();
677    delete *impl;
678  }
679  if (data->First()) {  // reach_input
680    LabelReachable<A> reachable(data->First());
681    reachable.Relabel(mfst, true);
682    if (!FLAGS_save_relabel_ipairs.empty()) {
683      vector<pair<Label, Label> > pairs;
684      reachable.RelabelPairs(&pairs, true);
685      WriteLabelPairs(FLAGS_save_relabel_ipairs, pairs);
686    }
687  } else {
688    LabelReachable<A> reachable(data->Second());
689    reachable.Relabel(mfst, false);
690    if (!FLAGS_save_relabel_opairs.empty()) {
691      vector<pair<Label, Label> > pairs;
692      reachable.RelabelPairs(&pairs, true);
693      WriteLabelPairs(FLAGS_save_relabel_opairs, pairs);
694    }
695  }
696  if (!is_mutable) {
697    *impl = new I(*mfst, name);
698    (*impl)->SetAddOn(data);
699    delete mfst;
700    data->DecrRefCount();
701  }
702}
703
704
705// Generic lookahead matcher, templated on the FST definition
706// - a wrapper around pointer to specific one.
707template <class F>
708class LookAheadMatcher {
709 public:
710  typedef F FST;
711  typedef typename F::Arc Arc;
712  typedef typename Arc::StateId StateId;
713  typedef typename Arc::Label Label;
714  typedef typename Arc::Weight Weight;
715  typedef LookAheadMatcherBase<Arc> LBase;
716
717  LookAheadMatcher(const F &fst, MatchType match_type) {
718    base_ = fst.InitMatcher(match_type);
719    if (!base_)
720      base_ = new SortedMatcher<F>(fst, match_type);
721    lookahead_ = false;
722  }
723
724  LookAheadMatcher(const LookAheadMatcher<F> &matcher, bool safe = false) {
725    base_ = matcher.base_->Copy(safe);
726    lookahead_ = matcher.lookahead_;
727  }
728
729  ~LookAheadMatcher() { delete base_; }
730
731  // General matcher methods
732  LookAheadMatcher<F> *Copy(bool safe = false) const {
733      return new LookAheadMatcher<F>(*this, safe);
734  }
735
736  MatchType Type(bool test) const { return base_->Type(test); }
737  void SetState(StateId s) { base_->SetState(s); }
738  bool Find(Label label) { return base_->Find(label); }
739  bool Done() const { return base_->Done(); }
740  const Arc& Value() const { return base_->Value(); }
741  void Next() { base_->Next(); }
742  const F &GetFst() const { return static_cast<const F &>(base_->GetFst()); }
743
744  uint64 Properties(uint64 props) const { return base_->Properties(props); }
745
746  uint32 Flags() const { return base_->Flags(); }
747
748  // Look-ahead methods
749  bool LookAheadLabel(Label label) const {
750    if (LookAheadCheck()) {
751      LBase *lbase = static_cast<LBase *>(base_);
752      return lbase->LookAheadLabel(label);
753    } else {
754      return true;
755    }
756  }
757
758  bool LookAheadFst(const Fst<Arc> &fst, StateId s) {
759    if (LookAheadCheck()) {
760      LBase *lbase = static_cast<LBase *>(base_);
761      return lbase->LookAheadFst(fst, s);
762    } else {
763      return true;
764    }
765  }
766
767  Weight LookAheadWeight() const {
768    if (LookAheadCheck()) {
769      LBase *lbase = static_cast<LBase *>(base_);
770      return lbase->LookAheadWeight();
771    } else {
772      return Weight::One();
773    }
774  }
775
776  bool LookAheadPrefix(Arc *arc) const {
777    if (LookAheadCheck()) {
778      LBase *lbase = static_cast<LBase *>(base_);
779      return lbase->LookAheadPrefix(arc);
780    } else {
781      return false;
782    }
783  }
784
785  void InitLookAheadFst(const Fst<Arc>& fst, bool copy = false) {
786    if (LookAheadCheck()) {
787      LBase *lbase = static_cast<LBase *>(base_);
788      lbase->InitLookAheadFst(fst, copy);
789    }
790  }
791
792 private:
793  bool LookAheadCheck() const {
794    if (!lookahead_) {
795      lookahead_ = base_->Flags() &
796          (kInputLookAheadMatcher | kOutputLookAheadMatcher);
797      if (!lookahead_) {
798        FSTERROR() << "LookAheadMatcher: No look-ahead matcher defined";
799      }
800    }
801    return lookahead_;
802  }
803
804  MatcherBase<Arc> *base_;
805  mutable bool lookahead_;
806
807  void operator=(const LookAheadMatcher<Arc> &);  // disallow
808};
809
810}  // namespace fst
811
812#endif  // FST_LIB_LOOKAHEAD_MATCHER_H__
813