compose.h revision 4a68b3365c8c50aa93505e99ead2565ab73dcdb0
1// compose.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// Class to compute the composition of two FSTs
18
19#ifndef FST_LIB_COMPOSE_H__
20#define FST_LIB_COMPOSE_H__
21
22#include <algorithm>
23
24#include <ext/hash_map>
25using __gnu_cxx::hash_map;
26
27#include "fst/lib/cache.h"
28#include "fst/lib/test-properties.h"
29
30namespace fst {
31
// Bit flags, packed into a uint64, that describe the user-selected
// extensions and the internally chosen optimizations of FST
// composition. They are passed as the template argument of
// ComposeFstOptions<T>; ComposeFstOptions<> (no bit set) requests
// "plain" composition without any extension.
enum ComposeTypes {
  // --- Low 32 bits: user-visible extensions --------------------------
  //
  // Special-symbol flags: rho ("rest"), phi ("failure") and sigma
  // ("any"). At most one of these six bits may be set; ComposeFstImpl
  // rejects any combination of them.
  COMPOSE_FST1_RHO    = 0x01ULL,  // rho on the output side of fst1
  COMPOSE_FST2_RHO    = 0x02ULL,  // rho on the input side of fst2
  COMPOSE_FST1_PHI    = 0x04ULL,  // phi on the output side of fst1
  COMPOSE_FST2_PHI    = 0x08ULL,  // phi on the input side of fst2
  COMPOSE_FST1_SIGMA  = 0x10ULL,  // sigma on the output side of fst1
  COMPOSE_FST2_SIGMA  = 0x20ULL,  // sigma on the input side of fst2

  // Bit 32: turns off every optimization and forces the generic
  // composition algorithm. Intended for internal testing only.
  COMPOSE_GENERIC     = 0x100000000ULL,

  // Convenience unions of the bits above (internal use only).
  COMPOSE_RHO         = COMPOSE_FST1_RHO | COMPOSE_FST2_RHO,
  COMPOSE_PHI         = COMPOSE_FST1_PHI | COMPOSE_FST2_PHI,
  COMPOSE_SIGMA       = COMPOSE_FST1_SIGMA | COMPOSE_FST2_SIGMA,
  COMPOSE_SPECIAL_SYMBOLS = COMPOSE_RHO | COMPOSE_PHI | COMPOSE_SIGMA,

  // --- High 32 bits: optimizations, normally set *internally* --------
  COMPOSE_FST1_STRING = 0x200000000ULL,   // bit 33: fst1 is a string
  COMPOSE_FST2_STRING = 0x400000000ULL,   // bit 34: fst2 is a string
  COMPOSE_FST1_DET    = 0x800000000ULL,   // bit 35: fst1 is deterministic
  COMPOSE_FST2_DET    = 0x1000000000ULL,  // bit 36: fst2 is deterministic

  // Selects the upper 32 bits (0xffffffff00000000).
  COMPOSE_INTERNAL_MASK    = ~0ULL << 32
};
75
76
77template <uint64 T = 0ULL>
78struct ComposeFstOptions : public CacheOptions {
79  explicit ComposeFstOptions(const CacheOptions &opts) : CacheOptions(opts) {}
80  ComposeFstOptions() { }
81};
82
83
84// Abstract base for the implementation of delayed ComposeFst. The
85// concrete specializations are templated on the (uint64-valued)
86// properties of the FSTs being composed.
87template <class A>
88class ComposeFstImplBase : public CacheImpl<A> {
89 public:
90  using FstImpl<A>::SetType;
91  using FstImpl<A>::SetProperties;
92  using FstImpl<A>::Properties;
93  using FstImpl<A>::SetInputSymbols;
94  using FstImpl<A>::SetOutputSymbols;
95
96  using CacheBaseImpl< CacheState<A> >::HasStart;
97  using CacheBaseImpl< CacheState<A> >::HasFinal;
98  using CacheBaseImpl< CacheState<A> >::HasArcs;
99
100  typedef typename A::Label Label;
101  typedef typename A::Weight Weight;
102  typedef typename A::StateId StateId;
103  typedef CacheState<A> State;
104
105  ComposeFstImplBase(const Fst<A> &fst1,
106                     const Fst<A> &fst2,
107                     const CacheOptions &opts)
108      :CacheImpl<A>(opts), fst1_(fst1.Copy()), fst2_(fst2.Copy()) {
109    SetType("compose");
110    uint64 props1 = fst1.Properties(kFstProperties, false);
111    uint64 props2 = fst2.Properties(kFstProperties, false);
112    SetProperties(ComposeProperties(props1, props2), kCopyProperties);
113
114    if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols()))
115      LOG(FATAL) << "ComposeFst: output symbol table of 1st argument "
116                 << "does not match input symbol table of 2nd argument";
117
118    SetInputSymbols(fst1.InputSymbols());
119    SetOutputSymbols(fst2.OutputSymbols());
120  }
121
122  virtual ~ComposeFstImplBase() {
123    delete fst1_;
124    delete fst2_;
125  }
126
127  StateId Start() {
128    if (!HasStart()) {
129      StateId start = ComputeStart();
130      if (start != kNoStateId) {
131        SetStart(start);
132      }
133    }
134    return CacheImpl<A>::Start();
135  }
136
137  Weight Final(StateId s) {
138    if (!HasFinal(s)) {
139      Weight final = ComputeFinal(s);
140      SetFinal(s, final);
141    }
142    return CacheImpl<A>::Final(s);
143  }
144
145  virtual void Expand(StateId s) = 0;
146
147  size_t NumArcs(StateId s) {
148    if (!HasArcs(s))
149      Expand(s);
150    return CacheImpl<A>::NumArcs(s);
151  }
152
153  size_t NumInputEpsilons(StateId s) {
154    if (!HasArcs(s))
155      Expand(s);
156    return CacheImpl<A>::NumInputEpsilons(s);
157  }
158
159  size_t NumOutputEpsilons(StateId s) {
160    if (!HasArcs(s))
161      Expand(s);
162    return CacheImpl<A>::NumOutputEpsilons(s);
163  }
164
165  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
166    if (!HasArcs(s))
167      Expand(s);
168    CacheImpl<A>::InitArcIterator(s, data);
169  }
170
171  // Access to flags encoding compose options/optimizations etc.  (for
172  // debugging).
173  virtual uint64 ComposeFlags() const = 0;
174
175 protected:
176  virtual StateId ComputeStart() = 0;
177  virtual Weight ComputeFinal(StateId s) = 0;
178
179  const Fst<A> *fst1_;            // first input Fst
180  const Fst<A> *fst2_;            // second input Fst
181};
182
183
184// The following class encapsulates implementation-dependent details
185// of state tuple lookup, i.e. a bijective mapping from triples of two
186// FST states and an epsilon filter state to the corresponding state
187// IDs of the fst resulting from composition. The mapping must
188// implement the [] operator in the style of STL associative
189// containers (map, hash_map), i.e. table[x] must return a reference
190// to the value associated with x. If x is an unassigned tuple, the
191// operator must automatically associate x with value 0.
192//
193// NB: "table[x] == 0" for unassigned tuples x is required by the
194// following off-by-one device used in the implementation of
195// ComposeFstImpl. The value stored in the table is equal to tuple ID
196// plus one, i.e. it is always a strictly positive number. Therefore,
197// table[x] is equal to 0 if and only if x is an unassigned tuple (in
198// which the algorithm assigns a new ID to x, and sets table[x] -
199// stored in a reference - to "new ID + 1"). This form of lookup is
200// more efficient than calling "find(x)" and "insert(make_pair(x, new
201// ID))" if x is an unassigned tuple.
202//
203// The generic implementation is a wrapper around a hash_map.
204template <class A, uint64 T>
205class ComposeStateTable {
206 public:
207  typedef typename A::StateId StateId;
208
209  struct StateTuple {
210    StateTuple() {}
211    StateTuple(StateId s1, StateId s2, int f)
212        : state_id1(s1), state_id2(s2), filt(f) {}
213    StateId state_id1;  // state Id on fst1
214    StateId state_id2;  // state Id on fst2
215    int filt;           // epsilon filter state
216  };
217
218  ComposeStateTable() {
219    StateTuple empty_tuple(kNoStateId, kNoStateId, 0);
220  }
221
222  // NB: if 'tuple' is not in 'table_', the pair (tuple, StateId()) is
223  // inserted into 'table_' (standard STL container semantics). Since
224  // StateId is a built-in type, the explicit default constructor call
225  // StateId() returns 0.
226  StateId &operator[](const StateTuple &tuple) {
227    return table_[tuple];
228  }
229
230 private:
231  // Comparison object for hashing StateTuple(s).
232  class StateTupleEqual {
233   public:
234    bool operator()(const StateTuple& x, const StateTuple& y) const {
235      return x.state_id1 == y.state_id1 &&
236             x.state_id2 == y.state_id2 &&
237             x.filt == y.filt;
238    }
239  };
240
241  static const int kPrime0 = 7853;
242  static const int kPrime1 = 7867;
243
244  // Hash function for StateTuple to Fst states.
245  class StateTupleKey {
246   public:
247    size_t operator()(const StateTuple& x) const {
248      return static_cast<size_t>(x.state_id1 +
249                                 x.state_id2 * kPrime0 +
250                                 x.filt * kPrime1);
251    }
252  };
253
254  // Lookup table mapping state tuples to state IDs.
255  typedef hash_map<StateTuple,
256                         StateId,
257                         StateTupleKey,
258                         StateTupleEqual> StateTable;
259 // Actual table data.
260  StateTable table_;
261
262  DISALLOW_EVIL_CONSTRUCTORS(ComposeStateTable);
263};
264
265
266// State tuple lookup table for the composition of a string FST with a
267// deterministic FST.  The class maps state tuples to their unique IDs
268// (i.e. states of the ComposeFst). Main optimization: due to the
269// 1-to-1 correspondence between the states of the input string FST
270// and those of the resulting (string) FST, a state tuple (s1, s2) is
271// simply mapped to StateId s1. Hence, we use an STL vector as a
272// lookup table. Template argument Fst1IsString specifies which FST is
273// a string (this determines whether or not we index the lookup table
274// by the first or by the second state).
275template <class A, bool Fst1IsString>
276class StringDetComposeStateTable {
277 public:
278  typedef typename A::StateId StateId;
279
280  struct StateTuple {
281    typedef typename A::StateId StateId;
282    StateTuple() {}
283    StateTuple(StateId s1, StateId s2, int /* f */)
284        : state_id1(s1), state_id2(s2) {}
285    StateId state_id1;  // state Id on fst1
286    StateId state_id2;  // state Id on fst2
287    static const int filt = 0;  // 'fake' epsilon filter - only needed
288                                // for API compatibility
289  };
290
291  StringDetComposeStateTable() {}
292
293  // Subscript operator. Behaves in a way similar to its map/hash_map
294  // counterpart, i.e. returns a reference to the value associated
295  // with 'tuple', inserting a 0 value if 'tuple' is unassigned.
296  StateId &operator[](const StateTuple &tuple) {
297    StateId index = Fst1IsString ? tuple.state_id1 : tuple.state_id2;
298    if (index >= (StateId)data_.size()) {
299      // NB: all values in [old_size; index] are initialized to 0.
300      data_.resize(index + 1);
301    }
302    return data_[index];
303  }
304
305 private:
306  vector<StateId> data_;
307
308  DISALLOW_EVIL_CONSTRUCTORS(StringDetComposeStateTable);
309};
310
311
312// Specializations of ComposeStateTable for the string/det case.
313// Both inherit from StringDetComposeStateTable.
314template <class A>
315class ComposeStateTable<A, COMPOSE_FST1_STRING | COMPOSE_FST2_DET>
316    : public StringDetComposeStateTable<A, true> { };
317
318template <class A>
319class ComposeStateTable<A, COMPOSE_FST2_STRING | COMPOSE_FST1_DET>
320    : public StringDetComposeStateTable<A, false> { };
321
322
323// Parameterized implementation of FST composition for a pair of FSTs
324// matching the property bit vector T. If possible,
325// instantiation-specific switches in the code are based on the values
326// of the bits in T, which are known at compile time, so unused code
327// should be optimized away by the compiler.
328template <class A, uint64 T>
329class ComposeFstImpl : public ComposeFstImplBase<A> {
330  typedef typename A::StateId StateId;
331  typedef typename A::Label   Label;
332  typedef typename A::Weight  Weight;
333  using FstImpl<A>::SetType;
334  using FstImpl<A>::SetProperties;
335
336  enum FindType { FIND_INPUT  = 1,          // find input label on fst2
337                  FIND_OUTPUT = 2,          // find output label on fst1
338                  FIND_BOTH   = 3 };        // find choice state dependent
339
340  typedef ComposeStateTable<A, T & COMPOSE_INTERNAL_MASK> StateTupleTable;
341  typedef typename StateTupleTable::StateTuple StateTuple;
342
343 public:
344  ComposeFstImpl(const Fst<A> &fst1,
345                 const Fst<A> &fst2,
346                 const CacheOptions &opts)
347      :ComposeFstImplBase<A>(fst1, fst2, opts) {
348
349    bool osorted = fst1.Properties(kOLabelSorted, false);
350    bool isorted = fst2.Properties(kILabelSorted, false);
351
352    switch (T & COMPOSE_SPECIAL_SYMBOLS) {
353      case COMPOSE_FST1_RHO:
354      case COMPOSE_FST1_PHI:
355      case COMPOSE_FST1_SIGMA:
356        if (!osorted || FLAGS_fst_verify_properties)
357          osorted = fst1.Properties(kOLabelSorted, true);
358        if (!osorted)
359          LOG(FATAL) << "ComposeFst: 1st argument not output label "
360                     << "sorted (special symbols present)";
361        break;
362      case COMPOSE_FST2_RHO:
363      case COMPOSE_FST2_PHI:
364      case COMPOSE_FST2_SIGMA:
365        if (!isorted || FLAGS_fst_verify_properties)
366          isorted = fst2.Properties(kILabelSorted, true);
367        if (!isorted)
368          LOG(FATAL) << "ComposeFst: 2nd argument not input label "
369                     << "sorted (special symbols present)";
370        break;
371      case 0:
372        if (!isorted && !osorted || FLAGS_fst_verify_properties) {
373          osorted = fst1.Properties(kOLabelSorted, true);
374          if (!osorted)
375            isorted = fst2.Properties(kILabelSorted, true);
376        }
377        break;
378      default:
379        LOG(FATAL)
380          << "ComposeFst: More than one special symbol used in composition";
381    }
382
383    if (isorted && (T & COMPOSE_FST2_SIGMA)) {
384      find_type_ = FIND_INPUT;
385    } else if (osorted && (T & COMPOSE_FST1_SIGMA)) {
386      find_type_ = FIND_OUTPUT;
387    } else if (isorted && (T & COMPOSE_FST2_PHI)) {
388      find_type_ = FIND_INPUT;
389    } else if (osorted && (T & COMPOSE_FST1_PHI)) {
390      find_type_ = FIND_OUTPUT;
391    } else if (isorted && (T & COMPOSE_FST2_RHO)) {
392      find_type_ = FIND_INPUT;
393    } else if (osorted && (T & COMPOSE_FST1_RHO)) {
394      find_type_ = FIND_OUTPUT;
395    } else if (isorted && (T & COMPOSE_FST1_STRING)) {
396      find_type_ = FIND_INPUT;
397    } else if(osorted && (T & COMPOSE_FST2_STRING)) {
398      find_type_ = FIND_OUTPUT;
399    } else if (isorted && osorted) {
400      find_type_ = FIND_BOTH;
401    } else if (isorted) {
402      find_type_ = FIND_INPUT;
403    } else if (osorted) {
404      find_type_ = FIND_OUTPUT;
405    } else {
406      LOG(FATAL) << "ComposeFst: 1st argument not output label sorted "
407                 << "and 2nd argument is not input label sorted";
408    }
409  }
410
411  // Finds/creates an Fst state given a StateTuple.  Only creates a new
412  // state if StateTuple is not found in the state hash.
413  //
414  // The method exploits the following device: all pairs stored in the
415  // associative container state_tuple_table_ are of the form (tuple,
416  // id(tuple) + 1), i.e. state_tuple_table_[tuple] > 0 if tuple has
417  // been stored previously. For unassigned tuples, the call to
418  // state_tuple_table_[tuple] creates a new pair (tuple, 0). As a
419  // result, state_tuple_table_[tuple] == 0 iff tuple is new.
420  StateId FindState(const StateTuple& tuple) {
421    StateId &assoc_value = state_tuple_table_[tuple];
422    if (assoc_value == 0) {  // tuple wasn't present in lookup table:
423                             // assign it a new ID.
424      state_tuples_.push_back(tuple);
425      assoc_value = state_tuples_.size();
426    }
427    return assoc_value - 1;  // NB: assoc_value = ID + 1
428  }
429
430  // Generates arc for composition state s from matched input Fst arcs.
431  void AddArc(StateId s, const A &arca, const A &arcb, int f,
432              bool find_input) {
433    A arc;
434    if (find_input) {
435      arc.ilabel = arcb.ilabel;
436      arc.olabel = arca.olabel;
437      arc.weight = Times(arcb.weight, arca.weight);
438      StateTuple tuple(arcb.nextstate, arca.nextstate, f);
439      arc.nextstate = FindState(tuple);
440    } else {
441      arc.ilabel = arca.ilabel;
442      arc.olabel = arcb.olabel;
443      arc.weight = Times(arca.weight, arcb.weight);
444      StateTuple tuple(arca.nextstate, arcb.nextstate, f);
445      arc.nextstate = FindState(tuple);
446    }
447    CacheImpl<A>::AddArc(s, arc);
448  }
449
450  // Arranges it so that the first arg to OrderedExpand is the Fst
451  // that will be passed to FindLabel.
452  void Expand(StateId s) {
453    StateTuple &tuple = state_tuples_[s];
454    StateId s1 = tuple.state_id1;
455    StateId s2 = tuple.state_id2;
456    int f = tuple.filt;
457    if (find_type_ == FIND_INPUT)
458      OrderedExpand(s, ComposeFstImplBase<A>::fst2_, s2,
459                    ComposeFstImplBase<A>::fst1_, s1, f, true);
460    else
461      OrderedExpand(s, ComposeFstImplBase<A>::fst1_, s1,
462                    ComposeFstImplBase<A>::fst2_, s2, f, false);
463  }
464
465  // Access to flags encoding compose options/optimizations etc.  (for
466  // debugging).
467  virtual uint64 ComposeFlags() const { return T; }
468
469 private:
470  // This does that actual matching of labels in the composition. The
471  // arguments are ordered so FindLabel is called with state SA of
472  // FSTA for each arc leaving state SB of FSTB. The FIND_INPUT arg
473  // determines whether the input or output label of arcs at SB is
474  // the one to match on.
475  void OrderedExpand(StateId s, const Fst<A> *fsta, StateId sa,
476                     const Fst<A> *fstb, StateId sb, int f, bool find_input) {
477
478    size_t numarcsa = fsta->NumArcs(sa);
479    size_t numepsa = find_input ? fsta->NumInputEpsilons(sa) :
480                     fsta->NumOutputEpsilons(sa);
481    bool finala = fsta->Final(sa) != Weight::Zero();
482    ArcIterator< Fst<A> > aitera(*fsta, sa);
483    // First handle special epsilons and sigmas on FSTA
484    for (; !aitera.Done(); aitera.Next()) {
485      const A &arca = aitera.Value();
486      Label match_labela = find_input ? arca.ilabel : arca.olabel;
487      if (match_labela > 0) {
488        break;
489      }
490      if ((T & COMPOSE_SIGMA) != 0 &&  match_labela == kSigmaLabel) {
491        // Found a sigma? Match it against all (non-special) symbols
492        // on side b.
493        for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
494             !aiterb.Done();
495             aiterb.Next()) {
496          const A &arcb = aiterb.Value();
497          Label labelb = find_input ? arcb.olabel : arcb.ilabel;
498          if (labelb <= 0) continue;
499          AddArc(s, arca, arcb, 0, find_input);
500        }
501      } else if (f == 0 && match_labela == 0) {
502        A earcb(0, 0, Weight::One(), sb);
503        AddArc(s, arca, earcb, 0, find_input);  // move forward on epsilon
504      }
505    }
506    // Next handle non-epsilon matches, rho labels, and epsilons on FSTB
507    for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
508         !aiterb.Done();
509         aiterb.Next()) {
510      const A &arcb = aiterb.Value();
511      Label match_labelb = find_input ? arcb.olabel : arcb.ilabel;
512      if (match_labelb) {  // Consider non-epsilon match
513        if (FindLabel(&aitera, numarcsa, match_labelb, find_input)) {
514          for (; !aitera.Done(); aitera.Next()) {
515            const A &arca = aitera.Value();
516            Label match_labela = find_input ? arca.ilabel : arca.olabel;
517            if (match_labela != match_labelb)
518              break;
519            AddArc(s, arca, arcb, 0, find_input);  // move forward on match
520          }
521        } else if ((T & COMPOSE_SPECIAL_SYMBOLS) != 0) {
522          // If there is no transition labelled 'match_labelb' in
523          // fsta, try matching 'match_labelb' against special symbols
524          // (Phi, Rho,...).
525          for (aitera.Reset(); !aitera.Done(); aitera.Next()) {
526            A arca = aitera.Value();
527            Label labela = find_input ? arca.ilabel : arca.olabel;
528            if (labela >= 0) {
529              break;
530            } else if (((T & COMPOSE_PHI) != 0) && (labela == kPhiLabel)) {
531              // Case 1: if a failure transition exists, follow its
532              // transitive closure until a) a transition labelled
533              // 'match_labelb' is found, or b) the initial state of
534              // fsta is reached.
535
536              StateId sf = sa;  // Start of current failure transition.
537              while (labela == kPhiLabel && sf != arca.nextstate) {
538                sf = arca.nextstate;
539
540                size_t numarcsf = fsta->NumArcs(sf);
541                ArcIterator< Fst<A> > aiterf(*fsta, sf);
542                if (FindLabel(&aiterf, numarcsf, match_labelb, find_input)) {
543                  // Sub-case 1a: there exists a transition starting
544                  // in sf and consuming symbol 'match_labelb'.
545                  AddArc(s, aiterf.Value(), arcb, 0, find_input);
546                  break;
547                } else {
548                  // No transition labelled 'match_labelb' found: try
549                  // next failure transition (starting at 'sf').
550                  for (aiterf.Reset(); !aiterf.Done(); aiterf.Next()) {
551                    arca = aiterf.Value();
552                    labela = find_input ? arca.ilabel : arca.olabel;
553                    if (labela >= kPhiLabel) break;
554                  }
555                }
556              }
557              if (labela == kPhiLabel && sf == arca.nextstate) {
558                // Sub-case 1b: failure transitions lead to start
559                // state without finding a matching
560                // transition. Therefore, we generate a loop in start
561                // state of fsta.
562                A loop(match_labelb, match_labelb, Weight::One(), sf);
563                AddArc(s, loop, arcb, 0, find_input);
564              }
565            } else if (((T & COMPOSE_RHO) != 0) && (labela == kRhoLabel)) {
566              // Case 2: 'match_labelb' can be matched against a
567              // "rest" (rho) label in fsta.
568              if (find_input) {
569                arca.ilabel = match_labelb;
570                if (arca.olabel == kRhoLabel)
571                  arca.olabel = match_labelb;
572              } else {
573                arca.olabel = match_labelb;
574                if (arca.ilabel == kRhoLabel)
575                  arca.ilabel = match_labelb;
576              }
577              AddArc(s, arca, arcb, 0, find_input);  // move fwd on match
578            }
579          }
580        }
581      } else if (numepsa != numarcsa || finala) {  // Handle FSTB epsilon
582        A earca(0, 0, Weight::One(), sa);
583        AddArc(s, earca, arcb, numepsa > 0, find_input);  // move on epsilon
584      }
585    }
586    SetArcs(s);
587   }
588
589
590  // Finds matches to MATCH_LABEL in arcs given by AITER
591  // using FIND_INPUT to determine whether to look on input or output.
592  bool FindLabel(ArcIterator< Fst<A> > *aiter, size_t numarcs,
593                 Label match_label, bool find_input) {
594    // binary search for match
595    size_t low = 0;
596    size_t high = numarcs;
597    while (low < high) {
598      size_t mid = (low + high) / 2;
599      aiter->Seek(mid);
600      Label label = find_input ?
601                    aiter->Value().ilabel : aiter->Value().olabel;
602      if (label > match_label) {
603        high = mid;
604      } else if (label < match_label) {
605        low = mid + 1;
606      } else {
607        // find first matching label (when non-determinism)
608        for (size_t i = mid; i > low; --i) {
609          aiter->Seek(i - 1);
610          label = find_input ? aiter->Value().ilabel : aiter->Value().olabel;
611          if (label != match_label) {
612            aiter->Seek(i);
613            return true;
614          }
615        }
616        return true;
617      }
618    }
619    return false;
620  }
621
622  StateId ComputeStart() {
623    StateId s1 = ComposeFstImplBase<A>::fst1_->Start();
624    StateId s2 = ComposeFstImplBase<A>::fst2_->Start();
625    if (s1 == kNoStateId || s2 == kNoStateId)
626      return kNoStateId;
627    StateTuple tuple(s1, s2, 0);
628    return FindState(tuple);
629  }
630
631  Weight ComputeFinal(StateId s) {
632    StateTuple &tuple = state_tuples_[s];
633    Weight final = Times(ComposeFstImplBase<A>::fst1_->Final(tuple.state_id1),
634                         ComposeFstImplBase<A>::fst2_->Final(tuple.state_id2));
635    return final;
636  }
637
638
639  FindType find_type_;            // find label on which side?
640
641  // Maps from StateId to StateTuple.
642  vector<StateTuple> state_tuples_;
643
644  // Maps from StateTuple to StateId.
645  StateTupleTable state_tuple_table_;
646
647  DISALLOW_EVIL_CONSTRUCTORS(ComposeFstImpl);
648};
649
650
651// Computes the composition of two transducers. This version is a
652// delayed Fst. If FST1 transduces string x to y with weight a and FST2
653// transduces y to z with weight b, then their composition transduces
654// string x to z with weight Times(x, z).
655//
656// The output labels of the first transducer or the input labels of
657// the second transducer must be sorted.  The weights need to form a
658// commutative semiring (valid for TropicalWeight and LogWeight).
659//
660// Complexity:
661// Assuming the first FST is unsorted and the second is sorted:
662// - Time: O(v1 v2 d1 (log d2 + m2)),
663// - Space: O(v1 v2)
664// where vi = # of states visited, di = maximum out-degree, and mi the
665// maximum multiplicity of the states visited for the ith
666// FST. Constant time and space to visit an input state or arc is
667// assumed and exclusive of caching.
668//
669// Caveats:
670// - ComposeFst does not trim its output (since it is a delayed operation).
671// - The efficiency of composition can be strongly affected by several factors:
672//   - the choice of which tnansducer is sorted - prefer sorting the FST
673//     that has the greater average out-degree.
674//   - the amount of non-determinism
675//   - the presence and location of epsilon transitions - avoid epsilon
676//     transitions on the output side of the first transducer or
677//     the input side of the second transducer or prefer placing
678//     them later in a path since they delay matching and can
679//     introduce non-coaccessible states and transitions.
680template <class A>
681class ComposeFst : public Fst<A> {
682 public:
683  friend class ArcIterator< ComposeFst<A> >;
684  friend class CacheStateIterator< ComposeFst<A> >;
685  friend class CacheArcIterator< ComposeFst<A> >;
686
687  typedef A Arc;
688  typedef typename A::Weight Weight;
689  typedef typename A::StateId StateId;
690  typedef CacheState<A> State;
691
692  ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2)
693      : impl_(Init(fst1, fst2, ComposeFstOptions<>())) { }
694
695  template <uint64 T>
696  ComposeFst(const Fst<A> &fst1,
697             const Fst<A> &fst2,
698             const ComposeFstOptions<T> &opts)
699      : impl_(Init(fst1, fst2, opts)) { }
700
701  ComposeFst(const ComposeFst<A> &fst) : Fst<A>(fst), impl_(fst.impl_) {
702    impl_->IncrRefCount();
703  }
704
705  virtual ~ComposeFst() { if (!impl_->DecrRefCount()) delete impl_;  }
706
707  virtual StateId Start() const { return impl_->Start(); }
708
709  virtual Weight Final(StateId s) const { return impl_->Final(s); }
710
711  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
712
713  virtual size_t NumInputEpsilons(StateId s) const {
714    return impl_->NumInputEpsilons(s);
715  }
716
717  virtual size_t NumOutputEpsilons(StateId s) const {
718    return impl_->NumOutputEpsilons(s);
719  }
720
721  virtual uint64 Properties(uint64 mask, bool test) const {
722    if (test) {
723      uint64 known, test = TestProperties(*this, mask, &known);
724      impl_->SetProperties(test, known);
725      return test & mask;
726    } else {
727      return impl_->Properties(mask);
728    }
729  }
730
731  virtual const string& Type() const { return impl_->Type(); }
732
733  virtual ComposeFst<A> *Copy() const {
734    return new ComposeFst<A>(*this);
735  }
736
737  virtual const SymbolTable* InputSymbols() const {
738    return impl_->InputSymbols();
739  }
740
741  virtual const SymbolTable* OutputSymbols() const {
742    return impl_->OutputSymbols();
743  }
744
745  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
746
747  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
748    impl_->InitArcIterator(s, data);
749  }
750
751  // Access to flags encoding compose options/optimizations etc.  (for
752  // debugging).
753  uint64 ComposeFlags() const { return impl_->ComposeFlags(); }
754
755 protected:
756  ComposeFstImplBase<A> *Impl() { return impl_; }
757
758 private:
759  ComposeFstImplBase<A> *impl_;
760
761  // Auxiliary method encapsulating the creation of a ComposeFst
762  // implementation that is appropriate for the properties of fst1 and
763  // fst2.
764  template <uint64 T>
765  static ComposeFstImplBase<A> *Init(
766      const Fst<A> &fst1,
767      const Fst<A> &fst2,
768      const ComposeFstOptions<T> &opts) {
769
770    // Filter for sort properties (forces a property check).
771    uint64 sort_props_mask = kILabelSorted | kOLabelSorted;
772    // Filter for optimization-related properties (does not force a
773    // property-check).
774    uint64 opt_props_mask =
775      kString | kIDeterministic | kODeterministic | kNoIEpsilons |
776      kNoOEpsilons;
777
778    uint64 props1 = fst1.Properties(sort_props_mask, true);
779    uint64 props2 = fst2.Properties(sort_props_mask, true);
780
781    props1 |= fst1.Properties(opt_props_mask, false);
782    props2 |= fst2.Properties(opt_props_mask, false);
783
784    if (!(Weight::Properties() & kCommutative)) {
785      props1 |= fst1.Properties(kUnweighted, true);
786      props2 |= fst2.Properties(kUnweighted, true);
787      if (!(props1 & kUnweighted) && !(props2 & kUnweighted))
788        LOG(FATAL) << "ComposeFst: Weight needs to be a commutative semiring: "
789                   << Weight::Type();
790    }
791
792    // Case 1: flag COMPOSE_GENERIC disables optimizations.
793    if (T & COMPOSE_GENERIC) {
794      return new ComposeFstImpl<A, T>(fst1, fst2, opts);
795    }
796
797    const uint64 kStringDetOptProps =
798      kIDeterministic | kILabelSorted | kNoIEpsilons;
799    const uint64 kDetStringOptProps =
800      kODeterministic | kOLabelSorted | kNoOEpsilons;
801
802    // Case 2: fst1 is a string, fst2 is deterministic and epsilon-free.
803    if ((props1 & kString) &&
804        !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) &&
805        ((props2 & kStringDetOptProps) == kStringDetOptProps)) {
806      return new ComposeFstImpl<A, T | COMPOSE_FST1_STRING | COMPOSE_FST2_DET>(
807          fst1, fst2, opts);
808    }
809    // Case 3: fst1 is deterministic and epsilon-free, fst2 is string.
810    if ((props2 & kString) &&
811        !(T & (COMPOSE_FST1_RHO | COMPOSE_FST1_PHI | COMPOSE_FST1_SIGMA)) &&
812        ((props1 & kDetStringOptProps) == kDetStringOptProps)) {
813      return new ComposeFstImpl<A, T | COMPOSE_FST2_STRING | COMPOSE_FST1_DET>(
814          fst1, fst2, opts);
815    }
816
817    // Default case: no optimizations.
818    return new ComposeFstImpl<A, T>(fst1, fst2, opts);
819  }
820
821  void operator=(const ComposeFst<A> &fst);  // disallow
822};
823
824
825// Specialization for ComposeFst.
826template<class A>
827class StateIterator< ComposeFst<A> >
828    : public CacheStateIterator< ComposeFst<A> > {
829 public:
830  explicit StateIterator(const ComposeFst<A> &fst)
831      : CacheStateIterator< ComposeFst<A> >(fst) {}
832};
833
834
835// Specialization for ComposeFst.
836template <class A>
837class ArcIterator< ComposeFst<A> >
838    : public CacheArcIterator< ComposeFst<A> > {
839 public:
840  typedef typename A::StateId StateId;
841
842  ArcIterator(const ComposeFst<A> &fst, StateId s)
843      : CacheArcIterator< ComposeFst<A> >(fst, s) {
844    if (!fst.impl_->HasArcs(s))
845      fst.impl_->Expand(s);
846  }
847
848 private:
849  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
850};
851
852template <class A> inline
853void ComposeFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
854  data->base = new StateIterator< ComposeFst<A> >(*this);
855}
856
// Useful alias: the delayed composition over StdArc.
typedef ComposeFst<StdArc> StdComposeFst;
859
860
// Options for the eager Compose() function.
struct ComposeOptions {
  bool connect;  // If true, Compose() calls Connect() on its output.

