1// cache.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// An Fst implementation that caches FST elements of a delayed
18// computation.
19
20#ifndef FST_LIB_CACHE_H__
21#define FST_LIB_CACHE_H__
22
23#include <list>
24
25#include "fst/lib/vector-fst.h"
26
27DECLARE_bool(fst_default_cache_gc);
28DECLARE_int64(fst_default_cache_gc_limit);
29
30namespace fst {
31
32struct CacheOptions {
33  bool gc;          // enable GC
34  size_t gc_limit;  // # of bytes allowed before GC
35
36
37  CacheOptions(bool g, size_t l) : gc(g), gc_limit(l) {}
38  CacheOptions()
39      : gc(FLAGS_fst_default_cache_gc),
40        gc_limit(FLAGS_fst_default_cache_gc_limit) {}
41};
42
43
44// This is a VectorFstBaseImpl container that holds a State similar to
45// VectorState but additionally has a flags data member (see
46// CacheState below). This class is used to cache FST elements with
47// the flags used to indicate what has been cached. Use HasStart()
48// HasFinal(), and HasArcs() to determine if cached and SetStart(),
49// SetFinal(), AddArc(), and SetArcs() to cache. Note you must set the
50// final weight even if the state is non-final to mark it as
51// cached. If the 'gc' option is 'false', cached items have the extent
52// of the FST - minimizing computation. If the 'gc' option is 'true',
53// garbage collection of states (not in use in an arc iterator) is
54// performed, in a rough approximation of LRU order, when 'gc_limit'
55// bytes is reached - controlling memory use. When 'gc_limit' is 0,
56// special optimizations apply - minimizing memory use.
57
58template <class S>
59class CacheBaseImpl : public VectorFstBaseImpl<S> {
60 public:
61  using FstImpl<typename S::Arc>::Type;
62  using VectorFstBaseImpl<S>::NumStates;
63  using VectorFstBaseImpl<S>::AddState;
64
65  typedef S State;
66  typedef typename S::Arc Arc;
67  typedef typename Arc::Weight Weight;
68  typedef typename Arc::StateId StateId;
69
70  CacheBaseImpl()
71      : cache_start_(false), nknown_states_(0), min_unexpanded_state_id_(0),
72        cache_first_state_id_(kNoStateId), cache_first_state_(0),
73        cache_gc_(FLAGS_fst_default_cache_gc),  cache_size_(0),
74        cache_limit_(FLAGS_fst_default_cache_gc_limit > kMinCacheLimit ||
75                     FLAGS_fst_default_cache_gc_limit == 0 ?
76                     FLAGS_fst_default_cache_gc_limit : kMinCacheLimit) {}
77
78  explicit CacheBaseImpl(const CacheOptions &opts)
79      : cache_start_(false), nknown_states_(0),
80        min_unexpanded_state_id_(0), cache_first_state_id_(kNoStateId),
81        cache_first_state_(0), cache_gc_(opts.gc), cache_size_(0),
82        cache_limit_(opts.gc_limit > kMinCacheLimit || opts.gc_limit == 0 ?
83                     opts.gc_limit : kMinCacheLimit) {}
84
85  ~CacheBaseImpl() {
86    delete cache_first_state_;
87  }
88
89  // Gets a state from its ID; state must exist.
90  const S *GetState(StateId s) const {
91    if (s == cache_first_state_id_)
92      return cache_first_state_;
93    else
94      return VectorFstBaseImpl<S>::GetState(s);
95  }
96
97  // Gets a state from its ID; state must exist.
98  S *GetState(StateId s) {
99    if (s == cache_first_state_id_)
100      return cache_first_state_;
101    else
102      return VectorFstBaseImpl<S>::GetState(s);
103  }
104
105  // Gets a state from its ID; return 0 if it doesn't exist.
106  const S *CheckState(StateId s) const {
107    if (s == cache_first_state_id_)
108      return cache_first_state_;
109    else if (s < NumStates())
110      return VectorFstBaseImpl<S>::GetState(s);
111    else
112      return 0;
113  }
114
115  // Gets a state from its ID; add it if necessary.
116  S *ExtendState(StateId s) {
117    if (s == cache_first_state_id_) {
118      return cache_first_state_;                   // Return 1st cached state
119    } else if (cache_limit_ == 0 && cache_first_state_id_ == kNoStateId) {
120      cache_first_state_id_ = s;                   // Remember 1st cached state
121      cache_first_state_ = new S;
122      return cache_first_state_;
123    } else if (cache_first_state_id_ != kNoStateId &&
124               cache_first_state_->ref_count == 0) {
125      cache_first_state_id_ = s;                   // Reuse 1st cached state
126      cache_first_state_->Reset();
127      return cache_first_state_;                   // Return 1st cached state
128    } else {
129      while (NumStates() <= s)                     // Add state to main cache
130        AddState(0);
131      if (!VectorFstBaseImpl<S>::GetState(s)) {
132        this->SetState(s, new S);
133        if (cache_first_state_id_ != kNoStateId) {  // Forget 1st cached state
134          while (NumStates() <= cache_first_state_id_)
135            AddState(0);
136          this->SetState(cache_first_state_id_, cache_first_state_);
137          if (cache_gc_) {
138            cache_states_.push_back(cache_first_state_id_);
139            cache_size_ += sizeof(S) +
140                           cache_first_state_->arcs.capacity() * sizeof(Arc);
141            cache_limit_ = kMinCacheLimit;
142          }
143          cache_first_state_id_ = kNoStateId;
144          cache_first_state_ = 0;
145        }
146        if (cache_gc_) {
147          cache_states_.push_back(s);
148          cache_size_ += sizeof(S);
149          if (cache_size_ > cache_limit_)
150            GC(s, false);
151        }
152      }
153      return VectorFstBaseImpl<S>::GetState(s);
154    }
155  }
156
157  void SetStart(StateId s) {
158    VectorFstBaseImpl<S>::SetStart(s);
159    cache_start_ = true;
160    if (s >= nknown_states_)
161      nknown_states_ = s + 1;
162  }
163
164  void SetFinal(StateId s, Weight w) {
165    S *state = ExtendState(s);
166    state->final = w;
167    state->flags |= kCacheFinal | kCacheRecent;
168  }
169
170  void AddArc(StateId s, const Arc &arc) {
171    S *state = ExtendState(s);
172    state->arcs.push_back(arc);
173  }
174
175  // Marks arcs of state s as cached.
176  void SetArcs(StateId s) {
177    S *state = ExtendState(s);
178    vector<Arc> &arcs = state->arcs;
179    state->niepsilons = state->noepsilons = 0;
180    for (unsigned int a = 0; a < arcs.size(); ++a) {
181      const Arc &arc = arcs[a];
182      if (arc.nextstate >= nknown_states_)
183        nknown_states_ = arc.nextstate + 1;
184      if (arc.ilabel == 0)
185        ++state->niepsilons;
186      if (arc.olabel == 0)
187        ++state->noepsilons;
188    }
189    ExpandedState(s);
190    state->flags |= kCacheArcs | kCacheRecent;
191    if (cache_gc_ && s != cache_first_state_id_) {
192      cache_size_ += arcs.capacity() * sizeof(Arc);
193      if (cache_size_ > cache_limit_)
194        GC(s, false);
195    }
196  };
197
198  void ReserveArcs(StateId s, size_t n) {
199    S *state = ExtendState(s);
200    state->arcs.reserve(n);
201  }
202
203  // Is the start state cached?
204  bool HasStart() const { return cache_start_; }
205  // Is the final weight of state s cached?
206
207  bool HasFinal(StateId s) const {
208    const S *state = CheckState(s);
209    if (state && state->flags & kCacheFinal) {
210      state->flags |= kCacheRecent;
211      return true;
212    } else {
213      return false;
214    }
215  }
216
217  // Are arcs of state s cached?
218  bool HasArcs(StateId s) const {
219    const S *state = CheckState(s);
220    if (state && state->flags & kCacheArcs) {
221      state->flags |= kCacheRecent;
222      return true;
223    } else {
224      return false;
225    }
226  }
227
228  Weight Final(StateId s) const {
229    const S *state = GetState(s);
230    return state->final;
231  }
232
233  size_t NumArcs(StateId s) const {
234    const S *state = GetState(s);
235    return state->arcs.size();
236  }
237
238  size_t NumInputEpsilons(StateId s) const {
239    const S *state = GetState(s);
240    return state->niepsilons;
241  }
242
243  size_t NumOutputEpsilons(StateId s) const {
244    const S *state = GetState(s);
245    return state->noepsilons;
246  }
247
248  // Provides information needed for generic arc iterator.
249  void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
250    const S *state = GetState(s);
251    data->base = 0;
252    data->narcs = state->arcs.size();
253    data->arcs = data->narcs > 0 ? &(state->arcs[0]) : 0;
254    data->ref_count = &(state->ref_count);
255    ++(*data->ref_count);
256  }
257
258  // Number of known states.
259  StateId NumKnownStates() const { return nknown_states_; }
260  // Find the mininum never-expanded state Id
261  StateId MinUnexpandedState() const {
262    while (min_unexpanded_state_id_ < (StateId)expanded_states_.size() &&
263          expanded_states_[min_unexpanded_state_id_])
264      ++min_unexpanded_state_id_;
265    return min_unexpanded_state_id_;
266  }
267
268  // Removes from cache_states_ and uncaches (not referenced-counted)
269  // states that have not been accessed since the last GC until
270  // cache_limit_/3 bytes are uncached.  If that fails to free enough,
271  // recurs uncaching recently visited states as well. If still
272  // unable to free enough memory, then widens cache_limit_.
273  void GC(StateId current, bool free_recent) {
274    if (!cache_gc_)
275      return;
276    VLOG(2) << "CacheImpl: Enter GC: object = " << Type() << "(" << this
277            << "), free recently cached = " << free_recent
278            << ", cache size = " << cache_size_
279            << ", cache limit = " << cache_limit_ << "\n";
280    typename list<StateId>::iterator siter = cache_states_.begin();
281
282    size_t cache_target = (2 * cache_limit_)/3 + 1;
283    while (siter != cache_states_.end()) {
284      StateId s = *siter;
285      S* state = VectorFstBaseImpl<S>::GetState(s);
286      if (cache_size_ > cache_target && state->ref_count == 0 &&
287          (free_recent || !(state->flags & kCacheRecent)) && s != current) {
288        cache_size_ -= sizeof(S) + state->arcs.capacity() * sizeof(Arc);
289        delete state;
290        this->SetState(s, 0);
291        cache_states_.erase(siter++);
292      } else {
293        state->flags &= ~kCacheRecent;
294        ++siter;
295      }
296    }
297    if (!free_recent && cache_size_ > cache_target) {
298      GC(current, true);
299    } else {
300      while (cache_size_ > cache_target) {
301        cache_limit_ *= 2;
302        cache_target *= 2;
303      }
304    }
305    VLOG(2) << "CacheImpl: Exit GC: object = " << Type() << "(" << this
306            << "), free recently cached = " << free_recent
307            << ", cache size = " << cache_size_
308            << ", cache limit = " << cache_limit_ << "\n";
309  }
310
311 private:
312  static const uint32 kCacheFinal =  0x0001;  // Final weight has been cached
313  static const uint32 kCacheArcs =   0x0002;  // Arcs have been cached
314  static const uint32 kCacheRecent = 0x0004;  // Mark as visited since GC
315
316  static const size_t kMinCacheLimit;         // Minimum (non-zero) cache limit
317
318  void ExpandedState(StateId s) {
319    if (s < min_unexpanded_state_id_)
320      return;
321    while ((StateId)expanded_states_.size() <= s)
322      expanded_states_.push_back(false);
323    expanded_states_[s] = true;
324  }
325
326  bool cache_start_;                         // Is the start state cached?
327  StateId nknown_states_;                    // # of known states
328  vector<bool> expanded_states_;             // states that have been expanded
329  mutable StateId min_unexpanded_state_id_;  // minimum never-expanded state Id
330  StateId cache_first_state_id_;             // First cached state id
331  S *cache_first_state_;                     // First cached state
332  list<StateId> cache_states_;               // list of currently cached states
333  bool cache_gc_;                            // enable GC
334  size_t cache_size_;                        // # of bytes cached
335  size_t cache_limit_;                       // # of bytes allowed before GC
336
337  void InitStateIterator(StateIteratorData<Arc> *);  // disallow
338  DISALLOW_EVIL_CONSTRUCTORS(CacheBaseImpl);
339};
340
341template <class S>
342const size_t CacheBaseImpl<S>::kMinCacheLimit = 8096;
343
344
// Arcs implemented by an STL vector per state. Similar to VectorState
// but adds flags and ref count to keep track of what has been cached.
template <class A>
struct CacheState {
  typedef A Arc;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;

