1// cache.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// An Fst implementation that caches FST elements of a delayed
20// computation.
21
22#ifndef FST_LIB_CACHE_H__
23#define FST_LIB_CACHE_H__
24
25#include <vector>
26using std::vector;
27#include <list>
28
29#include <fst/vector-fst.h>
30
31
32DECLARE_bool(fst_default_cache_gc);
33DECLARE_int64(fst_default_cache_gc_limit);
34
35namespace fst {
36
37struct CacheOptions {
38  bool gc;          // enable GC
39  size_t gc_limit;  // # of bytes allowed before GC
40
41  CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {}
42  CacheOptions()
43      : gc(FLAGS_fst_default_cache_gc),
44        gc_limit(FLAGS_fst_default_cache_gc_limit) {}
45};
46
47// A CacheStateAllocator allocates and frees CacheStates
48// template <class S>
49// struct CacheStateAllocator {
50//   S *Allocate(StateId s);
51//   void Free(S *state, StateId s);
52// };
53//
54
55// A simple allocator class, can be overridden as needed,
56// maintains a single entry cache.
// A simple allocator class, can be overridden as needed,
// maintains a single entry cache.
template <class S>
struct DefaultCacheStateAllocator {
  typedef typename S::Arc::StateId StateId;

  DefaultCacheStateAllocator() : mru_(NULL) {}

  ~DefaultCacheStateAllocator() {
    delete mru_;
  }

  // Returns the cached state after resetting it, if one is held;
  // otherwise heap-allocates a fresh state.  The state id is unused
  // by this allocator.
  S *Allocate(StateId s) {
    S *recycled = mru_;
    if (recycled == NULL)
      return new S();
    mru_ = NULL;
    recycled->Reset();
    return recycled;
  }

  // Takes ownership of 'state', keeping it as the single-entry cache.
  // Any previously cached state is destroyed (delete of NULL is a
  // no-op, so no explicit check is needed).
  void Free(S *state, StateId s) {
    delete mru_;
    mru_ = state;
  }

 private:
  S *mru_;  // most recently freed state, recycled by the next Allocate
};
87
88// VectorState but additionally has a flags data member (see
89// CacheState below). This class is used to cache FST elements with
90// the flags used to indicate what has been cached. Use HasStart()
91// HasFinal(), and HasArcs() to determine if cached and SetStart(),
92// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note
93// you must set the final weight even if the state is non-final to
94// mark it as cached. If the 'gc' option is 'false', cached items have
95// the extent of the FST - minimizing computation. If the 'gc' option
96// is 'true', garbage collection of states (not in use in an arc
97// iterator and not 'protected') is performed, in a rough
98// approximation of LRU order, when 'gc_limit' bytes is reached -
99// controlling memory use. When 'gc_limit' is 0, special optimizations
100// apply - minimizing memory use.
101
template <class S, class C = DefaultCacheStateAllocator<S> >
class CacheBaseImpl : public VectorFstBaseImpl<S> {
 public:
  typedef S State;
  typedef C Allocator;
  typedef typename State::Arc Arc;
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;

  using FstImpl<Arc>::Type;
  using FstImpl<Arc>::Properties;
  using FstImpl<Arc>::SetProperties;
  using VectorFstBaseImpl<State>::NumStates;
  using VectorFstBaseImpl<State>::Start;
  using VectorFstBaseImpl<State>::AddState;
  using VectorFstBaseImpl<State>::SetState;
  using VectorFstBaseImpl<State>::ReserveStates;

  // Constructs with GC parameters from the command-line flags.  Takes
  // ownership of 'allocator' (a default one is created when 0).  A
  // non-zero gc limit below kMinCacheLimit is raised to kMinCacheLimit;
  // a limit of 0 is kept as-is (it selects the single-state cache
  // optimization, see ExtendState).
  explicit CacheBaseImpl(C *allocator = 0)
      : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
        cache_first_state_id_(kNoStateId), cache_first_state_(0),
        cache_gc_(FLAGS_fst_default_cache_gc),  cache_size_(0),
        cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
                     FLAGS_fst_default_cache_gc_limit == 0 ?
                     FLAGS_fst_default_cache_gc_limit : kMinCacheLimit),
        protect_(false) {
    allocator_ = allocator ? allocator : new C();
  }

  // As above, but GC parameters are given explicitly by 'opts'.
  explicit CacheBaseImpl(const CacheOptions &opts, C *allocator = 0)
      : cache_start_(false), nknown_states_(0),
        min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
        cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
        cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
                     opts.gc_limit : kMinCacheLimit),
        protect_(false) {
    allocator_ = allocator ? allocator : new C();
  }

  // Preserve gc parameters. If preserve_cache true, also preserves
  // cache data by deep-copying every cached state with a fresh
  // allocator.
  CacheBaseImpl(const CacheBaseImpl<S, C> &impl, bool preserve_cache = false)
      : VectorFstBaseImpl<S>(), cache_start_(false), nknown_states_(0),
        min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
        cache_first_state_(0), cache_gc_(impl.cache_gc_), cache_size_(0),
        cache_limit_(impl.cache_limit_),
        protect_(impl.protect_) {
    allocator_ = new C();
    if (preserve_cache) {
      cache_start_ = impl.cache_start_;
      nknown_states_ = impl.nknown_states_;
      expanded_states_ = impl.expanded_states_;
      min_unexpanded_state_id_ = impl.min_unexpanded_state_id_;
      if (impl.cache_first_state_id_ != kNoStateId) {
        cache_first_state_id_ = impl.cache_first_state_id_;
        cache_first_state_ = allocator_->Allocate(cache_first_state_id_);
        *cache_first_state_ = *impl.cache_first_state_;
      }
      cache_states_ = impl.cache_states_;
      cache_size_ = impl.cache_size_;
      ReserveStates(impl.NumStates());
      for (StateId s = 0; s < impl.NumStates(); ++s) {
        const S *state =
            static_cast<const VectorFstBaseImpl<S> &>(impl).GetState(s);
        if (state) {
          S *copied_state = allocator_->Allocate(s);
          *copied_state = *state;
          AddState(copied_state);
        } else {
          AddState(0);  // hole: state 's' is currently uncached in 'impl'
        }
      }
      VectorFstBaseImpl<S>::SetStart(impl.Start());
    }
  }

  ~CacheBaseImpl() {
    // The first cached state is held outside the base-impl state
    // vector, so it must be released here; Free(0, ...) is harmless.
    allocator_->Free(cache_first_state_, cache_first_state_id_);
    delete allocator_;
  }

  // Gets a state from its ID; state must exist.
  const S *GetState(StateId s) const {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else
      return VectorFstBaseImpl<S>::GetState(s);
  }

  // Gets a state from its ID; state must exist.
  S *GetState(StateId s) {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else
      return VectorFstBaseImpl<S>::GetState(s);
  }

  // Gets a state from its ID; return 0 if it doesn't exist.
  const S *CheckState(StateId s) const {
    if (s == cache_first_state_id_)
      return cache_first_state_;
    else if (s < NumStates())
      return VectorFstBaseImpl<S>::GetState(s);
    else
      return 0;
  }

  // Gets a state from its ID; add it if necessary.
  S *ExtendState(StateId s);

  // Caches the start state and extends the known-state count to
  // include it.
  void SetStart(StateId s) {
    VectorFstBaseImpl<S>::SetStart(s);
    cache_start_ = true;
    if (s >= nknown_states_)
      nknown_states_ = s + 1;
  }

  // Caches the final weight of state 's'.  Must be called even for
  // non-final states to mark the final weight as computed.
  void SetFinal(StateId s, Weight w) {
    S *state = ExtendState(s);
    state->final = w;
    state->flags |= kCacheFinal | kCacheRecent | kCacheModified;
  }

  // AddArc adds a single arc to state s and does incremental cache
  // book-keeping.  For efficiency, prefer PushArc and SetArcs below
  // when possible.
  void AddArc(StateId s, const Arc &arc) {
    S *state = ExtendState(s);
    state->arcs.push_back(arc);
    if (arc.ilabel == 0) {
      ++state->niepsilons;
    }
    if (arc.olabel == 0) {
      ++state->noepsilons;
    }
    // 'parc' is the previously-last arc (if any), used for incremental
    // property computation; the new arc was just appended above.
    const Arc *parc = state->arcs.empty() ? 0 : &(state->arcs.back());
    SetProperties(AddArcProperties(Properties(), s, arc, parc));
    state->flags |= kCacheModified;
    if (cache_gc_ && s != cache_first_state_id_ &&
        !(state->flags & kCacheProtect)) {
      cache_size_ += sizeof(Arc);
      if (cache_size_ > cache_limit_)
        GC(s, false);
    }
  }

  // Adds a single arc to state s but delays cache book-keeping.
  // SetArcs must be called when all PushArc calls at a state are
  // complete.  Do not mix with calls to AddArc.
  void PushArc(StateId s, const Arc &arc) {
    S *state = ExtendState(s);
    state->arcs.push_back(arc);
  }

  // Marks arcs of state s as cached and does cache book-keeping after all
  // calls to PushArc have been completed.  Do not mix with calls to AddArc.
  void SetArcs(StateId s) {
    S *state = ExtendState(s);
    vector<Arc> &arcs = state->arcs;
    // Recompute epsilon counts from scratch and grow the known-state
    // count to cover every arc destination.
    state->niepsilons = state->noepsilons = 0;
    for (size_t a = 0; a < arcs.size(); ++a) {
      const Arc &arc = arcs[a];
      if (arc.nextstate >= nknown_states_)
        nknown_states_ = arc.nextstate + 1;
      if (arc.ilabel == 0)
        ++state->niepsilons;
      if (arc.olabel == 0)
        ++state->noepsilons;
    }
    ExpandedState(s);
    state->flags |= kCacheArcs | kCacheRecent | kCacheModified;
    if (cache_gc_ && s != cache_first_state_id_ &&
        !(state->flags & kCacheProtect)) {
      cache_size_ += arcs.capacity() * sizeof(Arc);
      if (cache_size_ > cache_limit_)
        GC(s, false);
    }
  };

  // Pre-reserves space for 'n' arcs at state 's'.
  void ReserveArcs(StateId s, size_t n) {
    S *state = ExtendState(s);
    state->arcs.reserve(n);
  }

  // Deletes the last 'n' arcs of state 's', keeping the epsilon counts
  // and cache size in sync.
  void DeleteArcs(StateId s, size_t n) {
    S *state = ExtendState(s);
    const vector<Arc> &arcs = state->arcs;
    for (size_t i = 0; i < n; ++i) {
      size_t j = arcs.size() - i - 1;
      if (arcs[j].ilabel == 0)
        --state->niepsilons;
      if (arcs[j].olabel == 0)
        --state->noepsilons;
    }

    state->arcs.resize(arcs.size() - n);
    SetProperties(DeleteArcsProperties(Properties()));
    state->flags |= kCacheModified;
    if (cache_gc_ && s != cache_first_state_id_ &&
        !(state->flags & kCacheProtect)) {
      cache_size_ -= n * sizeof(Arc);
    }
  }

  // Deletes all arcs of state 's'.
  void DeleteArcs(StateId s) {
    S *state = ExtendState(s);
    size_t n = state->arcs.size();
    state->niepsilons = 0;
    state->noepsilons = 0;
    state->arcs.clear();
    SetProperties(DeleteArcsProperties(Properties()));
    state->flags |= kCacheModified;
    if (cache_gc_ && s != cache_first_state_id_ &&
        !(state->flags & kCacheProtect)) {
      cache_size_ -= n * sizeof(Arc);
    }
  }

  // Deletes the given states and renumbers the survivors, keeping the
  // cached-state list consistent with the new ids.
  void DeleteStates(const vector<StateId> &dstates) {
    size_t old_num_states = NumStates();
    // Build old-id -> new-id map; deleted states map to kNoStateId.
    vector<StateId> newid(old_num_states, 0);
    for (size_t i = 0; i < dstates.size(); ++i)
      newid[dstates[i]] = kNoStateId;
    StateId nstates = 0;
    for (StateId s = 0; s < old_num_states; ++s) {
      if (newid[s] != kNoStateId) {
        newid[s] = nstates;
        ++nstates;
      }
    }
    // just for states_.resize(), does unnecessary walk.
    VectorFstBaseImpl<S>::DeleteStates(dstates);
    SetProperties(DeleteStatesProperties(Properties()));
    // Update list of cached states.
    typename list<StateId>::iterator siter = cache_states_.begin();
    while (siter != cache_states_.end()) {
      if (newid[*siter] != kNoStateId) {
        *siter = newid[*siter];
        ++siter;
      } else {
        cache_states_.erase(siter++);
      }
    }
  }

  // Deletes all states, resetting the cache to its initial condition.
  void DeleteStates() {
    cache_states_.clear();
    allocator_->Free(cache_first_state_, cache_first_state_id_);
    for (int s = 0; s < NumStates(); ++s) {
      allocator_->Free(VectorFstBaseImpl<S>::GetState(s), s);
      SetState(s, 0);
    }
    nknown_states_ = 0;
    min_unexpanded_state_id_ = 0;
    cache_first_state_id_ = kNoStateId;
    cache_first_state_ = 0;
    cache_size_ = 0;
    cache_start_ = false;
    VectorFstBaseImpl<State>::DeleteStates();
    SetProperties(DeleteAllStatesProperties(Properties(),
                                            kExpanded | kMutable));
  }

  // Is the start state cached?  An FST in an error state is treated as
  // if its start were cached, so that computation terminates.
  bool HasStart() const {
    if (!cache_start_ && Properties(kError))
      cache_start_ = true;
    return cache_start_;
  }

  // Is the final weight of state s cached?  Marks the state as
  // recently used for GC purposes.
  bool HasFinal(StateId s) const {
    const S *state = CheckState(s);
    if (state && state->flags & kCacheFinal) {
      state->flags |= kCacheRecent;
      return true;
    } else {
      return false;
    }
  }

  // Are arcs of state s cached?  Marks the state as recently used for
  // GC purposes.
  bool HasArcs(StateId s) const {
    const S *state = CheckState(s);
    if (state && state->flags & kCacheArcs) {
      state->flags |= kCacheRecent;
      return true;
    } else {
      return false;
    }
  }

  // Cached final weight of state 's'; HasFinal(s) must be true.
  Weight Final(StateId s) const {
    const S *state = GetState(s);
    return state->final;
  }

  // Number of cached arcs at state 's'; HasArcs(s) must be true.
  size_t NumArcs(StateId s) const {
    const S *state = GetState(s);
    return state->arcs.size();
  }

  size_t NumInputEpsilons(StateId s) const {
    const S *state = GetState(s);
    return state->niepsilons;
  }

  size_t NumOutputEpsilons(StateId s) const {
    const S *state = GetState(s);
    return state->noepsilons;
  }

  // Provides information needed for generic arc iterator.  Increments
  // the state's reference count; the iterator decrements it on
  // destruction, protecting the state from GC while in use.
  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    const S *state = GetState(s);
    data->base = 0;
    data->narcs = state->arcs.size();
    data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
    data->ref_count = &(state->ref_count);
    ++(*data->ref_count);
  }

  // Number of known states.
  StateId NumKnownStates() const { return nknown_states_; }

  // Update number of known states taking in account the existence of state s.
  void UpdateNumKnownStates(StateId s) {
    if (s >= nknown_states_)
      nknown_states_ = s + 1;
  }

  // Find the minimum never-expanded state Id.  Advances the cached
  // lower bound past any states already marked expanded.
  StateId MinUnexpandedState() const {
    while (min_unexpanded_state_id_ < expanded_states_.size() &&
          expanded_states_[min_unexpanded_state_id_])
      ++min_unexpanded_state_id_;
    return min_unexpanded_state_id_;
  }

  // Removes from cache_states_ and uncaches (not referenced-counted
  // or protected) states that have not been accessed since the last
  // GC until at most cache_fraction * cache_limit_ bytes are cached.
  // If that fails to free enough, recurs uncaching recently visited
  // states as well. If still unable to free enough memory, then
  // widens cache_limit_ to fulfill condition.
  void GC(StateId current, bool free_recent,  float cache_fraction = 0.666);

  // Sets/clears GC protection: if true, new states are protected
  // from garbage collection.
  void GCProtect(bool on) { protect_ = on; }

  // Marks state 's' as expanded (all its outgoing arcs discovered).
  void ExpandedState(StateId s) {
    if (s < min_unexpanded_state_id_)
      return;
    while (expanded_states_.size() <= s)
      expanded_states_.push_back(false);
    expanded_states_[s] = true;
  }

  C *GetAllocator() const {
    return allocator_;
  }

  // Caching on/off switch, limit and size accessors.
  bool GetCacheGc() const { return cache_gc_; }
  size_t GetCacheLimit() const { return cache_limit_; }
  size_t GetCacheSize() const { return cache_size_; }

 private:
  static const size_t kMinCacheLimit = 8096;   // Minimum (non-zero) cache limit

  static const uint32 kCacheFinal =    0x0001;  // Final weight has been cached
  static const uint32 kCacheArcs =     0x0002;  // Arcs have been cached
  static const uint32 kCacheRecent =   0x0004;  // Mark as visited since GC
  static const uint32 kCacheProtect =  0x0008;  // Mark state as GC protected

 public:
  static const uint32 kCacheModified = 0x0010;  // Mark state as modified
  static const uint32 kCacheFlags = kCacheFinal | kCacheArcs | kCacheRecent
      | kCacheProtect | kCacheModified;

 private:
  C *allocator_;                             // used to allocate new states
  mutable bool cache_start_;                 // Is the start state cached?
  StateId nknown_states_;                    // # of known states
  vector<bool> expanded_states_;             // states that have been expanded
  mutable StateId min_unexpanded_state_id_;  // minimum never-expanded state Id
  StateId cache_first_state_id_;             // First cached state id
  S *cache_first_state_;                     // First cached state
  list<StateId> cache_states_;               // list of currently cached states
  bool cache_gc_;                            // enable GC
  size_t cache_size_;                        // # of bytes cached
  size_t cache_limit_;                       // # of bytes allowed before GC
  bool protect_;                             // Protect new states from GC

  void operator=(const CacheBaseImpl<S, C> &impl);    // disallow
};
499
// Gets a state from its ID; add it if necessary.  When cache_limit_ is
// 0, a single state is held outside the base-impl state vector
// ('cache_first_state_') and recycled as computation moves on,
// minimizing memory use.
template <class S, class C>
S *CacheBaseImpl<S, C>::ExtendState(typename S::Arc::StateId s) {
  // If 'protect_' true and a new state, protects from garbage collection.
  if (s == cache_first_state_id_) {
    return cache_first_state_;                   // Return 1st cached state
  } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
    cache_first_state_id_ = s;                   // Remember 1st cached state
    cache_first_state_ = allocator_->Allocate(s);
    if (protect_) cache_first_state_->flags |= kCacheProtect;
    return cache_first_state_;
  } else if (cache_first_state_id_ != kNoStateId &&
             cache_first_state_->ref_count == 0 &&
             !(cache_first_state_->flags & kCacheProtect)) {
    // Recycle the single cached slot for the newly requested state.
    // With Default allocator, the Free and Allocate will reuse the same S*.
    allocator_->Free(cache_first_state_, cache_first_state_id_);
    cache_first_state_id_ = s;
    cache_first_state_ = allocator_->Allocate(s);
    if (protect_) cache_first_state_->flags |= kCacheProtect;
    return cache_first_state_;                   // Return 1st cached state
  } else {
    // General case: keep the state in the base-impl vector, growing it
    // with empty (0) slots as needed.
    while (NumStates() <= s)                     // Add state to main cache
      AddState(0);
    S *state = VectorFstBaseImpl<S>::GetState(s);
    if (!state) {
      state = allocator_->Allocate(s);
      if (protect_) state->flags |= kCacheProtect;
      SetState(s, state);
      if (cache_first_state_id_ != kNoStateId) {  // Forget 1st cached state
        // Migrate the separately-held first state into the vector and
        // abandon the single-state optimization for this FST.
        while (NumStates() <= cache_first_state_id_)
          AddState(0);
        SetState(cache_first_state_id_, cache_first_state_);
        if (cache_gc_ && !(cache_first_state_->flags & kCacheProtect)) {
          cache_states_.push_back(cache_first_state_id_);
          cache_size_ += sizeof(S) +
                         cache_first_state_->arcs.capacity() * sizeof(Arc);
        }
        cache_limit_ = kMinCacheLimit;
        cache_first_state_id_ = kNoStateId;
        cache_first_state_ = 0;
      }
      if (cache_gc_ && !protect_) {
        cache_states_.push_back(s);
        cache_size_ += sizeof(S);
        if (cache_size_ > cache_limit_)
          GC(s, false);
      }
    }
    return state;
  }
}
551
// Removes from cache_states_ and uncaches (not referenced-counted or
// protected) states that have not been accessed since the last GC
// until at most cache_fraction * cache_limit_ bytes are cached.  If
// that fails to free enough, recurs uncaching recently visited states
// as well. If still unable to free enough memory, then widens cache_limit_
// to fulfill condition.
template <class S, class C>
void CacheBaseImpl<S, C>::GC(typename S::Arc::StateId current,
                             bool free_recent, float cache_fraction) {
  if (!cache_gc_)
    return;
  VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
          << "), free recently cached = " << free_recent
          << ", cache size = " << cache_size_
          << ", cache frac = " << cache_fraction
          << ", cache limit = " << cache_limit_ << "\n";
  typename list<StateId>::iterator siter = cache_states_.begin();

  size_t cache_target = cache_fraction * cache_limit_;
  while (siter != cache_states_.end()) {
    StateId s = *siter;
    S* state = VectorFstBaseImpl<S>::GetState(s);
    // Evict only while over target, and only states with no live
    // iterators, that are not the state currently being expanded, and
    // (unless free_recent) not visited since the last GC.
    if (cache_size_ > cache_target && state->ref_count == 0 &&
        (free_recent || !(state->flags & kCacheRecent)) && s != current) {
      cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
      allocator_->Free(state, s);
      SetState(s, 0);
      cache_states_.erase(siter++);
    } else {
      // Surviving states lose their 'recent' mark, becoming eligible
      // for eviction on the next pass.
      state->flags &= ~kCacheRecent;
      ++siter;
    }
  }
  if (!free_recent && cache_size_ > cache_target) {   // recurses on recent
    GC(current, true);
  } else if (cache_target > 0) {                      // widens cache limit
    while (cache_size_ > cache_target) {
      cache_limit_ *= 2;
      cache_target *= 2;
    }
  } else if (cache_size_ > 0) {
    FSTERROR() << "CacheImpl:GC: Unable to free all cached states";
  }
  VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
          << "), free recently cached = " << free_recent
          << ", cache size = " << cache_size_
          << ", cache frac = " << cache_fraction
          << ", cache limit = " << cache_limit_ << "\n";
}
601
602template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheFinal;
603template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheArcs;
604template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheRecent;
605template <class S, class C> const uint32 CacheBaseImpl<S, C>::kCacheModified;
606template <class S, class C> const size_t CacheBaseImpl<S, C>::kMinCacheLimit;
607
608// Arcs implemented by an STL vector per state. Similar to VectorState
609// but adds flags and ref count to keep track of what has been cached.
610template <class A>
611struct CacheState {
612  typedef A Arc;
613  typedef typename A::Weight Weight;
614  typedef typename A::StateId StateId;
615
616  CacheState() :  final(Weight::Zero()), flags(0), ref_count(0) {}
617
618  void Reset() {
619    flags = 0;
620    ref_count = 0;
621    arcs.resize(0);
622  }
623
624  Weight final;              // Final weight
625  vector<A> arcs;            // Arcs represenation
626  size_t niepsilons;         // # of input epsilons
627  size_t noepsilons;         // # of output epsilons
628  mutable uint32 flags;
629  mutable int ref_count;
630};
631
632// A CacheBaseImpl with a commonly used CacheState.
633template <class A>
634class CacheImpl : public CacheBaseImpl< CacheState<A> > {
635 public:
636  typedef CacheState<A> State;
637
638  CacheImpl() {}
639
640  explicit CacheImpl(const CacheOptions &opts)
641      : CacheBaseImpl< CacheState<A> >(opts) {}
642
643  CacheImpl(const CacheImpl<A> &impl, bool preserve_cache = false)
644      : CacheBaseImpl<State>(impl, preserve_cache) {}
645
646 private:
647  void operator=(const CacheImpl<State> &impl);    // disallow
648};
649
650
651// Use this to make a state iterator for a CacheBaseImpl-derived Fst,
652// which must have type 'State' defined.  Note this iterator only
653// returns those states reachable from the initial state, so consider
654// implementing a class-specific one.
655template <class F>
656class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
657 public:
658  typedef typename F::Arc Arc;
659  typedef typename Arc::StateId StateId;
660  typedef typename F::State State;
661  typedef CacheBaseImpl<State> Impl;
662
663  CacheStateIterator(const F &fst, Impl *impl)
664      : fst_(fst), impl_(impl), s_(0) {
665        fst_.Start();  // force start state
666      }
667
668  bool Done() const {
669    if (s_ < impl_->NumKnownStates())
670      return false;
671    if (s_ < impl_->NumKnownStates())
672      return false;
673    for (StateId u = impl_->MinUnexpandedState();
674         u < impl_->NumKnownStates();
675         u = impl_->MinUnexpandedState()) {
676      // force state expansion
677      ArcIterator<F> aiter(fst_, u);
678      aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache);
679      for (; !aiter.Done(); aiter.Next())
680        impl_->UpdateNumKnownStates(aiter.Value().nextstate);
681      impl_->ExpandedState(u);
682      if (s_ < impl_->NumKnownStates())
683        return false;
684    }
685    return true;
686  }
687
688  StateId Value() const { return s_; }
689
690  void Next() { ++s_; }
691
692  void Reset() { s_ = 0; }
693
694 private:
695  // This allows base class virtual access to non-virtual derived-
696  // class members of the same name. It makes the derived class more
697  // efficient to use but unsafe to further derive.
698  virtual bool Done_() const { return Done(); }
699  virtual StateId Value_() const { return Value(); }
700  virtual void Next_() { Next(); }
701  virtual void Reset_() { Reset(); }
702
703  const F &fst_;
704  Impl *impl_;
705  StateId s_;
706};
707
708
709// Use this to make an arc iterator for a CacheBaseImpl-derived Fst,
710// which must have types 'Arc' and 'State' defined.
711template <class F,
712          class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
713class CacheArcIterator {
714 public:
715  typedef typename F::Arc Arc;
716  typedef typename F::State State;
717  typedef typename Arc::StateId StateId;
718  typedef CacheBaseImpl<State, C> Impl;
719
720  CacheArcIterator(Impl *impl, StateId s) : i_(0) {
721    state_ = impl->ExtendState(s);
722    ++state_->ref_count;
723  }
724
725  ~CacheArcIterator() { --state_->ref_count;  }
726
727  bool Done() const { return i_ >= state_->arcs.size(); }
728
729  const Arc& Value() const { return state_->arcs[i_]; }
730
731  void Next() { ++i_; }
732
733  size_t Position() const { return i_; }
734
735  void Reset() { i_ = 0; }
736
737  void Seek(size_t a) { i_ = a; }
738
739  uint32 Flags() const {
740    return kArcValueFlags;
741  }
742
743  void SetFlags(uint32 flags, uint32 mask) {}
744
745 private:
746  const State *state_;
747  size_t i_;
748
749  DISALLOW_COPY_AND_ASSIGN(CacheArcIterator);
750};
751
// Use this to make a mutable arc iterator for a CacheBaseImpl-derived Fst,
// which must have types 'Arc' and 'State' defined.
template <class F,
          class C = DefaultCacheStateAllocator<CacheState<typename F::Arc> > >