  ComposeOptions() : connect(true) { }
  // Intentionally left non-explicit: callers may rely on passing a bare
  // bool where a ComposeOptions is expected.
  ComposeOptions(bool c) : connect(c) {}
};
867
868
869// Computes the composition of two transducers. This version writes
870// the composed FST into a MurableFst. If FST1 transduces string x to
871// y with weight a and FST2 transduces y to z with weight b, then
872// their composition transduces string x to z with weight
873// Times(x, z).
874//
875// The output labels of the first transducer or the input labels of
876// the second transducer must be sorted.  The weights need to form a
877// commutative semiring (valid for TropicalWeight and LogWeight).
878//
879// Complexity:
880// Assuming the first FST is unsorted and the second is sorted:
881// - Time: O(V1 V2 D1 (log D2 + M2)),
882// - Space: O(V1 V2 D1 M2)
883// where Vi = # of states, Di = maximum out-degree, and Mi is
884// the maximum multiplicity for the ith FST.
885//
886// Caveats:
887// - Compose trims its output.
888// - The efficiency of composition can be strongly affected by several factors:
889//   - the choice of which tnansducer is sorted - prefer sorting the FST
890//     that has the greater average out-degree.
891//   - the amount of non-determinism
892//   - the presence and location of epsilon transitions - avoid epsilon
893//     transitions on the output side of the first transducer or
894//     the input side of the second transducer or prefer placing
895//     them later in a path since they delay matching and can
896//     introduce non-coaccessible states and transitions.
897template<class Arc>
898void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
899             MutableFst<Arc> *ofst,
900             const ComposeOptions &opts = ComposeOptions()) {
901  ComposeFstOptions<> nopts;
902  nopts.gc_limit = 0;  // Cache only the last state for fastest copy.
903  *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
904  if (opts.connect)
905    Connect(ofst);
906}
907
908}  // namespace fst
909
910#endif  // FST_LIB_COMPOSE_H__
911