  // Creates an empty state: final weight Zero(), no arcs, zero epsilon
  // counts, no flags set and no outstanding references.
  CacheState()
      : final(Weight::Zero()), niepsilons(0), noepsilons(0),
        flags(0), ref_count(0) {}

  // Returns the state to its freshly constructed contents so that the
  // object can be reused for another state ID. The arc vector is emptied
  // with resize(0) rather than replaced, so the vector object is reused.
  void Reset() {
    final = Weight::Zero();
    niepsilons = 0;
    noepsilons = 0;
    flags = 0;
    ref_count = 0;
    arcs.resize(0);
  }

  Weight final;              // Final weight
  vector<A> arcs;            // Arcs representation
  size_t niepsilons;         // # of input epsilons
  size_t noepsilons;         // # of output epsilons
  mutable uint32 flags;      // Bit flags recording what has been cached
  mutable int ref_count;     // # of outstanding references to this state
};
368
369// A CacheBaseImpl with a commonly used CacheState.
370template <class A>
371class CacheImpl : public CacheBaseImpl< CacheState<A> > {
372 public:
373  typedef CacheState<A> State;
374
375  CacheImpl() {}
376
377  explicit CacheImpl(const CacheOptions &opts)
378      : CacheBaseImpl< CacheState<A> >(opts) {}
379
380 private:
381  DISALLOW_EVIL_CONSTRUCTORS(CacheImpl);
382};
383
384
385// Use this to make a state iterator for a CacheBaseImpl-derived Fst.
386// You'll need to make this class a friend of your derived Fst.
387// Note this iterator only returns those states reachable from
388// the initial state, so consider implementing a class-specific one.
389template <class F>
390class CacheStateIterator : public StateIteratorBase<typename F::Arc> {
391 public:
392  typedef typename F::Arc Arc;
393  typedef typename Arc::StateId StateId;
394
395  explicit CacheStateIterator(const F &fst) : fst_(fst), s_(0) {}
396
397  virtual bool Done() const {
398    if (s_ < fst_.impl_->NumKnownStates())
399      return false;
400    fst_.Start();  // force start state
401    if (s_ < fst_.impl_->NumKnownStates())
402      return false;
403    for (int u = fst_.impl_->MinUnexpandedState();
404         u < fst_.impl_->NumKnownStates();
405         u = fst_.impl_->MinUnexpandedState()) {
406      ArcIterator<F>(fst_, u);  // force state expansion
407      if (s_ < fst_.impl_->NumKnownStates())
408        return false;
409    }
410    return true;
411  }
412
413  virtual StateId Value() const { return s_; }
414
415  virtual void Next() { ++s_; }
416
417  virtual void Reset() { s_ = 0; }
418
419 private:
420  const F &fst_;
421  StateId s_;
422};
423
424
425// Use this to make an arc iterator for a CacheBaseImpl-derived Fst.
426// You'll need to make this class a friend of your derived Fst and
427// define types Arc and State.
428template <class F>
429class CacheArcIterator {
430 public:
431  typedef typename F::Arc Arc;
432  typedef typename F::State State;
433  typedef typename Arc::StateId StateId;
434
435  CacheArcIterator(const F &fst, StateId s) : i_(0) {
436    state_ = fst.impl_->ExtendState(s);
437    ++state_->ref_count;
438  }
439
440  ~CacheArcIterator() { --state_->ref_count;  }
441
442  bool Done() const { return i_ >= state_->arcs.size(); }
443
444  const Arc& Value() const { return state_->arcs[i_]; }
445
446  void Next() { ++i_; }
447
448  void Reset() { i_ = 0; }
449
450  void Seek(size_t a) { i_ = a; }
451
452 private:
453  const State *state_;
454  size_t i_;
455
456  DISALLOW_EVIL_CONSTRUCTORS(CacheArcIterator);
457};
458
459}  // namespace fst
460
461#endif  // FST_LIB_CACHE_H__
462