compose.h revision 73018b4a1d088cdda0e7bd059fddf1f308a8195a
1// compose.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// Class to compute the composition of two FSTs
18
19#ifndef FST_LIB_COMPOSE_H__
20#define FST_LIB_COMPOSE_H__
21
22#include <algorithm>
23
24#include <unordered_map>
25
26#include "fst/lib/cache.h"
27#include "fst/lib/test-properties.h"
28
29namespace fst {
30
31// Enumeration of uint64 bits used to represent the user-defined
32// properties of FST composition (in the template parameter to
33// ComposeFstOptions<T>). The bits stand for extensions of generic FST
34// composition. ComposeFstOptions<> (all the bits unset) is the "plain"
35// compose without any extra extensions.
36enum ComposeTypes {
37  // RHO: flags dealing with a special "rest" symbol in the FSTs.
38  // NB: at most one of the bits COMPOSE_FST1_RHO, COMPOSE_FST2_RHO
39  // may be set.
40  COMPOSE_FST1_RHO    = 1ULL<<0,  // "Rest" symbol on the output side of fst1.
41  COMPOSE_FST2_RHO    = 1ULL<<1,  // "Rest" symbol on the input side of fst2.
42  COMPOSE_FST1_PHI    = 1ULL<<2,  // "Failure" symbol on the output
43                                  // side of fst1.
44  COMPOSE_FST2_PHI    = 1ULL<<3,  // "Failure" symbol on the input side
45                                  // of fst2.
46  COMPOSE_FST1_SIGMA  = 1ULL<<4,  // "Any" symbol on the output side of
47                                  // fst1.
48  COMPOSE_FST2_SIGMA  = 1ULL<<5,  // "Any" symbol on the input side of
49                                  // fst2.
50  // Optimization related bits.
51  COMPOSE_GENERIC     = 1ULL<<32,  // Disables optimizations, applies
52                                   // the generic version of the
53                                   // composition algorithm. This flag
54                                   // is used for internal testing
55                                   // only.
56
57  // -----------------------------------------------------------------
58  // Auxiliary enum values denoting specific combinations of
59  // bits. Internal use only.
60  COMPOSE_RHO         = COMPOSE_FST1_RHO | COMPOSE_FST2_RHO,
61  COMPOSE_PHI         = COMPOSE_FST1_PHI | COMPOSE_FST2_PHI,
62  COMPOSE_SIGMA       = COMPOSE_FST1_SIGMA | COMPOSE_FST2_SIGMA,
63  COMPOSE_SPECIAL_SYMBOLS = COMPOSE_RHO | COMPOSE_PHI | COMPOSE_SIGMA,
64
65  // -----------------------------------------------------------------
66  // The following bits, denoting specific optimizations, are
67  // typically set *internally* by the composition algorithm.
68  COMPOSE_FST1_STRING = 1ULL<<33,  // fst1 is a string
69  COMPOSE_FST2_STRING = 1ULL<<34,  // fst2 is a string
70  COMPOSE_FST1_DET    = 1ULL<<35,  // fst1 is deterministic
71  COMPOSE_FST2_DET    = 1ULL<<36,  // fst2 is deterministic
72  COMPOSE_INTERNAL_MASK    = 0xffffffff00000000ULL
73};
74
75
76template <uint64 T = 0ULL>
77struct ComposeFstOptions : public CacheOptions {
78  explicit ComposeFstOptions(const CacheOptions &opts) : CacheOptions(opts) {}
79  ComposeFstOptions() { }
80};
81
82
83// Abstract base for the implementation of delayed ComposeFst. The
84// concrete specializations are templated on the (uint64-valued)
85// properties of the FSTs being composed.
86template <class A>
87class ComposeFstImplBase : public CacheImpl<A> {
88 public:
89  using FstImpl<A>::SetType;
90  using FstImpl<A>::SetProperties;
91  using FstImpl<A>::Properties;
92  using FstImpl<A>::SetInputSymbols;
93  using FstImpl<A>::SetOutputSymbols;
94
95  using CacheBaseImpl< CacheState<A> >::HasStart;
96  using CacheBaseImpl< CacheState<A> >::HasFinal;
97  using CacheBaseImpl< CacheState<A> >::HasArcs;
98
99  typedef typename A::Label Label;
100  typedef typename A::Weight Weight;
101  typedef typename A::StateId StateId;
102  typedef CacheState<A> State;
103
104  ComposeFstImplBase(const Fst<A> &fst1,
105                     const Fst<A> &fst2,
106                     const CacheOptions &opts)
107      :CacheImpl<A>(opts), fst1_(fst1.Copy()), fst2_(fst2.Copy()) {
108    SetType("compose");
109    uint64 props1 = fst1.Properties(kFstProperties, false);
110    uint64 props2 = fst2.Properties(kFstProperties, false);
111    SetProperties(ComposeProperties(props1, props2), kCopyProperties);
112
113    if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols()))
114      LOG(FATAL) << "ComposeFst: output symbol table of 1st argument "
115                 << "does not match input symbol table of 2nd argument";
116
117    SetInputSymbols(fst1.InputSymbols());
118    SetOutputSymbols(fst2.OutputSymbols());
119  }
120
121  virtual ~ComposeFstImplBase() {
122    delete fst1_;
123    delete fst2_;
124  }
125
126  StateId Start() {
127    if (!HasStart()) {
128      StateId start = ComputeStart();
129      if (start != kNoStateId) {
130        this->SetStart(start);
131      }
132    }
133    return CacheImpl<A>::Start();
134  }
135
136  Weight Final(StateId s) {
137    if (!HasFinal(s)) {
138      Weight final = ComputeFinal(s);
139      this->SetFinal(s, final);
140    }
141    return CacheImpl<A>::Final(s);
142  }
143
144  virtual void Expand(StateId s) = 0;
145
146  size_t NumArcs(StateId s) {
147    if (!HasArcs(s))
148      Expand(s);
149    return CacheImpl<A>::NumArcs(s);
150  }
151
152  size_t NumInputEpsilons(StateId s) {
153    if (!HasArcs(s))
154      Expand(s);
155    return CacheImpl<A>::NumInputEpsilons(s);
156  }
157
158  size_t NumOutputEpsilons(StateId s) {
159    if (!HasArcs(s))
160      Expand(s);
161    return CacheImpl<A>::NumOutputEpsilons(s);
162  }
163
164  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
165    if (!HasArcs(s))
166      Expand(s);
167    CacheImpl<A>::InitArcIterator(s, data);
168  }
169
170  // Access to flags encoding compose options/optimizations etc.  (for
171  // debugging).
172  virtual uint64 ComposeFlags() const = 0;
173
174 protected:
175  virtual StateId ComputeStart() = 0;
176  virtual Weight ComputeFinal(StateId s) = 0;
177
178  const Fst<A> *fst1_;            // first input Fst
179  const Fst<A> *fst2_;            // second input Fst
180};
181
182
183// The following class encapsulates implementation-dependent details
184// of state tuple lookup, i.e. a bijective mapping from triples of two
185// FST states and an epsilon filter state to the corresponding state
186// IDs of the fst resulting from composition. The mapping must
187// implement the [] operator in the style of STL associative
188// containers (map, hash_map), i.e. table[x] must return a reference
189// to the value associated with x. If x is an unassigned tuple, the
190// operator must automatically associate x with value 0.
191//
192// NB: "table[x] == 0" for unassigned tuples x is required by the
193// following off-by-one device used in the implementation of
194// ComposeFstImpl. The value stored in the table is equal to tuple ID
195// plus one, i.e. it is always a strictly positive number. Therefore,
196// table[x] is equal to 0 if and only if x is an unassigned tuple (in
197// which the algorithm assigns a new ID to x, and sets table[x] -
198// stored in a reference - to "new ID + 1"). This form of lookup is
199// more efficient than calling "find(x)" and "insert(make_pair(x, new
200// ID))" if x is an unassigned tuple.
201//
202// The generic implementation is a wrapper around a hash_map.
203template <class A, uint64 T>
204class ComposeStateTable {
205 public:
206  typedef typename A::StateId StateId;
207
208  struct StateTuple {
209    StateTuple() {}
210    StateTuple(StateId s1, StateId s2, int f)
211        : state_id1(s1), state_id2(s2), filt(f) {}
212    StateId state_id1;  // state Id on fst1
213    StateId state_id2;  // state Id on fst2
214    int filt;           // epsilon filter state
215  };
216
217  ComposeStateTable() {
218    StateTuple empty_tuple(kNoStateId, kNoStateId, 0);
219  }
220
221  // NB: if 'tuple' is not in 'table_', the pair (tuple, StateId()) is
222  // inserted into 'table_' (standard STL container semantics). Since
223  // StateId is a built-in type, the explicit default constructor call
224  // StateId() returns 0.
225  StateId &operator[](const StateTuple &tuple) {
226    return table_[tuple];
227  }
228
229 private:
230  // Comparison object for hashing StateTuple(s).
231  class StateTupleEqual {
232   public:
233    bool operator()(const StateTuple& x, const StateTuple& y) const {
234      return x.state_id1 == y.state_id1 &&
235             x.state_id2 == y.state_id2 &&
236             x.filt == y.filt;
237    }
238  };
239
240  static const int kPrime0 = 7853;
241  static const int kPrime1 = 7867;
242
243  // Hash function for StateTuple to Fst states.
244  class StateTupleKey {
245   public:
246    size_t operator()(const StateTuple& x) const {
247      return static_cast<size_t>(x.state_id1 +
248                                 x.state_id2 * kPrime0 +
249                                 x.filt * kPrime1);
250    }
251  };
252
253  // Lookup table mapping state tuples to state IDs.
254  typedef std::unordered_map<StateTuple, StateId, StateTupleKey,
255                             StateTupleEqual> StateTable;
256 // Actual table data.
257  StateTable table_;
258
259  DISALLOW_EVIL_CONSTRUCTORS(ComposeStateTable);
260};
261
262
263// State tuple lookup table for the composition of a string FST with a
264// deterministic FST.  The class maps state tuples to their unique IDs
265// (i.e. states of the ComposeFst). Main optimization: due to the
266// 1-to-1 correspondence between the states of the input string FST
267// and those of the resulting (string) FST, a state tuple (s1, s2) is
268// simply mapped to StateId s1. Hence, we use an STL vector as a
269// lookup table. Template argument Fst1IsString specifies which FST is
270// a string (this determines whether or not we index the lookup table
271// by the first or by the second state).
272template <class A, bool Fst1IsString>
273class StringDetComposeStateTable {
274 public:
275  typedef typename A::StateId StateId;
276
277  struct StateTuple {
278    typedef typename A::StateId StateId;
279    StateTuple() {}
280    StateTuple(StateId s1, StateId s2, int /* f */)
281        : state_id1(s1), state_id2(s2) {}
282    StateId state_id1;  // state Id on fst1
283    StateId state_id2;  // state Id on fst2
284    static const int filt = 0;  // 'fake' epsilon filter - only needed
285                                // for API compatibility
286  };
287
288  StringDetComposeStateTable() {}
289
290  // Subscript operator. Behaves in a way similar to its map/hash_map
291  // counterpart, i.e. returns a reference to the value associated
292  // with 'tuple', inserting a 0 value if 'tuple' is unassigned.
293  StateId &operator[](const StateTuple &tuple) {
294    StateId index = Fst1IsString ? tuple.state_id1 : tuple.state_id2;
295    if (index >= (StateId)data_.size()) {
296      // NB: all values in [old_size; index] are initialized to 0.
297      data_.resize(index + 1);
298    }
299    return data_[index];
300  }
301
302 private:
303  vector<StateId> data_;
304
305  DISALLOW_EVIL_CONSTRUCTORS(StringDetComposeStateTable);
306};
307
308
309// Specializations of ComposeStateTable for the string/det case.
310// Both inherit from StringDetComposeStateTable.
311template <class A>
312class ComposeStateTable<A, COMPOSE_FST1_STRING | COMPOSE_FST2_DET>
313    : public StringDetComposeStateTable<A, true> { };
314
315template <class A>
316class ComposeStateTable<A, COMPOSE_FST2_STRING | COMPOSE_FST1_DET>
317    : public StringDetComposeStateTable<A, false> { };
318
319
320// Parameterized implementation of FST composition for a pair of FSTs
321// matching the property bit vector T. If possible,
322// instantiation-specific switches in the code are based on the values
323// of the bits in T, which are known at compile time, so unused code
324// should be optimized away by the compiler.
325template <class A, uint64 T>
326class ComposeFstImpl : public ComposeFstImplBase<A> {
327  typedef typename A::StateId StateId;
328  typedef typename A::Label   Label;
329  typedef typename A::Weight  Weight;
330  using FstImpl<A>::SetType;
331  using FstImpl<A>::SetProperties;
332
333  enum FindType { FIND_INPUT  = 1,          // find input label on fst2
334                  FIND_OUTPUT = 2,          // find output label on fst1
335                  FIND_BOTH   = 3 };        // find choice state dependent
336
337  typedef ComposeStateTable<A, T & COMPOSE_INTERNAL_MASK> StateTupleTable;
338  typedef typename StateTupleTable::StateTuple StateTuple;
339
340 public:
341  ComposeFstImpl(const Fst<A> &fst1,
342                 const Fst<A> &fst2,
343                 const CacheOptions &opts)
344      :ComposeFstImplBase<A>(fst1, fst2, opts) {
345
346    bool osorted = fst1.Properties(kOLabelSorted, false);
347    bool isorted = fst2.Properties(kILabelSorted, false);
348
349    switch (T & COMPOSE_SPECIAL_SYMBOLS) {
350      case COMPOSE_FST1_RHO:
351      case COMPOSE_FST1_PHI:
352      case COMPOSE_FST1_SIGMA:
353        if (!osorted || FLAGS_fst_verify_properties)
354          osorted = fst1.Properties(kOLabelSorted, true);
355        if (!osorted)
356          LOG(FATAL) << "ComposeFst: 1st argument not output label "
357                     << "sorted (special symbols present)";
358        break;
359      case COMPOSE_FST2_RHO:
360      case COMPOSE_FST2_PHI:
361      case COMPOSE_FST2_SIGMA:
362        if (!isorted || FLAGS_fst_verify_properties)
363          isorted = fst2.Properties(kILabelSorted, true);
364        if (!isorted)
365          LOG(FATAL) << "ComposeFst: 2nd argument not input label "
366                     << "sorted (special symbols present)";
367        break;
368      case 0:
369        if ((!isorted && !osorted) || FLAGS_fst_verify_properties) {
370          osorted = fst1.Properties(kOLabelSorted, true);
371          if (!osorted)
372            isorted = fst2.Properties(kILabelSorted, true);
373        }
374        break;
375      default:
376        LOG(FATAL)
377          << "ComposeFst: More than one special symbol used in composition";
378    }
379
380    if (isorted && (T & COMPOSE_FST2_SIGMA)) {
381      find_type_ = FIND_INPUT;
382    } else if (osorted && (T & COMPOSE_FST1_SIGMA)) {
383      find_type_ = FIND_OUTPUT;
384    } else if (isorted && (T & COMPOSE_FST2_PHI)) {
385      find_type_ = FIND_INPUT;
386    } else if (osorted && (T & COMPOSE_FST1_PHI)) {
387      find_type_ = FIND_OUTPUT;
388    } else if (isorted && (T & COMPOSE_FST2_RHO)) {
389      find_type_ = FIND_INPUT;
390    } else if (osorted && (T & COMPOSE_FST1_RHO)) {
391      find_type_ = FIND_OUTPUT;
392    } else if (isorted && (T & COMPOSE_FST1_STRING)) {
393      find_type_ = FIND_INPUT;
394    } else if(osorted && (T & COMPOSE_FST2_STRING)) {
395      find_type_ = FIND_OUTPUT;
396    } else if (isorted && osorted) {
397      find_type_ = FIND_BOTH;
398    } else if (isorted) {
399      find_type_ = FIND_INPUT;
400    } else if (osorted) {
401      find_type_ = FIND_OUTPUT;
402    } else {
403      LOG(FATAL) << "ComposeFst: 1st argument not output label sorted "
404                 << "and 2nd argument is not input label sorted";
405    }
406  }
407
408  // Finds/creates an Fst state given a StateTuple.  Only creates a new
409  // state if StateTuple is not found in the state hash.
410  //
411  // The method exploits the following device: all pairs stored in the
412  // associative container state_tuple_table_ are of the form (tuple,
413  // id(tuple) + 1), i.e. state_tuple_table_[tuple] > 0 if tuple has
414  // been stored previously. For unassigned tuples, the call to
415  // state_tuple_table_[tuple] creates a new pair (tuple, 0). As a
416  // result, state_tuple_table_[tuple] == 0 iff tuple is new.
417  StateId FindState(const StateTuple& tuple) {
418    StateId &assoc_value = state_tuple_table_[tuple];
419    if (assoc_value == 0) {  // tuple wasn't present in lookup table:
420                             // assign it a new ID.
421      state_tuples_.push_back(tuple);
422      assoc_value = state_tuples_.size();
423    }
424    return assoc_value - 1;  // NB: assoc_value = ID + 1
425  }
426
427  // Generates arc for composition state s from matched input Fst arcs.
428  void AddArc(StateId s, const A &arca, const A &arcb, int f,
429              bool find_input) {
430    A arc;
431    if (find_input) {
432      arc.ilabel = arcb.ilabel;
433      arc.olabel = arca.olabel;
434      arc.weight = Times(arcb.weight, arca.weight);
435      StateTuple tuple(arcb.nextstate, arca.nextstate, f);
436      arc.nextstate = FindState(tuple);
437    } else {
438      arc.ilabel = arca.ilabel;
439      arc.olabel = arcb.olabel;
440      arc.weight = Times(arca.weight, arcb.weight);
441      StateTuple tuple(arca.nextstate, arcb.nextstate, f);
442      arc.nextstate = FindState(tuple);
443    }
444    CacheImpl<A>::AddArc(s, arc);
445  }
446
447  // Arranges it so that the first arg to OrderedExpand is the Fst
448  // that will be passed to FindLabel.
449  void Expand(StateId s) {
450    StateTuple &tuple = state_tuples_[s];
451    StateId s1 = tuple.state_id1;
452    StateId s2 = tuple.state_id2;
453    int f = tuple.filt;
454    if (find_type_ == FIND_INPUT)
455      OrderedExpand(s, ComposeFstImplBase<A>::fst2_, s2,
456                    ComposeFstImplBase<A>::fst1_, s1, f, true);
457    else
458      OrderedExpand(s, ComposeFstImplBase<A>::fst1_, s1,
459                    ComposeFstImplBase<A>::fst2_, s2, f, false);
460  }
461
462  // Access to flags encoding compose options/optimizations etc.  (for
463  // debugging).
464  virtual uint64 ComposeFlags() const { return T; }
465
466 private:
467  // This does that actual matching of labels in the composition. The
468  // arguments are ordered so FindLabel is called with state SA of
469  // FSTA for each arc leaving state SB of FSTB. The FIND_INPUT arg
470  // determines whether the input or output label of arcs at SB is
471  // the one to match on.
472  void OrderedExpand(StateId s, const Fst<A> *fsta, StateId sa,
473                     const Fst<A> *fstb, StateId sb, int f, bool find_input) {
474
475    size_t numarcsa = fsta->NumArcs(sa);
476    size_t numepsa = find_input ? fsta->NumInputEpsilons(sa) :
477                     fsta->NumOutputEpsilons(sa);
478    bool finala = fsta->Final(sa) != Weight::Zero();
479    ArcIterator< Fst<A> > aitera(*fsta, sa);
480    // First handle special epsilons and sigmas on FSTA
481    for (; !aitera.Done(); aitera.Next()) {
482      const A &arca = aitera.Value();
483      Label match_labela = find_input ? arca.ilabel : arca.olabel;
484      if (match_labela > 0) {
485        break;
486      }
487      if ((T & COMPOSE_SIGMA) != 0 &&  match_labela == kSigmaLabel) {
488        // Found a sigma? Match it against all (non-special) symbols
489        // on side b.
490        for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
491             !aiterb.Done();
492             aiterb.Next()) {
493          const A &arcb = aiterb.Value();
494          Label labelb = find_input ? arcb.olabel : arcb.ilabel;
495          if (labelb <= 0) continue;
496          AddArc(s, arca, arcb, 0, find_input);
497        }
498      } else if (f == 0 && match_labela == 0) {
499        A earcb(0, 0, Weight::One(), sb);
500        AddArc(s, arca, earcb, 0, find_input);  // move forward on epsilon
501      }
502    }
503    // Next handle non-epsilon matches, rho labels, and epsilons on FSTB
504    for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
505         !aiterb.Done();
506         aiterb.Next()) {
507      const A &arcb = aiterb.Value();
508      Label match_labelb = find_input ? arcb.olabel : arcb.ilabel;
509      if (match_labelb) {  // Consider non-epsilon match
510        if (FindLabel(&aitera, numarcsa, match_labelb, find_input)) {
511          for (; !aitera.Done(); aitera.Next()) {
512            const A &arca = aitera.Value();
513            Label match_labela = find_input ? arca.ilabel : arca.olabel;
514            if (match_labela != match_labelb)
515              break;
516            AddArc(s, arca, arcb, 0, find_input);  // move forward on match
517          }
518        } else if ((T & COMPOSE_SPECIAL_SYMBOLS) != 0) {
519          // If there is no transition labelled 'match_labelb' in
520          // fsta, try matching 'match_labelb' against special symbols
521          // (Phi, Rho,...).
522          for (aitera.Reset(); !aitera.Done(); aitera.Next()) {
523            A arca = aitera.Value();
524            Label labela = find_input ? arca.ilabel : arca.olabel;
525            if (labela >= 0) {
526              break;
527            } else if (((T & COMPOSE_PHI) != 0) && (labela == kPhiLabel)) {
528              // Case 1: if a failure transition exists, follow its
529              // transitive closure until a) a transition labelled
530              // 'match_labelb' is found, or b) the initial state of
531              // fsta is reached.
532
533              StateId sf = sa;  // Start of current failure transition.
534              while (labela == kPhiLabel && sf != arca.nextstate) {
535                sf = arca.nextstate;
536
537                size_t numarcsf = fsta->NumArcs(sf);
538                ArcIterator< Fst<A> > aiterf(*fsta, sf);
539                if (FindLabel(&aiterf, numarcsf, match_labelb, find_input)) {
540                  // Sub-case 1a: there exists a transition starting
541                  // in sf and consuming symbol 'match_labelb'.
542                  AddArc(s, aiterf.Value(), arcb, 0, find_input);
543                  break;
544                } else {
545                  // No transition labelled 'match_labelb' found: try
546                  // next failure transition (starting at 'sf').
547                  for (aiterf.Reset(); !aiterf.Done(); aiterf.Next()) {
548                    arca = aiterf.Value();
549                    labela = find_input ? arca.ilabel : arca.olabel;
550                    if (labela >= kPhiLabel) break;
551                  }
552                }
553              }
554              if (labela == kPhiLabel && sf == arca.nextstate) {
555                // Sub-case 1b: failure transitions lead to start
556                // state without finding a matching
557                // transition. Therefore, we generate a loop in start
558                // state of fsta.
559                A loop(match_labelb, match_labelb, Weight::One(), sf);
560                AddArc(s, loop, arcb, 0, find_input);
561              }
562            } else if (((T & COMPOSE_RHO) != 0) && (labela == kRhoLabel)) {
563              // Case 2: 'match_labelb' can be matched against a
564              // "rest" (rho) label in fsta.
565              if (find_input) {
566                arca.ilabel = match_labelb;
567                if (arca.olabel == kRhoLabel)
568                  arca.olabel = match_labelb;
569              } else {
570                arca.olabel = match_labelb;
571                if (arca.ilabel == kRhoLabel)
572                  arca.ilabel = match_labelb;
573              }
574              AddArc(s, arca, arcb, 0, find_input);  // move fwd on match
575            }
576          }
577        }
578      } else if (numepsa != numarcsa || finala) {  // Handle FSTB epsilon
579        A earca(0, 0, Weight::One(), sa);
580        AddArc(s, earca, arcb, numepsa > 0, find_input);  // move on epsilon
581      }
582    }
583    this->SetArcs(s);
584   }
585
586
587  // Finds matches to MATCH_LABEL in arcs given by AITER
588  // using FIND_INPUT to determine whether to look on input or output.
589  bool FindLabel(ArcIterator< Fst<A> > *aiter, size_t numarcs,
590                 Label match_label, bool find_input) {
591    // binary search for match
592    size_t low = 0;
593    size_t high = numarcs;
594    while (low < high) {
595      size_t mid = (low + high) / 2;
596      aiter->Seek(mid);
597      Label label = find_input ?
598                    aiter->Value().ilabel : aiter->Value().olabel;
599      if (label > match_label) {
600        high = mid;
601      } else if (label < match_label) {
602        low = mid + 1;
603      } else {
604        // find first matching label (when non-determinism)
605        for (size_t i = mid; i > low; --i) {
606          aiter->Seek(i - 1);
607          label = find_input ? aiter->Value().ilabel : aiter->Value().olabel;
608          if (label != match_label) {
609            aiter->Seek(i);
610            return true;
611          }
612        }
613        return true;
614      }
615    }
616    return false;
617  }
618
619  StateId ComputeStart() {
620    StateId s1 = ComposeFstImplBase<A>::fst1_->Start();
621    StateId s2 = ComposeFstImplBase<A>::fst2_->Start();
622    if (s1 == kNoStateId || s2 == kNoStateId)
623      return kNoStateId;
624    StateTuple tuple(s1, s2, 0);
625    return FindState(tuple);
626  }
627
628  Weight ComputeFinal(StateId s) {
629    StateTuple &tuple = state_tuples_[s];
630    Weight final = Times(ComposeFstImplBase<A>::fst1_->Final(tuple.state_id1),
631                         ComposeFstImplBase<A>::fst2_->Final(tuple.state_id2));
632    return final;
633  }
634
635
636  FindType find_type_;            // find label on which side?
637
638  // Maps from StateId to StateTuple.
639  vector<StateTuple> state_tuples_;
640
641  // Maps from StateTuple to StateId.
642  StateTupleTable state_tuple_table_;
643
644  DISALLOW_EVIL_CONSTRUCTORS(ComposeFstImpl);
645};
646
647
648// Computes the composition of two transducers. This version is a
649// delayed Fst. If FST1 transduces string x to y with weight a and FST2
650// transduces y to z with weight b, then their composition transduces
651// string x to z with weight Times(x, z).
652//
653// The output labels of the first transducer or the input labels of
654// the second transducer must be sorted.  The weights need to form a
655// commutative semiring (valid for TropicalWeight and LogWeight).
656//
657// Complexity:
658// Assuming the first FST is unsorted and the second is sorted:
659// - Time: O(v1 v2 d1 (log d2 + m2)),
660// - Space: O(v1 v2)
661// where vi = # of states visited, di = maximum out-degree, and mi the
662// maximum multiplicity of the states visited for the ith
663// FST. Constant time and space to visit an input state or arc is
664// assumed and exclusive of caching.
665//
666// Caveats:
667// - ComposeFst does not trim its output (since it is a delayed operation).
668// - The efficiency of composition can be strongly affected by several factors:
669//   - the choice of which tnansducer is sorted - prefer sorting the FST
670//     that has the greater average out-degree.
671//   - the amount of non-determinism
672//   - the presence and location of epsilon transitions - avoid epsilon
673//     transitions on the output side of the first transducer or
674//     the input side of the second transducer or prefer placing
675//     them later in a path since they delay matching and can
676//     introduce non-coaccessible states and transitions.
677template <class A>
678class ComposeFst : public Fst<A> {
679 public:
680  friend class ArcIterator< ComposeFst<A> >;
681  friend class CacheStateIterator< ComposeFst<A> >;
682  friend class CacheArcIterator< ComposeFst<A> >;
683
684  typedef A Arc;
685  typedef typename A::Weight Weight;
686  typedef typename A::StateId StateId;
687  typedef CacheState<A> State;
688
689  ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2)
690      : impl_(Init(fst1, fst2, ComposeFstOptions<>())) { }
691
692  template <uint64 T>
693  ComposeFst(const Fst<A> &fst1,
694             const Fst<A> &fst2,
695             const ComposeFstOptions<T> &opts)
696      : impl_(Init(fst1, fst2, opts)) { }
697
698  ComposeFst(const ComposeFst<A> &fst) : Fst<A>(fst), impl_(fst.impl_) {
699    impl_->IncrRefCount();
700  }
701
702  virtual ~ComposeFst() { if (!impl_->DecrRefCount()) delete impl_;  }
703
704  virtual StateId Start() const { return impl_->Start(); }
705
706  virtual Weight Final(StateId s) const { return impl_->Final(s); }
707
708  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
709
710  virtual size_t NumInputEpsilons(StateId s) const {
711    return impl_->NumInputEpsilons(s);
712  }
713
714  virtual size_t NumOutputEpsilons(StateId s) const {
715    return impl_->NumOutputEpsilons(s);
716  }
717
718  virtual uint64 Properties(uint64 mask, bool test) const {
719    if (test) {
720      uint64 known, test = TestProperties(*this, mask, &known);
721      impl_->SetProperties(test, known);
722      return test & mask;
723    } else {
724      return impl_->Properties(mask);
725    }
726  }
727
728  virtual const string& Type() const { return impl_->Type(); }
729
730  virtual ComposeFst<A> *Copy() const {
731    return new ComposeFst<A>(*this);
732  }
733
734  virtual const SymbolTable* InputSymbols() const {
735    return impl_->InputSymbols();
736  }
737
738  virtual const SymbolTable* OutputSymbols() const {
739    return impl_->OutputSymbols();
740  }
741
742  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
743
744  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
745    impl_->InitArcIterator(s, data);
746  }
747
748  // Access to flags encoding compose options/optimizations etc.  (for
749  // debugging).
750  uint64 ComposeFlags() const { return impl_->ComposeFlags(); }
751
752 protected:
753  ComposeFstImplBase<A> *Impl() { return impl_; }
754
755 private:
756  ComposeFstImplBase<A> *impl_;
757
758  // Auxiliary method encapsulating the creation of a ComposeFst
759  // implementation that is appropriate for the properties of fst1 and
760  // fst2.
761  template <uint64 T>
762  static ComposeFstImplBase<A> *Init(
763      const Fst<A> &fst1,
764      const Fst<A> &fst2,
765      const ComposeFstOptions<T> &opts) {
766
767    // Filter for sort properties (forces a property check).
768    uint64 sort_props_mask = kILabelSorted | kOLabelSorted;
769    // Filter for optimization-related properties (does not force a
770    // property-check).
771    uint64 opt_props_mask =
772      kString | kIDeterministic | kODeterministic | kNoIEpsilons |
773      kNoOEpsilons;
774
775    uint64 props1 = fst1.Properties(sort_props_mask, true);
776    uint64 props2 = fst2.Properties(sort_props_mask, true);
777
778    props1 |= fst1.Properties(opt_props_mask, false);
779    props2 |= fst2.Properties(opt_props_mask, false);
780
781    if (!(Weight::Properties() & kCommutative)) {
782      props1 |= fst1.Properties(kUnweighted, true);
783      props2 |= fst2.Properties(kUnweighted, true);
784      if (!(props1 & kUnweighted) && !(props2 & kUnweighted))
785        LOG(FATAL) << "ComposeFst: Weight needs to be a commutative semiring: "
786                   << Weight::Type();
787    }
788
789    // Case 1: flag COMPOSE_GENERIC disables optimizations.
790    if (T & COMPOSE_GENERIC) {
791      return new ComposeFstImpl<A, T>(fst1, fst2, opts);
792    }
793
794    const uint64 kStringDetOptProps =
795      kIDeterministic | kILabelSorted | kNoIEpsilons;
796    const uint64 kDetStringOptProps =
797      kODeterministic | kOLabelSorted | kNoOEpsilons;
798
799    // Case 2: fst1 is a string, fst2 is deterministic and epsilon-free.
800    if ((props1 & kString) &&
801        !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) &&
802        ((props2 & kStringDetOptProps) == kStringDetOptProps)) {
803      return new ComposeFstImpl<A, T | COMPOSE_FST1_STRING | COMPOSE_FST2_DET>(
804          fst1, fst2, opts);
805    }
806    // Case 3: fst1 is deterministic and epsilon-free, fst2 is string.
807    if ((props2 & kString) &&
808        !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) &&
809        ((props1 & kDetStringOptProps) == kDetStringOptProps)) {
810      return new ComposeFstImpl<A, T | COMPOSE_FST2_STRING | COMPOSE_FST1_DET>(
811          fst1, fst2, opts);
812    }
813
814    // Default case: no optimizations.
815    return new ComposeFstImpl<A, T>(fst1, fst2, opts);
816  }
817
818  void operator=(const ComposeFst<A> &fst);  // disallow
819};
820
821
822// Specialization for ComposeFst.
823template<class A>
824class StateIterator< ComposeFst<A> >
825    : public CacheStateIterator< ComposeFst<A> > {
826 public:
827  explicit StateIterator(const ComposeFst<A> &fst)
828      : CacheStateIterator< ComposeFst<A> >(fst) {}
829};
830
831
832// Specialization for ComposeFst.
833template <class A>
834class ArcIterator< ComposeFst<A> >
835    : public CacheArcIterator< ComposeFst<A> > {
836 public:
837  typedef typename A::StateId StateId;
838
839  ArcIterator(const ComposeFst<A> &fst, StateId s)
840      : CacheArcIterator< ComposeFst<A> >(fst, s) {
841    if (!fst.impl_->HasArcs(s))
842      fst.impl_->Expand(s);
843  }
844
845 private:
846  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
847};
848
849template <class A> inline
850void ComposeFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
851  data->base = new StateIterator< ComposeFst<A> >(*this);
852}
853
854// Useful alias when using StdArc.
855typedef ComposeFst<StdArc> StdComposeFst;
856
857
858struct ComposeOptions {
859  bool connect;  // Connect output
860
861  ComposeOptions(bool c) : connect(c) {}
862  ComposeOptions() : connect(true) { }
863};
864
865
866// Computes the composition of two transducers. This version writes
867// the composed FST into a MurableFst. If FST1 transduces string x to
868// y with weight a and FST2 transduces y to z with weight b, then
869// their composition transduces string x to z with weight
870// Times(x, z).
871//
872// The output labels of the first transducer or the input labels of
873// the second transducer must be sorted.  The weights need to form a
874// commutative semiring (valid for TropicalWeight and LogWeight).
875//
876// Complexity:
877// Assuming the first FST is unsorted and the second is sorted:
878// - Time: O(V1 V2 D1 (log D2 + M2)),
879// - Space: O(V1 V2 D1 M2)
880// where Vi = # of states, Di = maximum out-degree, and Mi is
881// the maximum multiplicity for the ith FST.
882//
883// Caveats:
884// - Compose trims its output.
885// - The efficiency of composition can be strongly affected by several factors:
886//   - the choice of which tnansducer is sorted - prefer sorting the FST
887//     that has the greater average out-degree.
888//   - the amount of non-determinism
889//   - the presence and location of epsilon transitions - avoid epsilon
890//     transitions on the output side of the first transducer or
891//     the input side of the second transducer or prefer placing
892//     them later in a path since they delay matching and can
893//     introduce non-coaccessible states and transitions.
894template<class Arc>
895void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
896             MutableFst<Arc> *ofst,
897             const ComposeOptions &opts = ComposeOptions()) {
898  ComposeFstOptions<> nopts;
899  nopts.gc_limit = 0;  // Cache only the last state for fastest copy.
900  *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
901  if (opts.connect)
902    Connect(ofst);
903}
904
905}  // namespace fst
906
907#endif  // FST_LIB_COMPOSE_H__
908