class CacheMutableArcIterator
    : public MutableArcIteratorBase<typename F::Arc> {
 public:
  typedef typename F::State State;
  typedef typename F::Arc Arc;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;
  typedef CacheBaseImpl<State, C> Impl;

  // You will need to call MutateCheck() in the constructor.
  CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) {
    state_ = impl_->ExtendState(s_);
    ++state_->ref_count;  // pin the state against GC while iterating
  };

  ~CacheMutableArcIterator() {
    --state_->ref_count;
  }

  bool Done() const { return i_ >= state_->arcs.size(); }

  const Arc& Value() const { return state_->arcs[i_]; }

  void Next() { ++i_; }

  size_t Position() const { return i_; }

  void Reset() { i_ = 0; }

  void Seek(size_t a) { i_ = a; }

  // Replaces the current arc, incrementally maintaining the state's
  // epsilon counts and the impl's arc-dependent properties.
  void SetValue(const Arc& arc) {
    state_->flags |= CacheBaseImpl<State, C>::kCacheModified;
    uint64 properties = impl_->Properties();
    Arc& oarc = state_->arcs[i_];
    // Clear the properties that the outgoing arc may have implied.
    if (oarc.ilabel != oarc.olabel)
      properties &= ~kNotAcceptor;
    if (oarc.ilabel == 0) {
      --state_->niepsilons;
      properties &= ~kIEpsilons;
      if (oarc.olabel == 0)
        properties &= ~kEpsilons;
    }
    if (oarc.olabel == 0) {
      --state_->noepsilons;
      properties &= ~kOEpsilons;
    }
    if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One())
      properties &= ~kWeighted;
    oarc = arc;
    // Set the properties implied by the incoming arc.
    if (arc.ilabel != arc.olabel) {
      properties |= kNotAcceptor;
      properties &= ~kAcceptor;
    }
    if (arc.ilabel == 0) {
      ++state_->niepsilons;
      properties |= kIEpsilons;
      properties &= ~kNoIEpsilons;
      if (arc.olabel == 0) {
        properties |= kEpsilons;
        properties &= ~kNoEpsilons;
      }
    }
    if (arc.olabel == 0) {
      ++state_->noepsilons;
      properties |= kOEpsilons;
      properties &= ~kNoOEpsilons;
    }
    if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) {
      properties |= kWeighted;
      properties &= ~kUnweighted;
    }
    // Keep only the properties that can be maintained per-arc; all
    // others would need a global recomputation, so they are dropped.
    properties &= kSetArcProperties | kAcceptor | kNotAcceptor |
        kEpsilons | kNoEpsilons | kIEpsilons | kNoIEpsilons |
        kOEpsilons | kNoOEpsilons | kWeighted | kUnweighted;
    impl_->SetProperties(properties);
  }

  // Cached arcs always carry full values; flag requests are ignored.
  uint32 Flags() const {
    return kArcValueFlags;
  }

  void SetFlags(uint32 f, uint32 m) {}

 private:
  // Forwards the base class's virtual interface to the non-virtual
  // members above.
  virtual bool Done_() const { return Done(); }
  virtual const Arc& Value_() const { return Value(); }
  virtual void Next_() { Next(); }
  virtual size_t Position_() const { return Position(); }
  virtual void Reset_() { Reset(); }
  virtual void Seek_(size_t a) { Seek(a); }
  virtual void SetValue_(const Arc &a) { SetValue(a); }
  uint32 Flags_() const { return Flags(); }
  void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }

  size_t i_;
  StateId s_;
  Impl *impl_;
  State *state_;

  DISALLOW_COPY_AND_ASSIGN(CacheMutableArcIterator);
};
858
859}  // namespace fst
860
861#endif  // FST_LIB_CACHE_H__
862