replace.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
147be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// replace.h
247be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org
347be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// Licensed under the Apache License, Version 2.0 (the "License");
447be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// you may not use this file except in compliance with the License.
547be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// You may obtain a copy of the License at
647be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org//
747be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org//     http://www.apache.org/licenses/LICENSE-2.0
847be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org//
947be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// Unless required by applicable law or agreed to in writing, software
1047be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// distributed under the License is distributed on an "AS IS" BASIS,
1147be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1247be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// See the License for the specific language governing permissions and
1347be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// limitations under the License.
1447be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org//
1547be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// Copyright 2005-2010 Google, Inc.
1647be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// Author: johans@google.com (Johan Schalkwyk)
1747be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org//
1847be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// \file
1947be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// Functions and classes for the recursive replacement of Fsts.
2047be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org//
2147be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org
2247be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#ifndef FST_LIB_REPLACE_H__
2347be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#define FST_LIB_REPLACE_H__
2447be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org
2547be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <unordered_map>
2647be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.orgusing std::tr1::unordered_map;
2747be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.orgusing std::tr1::unordered_multimap;
2847be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <set>
2947be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <string>
3047be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <utility>
3147be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.orgusing std::pair; using std::make_pair;
3247be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <vector>
3347be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.orgusing std::vector;
3447be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org
3547be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <fst/cache.h>
3647be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <fst/expanded-fst.h>
3747be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <fst/fst.h>
3847be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <fst/matcher.h>
3947be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <fst/replace-util.h>
4047be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <fst/state-table.h>
4147be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org#include <fst/test-properties.h>
4247be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org
4347be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.orgnamespace fst {
4447be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org
//
// REPLACE STATE TUPLES AND TABLES
//
// The replace state table has the form
//
// template <class A, class P>
// class ReplaceStateTable {
//  public:
//   typedef A Arc;
//   typedef P PrefixId;
//   typedef typename A::StateId StateId;
//   typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
//   typedef typename A::Label Label;
//
//   // Required constructor
//   ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples,
//                     Label root);
//
//   // Required copy constructor that does not copy state
//   ReplaceStateTable(const ReplaceStateTable<A,P> &table);
//
//   // Looks up the state ID of a tuple. If it doesn't exist, adds it.
//   StateId FindState(const StateTuple &tuple);
//
//   // Looks up the state tuple of an ID.
//   const StateTuple &Tuple(StateId id) const;
// };
7247be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org
7347be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org
7447be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// \struct ReplaceStateTuple
7547be73b8629244d6bb63a28198f97f040ce53d21henrike@webrtc.org// \brief Tuple of information that uniquely defines a state in replace
76template <class S, class P>
77struct ReplaceStateTuple {
78  typedef S StateId;
79  typedef P PrefixId;
80
81  ReplaceStateTuple()
82      : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {}
83
84  ReplaceStateTuple(PrefixId p, StateId f, StateId s)
85      : prefix_id(p), fst_id(f), fst_state(s) {}
86
87  PrefixId prefix_id;  // index in prefix table
88  StateId fst_id;      // current fst being walked
89  StateId fst_state;   // current state in fst being walked, not to be
90                       // confused with the state_id of the combined fst
91};
92
93
94// Equality of replace state tuples.
95template <class S, class P>
96inline bool operator==(const ReplaceStateTuple<S, P>& x,
97                       const ReplaceStateTuple<S, P>& y) {
98  return x.prefix_id == y.prefix_id &&
99      x.fst_id == y.fst_id &&
100      x.fst_state == y.fst_state;
101}
102
103
104// \class ReplaceRootSelector
105// Functor returning true for tuples corresponding to states in the root FST
106template <class S, class P>
107class ReplaceRootSelector {
108 public:
109  bool operator()(const ReplaceStateTuple<S, P> &tuple) const {
110    return tuple.prefix_id == 0;
111  }
112};
113
114
115// \class ReplaceFingerprint
116// Fingerprint for general replace state tuples.
117template <class S, class P>
118class ReplaceFingerprint {
119 public:
120  ReplaceFingerprint(const vector<uint64> *size_array)
121      : cumulative_size_array_(size_array) {}
122
123  uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const {
124    return tuple.prefix_id * (cumulative_size_array_->back()) +
125        cumulative_size_array_->at(tuple.fst_id - 1) +
126        tuple.fst_state;
127  }
128
129 private:
130  const vector<uint64> *cumulative_size_array_;
131};
132
133
134// \class ReplaceFstStateFingerprint
135// Useful when the fst_state uniquely define the tuple.
136template <class S, class P>
137class ReplaceFstStateFingerprint {
138 public:
139  uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const {
140    return tuple.fst_state;
141  }
142};
143
144
145// \class ReplaceHash
146// A generic hash function for replace state tuples.
147template <typename S, typename P>
148class ReplaceHash {
149 public:
150  size_t operator()(const ReplaceStateTuple<S, P>& t) const {
151    return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1;
152  }
153 private:
154  static const size_t kPrime0;
155  static const size_t kPrime1;
156};
157
158template <typename S, typename P>
159const size_t ReplaceHash<S, P>::kPrime0 = 7853;
160
161template <typename S, typename P>
162const size_t ReplaceHash<S, P>::kPrime1 = 7867;
163
164template <class A, class T> class ReplaceFstMatcher;
165
166
167// \class VectorHashReplaceStateTable
168// A two-level state table for replace.
169// Warning: calls CountStates to compute the number of states of each
170// component Fst.
171template <class A, class P = ssize_t>
172class VectorHashReplaceStateTable {
173 public:
174  typedef A Arc;
175  typedef typename A::StateId StateId;
176  typedef typename A::Label Label;
177  typedef P PrefixId;
178  typedef ReplaceStateTuple<StateId, P> StateTuple;
179  typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>,
180                               ReplaceRootSelector<StateId, P>,
181                               ReplaceFstStateFingerprint<StateId, P>,
182                               ReplaceFingerprint<StateId, P> > StateTable;
183
184  VectorHashReplaceStateTable(
185      const vector<pair<Label, const Fst<A>*> > &fst_tuples,
186      Label root) : root_size_(0) {
187    cumulative_size_array_.push_back(0);
188    for (size_t i = 0; i < fst_tuples.size(); ++i) {
189      if (fst_tuples[i].first == root) {
190        root_size_ = CountStates(*(fst_tuples[i].second));
191        cumulative_size_array_.push_back(cumulative_size_array_.back());
192      } else {
193        cumulative_size_array_.push_back(cumulative_size_array_.back() +
194                                         CountStates(*(fst_tuples[i].second)));
195      }
196    }
197    state_table_ = new StateTable(
198        new ReplaceRootSelector<StateId, P>,
199        new ReplaceFstStateFingerprint<StateId, P>,
200        new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
201        root_size_,
202        root_size_ + cumulative_size_array_.back());
203  }
204
205  VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table)
206      : root_size_(table.root_size_),
207        cumulative_size_array_(table.cumulative_size_array_) {
208    state_table_ = new StateTable(
209        new ReplaceRootSelector<StateId, P>,
210        new ReplaceFstStateFingerprint<StateId, P>,
211        new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
212        root_size_,
213        root_size_ + cumulative_size_array_.back());
214  }
215
216  ~VectorHashReplaceStateTable() {
217    delete state_table_;
218  }
219
220  StateId FindState(const StateTuple &tuple) {
221    return state_table_->FindState(tuple);
222  }
223
224  const StateTuple &Tuple(StateId id) const {
225    return state_table_->Tuple(id);
226  }
227
228 private:
229  StateId root_size_;
230  vector<uint64> cumulative_size_array_;
231  StateTable *state_table_;
232};
233
234
235// \class DefaultReplaceStateTable
236// Default replace state table
237template <class A, class P = ssize_t>
238class DefaultReplaceStateTable : public CompactHashStateTable<
239  ReplaceStateTuple<typename A::StateId, P>,
240  ReplaceHash<typename A::StateId, P> > {
241 public:
242  typedef A Arc;
243  typedef typename A::StateId StateId;
244  typedef typename A::Label Label;
245  typedef P PrefixId;
246  typedef ReplaceStateTuple<StateId, P> StateTuple;
247  typedef CompactHashStateTable<StateTuple,
248                                ReplaceHash<StateId, PrefixId> > StateTable;
249
250  using StateTable::FindState;
251  using StateTable::Tuple;
252
253  DefaultReplaceStateTable(
254      const vector<pair<Label, const Fst<A>*> > &fst_tuples,
255      Label root) {}
256
257  DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table)
258      : StateTable() {}
259};
260
261//
262// REPLACE FST CLASS
263//
264
265// By default ReplaceFst will copy the input label of the 'replace arc'.
266// For acceptors we do not want this behaviour. Instead we need to
267// create an epsilon arc when recursing into the appropriate Fst.
268// The 'epsilon_on_replace' option can be used to toggle this behaviour.
269template <class A, class T = DefaultReplaceStateTable<A> >
270struct ReplaceFstOptions : CacheOptions {
271  int64 root;    // root rule for expansion
272  bool  epsilon_on_replace;
273  bool  take_ownership;  // take ownership of input Fst(s)
274  T*    state_table;
275
276  ReplaceFstOptions(const CacheOptions &opts, int64 r)
277      : CacheOptions(opts),
278        root(r),
279        epsilon_on_replace(false),
280        take_ownership(false),
281        state_table(0) {}
282  explicit ReplaceFstOptions(int64 r)
283      : root(r),
284        epsilon_on_replace(false),
285        take_ownership(false),
286        state_table(0) {}
287  ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
288      : root(r),
289        epsilon_on_replace(epsilon_replace_arc),
290        take_ownership(false),
291        state_table(0) {}
292  ReplaceFstOptions()
293      : root(kNoLabel),
294        epsilon_on_replace(false),
295        take_ownership(false),
296        state_table(0) {}
297};
298
299
300// \class ReplaceFstImpl
301// \brief Implementation class for replace class Fst
302//
303// The replace implementation class supports a dynamic
304// expansion of a recursive transition network represented as Fst
305// with dynamic replacable arcs.
306//
307template <class A, class T>
308class ReplaceFstImpl : public CacheImpl<A> {
309  friend class ReplaceFstMatcher<A, T>;
310
311 public:
312  using FstImpl<A>::SetType;
313  using FstImpl<A>::SetProperties;
314  using FstImpl<A>::WriteHeader;
315  using FstImpl<A>::SetInputSymbols;
316  using FstImpl<A>::SetOutputSymbols;
317  using FstImpl<A>::InputSymbols;
318  using FstImpl<A>::OutputSymbols;
319
320  using CacheImpl<A>::PushArc;
321  using CacheImpl<A>::HasArcs;
322  using CacheImpl<A>::HasFinal;
323  using CacheImpl<A>::HasStart;
324  using CacheImpl<A>::SetArcs;
325  using CacheImpl<A>::SetFinal;
326  using CacheImpl<A>::SetStart;
327
328  typedef typename A::Label   Label;
329  typedef typename A::Weight  Weight;
330  typedef typename A::StateId StateId;
331  typedef CacheState<A> State;
332  typedef A Arc;
333  typedef unordered_map<Label, Label> NonTerminalHash;
334
335  typedef T StateTable;
336  typedef typename T::PrefixId PrefixId;
337  typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
338
339  // constructor for replace class implementation.
340  // \param fst_tuples array of label/fst tuples, one for each non-terminal
341  ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
342                 const ReplaceFstOptions<A, T> &opts)
343      : CacheImpl<A>(opts),
344        epsilon_on_replace_(opts.epsilon_on_replace),
345        state_table_(opts.state_table ? opts.state_table :
346                     new StateTable(fst_tuples, opts.root)) {
347
348    SetType("replace");
349
350    if (fst_tuples.size() > 0) {
351      SetInputSymbols(fst_tuples[0].second->InputSymbols());
352      SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
353    }
354
355    bool all_negative = true;  // all nonterminals are negative?
356    bool dense_range = true;   // all nonterminals are positive
357                               // and form a dense range containing 1?
358    for (size_t i = 0; i < fst_tuples.size(); ++i) {
359      Label nonterminal = fst_tuples[i].first;
360      if (nonterminal >= 0)
361        all_negative = false;
362      if (nonterminal > fst_tuples.size() || nonterminal <= 0)
363        dense_range = false;
364    }
365
366    vector<uint64> inprops;
367    bool all_ilabel_sorted = true;
368    bool all_olabel_sorted = true;
369    bool all_non_empty = true;
370    fst_array_.push_back(0);
371    for (size_t i = 0; i < fst_tuples.size(); ++i) {
372      Label label = fst_tuples[i].first;
373      const Fst<A> *fst = fst_tuples[i].second;
374      nonterminal_hash_[label] = fst_array_.size();
375      nonterminal_set_.insert(label);
376      fst_array_.push_back(opts.take_ownership ? fst : fst->Copy());
377      if (fst->Start() == kNoStateId)
378        all_non_empty = false;
379      if(!fst->Properties(kILabelSorted, false))
380        all_ilabel_sorted = false;
381      if(!fst->Properties(kOLabelSorted, false))
382        all_olabel_sorted = false;
383      inprops.push_back(fst->Properties(kCopyProperties, false));
384      if (i) {
385        if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
386          FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i
387                     << " does not match input symbols of base Fst (0'th fst)";
388          SetProperties(kError, kError);
389        }
390        if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
391          FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i
392                     << " does not match output symbols of base Fst "
393                     << "(0'th fst)";
394          SetProperties(kError, kError);
395        }
396      }
397    }
398    Label nonterminal = nonterminal_hash_[opts.root];
399    if ((nonterminal == 0) && (fst_array_.size() > 1)) {
400      FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '"
401                 << opts.root << "' in the input tuple vector";
402      SetProperties(kError, kError);
403    }
404    root_ = (nonterminal > 0) ? nonterminal : 1;
405
406    SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_,
407                                    all_non_empty));
408    // We assume that all terminals are positive.  The resulting
409    // ReplaceFst is known to be kILabelSorted when all sub-FSTs are
410    // kILabelSorted and one of the 3 following conditions is satisfied:
411    //  1. 'epsilon_on_replace' is false, or
412    //  2. all non-terminals are negative, or
413    //  3. all non-terninals are positive and form a dense range containing 1.
414    if (all_ilabel_sorted &&
415        (!epsilon_on_replace_ || all_negative || dense_range))
416      SetProperties(kILabelSorted, kILabelSorted);
417    // Similarly, the resulting ReplaceFst is known to be
418    // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of
419    // the 2 following conditions is satisfied:
420    //  1. all non-terminals are negative, or
421    //  2. all non-terninals are positive and form a dense range containing 1.
422    if (all_olabel_sorted && (all_negative || dense_range))
423      SetProperties(kOLabelSorted, kOLabelSorted);
424
425    // Enable optional caching as long as sorted and all non empty.
426    if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty)
427      always_cache_ = false;
428    else
429      always_cache_ = true;
430    VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
431            << (always_cache_ ? "true" : "false");
432  }
433
434  ReplaceFstImpl(const ReplaceFstImpl& impl)
435      : CacheImpl<A>(impl),
436        epsilon_on_replace_(impl.epsilon_on_replace_),
437        always_cache_(impl.always_cache_),
438        state_table_(new StateTable(*(impl.state_table_))),
439        nonterminal_set_(impl.nonterminal_set_),
440        nonterminal_hash_(impl.nonterminal_hash_),
441        root_(impl.root_) {
442    SetType("replace");
443    SetProperties(impl.Properties(), kCopyProperties);
444    SetInputSymbols(impl.InputSymbols());
445    SetOutputSymbols(impl.OutputSymbols());
446    fst_array_.reserve(impl.fst_array_.size());
447    fst_array_.push_back(0);
448    for (size_t i = 1; i < impl.fst_array_.size(); ++i) {
449      fst_array_.push_back(impl.fst_array_[i]->Copy(true));
450    }
451  }
452
453  ~ReplaceFstImpl() {
454    VLOG(2) << "~ReplaceFstImpl: gc = "
455            << (CacheImpl<A>::GetCacheGc() ? "true" : "false")
456            << ", gc_size = " << CacheImpl<A>::GetCacheSize()
457            << ", gc_limit = " << CacheImpl<A>::GetCacheLimit();
458
459    delete state_table_;
460    for (size_t i = 1; i < fst_array_.size(); ++i) {
461      delete fst_array_[i];
462    }
463  }
464
465  // Computes the dependency graph of the replace class and returns
466  // true if the dependencies are cyclic. Cyclic dependencies will result
467  // in an un-expandable replace fst.
468  bool CyclicDependencies() const {
469    ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_);
470    return replace_util.CyclicDependencies();
471  }
472
473  // Return or compute start state of replace fst
474  StateId Start() {
475    if (!HasStart()) {
476      if (fst_array_.size() == 1) {      // no fsts defined for replace
477        SetStart(kNoStateId);
478        return kNoStateId;
479      } else {
480        const Fst<A>* fst = fst_array_[root_];
481        StateId fst_start = fst->Start();
482        if (fst_start == kNoStateId)  // root Fst is empty
483          return kNoStateId;
484
485        PrefixId prefix = GetPrefixId(StackPrefix());
486        StateId start = state_table_->FindState(
487            StateTuple(prefix, root_, fst_start));
488        SetStart(start);
489        return start;
490      }
491    } else {
492      return CacheImpl<A>::Start();
493    }
494  }
495
496  // return final weight of state (kInfWeight means state is not final)
497  Weight Final(StateId s) {
498    if (!HasFinal(s)) {
499      const StateTuple& tuple  = state_table_->Tuple(s);
500      const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
501      const Fst<A>* fst = fst_array_[tuple.fst_id];
502      StateId fst_state = tuple.fst_state;
503
504      if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
505        SetFinal(s, fst->Final(fst_state));
506      else
507        SetFinal(s, Weight::Zero());
508    }
509    return CacheImpl<A>::Final(s);
510  }
511
512  size_t NumArcs(StateId s) {
513    if (HasArcs(s)) {  // If state cached, use the cached value.
514      return CacheImpl<A>::NumArcs(s);
515    } else if (always_cache_) {  // If always caching, expand and cache state.
516      Expand(s);
517      return CacheImpl<A>::NumArcs(s);
518    } else {  // Otherwise compute the number of arcs without expanding.
519      StateTuple tuple  = state_table_->Tuple(s);
520      if (tuple.fst_state == kNoStateId)
521        return 0;
522
523      const Fst<A>* fst = fst_array_[tuple.fst_id];
524      size_t num_arcs = fst->NumArcs(tuple.fst_state);
525      if (ComputeFinalArc(tuple, 0))
526        num_arcs++;
527
528      return num_arcs;
529    }
530  }
531
532  // Returns whether a given label is a non terminal
533  bool IsNonTerminal(Label l) const {
534    // TODO(allauzen): be smarter and take advantage of
535    // all_dense or all_negative.
536    // Use also in ComputeArc, this would require changes to replace
537    // so that recursing into an empty fst lead to a non co-accessible
538    // state instead of deleting the arc as done currently.
539    // Current use correct, since i/olabel sorted iff all_non_empty.
540    typename NonTerminalHash::const_iterator it =
541        nonterminal_hash_.find(l);
542    return it != nonterminal_hash_.end();
543  }
544
545  size_t NumInputEpsilons(StateId s) {
546    if (HasArcs(s)) {
547      // If state cached, use the cached value.
548      return CacheImpl<A>::NumInputEpsilons(s);
549    } else if (always_cache_ || !Properties(kILabelSorted)) {
550      // If always caching or if the number of input epsilons is too expensive
551      // to compute without caching (i.e. not ilabel sorted),
552      // then expand and cache state.
553      Expand(s);
554      return CacheImpl<A>::NumInputEpsilons(s);
555    } else {
556      // Otherwise, compute the number of input epsilons without caching.
557      StateTuple tuple  = state_table_->Tuple(s);
558      if (tuple.fst_state == kNoStateId)
559        return 0;
560      const Fst<A>* fst = fst_array_[tuple.fst_id];
561      size_t num  = 0;
562      if (!epsilon_on_replace_) {
563        // If epsilon_on_replace is false, all input epsilon arcs
564        // are also input epsilons arcs in the underlying machine.
565        fst->NumInputEpsilons(tuple.fst_state);
566      } else {
567        // Otherwise, one need to consider that all non-terminal arcs
568        // in the underlying machine also become input epsilon arc.
569        ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
570        for (; !aiter.Done() &&
571                 ((aiter.Value().ilabel == 0) ||
572                  IsNonTerminal(aiter.Value().olabel));
573             aiter.Next())
574          ++num;
575      }
576      if (ComputeFinalArc(tuple, 0))
577        num++;
578      return num;
579    }
580  }
581
582  size_t NumOutputEpsilons(StateId s) {
583    if (HasArcs(s)) {
584      // If state cached, use the cached value.
585      return CacheImpl<A>::NumOutputEpsilons(s);
586    } else if(always_cache_ || !Properties(kOLabelSorted)) {
587      // If always caching or if the number of output epsilons is too expensive
588      // to compute without caching (i.e. not olabel sorted),
589      // then expand and cache state.
590      Expand(s);
591      return CacheImpl<A>::NumOutputEpsilons(s);
592    } else {
593      // Otherwise, compute the number of output epsilons without caching.
594      StateTuple tuple  = state_table_->Tuple(s);
595      if (tuple.fst_state == kNoStateId)
596        return 0;
597      const Fst<A>* fst = fst_array_[tuple.fst_id];
598      size_t num  = 0;
599      ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
600      for (; !aiter.Done() &&
601               ((aiter.Value().olabel == 0) ||
602                IsNonTerminal(aiter.Value().olabel));
603           aiter.Next())
604        ++num;
605      if (ComputeFinalArc(tuple, 0))
606        num++;
607      return num;
608    }
609  }
610
611  uint64 Properties() const { return Properties(kFstProperties); }
612
613  // Set error if found; return FST impl properties.
614  uint64 Properties(uint64 mask) const {
615    if (mask & kError) {
616      for (size_t i = 1; i < fst_array_.size(); ++i) {
617        if (fst_array_[i]->Properties(kError, false))
618          SetProperties(kError, kError);
619      }
620    }
621    return FstImpl<Arc>::Properties(mask);
622  }
623
624  // return the base arc iterator, if arcs have not been computed yet,
625  // extend/recurse for new arcs.
626  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
627    if (!HasArcs(s))
628      Expand(s);
629    CacheImpl<A>::InitArcIterator(s, data);
630    // TODO(allauzen): Set behaviour of generic iterator
631    // Warning: ArcIterator<ReplaceFst<A> >::InitCache()
632    // relies on current behaviour.
633  }
634
635
636  // Extend current state (walk arcs one level deep)
637  void Expand(StateId s) {
638    StateTuple tuple = state_table_->Tuple(s);
639
640    // If local fst is empty
641    if (tuple.fst_state == kNoStateId) {
642      SetArcs(s);
643      return;
644    }
645
646    ArcIterator< Fst<A> > aiter(
647        *(fst_array_[tuple.fst_id]), tuple.fst_state);
648    Arc arc;
649
650    // Create a final arc when needed
651    if (ComputeFinalArc(tuple, &arc))
652      PushArc(s, arc);
653
654    // Expand all arcs leaving the state
655    for (;!aiter.Done(); aiter.Next()) {
656      if (ComputeArc(tuple, aiter.Value(), &arc))
657        PushArc(s, arc);
658    }
659
660    SetArcs(s);
661  }
662
663  void Expand(StateId s, const StateTuple &tuple,
664              const ArcIteratorData<A> &data) {
665     // If local fst is empty
666    if (tuple.fst_state == kNoStateId) {
667      SetArcs(s);
668      return;
669    }
670
671    ArcIterator< Fst<A> > aiter(data);
672    Arc arc;
673
674    // Create a final arc when needed
675    if (ComputeFinalArc(tuple, &arc))
676      AddArc(s, arc);
677
678    // Expand all arcs leaving the state
679    for (; !aiter.Done(); aiter.Next()) {
680      if (ComputeArc(tuple, aiter.Value(), &arc))
681        AddArc(s, arc);
682    }
683
684    SetArcs(s);
685  }
686
687  // If arcp == 0, only returns if a final arc is required, does not
688  // actually compute it.
689  bool ComputeFinalArc(const StateTuple &tuple, A* arcp,
690                       uint32 flags = kArcValueFlags) {
691    const Fst<A>* fst = fst_array_[tuple.fst_id];
692    StateId fst_state = tuple.fst_state;
693    if (fst_state == kNoStateId)
694      return false;
695
696   // if state is final, pop up stack
697    const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
698    if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
699      if (arcp) {
700        arcp->ilabel = 0;
701        arcp->olabel = 0;
702        if (flags & kArcNextStateValue) {
703          PrefixId prefix_id = PopPrefix(stack);
704          const PrefixTuple& top = stack.Top();
705          arcp->nextstate = state_table_->FindState(
706              StateTuple(prefix_id, top.fst_id, top.nextstate));
707        }
708        if (flags & kArcWeightValue)
709          arcp->weight = fst->Final(fst_state);
710      }
711      return true;
712    } else {
713      return false;
714    }
715  }
716
717  // Compute the arc in the replace fst corresponding to a given
718  // in the underlying machine. Returns false if the underlying arc
719  // corresponds to no arc in the replace.
720  bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp,
721                  uint32 flags = kArcValueFlags) {
722    if (!epsilon_on_replace_ &&
723        (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
724      *arcp = arc;
725      return true;
726    }
727
728    if (arc.olabel == 0) {  // expand local fst
729      StateId nextstate = flags & kArcNextStateValue
730          ? state_table_->FindState(
731              StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
732          : kNoStateId;
733      *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
734    } else {
735      // check for non terminal
736      typename NonTerminalHash::const_iterator it =
737          nonterminal_hash_.find(arc.olabel);
738      if (it != nonterminal_hash_.end()) {  // recurse into non terminal
739        Label nonterminal = it->second;
740        const Fst<A>* nt_fst = fst_array_[nonterminal];
741        PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
742                                        tuple.fst_id, arc.nextstate);
743
744        // if start state is valid replace, else arc is implicitly
745        // deleted
746        StateId nt_start = nt_fst->Start();
747        if (nt_start != kNoStateId) {
748          StateId nt_nextstate =  flags & kArcNextStateValue
749              ? state_table_->FindState(
750                  StateTuple(nt_prefix, nonterminal, nt_start))
751              : kNoStateId;
752          Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel;
753          *arcp = A(ilabel, 0, arc.weight, nt_nextstate);
754        } else {
755          return false;
756        }
757      } else {
758        StateId nextstate = flags & kArcNextStateValue
759            ? state_table_->FindState(
760                StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
761            : kNoStateId;
762        *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
763      }
764    }
765    return true;
766  }
767
768  // Returns the arc iterator flags supported by this Fst.
769  uint32 ArcIteratorFlags() const {
770    uint32 flags = kArcValueFlags;
771    if (!always_cache_)
772      flags |= kArcNoCache;
773    return flags;
774  }
775
776  T* GetStateTable() const {
777    return state_table_;
778  }
779
780  const Fst<A>* GetFst(Label fst_id) const {
781    return fst_array_[fst_id];
782  }
783
784  bool EpsilonOnReplace() const { return epsilon_on_replace_; }
785
786  // private helper classes
787 private:
788  static const size_t kPrime0;
789
790  // \class PrefixTuple
791  // \brief Tuple of fst_id and destination state (entry in stack prefix)
792  struct PrefixTuple {
793    PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
794
795    Label   fst_id;
796    StateId nextstate;
797  };
798
799  // \class StackPrefix
800  // \brief Container for stack prefix.
801  class StackPrefix {
802   public:
803    StackPrefix() {}
804
805    // copy constructor
806    StackPrefix(const StackPrefix& x) :
807        prefix_(x.prefix_) {
808    }
809
810    void Push(StateId fst_id, StateId nextstate) {
811      prefix_.push_back(PrefixTuple(fst_id, nextstate));
812    }
813
814    void Pop() {
815      prefix_.pop_back();
816    }
817
818    const PrefixTuple& Top() const {
819      return prefix_[prefix_.size()-1];
820    }
821
822    size_t Depth() const {
823      return prefix_.size();
824    }
825
826   public:
827    vector<PrefixTuple> prefix_;
828  };
829
830
831  // \class StackPrefixEqual
832  // \brief Compare two stack prefix classes for equality
833  class StackPrefixEqual {
834   public:
835    bool operator()(const StackPrefix& x, const StackPrefix& y) const {
836      if (x.prefix_.size() != y.prefix_.size()) return false;
837      for (size_t i = 0; i < x.prefix_.size(); ++i) {
838        if (x.prefix_[i].fst_id    != y.prefix_[i].fst_id ||
839           x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
840      }
841      return true;
842    }
843  };
844
845  //
846  // \class StackPrefixKey
847  // \brief Hash function for stack prefix to prefix id
848  class StackPrefixKey {
849   public:
850    size_t operator()(const StackPrefix& x) const {
851      size_t sum = 0;
852      for (size_t i = 0; i < x.prefix_.size(); ++i) {
853        sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
854      }
855      return sum;
856    }
857  };
858
  // Maps each distinct stack prefix to its dense prefix id (see GetPrefixId()).
  typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual>
  StackPrefixHash;
861
862  // private methods
863 private:
864  // hash stack prefix (return unique index into stackprefix array)
865  PrefixId GetPrefixId(const StackPrefix& prefix) {
866    typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
867    if (it == prefix_hash_.end()) {
868      PrefixId prefix_id = stackprefix_array_.size();
869      stackprefix_array_.push_back(prefix);
870      prefix_hash_[prefix] = prefix_id;
871      return prefix_id;
872    } else {
873      return it->second;
874    }
875  }
876
877  // prefix id after a stack pop
878  PrefixId PopPrefix(StackPrefix prefix) {
879    prefix.Pop();
880    return GetPrefixId(prefix);
881  }
882
883  // prefix id after a stack push
884  PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
885    prefix.Push(fst_id, nextstate);
886    return GetPrefixId(prefix);
887  }
888
889
  // private data
 private:
  // runtime options
  // If true, ComputeArc() gives non-terminal arcs an epsilon (0) input
  // label; otherwise the arc's original input label is kept.
  bool epsilon_on_replace_;
  bool always_cache_;  // Optionally caching arc iterator disabled when true
                       // (ArcIteratorFlags() then omits kArcNoCache).

  // state table
  // Maps StateTuples (prefix id, fst id, component state) to replace-fst
  // state ids via FindState(), and back via Tuple().
  // NOTE(review): ownership is handled outside this chunk; confirm who
  // deletes it.
  StateTable *state_table_;

  // cross index of unique stack prefix
  // could potentially have one copy of prefix array
  // 'prefix_hash_' maps a prefix to its id; 'stackprefix_array_[id]' maps an
  // id back to its prefix (both maintained by GetPrefixId()).
  StackPrefixHash prefix_hash_;
  vector<StackPrefix> stackprefix_array_;

  // Non-terminal labels. The replace matcher registers each one as a
  // multi-epsilon label on its per-component matchers.
  set<Label> nonterminal_set_;
  // Maps a non-terminal output label to the index of its fst in 'fst_array_'.
  NonTerminalHash nonterminal_hash_;
  // Component fsts indexed by fst id; entries may be null.
  vector<const Fst<A>*> fst_array_;
  // NOTE(review): presumably the id of the root fst; it is not read in this
  // part of the file.
  Label root_;

  void operator=(const ReplaceFstImpl<A, T> &);  // disallow
910};
911
912
// Multiplier applied to each entry's 'nextstate' by StackPrefixKey when
// hashing stack prefixes.
template <class A, class T>
const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853;
915
916//
917// \class ReplaceFst
918// \brief Recursivively replaces arcs in the root Fst with other Fsts.
919// This version is a delayed Fst.
920//
921// ReplaceFst supports dynamic replacement of arcs in one Fst with
922// another Fst. This replacement is recursive.  ReplaceFst can be used
923// to support a variety of delayed constructions such as recursive
924// transition networks, union, or closure.  It is constructed with an
925// array of Fst(s). One Fst represents the root (or topology)
926// machine. The root Fst refers to other Fsts by recursively replacing
927// arcs labeled as non-terminals with the matching non-terminal
928// Fst. Currently the ReplaceFst uses the output symbols of the arcs
929// to determine whether the arc is a non-terminal arc or not. A
930// non-terminal can be any label that is not a non-zero terminal label
931// in the output alphabet.
932//
933// Note that the constructor uses a vector of pair<>. These correspond
934// to the tuple of non-terminal Label and corresponding Fst. For example
935// to implement the closure operation we need 2 Fsts. The first root
936// Fst is a single Arc on the start State that self loops, it references
937// the particular machine for which we are performing the closure operation.
938//
939// The ReplaceFst class supports an optionally caching arc iterator:
940//    ArcIterator< ReplaceFst<A> >
941// The ReplaceFst need to be built such that it is known to be ilabel
942// or olabel sorted (see usage below).
943//
944// Observe that Matcher<Fst<A> > will use the optionally caching arc
945// iterator when available (Fst is ilabel sorted and matching on the
946// input, or Fst is olabel sorted and matching on the output).
947// In order to obtain the most efficient behaviour, it is recommended
948// to set 'epsilon_on_replace' to false (this means constructing acceptors
949// as transducers with epsilons on the input side of nonterminal arcs)
950// and matching on the input side.
951//
952// This class attaches interface to implementation and handles
953// reference counting, delegating most methods to ImplToFst.
954template <class A, class T = DefaultReplaceStateTable<A> >
955class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > {
956 public:
957  friend class ArcIterator< ReplaceFst<A, T> >;
958  friend class StateIterator< ReplaceFst<A, T> >;
959  friend class ReplaceFstMatcher<A, T>;
960
961  typedef A Arc;
962  typedef typename A::Label   Label;
963  typedef typename A::Weight  Weight;
964  typedef typename A::StateId StateId;
965  typedef CacheState<A> State;
966  typedef ReplaceFstImpl<A, T> Impl;
967
968  using ImplToFst<Impl>::Properties;
969
970  ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
971             Label root)
972      : ImplToFst<Impl>(new Impl(fst_array, ReplaceFstOptions<A, T>(root))) {}
973
974  ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
975             const ReplaceFstOptions<A, T> &opts)
976      : ImplToFst<Impl>(new Impl(fst_array, opts)) {}
977
978  // See Fst<>::Copy() for doc.
979  ReplaceFst(const ReplaceFst<A, T>& fst, bool safe = false)
980      : ImplToFst<Impl>(fst, safe) {}
981
982  // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
983  virtual ReplaceFst<A, T> *Copy(bool safe = false) const {
984    return new ReplaceFst<A, T>(*this, safe);
985  }
986
987  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
988
989  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
990    GetImpl()->InitArcIterator(s, data);
991  }
992
993  virtual MatcherBase<A> *InitMatcher(MatchType match_type) const {
994    if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
995        ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
996         (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
997      return new ReplaceFstMatcher<A, T>(*this, match_type);
998    }
999    else {
1000      VLOG(2) << "Not using replace matcher";
1001      return 0;
1002    }
1003  }
1004
1005  bool CyclicDependencies() const {
1006    return GetImpl()->CyclicDependencies();
1007  }
1008
1009 private:
1010  // Makes visible to friends.
1011  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
1012
1013  void operator=(const ReplaceFst<A> &fst);  // disallow
1014};
1015
1016
1017// Specialization for ReplaceFst.
1018template<class A, class T>
1019class StateIterator< ReplaceFst<A, T> >
1020    : public CacheStateIterator< ReplaceFst<A, T> > {
1021 public:
1022  explicit StateIterator(const ReplaceFst<A, T> &fst)
1023      : CacheStateIterator< ReplaceFst<A, T> >(fst, fst.GetImpl()) {}
1024
1025 private:
1026  DISALLOW_COPY_AND_ASSIGN(StateIterator);
1027};
1028
1029
1030// Specialization for ReplaceFst.
1031// Implements optional caching. It can be used as follows:
1032//
1033//   ReplaceFst<A> replace;
1034//   ArcIterator< ReplaceFst<A> > aiter(replace, s);
1035//   // Note: ArcIterator< Fst<A> > is always a caching arc iterator.
1036//   aiter.SetFlags(kArcNoCache, kArcNoCache);
1037//   // Use the arc iterator, no arc will be cached, no state will be expanded.
1038//   // The varied 'kArcValueFlags' can be used to decide which part
1039//   // of arc values needs to be computed.
1040//   aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1041//   // Only want the ilabel for this arc
1042//   aiter.Value();  // Does not compute the destination state.
1043//   aiter.Next();
1044//   aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1045//   // Want both ilabel and nextstate for that arc
1046//   aiter.Value();  // Does compute the destination state and inserts it
1047//                   // in the replace state table.
1048//   // No Arc has been cached at that point.
1049//
1050template <class A, class T>
1051class ArcIterator< ReplaceFst<A, T> > {
1052 public:
1053  typedef A Arc;
1054  typedef typename A::StateId StateId;
1055
1056  ArcIterator(const ReplaceFst<A, T> &fst, StateId s)
1057      : fst_(fst), state_(s), pos_(0), offset_(0), flags_(0), arcs_(0),
1058        data_flags_(0), final_flags_(0) {
1059    cache_data_.ref_count = 0;
1060    local_data_.ref_count = 0;
1061
1062    // If FST does not support optional caching, force caching.
1063    if(!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1064       !(fst_.GetImpl()->HasArcs(state_)))
1065       fst_.GetImpl()->Expand(state_);
1066
1067    // If state is already cached, use cached arcs array.
1068    if (fst_.GetImpl()->HasArcs(state_)) {
1069      (fst_.GetImpl())->template CacheImpl<A>::InitArcIterator(state_,
1070                                                               &cache_data_);
1071      num_arcs_ = cache_data_.narcs;
1072      arcs_ = cache_data_.arcs;      // 'arcs_' is a ptr to the cached arcs.
1073      data_flags_ = kArcValueFlags;  // All the arc member values are valid.
1074    } else {  // Otherwise delay decision until Value() is called.
1075      tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_);
1076      if (tuple_.fst_state == kNoStateId) {
1077        num_arcs_ = 0;
1078      } else {
1079        // The decision to cache or not to cache has been defered
1080        // until Value() or SetFlags() is called. However, the arc
1081        // iterator is set up now to be ready for non-caching in order
1082        // to keep the Value() method simple and efficient.
1083        const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id);
1084        fst->InitArcIterator(tuple_.fst_state, &local_data_);
1085        // 'arcs_' is a pointer to the arcs in the underlying machine.
1086        arcs_ = local_data_.arcs;
1087        // Compute the final arc (but not its destination state)
1088        // if a final arc is required.
1089        bool has_final_arc = fst_.GetImpl()->ComputeFinalArc(
1090            tuple_,
1091            &final_arc_,
1092            kArcValueFlags & ~kArcNextStateValue);
1093        // Set the arc value flags that hold for 'final_arc_'.
1094        final_flags_ = kArcValueFlags & ~kArcNextStateValue;
1095        // Compute the number of arcs.
1096        num_arcs_ = local_data_.narcs;
1097        if (has_final_arc)
1098          ++num_arcs_;
1099        // Set the offset between the underlying arc positions and
1100        // the positions in the arc iterator.
1101        offset_ = num_arcs_ - local_data_.narcs;
1102        // Defers the decision to cache or not until Value() or
1103        // SetFlags() is called.
1104        data_flags_ = 0;
1105      }
1106    }
1107  }
1108
1109  ~ArcIterator() {
1110    if (cache_data_.ref_count)
1111      --(*cache_data_.ref_count);
1112    if (local_data_.ref_count)
1113      --(*local_data_.ref_count);
1114  }
1115
1116  void ExpandAndCache() const   {
1117    // TODO(allauzen): revisit this
1118    // fst_.GetImpl()->Expand(state_, tuple_, local_data_);
1119    // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_,
1120    //                                               &cache_data_);
1121    //
1122    fst_.InitArcIterator(state_, &cache_data_);  // Expand and cache state.
1123    arcs_ = cache_data_.arcs;  // 'arcs_' is a pointer to the cached arcs.
1124    data_flags_ = kArcValueFlags;  // All the arc member values are valid.
1125    offset_ = 0;  // No offset
1126
1127  }
1128
1129  void Init() {
1130    if (flags_ & kArcNoCache) {  // If caching is disabled
1131      // 'arcs_' is a pointer to the arcs in the underlying machine.
1132      arcs_ = local_data_.arcs;
1133      // Set the arcs value flags that hold for 'arcs_'.
1134      data_flags_ = kArcWeightValue;
1135      if (!fst_.GetImpl()->EpsilonOnReplace())
1136          data_flags_ |= kArcILabelValue;
1137      // Set the offset between the underlying arc positions and
1138      // the positions in the arc iterator.
1139      offset_ = num_arcs_ - local_data_.narcs;
1140    } else {  // Otherwise, expand and cache
1141      ExpandAndCache();
1142    }
1143  }
1144
1145  bool Done() const { return pos_ >= num_arcs_; }
1146
1147  const A& Value() const {
1148    // If 'data_flags_' was set to 0, non-caching was not requested
1149    if (!data_flags_) {
1150      // TODO(allauzen): revisit this.
1151      if (flags_ & kArcNoCache) {
1152        // Should never happen.
1153        FSTERROR() << "ReplaceFst: inconsistent arc iterator flags";
1154      }
1155      ExpandAndCache();  // Expand and cache.
1156    }
1157
1158    if (pos_ - offset_ >= 0) {  // The requested arc is not the 'final' arc.
1159      const A& arc = arcs_[pos_ - offset_];
1160      if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
1161        // If the value flags for 'arc' match the recquired value flags
1162        // then return 'arc'.
1163        return arc;
1164      } else {
1165        // Otherwise, compute the corresponding arc on-the-fly.
1166        fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags);
1167        return arc_;
1168      }
1169    } else {  // The requested arc is the 'final' arc.
1170      if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
1171        // If the arc value flags that hold for the final arc
1172        // do not match the requested value flags, then
1173        // 'final_arc_' needs to be updated.
1174        fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_,
1175                                    flags_ & kArcValueFlags);
1176        final_flags_ = flags_ & kArcValueFlags;
1177      }
1178      return final_arc_;
1179    }
1180  }
1181
1182  void Next() { ++pos_; }
1183
1184  size_t Position() const { return pos_; }
1185
1186  void Reset() { pos_ = 0;  }
1187
1188  void Seek(size_t pos) { pos_ = pos; }
1189
1190  uint32 Flags() const { return flags_; }
1191
1192  void SetFlags(uint32 f, uint32 mask) {
1193    // Update the flags taking into account what flags are supported
1194    // by the Fst.
1195    flags_ &= ~mask;
1196    flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags());
1197    // If non-caching is not requested (and caching has not already
1198    // been performed), then flush 'data_flags_' to request caching
1199    // during the next call to Value().
1200    if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
1201      if (!fst_.GetImpl()->HasArcs(state_))
1202         data_flags_ = 0;
1203    }
1204    // If 'data_flags_' has been flushed but non-caching is requested
1205    // before calling Value(), then set up the iterator for non-caching.
1206    if ((f & kArcNoCache) && (!data_flags_))
1207      Init();
1208  }
1209
1210 private:
1211  const ReplaceFst<A, T> &fst_;           // Reference to the FST
1212  StateId state_;                         // State in the FST
1213  mutable typename T::StateTuple tuple_;  // Tuple corresponding to state_
1214
1215  ssize_t pos_;             // Current position
1216  mutable ssize_t offset_;  // Offset between position in iterator and in arcs_
1217  ssize_t num_arcs_;        // Number of arcs at state_
1218  uint32 flags_;            // Behavorial flags for the arc iterator
1219  mutable Arc arc_;         // Memory to temporarily store computed arcs
1220
1221  mutable ArcIteratorData<Arc> cache_data_;  // Arc iterator data in cache
1222  mutable ArcIteratorData<Arc> local_data_;  // Arc iterator data in local fst
1223
1224  mutable const A* arcs_;       // Array of arcs
1225  mutable uint32 data_flags_;   // Arc value flags valid for data in arcs_
1226  mutable Arc final_arc_;       // Final arc (when required)
1227  mutable uint32 final_flags_;  // Arc value flags valid for final_arc_
1228
1229  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
1230};
1231
1232
1233template <class A, class T>
1234class ReplaceFstMatcher : public MatcherBase<A> {
1235 public:
1236  typedef A Arc;
1237  typedef typename A::StateId StateId;
1238  typedef typename A::Label Label;
1239  typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher;
1240
1241  ReplaceFstMatcher(const ReplaceFst<A, T> &fst, fst::MatchType match_type)
1242      : fst_(fst),
1243        impl_(fst_.GetImpl()),
1244        s_(fst::kNoStateId),
1245        match_type_(match_type),
1246        current_loop_(false),
1247        final_arc_(false),
1248        loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
1249    if (match_type_ == fst::MATCH_OUTPUT)
1250      swap(loop_.ilabel, loop_.olabel);
1251    InitMatchers();
1252  }
1253
1254  ReplaceFstMatcher(const ReplaceFstMatcher<A, T> &matcher, bool safe = false)
1255      : fst_(matcher.fst_),
1256        impl_(fst_.GetImpl()),
1257        s_(fst::kNoStateId),
1258        match_type_(matcher.match_type_),
1259        current_loop_(false),
1260        loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
1261    if (match_type_ == fst::MATCH_OUTPUT)
1262      swap(loop_.ilabel, loop_.olabel);
1263    InitMatchers();
1264  }
1265
1266  // Create a local matcher for each component Fst of replace.
1267  // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsilonMatcher
1268  // is used to match each non-terminal arc, since these non-terminal
1269  // turn into epsilons on recursion.
1270  void InitMatchers() {
1271    const vector<const Fst<A>*>& fst_array = impl_->fst_array_;
1272    matcher_.resize(fst_array.size(), 0);
1273    for (size_t i = 0; i < fst_array.size(); ++i) {
1274      if (fst_array[i]) {
1275        matcher_[i] =
1276            new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList);
1277
1278        typename set<Label>::iterator it = impl_->nonterminal_set_.begin();
1279        for (; it != impl_->nonterminal_set_.end(); ++it) {
1280          matcher_[i]->AddMultiEpsLabel(*it);
1281        }
1282      }
1283    }
1284  }
1285
1286  virtual ReplaceFstMatcher<A, T> *Copy(bool safe = false) const {
1287    return new ReplaceFstMatcher<A, T>(*this, safe);
1288  }
1289
1290  virtual ~ReplaceFstMatcher() {
1291    for (size_t i = 0; i < matcher_.size(); ++i)
1292      delete matcher_[i];
1293  }
1294
1295  virtual MatchType Type(bool test) const {
1296    if (match_type_ == MATCH_NONE)
1297      return match_type_;
1298
1299    uint64 true_prop =  match_type_ == MATCH_INPUT ?
1300        kILabelSorted : kOLabelSorted;
1301    uint64 false_prop = match_type_ == MATCH_INPUT ?
1302        kNotILabelSorted : kNotOLabelSorted;
1303    uint64 props = fst_.Properties(true_prop | false_prop, test);
1304
1305    if (props & true_prop)
1306      return match_type_;
1307    else if (props & false_prop)
1308      return MATCH_NONE;
1309    else
1310      return MATCH_UNKNOWN;
1311  }
1312
1313  virtual const Fst<A> &GetFst() const {
1314    return fst_;
1315  }
1316
1317  virtual uint64 Properties(uint64 props) const {
1318    return props;
1319  }
1320
1321 private:
1322  // Set the sate from which our matching happens.
1323  virtual void SetState_(StateId s) {
1324    if (s_ == s) return;
1325
1326    s_ = s;
1327    tuple_ = impl_->GetStateTable()->Tuple(s_);
1328    if (tuple_.fst_state == kNoStateId) {
1329      done_ = true;
1330      return;
1331    }
1332    // Get current matcher. Used for non epsilon matching
1333    current_matcher_ = matcher_[tuple_.fst_id];
1334    current_matcher_->SetState(tuple_.fst_state);
1335    loop_.nextstate = s_;
1336
1337    final_arc_ = false;
1338  }
1339
1340  // Search for label, from previous set state. If label == 0, first
1341  // hallucinate and epsilon loop, else use the underlying matcher to
1342  // search for the label or epsilons.
1343  // - Note since the ReplaceFST recursion on non-terminal arcs causes
1344  //   epsilon transitions to be created we use the MultiEpsilonMatcher
1345  //   to search for possible matches of non terminals.
1346  // - If the component Fst reaches a final state we also need to add
1347  //   the exiting final arc.
1348  virtual bool Find_(Label label) {
1349    bool found = false;
1350    label_ = label;
1351    if (label_ == 0 || label_ == kNoLabel) {
1352      // Compute loop directly, saving Replace::ComputeArc
1353      if (label_ == 0) {
1354        current_loop_ = true;
1355        found = true;
1356      }
1357      // Search for matching multi epsilons
1358      final_arc_ = impl_->ComputeFinalArc(tuple_, 0);
1359      found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1360    } else {
1361      // Search on sub machine directly using sub machine matcher.
1362      found = current_matcher_->Find(label_);
1363    }
1364    return found;
1365  }
1366
1367  virtual bool Done_() const {
1368    return !current_loop_ && !final_arc_ && current_matcher_->Done();
1369  }
1370
1371  virtual const Arc& Value_() const {
1372    if (current_loop_) {
1373      return loop_;
1374    }
1375    if (final_arc_) {
1376      impl_->ComputeFinalArc(tuple_, &arc_);
1377      return arc_;
1378    }
1379    const Arc& component_arc = current_matcher_->Value();
1380    impl_->ComputeArc(tuple_, component_arc, &arc_);
1381    return arc_;
1382  }
1383
1384  virtual void Next_() {
1385    if (current_loop_) {
1386      current_loop_ = false;
1387      return;
1388    }
1389    if (final_arc_) {
1390      final_arc_ = false;
1391      return;
1392    }
1393    current_matcher_->Next();
1394  }
1395
1396  const ReplaceFst<A, T>& fst_;
1397  ReplaceFstImpl<A, T> *impl_;
1398  LocalMatcher* current_matcher_;
1399  vector<LocalMatcher*> matcher_;
1400
1401  StateId s_;                        // Current state
1402  Label label_;                      // Current label
1403
1404  MatchType match_type_;             // Supplied by caller
1405  mutable bool done_;
1406  mutable bool current_loop_;        // Current arc is the implicit loop
1407  mutable bool final_arc_;           // Current arc for exiting recursion
1408  mutable typename T::StateTuple tuple_;  // Tuple corresponding to state_
1409  mutable Arc arc_;
1410  Arc loop_;
1411};
1412
1413template <class A, class T> inline
1414void ReplaceFst<A, T>::InitStateIterator(StateIteratorData<A> *data) const {
1415  data->base = new StateIterator< ReplaceFst<A, T> >(*this);
1416}
1417
// ReplaceFst over StdArc, using the default replace state table.
typedef ReplaceFst<StdArc> StdReplaceFst;
1419
1420
1421// // Recursivively replaces arcs in the root Fst with other Fsts.
1422// This version writes the result of replacement to an output MutableFst.
1423//
1424// Replace supports replacement of arcs in one Fst with another
1425// Fst. This replacement is recursive.  Replace takes an array of
1426// Fst(s). One Fst represents the root (or topology) machine. The root
1427// Fst refers to other Fsts by recursively replacing arcs labeled as
1428// non-terminals with the matching non-terminal Fst. Currently Replace
1429// uses the output symbols of the arcs to determine whether the arc is
1430// a non-terminal arc or not. A non-terminal can be any label that is
1431// not a non-zero terminal label in the output alphabet.  Note that
1432// input argument is a vector of pair<>. These correspond to the tuple
1433// of non-terminal Label and corresponding Fst.
1434template<class Arc>
1435void Replace(const vector<pair<typename Arc::Label,
1436             const Fst<Arc>* > >& ifst_array,
1437             MutableFst<Arc> *ofst, typename Arc::Label root,
1438             bool epsilon_on_replace) {
1439  ReplaceFstOptions<Arc> opts(root, epsilon_on_replace);
1440  opts.gc_limit = 0;  // Cache only the last state for fastest copy.
1441  *ofst = ReplaceFst<Arc>(ifst_array, opts);
1442}
1443
1444template<class Arc>
1445void Replace(const vector<pair<typename Arc::Label,
1446             const Fst<Arc>* > >& ifst_array,
1447             MutableFst<Arc> *ofst, typename Arc::Label root) {
1448  Replace(ifst_array, ofst, root, false);
1449}
1450
1451}  // namespace fst
1452
1453#endif  // FST_LIB_REPLACE_H__
1454