accumulator.h revision dfd8b8327b93660601d016cdc6f29f433b45a8d8
1// accumulator.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Classes to accumulate arc weights. Useful for weight lookahead.
20
21#ifndef FST_LIB_ACCUMULATOR_H__
22#define FST_LIB_ACCUMULATOR_H__
23
24#include <algorithm>
25#include <functional>
26#include <unordered_map>
27using std::tr1::unordered_map;
28using std::tr1::unordered_multimap;
29#include <vector>
30using std::vector;
31
32#include <fst/arcfilter.h>
33#include <fst/arcsort.h>
34#include <fst/dfs-visit.h>
35#include <fst/expanded-fst.h>
36#include <fst/replace.h>
37
38namespace fst {
39
40// This class accumulates arc weights using the semiring Plus().
41template <class A>
42class DefaultAccumulator {
43 public:
44  typedef A Arc;
45  typedef typename A::StateId StateId;
46  typedef typename A::Weight Weight;
47
48  DefaultAccumulator() {}
49
50  DefaultAccumulator(const DefaultAccumulator<A> &acc) {}
51
52  void Init(const Fst<A>& fst, bool copy = false) {}
53
54  void SetState(StateId) {}
55
56  Weight Sum(Weight w, Weight v) {
57    return Plus(w, v);
58  }
59
60  template <class ArcIterator>
61  Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
62             ssize_t end) {
63    Weight sum = w;
64    aiter->Seek(begin);
65    for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
66      sum = Plus(sum, aiter->Value().weight);
67    return sum;
68  }
69
70  bool Error() const { return false; }
71
72 private:
73  void operator=(const DefaultAccumulator<A> &);   // Disallow
74};
75
76
77// This class accumulates arc weights using the log semiring Plus()
78// assuming an arc weight has a WeightConvert specialization to
79// and from log64 weights.
80template <class A>
81class LogAccumulator {
82 public:
83  typedef A Arc;
84  typedef typename A::StateId StateId;
85  typedef typename A::Weight Weight;
86
87  LogAccumulator() {}
88
89  LogAccumulator(const LogAccumulator<A> &acc) {}
90
91  void Init(const Fst<A>& fst, bool copy = false) {}
92
93  void SetState(StateId) {}
94
95  Weight Sum(Weight w, Weight v) {
96    return LogPlus(w, v);
97  }
98
99  template <class ArcIterator>
100  Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
101             ssize_t end) {
102    Weight sum = w;
103    aiter->Seek(begin);
104    for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
105      sum = LogPlus(sum, aiter->Value().weight);
106    return sum;
107  }
108
109  bool Error() const { return false; }
110
111 private:
112  double LogPosExp(double x) { return log(1.0F + exp(-x)); }
113
114  Weight LogPlus(Weight w, Weight v) {
115    double f1 = to_log_weight_(w).Value();
116    double f2 = to_log_weight_(v).Value();
117    if (f1 > f2)
118      return to_weight_(f2 - LogPosExp(f1 - f2));
119    else
120      return to_weight_(f1 - LogPosExp(f2 - f1));
121  }
122
123  WeightConvert<Weight, Log64Weight> to_log_weight_;
124  WeightConvert<Log64Weight, Weight> to_weight_;
125
126  void operator=(const LogAccumulator<A> &);   // Disallow
127};
128
129
130// Stores shareable data for fast log accumulator copies.
131class FastLogAccumulatorData {
132 public:
133  FastLogAccumulatorData() {}
134
135  vector<double> *Weights() { return &weights_; }
136  vector<ssize_t> *WeightPositions() { return &weight_positions_; }
137  double *WeightEnd() { return &(weights_[weights_.size() - 1]); };
138  int RefCount() const { return ref_count_.count(); }
139  int IncrRefCount() { return ref_count_.Incr(); }
140  int DecrRefCount() { return ref_count_.Decr(); }
141
142 private:
143  // Cummulative weight per state for all states s.t. # of arcs >
144  // arc_limit_ with arcs in order. Special first element per state
145  // being Log64Weight::Zero();
146  vector<double> weights_;
147  // Maps from state to corresponding beginning weight position in
148  // weights_. Position -1 means no pre-computed weights for that
149  // state.
150  vector<ssize_t> weight_positions_;
151  RefCounter ref_count_;                  // Reference count.
152
153  DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData);
154};
155
156
157// This class accumulates arc weights using the log semiring Plus()
158// assuming an arc weight has a WeightConvert specialization to and
159// from log64 weights. The member function Init(fst) has to be called
160// to setup pre-computed weight information.
161template <class A>
162class FastLogAccumulator {
163 public:
164  typedef A Arc;
165  typedef typename A::StateId StateId;
166  typedef typename A::Weight Weight;
167
168  explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
169      : arc_limit_(arc_limit),
170        arc_period_(arc_period),
171        data_(new FastLogAccumulatorData()),
172        error_(false) {}
173
174  FastLogAccumulator(const FastLogAccumulator<A> &acc)
175      : arc_limit_(acc.arc_limit_),
176        arc_period_(acc.arc_period_),
177        data_(acc.data_),
178        error_(acc.error_) {
179    data_->IncrRefCount();
180  }
181
182  ~FastLogAccumulator() {
183    if (!data_->DecrRefCount())
184      delete data_;
185  }
186
187  void SetState(StateId s) {
188    vector<double> &weights = *data_->Weights();
189    vector<ssize_t> &weight_positions = *data_->WeightPositions();
190
191    if (weight_positions.size() <= s) {
192      FSTERROR() << "FastLogAccumulator::SetState: invalid state id.";
193      error_ = true;
194      return;
195    }
196
197    ssize_t pos = weight_positions[s];
198    if (pos >= 0)
199      state_weights_ = &(weights[pos]);
200    else
201      state_weights_ = 0;
202  }
203
204  Weight Sum(Weight w, Weight v) {
205    return LogPlus(w, v);
206  }
207
208  template <class ArcIterator>
209  Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
210             ssize_t end) {
211    if (error_) return Weight::NoWeight();
212    Weight sum = w;
213    // Finds begin and end of pre-stored weights
214    ssize_t index_begin = -1, index_end = -1;
215    ssize_t stored_begin = end, stored_end = end;
216    if (state_weights_ != 0) {
217      index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0;
218      index_end = end / arc_period_;
219      stored_begin = index_begin * arc_period_;
220      stored_end = index_end * arc_period_;
221    }
222    // Computes sum before pre-stored weights
223    if (begin < stored_begin) {
224      ssize_t pos_end = min(stored_begin, end);
225      aiter->Seek(begin);
226      for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos)
227        sum = LogPlus(sum, aiter->Value().weight);
228    }
229    // Computes sum between pre-stored weights
230    if (stored_begin < stored_end) {
231      sum = LogPlus(sum, LogMinus(state_weights_[index_end],
232                                  state_weights_[index_begin]));
233    }
234    // Computes sum after pre-stored weights
235    if (stored_end < end) {
236      ssize_t pos_start = max(stored_begin, stored_end);
237      aiter->Seek(pos_start);
238      for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos)
239        sum = LogPlus(sum, aiter->Value().weight);
240    }
241    return sum;
242  }
243
244  template <class F>
245  void Init(const F &fst, bool copy = false) {
246    if (copy)
247      return;
248    vector<double> &weights = *data_->Weights();
249    vector<ssize_t> &weight_positions = *data_->WeightPositions();
250    if (!weights.empty() || arc_limit_ < arc_period_) {
251      FSTERROR() << "FastLogAccumulator: initialization error.";
252      error_ = true;
253      return;
254    }
255    weight_positions.reserve(CountStates(fst));
256
257    ssize_t weight_position = 0;
258    for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
259      StateId s = siter.Value();
260      if (fst.NumArcs(s) >= arc_limit_) {
261        double sum = FloatLimits<double>::PosInfinity();
262        weight_positions.push_back(weight_position);
263        weights.push_back(sum);
264        ++weight_position;
265        ssize_t narcs = 0;
266        for(ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) {
267          const A &arc = aiter.Value();
268          sum = LogPlus(sum, arc.weight);
269          // Stores cumulative weight distribution per arc_period_.
270          if (++narcs % arc_period_ == 0) {
271            weights.push_back(sum);
272            ++weight_position;
273          }
274        }
275      } else {
276        weight_positions.push_back(-1);
277      }
278    }
279  }
280
281  bool Error() const { return error_; }
282
283 private:
284  double LogPosExp(double x) {
285    return x == FloatLimits<double>::PosInfinity() ?
286        0.0 : log(1.0F + exp(-x));
287  }
288
289  double LogMinusExp(double x) {
290    return x == FloatLimits<double>::PosInfinity() ?
291        0.0 : log(1.0F - exp(-x));
292  }
293
294  Weight LogPlus(Weight w, Weight v) {
295    double f1 = to_log_weight_(w).Value();
296    double f2 = to_log_weight_(v).Value();
297    if (f1 > f2)
298      return to_weight_(f2 - LogPosExp(f1 - f2));
299    else
300      return to_weight_(f1 - LogPosExp(f2 - f1));
301  }
302
303  double LogPlus(double f1, Weight v) {
304    double f2 = to_log_weight_(v).Value();
305    if (f1 == FloatLimits<double>::PosInfinity())
306      return f2;
307    else if (f1 > f2)
308      return f2 - LogPosExp(f1 - f2);
309    else
310      return f1 - LogPosExp(f2 - f1);
311  }
312
313  Weight LogMinus(double f1, double f2) {
314    if (f1 >= f2) {
315      FSTERROR() << "FastLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1
316                 << " and f2 = " << f2;
317      error_ = true;
318      return Weight::NoWeight();
319    }
320    if (f2 == FloatLimits<double>::PosInfinity())
321      return to_weight_(f1);
322    else
323      return to_weight_(f1 - LogMinusExp(f2 - f1));
324  }
325
326  WeightConvert<Weight, Log64Weight> to_log_weight_;
327  WeightConvert<Log64Weight, Weight> to_weight_;
328
329  ssize_t arc_limit_;     // Minimum # of arcs to pre-compute state
330  ssize_t arc_period_;    // Save cumulative weights per 'arc_period_'.
331  bool init_;             // Cumulative weights initialized?
332  FastLogAccumulatorData *data_;
333  double *state_weights_;
334  bool error_;
335
336  void operator=(const FastLogAccumulator<A> &);   // Disallow
337};
338
339
340// Stores shareable data for cache log accumulator copies.
341// All copies share the same cache.
342template <class A>
343class CacheLogAccumulatorData {
344 public:
345  typedef A Arc;
346  typedef typename A::StateId StateId;
347  typedef typename A::Weight Weight;
348
349  CacheLogAccumulatorData(bool gc, size_t gc_limit)
350      : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}
351
352  ~CacheLogAccumulatorData() {
353    for(typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
354        it != cache_.end();
355        ++it)
356      delete it->second.weights;
357  }
358
359  bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }
360
361  vector<double> *GetWeights(StateId s) {
362    typename unordered_map<StateId, CacheState>::iterator it = cache_.find(s);
363    if (it != cache_.end()) {
364      it->second.recent = true;
365      return it->second.weights;
366    } else {
367      return 0;
368    }
369  }
370
371  void AddWeights(StateId s, vector<double> *weights) {
372    if (cache_gc_ && cache_size_ >= cache_limit_)
373      GC(false);
374    cache_.insert(make_pair(s, CacheState(weights, true)));
375    if (cache_gc_)
376      cache_size_ += weights->capacity() * sizeof(double);
377  }
378
379  int RefCount() const { return ref_count_.count(); }
380  int IncrRefCount() { return ref_count_.Incr(); }
381  int DecrRefCount() { return ref_count_.Decr(); }
382
383 private:
384  // Cached information for a given state.
385  struct CacheState {
386    vector<double>* weights;  // Accumulated weights for this state.
387    bool recent;              // Has this state been accessed since last GC?
388
389    CacheState(vector<double> *w, bool r) : weights(w), recent(r) {}
390  };
391
392  // Garbage collect: Delete from cache states that have not been
393  // accessed since the last GC ('free_recent = false') until
394  // 'cache_size_' is 2/3 of 'cache_limit_'. If it does not free enough
395  // memory, start deleting recently accessed states.
396  void GC(bool free_recent) {
397    size_t cache_target = (2 * cache_limit_)/3 + 1;
398    typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
399    while (it != cache_.end() && cache_size_ > cache_target) {
400      CacheState &cs = it->second;
401      if (free_recent || !cs.recent) {
402        cache_size_ -= cs.weights->capacity() * sizeof(double);
403        delete cs.weights;
404        cache_.erase(it++);
405      } else {
406        cs.recent = false;
407        ++it;
408      }
409    }
410    if (!free_recent && cache_size_ > cache_target)
411      GC(true);
412  }
413
414  unordered_map<StateId, CacheState> cache_;  // Cache
415  bool cache_gc_;                        // Enable garbage collection
416  size_t cache_limit_;                   // # of bytes cached
417  size_t cache_size_;                    // # of bytes allowed before GC
418  RefCounter ref_count_;
419
420  DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData);
421};
422
423// This class accumulates arc weights using the log semiring Plus()
424//  has a WeightConvert specialization to and from log64 weights.  It
425//  is similar to the FastLogAccumator. However here, the accumulated
426//  weights are pre-computed and stored only for the states that are
427//  visited. The member function Init(fst) has to be called to setup
428//  this accumulator.
429template <class A>
430class CacheLogAccumulator {
431 public:
432  typedef A Arc;
433  typedef typename A::StateId StateId;
434  typedef typename A::Weight Weight;
435
436  explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false,
437                               size_t gc_limit = 10 * 1024 * 1024)
438      : arc_limit_(arc_limit), fst_(0), data_(
439          new CacheLogAccumulatorData<A>(gc, gc_limit)), s_(kNoStateId),
440        error_(false) {}
441
442  CacheLogAccumulator(const CacheLogAccumulator<A> &acc)
443      : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0),
444        data_(acc.data_), s_(kNoStateId), error_(acc.error_) {
445    data_->IncrRefCount();
446  }
447
448  ~CacheLogAccumulator() {
449    if (fst_)
450      delete fst_;
451    if (!data_->DecrRefCount())
452      delete data_;
453  }
454
455  // Arg 'arc_limit' specifies minimum # of arcs to pre-compute state.
456  void Init(const Fst<A> &fst, bool copy = false) {
457    if (copy) {
458      delete fst_;
459    } else if (fst_) {
460      FSTERROR() << "CacheLogAccumulator: initialization error.";
461      error_ = true;
462      return;
463    }
464    fst_ = fst.Copy();
465  }
466
467  void SetState(StateId s, int depth = 0) {
468    if (s == s_)
469      return;
470    s_ = s;
471
472    if (data_->CacheDisabled() || error_) {
473      weights_ = 0;
474      return;
475    }
476
477    if (!fst_) {
478      FSTERROR() << "CacheLogAccumulator::SetState: incorrectly initialized.";
479      error_ = true;
480      weights_ = 0;
481      return;
482    }
483
484    weights_ = data_->GetWeights(s);
485    if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) {
486      weights_ = new vector<double>;
487      weights_->reserve(fst_->NumArcs(s) + 1);
488      weights_->push_back(FloatLimits<double>::PosInfinity());
489      data_->AddWeights(s, weights_);
490    }
491  }
492
493  Weight Sum(Weight w, Weight v) {
494    return LogPlus(w, v);
495  }
496
497  template <class Iterator>
498  Weight Sum(Weight w, Iterator *aiter, ssize_t begin,
499             ssize_t end) {
500    if (weights_ == 0) {
501      Weight sum = w;
502      aiter->Seek(begin);
503      for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
504        sum = LogPlus(sum, aiter->Value().weight);
505      return sum;
506    } else {
507      if (weights_->size() <= end)
508        for (aiter->Seek(weights_->size() - 1);
509             weights_->size() <= end;
510             aiter->Next())
511          weights_->push_back(LogPlus(weights_->back(),
512                                      aiter->Value().weight));
513      return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin]));
514    }
515  }
516
517  template <class Iterator>
518  size_t LowerBound(double w, Iterator *aiter) {
519    if (weights_ != 0) {
520      return lower_bound(weights_->begin() + 1,
521                         weights_->end(),
522                         w,
523                         std::greater<double>())
524          - weights_->begin() - 1;
525    } else {
526      size_t n = 0;
527      double x =  FloatLimits<double>::PosInfinity();
528      for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
529        x = LogPlus(x, aiter->Value().weight);
530        if (x < w) break;
531      }
532      return n;
533    }
534  }
535
536  bool Error() const { return error_; }
537
538 private:
539  double LogPosExp(double x) {
540    return x == FloatLimits<double>::PosInfinity() ?
541        0.0 : log(1.0F + exp(-x));
542  }
543
544  double LogMinusExp(double x) {
545    return x == FloatLimits<double>::PosInfinity() ?
546        0.0 : log(1.0F - exp(-x));
547  }
548
549  Weight LogPlus(Weight w, Weight v) {
550    double f1 = to_log_weight_(w).Value();
551    double f2 = to_log_weight_(v).Value();
552    if (f1 > f2)
553      return to_weight_(f2 - LogPosExp(f1 - f2));
554    else
555      return to_weight_(f1 - LogPosExp(f2 - f1));
556  }
557
558  double LogPlus(double f1, Weight v) {
559    double f2 = to_log_weight_(v).Value();
560    if (f1 == FloatLimits<double>::PosInfinity())
561      return f2;
562    else if (f1 > f2)
563      return f2 - LogPosExp(f1 - f2);
564    else
565      return f1 - LogPosExp(f2 - f1);
566  }
567
568  Weight LogMinus(double f1, double f2) {
569    if (f1 >= f2) {
570      FSTERROR() << "CacheLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1
571                 << " and f2 = " << f2;
572      error_ = true;
573      return Weight::NoWeight();
574    }
575    if (f2 == FloatLimits<double>::PosInfinity())
576      return to_weight_(f1);
577    else
578      return to_weight_(f1 - LogMinusExp(f2 - f1));
579  }
580
581  WeightConvert<Weight, Log64Weight> to_log_weight_;
582  WeightConvert<Log64Weight, Weight> to_weight_;
583
584  ssize_t arc_limit_;                    // Minimum # of arcs to cache a state
585  vector<double> *weights_;              // Accumulated weights for cur. state
586  const Fst<A>* fst_;                    // Input fst
587  CacheLogAccumulatorData<A> *data_;     // Cache data
588  StateId s_;                            // Current state
589  bool error_;
590
591  void operator=(const CacheLogAccumulator<A> &);   // Disallow
592};
593
594
595// Stores shareable data for replace accumulator copies.
596template <class Accumulator, class T>
597class ReplaceAccumulatorData {
598 public:
599  typedef typename Accumulator::Arc Arc;
600  typedef typename Arc::StateId StateId;
601  typedef typename Arc::Label Label;
602  typedef T StateTable;
603  typedef typename T::StateTuple StateTuple;
604
605  ReplaceAccumulatorData() : state_table_(0) {}
606
607  ReplaceAccumulatorData(const vector<Accumulator*> &accumulators)
608      : state_table_(0), accumulators_(accumulators) {}
609
610  ~ReplaceAccumulatorData() {
611    for (size_t i = 0; i < fst_array_.size(); ++i)
612      delete fst_array_[i];
613    for (size_t i = 0; i < accumulators_.size(); ++i)
614      delete accumulators_[i];
615  }
616
617  void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
618       const StateTable *state_table) {
619    state_table_ = state_table;
620    accumulators_.resize(fst_tuples.size());
621    for (size_t i = 0; i < accumulators_.size(); ++i) {
622      if (!accumulators_[i])
623        accumulators_[i] = new Accumulator;
624      accumulators_[i]->Init(*(fst_tuples[i].second));
625      fst_array_.push_back(fst_tuples[i].second->Copy());
626    }
627  }
628
629  const StateTuple &GetTuple(StateId s) const {
630    return state_table_->Tuple(s);
631  }
632
633  Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; }
634
635  const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i]; }
636
637  int RefCount() const { return ref_count_.count(); }
638  int IncrRefCount() { return ref_count_.Incr(); }
639  int DecrRefCount() { return ref_count_.Decr(); }
640
641 private:
642  const T * state_table_;
643  vector<Accumulator*> accumulators_;
644  vector<const Fst<Arc>*> fst_array_;
645  RefCounter ref_count_;
646
647  DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData);
648};
649
650// This class accumulates weights in a ReplaceFst.  The 'Init' method
651// takes as input the argument used to build the ReplaceFst and the
652// ReplaceFst state table. It uses accumulators of type 'Accumulator'
653// in the underlying FSTs.
654template <class Accumulator,
655          class T = DefaultReplaceStateTable<typename Accumulator::Arc> >
656class ReplaceAccumulator {
657 public:
658  typedef typename Accumulator::Arc Arc;
659  typedef typename Arc::StateId StateId;
660  typedef typename Arc::Label Label;
661  typedef typename Arc::Weight Weight;
662  typedef T StateTable;
663  typedef typename T::StateTuple StateTuple;
664
665  ReplaceAccumulator()
666      : init_(false), data_(new ReplaceAccumulatorData<Accumulator, T>()),
667        error_(false) {}
668
669  ReplaceAccumulator(const vector<Accumulator*> &accumulators)
670      : init_(false),
671        data_(new ReplaceAccumulatorData<Accumulator, T>(accumulators)),
672        error_(false) {}
673
674  ReplaceAccumulator(const ReplaceAccumulator<Accumulator, T> &acc)
675      : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
676    if (!init_)
677      FSTERROR() << "ReplaceAccumulator: can't copy unintialized accumulator";
678    data_->IncrRefCount();
679  }
680
681  ~ReplaceAccumulator() {
682     if (!data_->DecrRefCount())
683      delete data_;
684  }
685
686  // Does not take ownership of the state table, the state table
687  // is own by the ReplaceFst
688  void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
689            const StateTable *state_table) {
690    init_ = true;
691    data_->Init(fst_tuples, state_table);
692  }
693
694  void SetState(StateId s) {
695    if (!init_) {
696      FSTERROR() << "ReplaceAccumulator::SetState: incorrectly initialized.";
697      error_ = true;
698      return;
699    }
700    StateTuple tuple = data_->GetTuple(s);
701    fst_id_ = tuple.fst_id - 1;  // Replace FST ID is 1-based
702    data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
703    if ((tuple.prefix_id != 0) &&
704        (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
705      offset_ = 1;
706      offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
707    } else {
708      offset_ = 0;
709      offset_weight_ = Weight::Zero();
710    }
711  }
712
713  Weight Sum(Weight w, Weight v) {
714    if (error_) return Weight::NoWeight();
715    return data_->GetAccumulator(fst_id_)->Sum(w, v);
716  }
717
718  template <class ArcIterator>
719  Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
720             ssize_t end) {
721    if (error_) return Weight::NoWeight();
722    Weight sum = begin == end ? Weight::Zero()
723        : data_->GetAccumulator(fst_id_)->Sum(
724            w, aiter, begin ? begin - offset_ : 0, end - offset_);
725    if (begin == 0 && end != 0 && offset_ > 0)
726      sum = Sum(offset_weight_, sum);
727    return sum;
728  }
729
730  bool Error() const { return error_; }
731
732 private:
733  bool init_;
734  ReplaceAccumulatorData<Accumulator, T> *data_;
735  Label fst_id_;
736  size_t offset_;
737  Weight offset_weight_;
738  bool error_;
739
740  void operator=(const ReplaceAccumulator<Accumulator, T> &);   // Disallow
741};
742
743}  // namespace fst
744
745#endif  // FST_LIB_ACCUMULATOR_H__
746