1f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// accumulator.h
2f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
3f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Licensed under the Apache License, Version 2.0 (the "License");
4f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// you may not use this file except in compliance with the License.
5f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// You may obtain a copy of the License at
6f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//
7f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//     http://www.apache.org/licenses/LICENSE-2.0
8f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//
9f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Unless required by applicable law or agreed to in writing, software
10f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// distributed under the License is distributed on an "AS IS" BASIS,
11f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// See the License for the specific language governing permissions and
13f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// limitations under the License.
14f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//
15f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Copyright 2005-2010 Google, Inc.
16f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Author: riley@google.com (Michael Riley)
17f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//
18f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// \file
19f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Classes to accumulate arc weights. Useful for weight lookahead.
20f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
21f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#ifndef FST_LIB_ACCUMULATOR_H__
22f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#define FST_LIB_ACCUMULATOR_H__
23f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
24f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <algorithm>
25f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <functional>
263da1eb108d36da35333b2d655202791af854996bPrzemyslaw Szczepaniak#include <tr1/unordered_map>
27f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonusing std::tr1::unordered_map;
28f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonusing std::tr1::unordered_multimap;
29f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <vector>
30f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonusing std::vector;
31f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
32f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/arcfilter.h>
33f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/arcsort.h>
34f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/dfs-visit.h>
35f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/expanded-fst.h>
36f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/replace.h>
37f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
38f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonnamespace fst {
39f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
40f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// This class accumulates arc weights using the semiring Plus().
41f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class A>
42f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass DefaultAccumulator {
43f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public:
44f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef A Arc;
45f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename A::StateId StateId;
46f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename A::Weight Weight;
47f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
48f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  DefaultAccumulator() {}
49f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
50f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  DefaultAccumulator(const DefaultAccumulator<A> &acc) {}
51f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
52f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void Init(const Fst<A>& fst, bool copy = false) {}
53f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
54f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void SetState(StateId) {}
55f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
56f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight Sum(Weight w, Weight v) {
57f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return Plus(w, v);
58f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
59f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
60f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  template <class ArcIterator>
61f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
62f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson             ssize_t end) {
63f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    Weight sum = w;
64f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    aiter->Seek(begin);
65f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
66f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      sum = Plus(sum, aiter->Value().weight);
67f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return sum;
68f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
69f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
70f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool Error() const { return false; }
71f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
72f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private:
73f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void operator=(const DefaultAccumulator<A> &);   // Disallow
74f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson};
75f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
76f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
77f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// This class accumulates arc weights using the log semiring Plus()
78f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// assuming an arc weight has a WeightConvert specialization to
79f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// and from log64 weights.
80f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class A>
81f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass LogAccumulator {
82f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public:
83f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef A Arc;
84f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename A::StateId StateId;
85f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename A::Weight Weight;
86f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
87f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  LogAccumulator() {}
88f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
89f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  LogAccumulator(const LogAccumulator<A> &acc) {}
90f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
91f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void Init(const Fst<A>& fst, bool copy = false) {}
92f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
93f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void SetState(StateId) {}
94f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
95f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight Sum(Weight w, Weight v) {
96f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return LogPlus(w, v);
97f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
98f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
99f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  template <class ArcIterator>
100f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
101f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson             ssize_t end) {
102f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    Weight sum = w;
103f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    aiter->Seek(begin);
104f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
105f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      sum = LogPlus(sum, aiter->Value().weight);
106f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return sum;
107f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
108f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
109f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool Error() const { return false; }
110f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
111f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private:
112f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  double LogPosExp(double x) { return log(1.0F + exp(-x)); }
113f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
114f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight LogPlus(Weight w, Weight v) {
115f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    double f1 = to_log_weight_(w).Value();
116f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    double f2 = to_log_weight_(v).Value();
117f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (f1 > f2)
118f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return to_weight_(f2 - LogPosExp(f1 - f2));
119f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    else
120f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return to_weight_(f1 - LogPosExp(f2 - f1));
121f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
122f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
123f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  WeightConvert<Weight, Log64Weight> to_log_weight_;
124f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  WeightConvert<Log64Weight, Weight> to_weight_;
125f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
126f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void operator=(const LogAccumulator<A> &);   // Disallow
127f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson};
128f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
129f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
130f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Stores shareable data for fast log accumulator copies.
131f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass FastLogAccumulatorData {
132f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public:
133f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  FastLogAccumulatorData() {}
134f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
135f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  vector<double> *Weights() { return &weights_; }
136f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  vector<ssize_t> *WeightPositions() { return &weight_positions_; }
137f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  double *WeightEnd() { return &(weights_[weights_.size() - 1]); };
138f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  int RefCount() const { return ref_count_.count(); }
139f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  int IncrRefCount() { return ref_count_.Incr(); }
140f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  int DecrRefCount() { return ref_count_.Decr(); }
141f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
142f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private:
143f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // Cummulative weight per state for all states s.t. # of arcs >
144f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // arc_limit_ with arcs in order. Special first element per state
145f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // being Log64Weight::Zero();
146f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  vector<double> weights_;
147f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // Maps from state to corresponding beginning weight position in
148f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // weights_. Position -1 means no pre-computed weights for that
149f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // state.
150f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  vector<ssize_t> weight_positions_;
151f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  RefCounter ref_count_;                  // Reference count.
152f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
153f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData);
154f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson};
155f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
156f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
157f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// This class accumulates arc weights using the log semiring Plus()
158f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// assuming an arc weight has a WeightConvert specialization to and
159f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// from log64 weights. The member function Init(fst) has to be called
160f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// to setup pre-computed weight information.
161f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class A>
162f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass FastLogAccumulator {
163f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public:
164f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef A Arc;
165f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename A::StateId StateId;
166f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename A::Weight Weight;
167f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
168f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10)
169f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      : arc_limit_(arc_limit),
170f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        arc_period_(arc_period),
171f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        data_(new FastLogAccumulatorData()),
172f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        error_(false) {}
173f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
174f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  FastLogAccumulator(const FastLogAccumulator<A> &acc)
175f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      : arc_limit_(acc.arc_limit_),
176f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        arc_period_(acc.arc_period_),
177f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        data_(acc.data_),
178f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        error_(acc.error_) {
179f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    data_->IncrRefCount();
180f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
181f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
182f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ~FastLogAccumulator() {
183f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (!data_->DecrRefCount())
184f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      delete data_;
185f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
186f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
187f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void SetState(StateId s) {
188f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    vector<double> &weights = *data_->Weights();
189f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    vector<ssize_t> &weight_positions = *data_->WeightPositions();
190f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
191f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (weight_positions.size() <= s) {
192f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      FSTERROR() << "FastLogAccumulator::SetState: invalid state id.";
193f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      error_ = true;
194f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return;
195f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
196f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
197f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    ssize_t pos = weight_positions[s];
198f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (pos >= 0)
199f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      state_weights_ = &(weights[pos]);
200f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    else
201f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      state_weights_ = 0;
202f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
203f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
204f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight Sum(Weight w, Weight v) {
205f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return LogPlus(w, v);
206f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
207f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
208f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  template <class ArcIterator>
209f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
210f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson             ssize_t end) {
211f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (error_) return Weight::NoWeight();
212f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    Weight sum = w;
213f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    // Finds begin and end of pre-stored weights
214f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    ssize_t index_begin = -1, index_end = -1;
215f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    ssize_t stored_begin = end, stored_end = end;
216f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (state_weights_ != 0) {
217f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0;
218f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      index_end = end / arc_period_;
219f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      stored_begin = index_begin * arc_period_;
220f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      stored_end = index_end * arc_period_;
221f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
222f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    // Computes sum before pre-stored weights
223f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (begin < stored_begin) {
224f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      ssize_t pos_end = min(stored_begin, end);
225f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      aiter->Seek(begin);
226f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos)
227f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        sum = LogPlus(sum, aiter->Value().weight);
228f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
229f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    // Computes sum between pre-stored weights
230f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (stored_begin < stored_end) {
231f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      sum = LogPlus(sum, LogMinus(state_weights_[index_end],
232f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                                  state_weights_[index_begin]));
233f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
234f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    // Computes sum after pre-stored weights
235f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (stored_end < end) {
236f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      ssize_t pos_start = max(stored_begin, stored_end);
237f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      aiter->Seek(pos_start);
238f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos)
239f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        sum = LogPlus(sum, aiter->Value().weight);
240f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
241f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return sum;
242f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
243f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
244f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  template <class F>
245f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void Init(const F &fst, bool copy = false) {
246f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (copy)
247f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return;
248f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    vector<double> &weights = *data_->Weights();
249f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    vector<ssize_t> &weight_positions = *data_->WeightPositions();
250f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (!weights.empty() || arc_limit_ < arc_period_) {
251f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      FSTERROR() << "FastLogAccumulator: initialization error.";
252f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      error_ = true;
253f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return;
254f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
255f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    weight_positions.reserve(CountStates(fst));
256f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
257f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    ssize_t weight_position = 0;
258f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) {
259f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      StateId s = siter.Value();
260f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      if (fst.NumArcs(s) >= arc_limit_) {
261dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin        double sum = FloatLimits<double>::PosInfinity();
262f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        weight_positions.push_back(weight_position);
263f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        weights.push_back(sum);
264f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        ++weight_position;
265f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        ssize_t narcs = 0;
266f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        for(ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) {
267f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          const A &arc = aiter.Value();
268f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          sum = LogPlus(sum, arc.weight);
269f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          // Stores cumulative weight distribution per arc_period_.
270f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          if (++narcs % arc_period_ == 0) {
271f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson            weights.push_back(sum);
272f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson            ++weight_position;
273f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          }
274f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        }
275f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      } else {
276f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        weight_positions.push_back(-1);
277f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      }
278f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
279f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
280f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
281f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool Error() const { return error_; }
282f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
283f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private:
284f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  double LogPosExp(double x) {
285dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin    return x == FloatLimits<double>::PosInfinity() ?
286f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        0.0 : log(1.0F + exp(-x));
287f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
288f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
289f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  double LogMinusExp(double x) {
290dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin    return x == FloatLimits<double>::PosInfinity() ?
291f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        0.0 : log(1.0F - exp(-x));
292f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
293f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
294f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight LogPlus(Weight w, Weight v) {
295f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    double f1 = to_log_weight_(w).Value();
296f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    double f2 = to_log_weight_(v).Value();
297f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (f1 > f2)
298f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return to_weight_(f2 - LogPosExp(f1 - f2));
299f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    else
300f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return to_weight_(f1 - LogPosExp(f2 - f1));
301f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
302f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
303f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  double LogPlus(double f1, Weight v) {
304f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    double f2 = to_log_weight_(v).Value();
305dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin    if (f1 == FloatLimits<double>::PosInfinity())
306f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return f2;
307f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    else if (f1 > f2)
308f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return f2 - LogPosExp(f1 - f2);
309f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    else
310f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return f1 - LogPosExp(f2 - f1);
311f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
312f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
313f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight LogMinus(double f1, double f2) {
314f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (f1 >= f2) {
315f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      FSTERROR() << "FastLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1
316f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                 << " and f2 = " << f2;
317f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      error_ = true;
318f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return Weight::NoWeight();
319f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
320dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin    if (f2 == FloatLimits<double>::PosInfinity())
321f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return to_weight_(f1);
322f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    else
323f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return to_weight_(f1 - LogMinusExp(f2 - f1));
324f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
325f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
326f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  WeightConvert<Weight, Log64Weight> to_log_weight_;
327f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  WeightConvert<Log64Weight, Weight> to_weight_;
328f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
329f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ssize_t arc_limit_;     // Minimum # of arcs to pre-compute state
330f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ssize_t arc_period_;    // Save cumulative weights per 'arc_period_'.
331f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool init_;             // Cumulative weights initialized?
332f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  FastLogAccumulatorData *data_;
333f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  double *state_weights_;
334f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool error_;
335f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
336f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void operator=(const FastLogAccumulator<A> &);   // Disallow
337f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson};
338f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
339f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
340f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Stores shareable data for cache log accumulator copies.
341f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// All copies share the same cache.
342f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class A>
343f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass CacheLogAccumulatorData {
344f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public:
345f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef A Arc;
346f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename A::StateId StateId;
347f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename A::Weight Weight;
348f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
349f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  CacheLogAccumulatorData(bool gc, size_t gc_limit)
350f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {}
351f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
352f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ~CacheLogAccumulatorData() {
353f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    for(typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
354f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        it != cache_.end();
355f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        ++it)
356f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      delete it->second.weights;
357f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
358f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
359f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; }
360f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
361f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  vector<double> *GetWeights(StateId s) {
362f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    typename unordered_map<StateId, CacheState>::iterator it = cache_.find(s);
363f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (it != cache_.end()) {
364f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      it->second.recent = true;
365f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return it->second.weights;
366f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    } else {
367f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return 0;
368f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
369f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
370f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
371f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void AddWeights(StateId s, vector<double> *weights) {
372f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (cache_gc_ && cache_size_ >= cache_limit_)
373f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      GC(false);
374f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    cache_.insert(make_pair(s, CacheState(weights, true)));
375f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (cache_gc_)
376f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      cache_size_ += weights->capacity() * sizeof(double);
377f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
378f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
379f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  int RefCount() const { return ref_count_.count(); }
380f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  int IncrRefCount() { return ref_count_.Incr(); }
381f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  int DecrRefCount() { return ref_count_.Decr(); }
382f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
383f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private:
// Per-state cache entry: the accumulated weights together with the
// usage flag consulted by garbage collection.
struct CacheState {
  vector<double>* weights;  // Accumulated weights stored for the state.
  bool recent;              // True if accessed since the last GC pass.

  CacheState(vector<double> *weight_vector, bool accessed)
      : weights(weight_vector), recent(accessed) {}
};
391f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
392f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // Garbage collect: Delete from cache states that have not been
393f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // accessed since the last GC ('free_recent = false') until
394f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // 'cache_size_' is 2/3 of 'cache_limit_'. If it does not free enough
395f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // memory, start deleting recently accessed states.
396f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void GC(bool free_recent) {
397f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    size_t cache_target = (2 * cache_limit_)/3 + 1;
398f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    typename unordered_map<StateId, CacheState>::iterator it = cache_.begin();
399f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    while (it != cache_.end() && cache_size_ > cache_target) {
400f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      CacheState &cs = it->second;
401f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      if (free_recent || !cs.recent) {
402f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        cache_size_ -= cs.weights->capacity() * sizeof(double);
403f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        delete cs.weights;
404f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        cache_.erase(it++);
405f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      } else {
406f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        cs.recent = false;
407f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        ++it;
408f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      }
409f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
410f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (!free_recent && cache_size_ > cache_target)
411f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      GC(true);
412f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
413f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
414f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  unordered_map<StateId, CacheState> cache_;  // Cache: state -> cached weights
415f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool cache_gc_;                        // Enable garbage collection
416f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  size_t cache_limit_;                   // # of bytes allowed before GC
417f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  size_t cache_size_;                    // # of bytes cached
418f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  RefCounter ref_count_;
419f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
420f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData);
421f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson};
422f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
423f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// This class accumulates arc weights using the log semiring Plus()
424f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//  has a WeightConvert specialization to and from log64 weights.  It
425f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//  is similar to the FastLogAccumator. However here, the accumulated
426f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//  weights are pre-computed and stored only for the states that are
427f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//  visited. The member function Init(fst) has to be called to setup
428f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//  this accumulator.
429f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class A>
430f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass CacheLogAccumulator {
431f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public:
432f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef A Arc;
433f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename A::StateId StateId;
434f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename A::Weight Weight;
435f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
436f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false,
437f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                               size_t gc_limit = 10 * 1024 * 1024)
438f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      : arc_limit_(arc_limit), fst_(0), data_(
439f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          new CacheLogAccumulatorData<A>(gc, gc_limit)), s_(kNoStateId),
440f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        error_(false) {}
441f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
442f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  CacheLogAccumulator(const CacheLogAccumulator<A> &acc)
443f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0),
444f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        data_(acc.data_), s_(kNoStateId), error_(acc.error_) {
445f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    data_->IncrRefCount();
446f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
447f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
448f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ~CacheLogAccumulator() {
449f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (fst_)
450f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      delete fst_;
451f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (!data_->DecrRefCount())
452f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      delete data_;
453f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
454f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
455f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // Arg 'arc_limit' specifies minimum # of arcs to pre-compute state.
456f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void Init(const Fst<A> &fst, bool copy = false) {
457f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (copy) {
458f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      delete fst_;
459f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    } else if (fst_) {
460f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      FSTERROR() << "CacheLogAccumulator: initialization error.";
461f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      error_ = true;
462f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return;
463f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
464f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    fst_ = fst.Copy();
465f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
466f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
467f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void SetState(StateId s, int depth = 0) {
468f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (s == s_)
469f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return;
470f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    s_ = s;
471f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
472f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (data_->CacheDisabled() || error_) {
473f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      weights_ = 0;
474f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return;
475f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
476f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
477f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (!fst_) {
478f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      FSTERROR() << "CacheLogAccumulator::SetState: incorrectly initialized.";
479f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      error_ = true;
480f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      weights_ = 0;
481f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return;
482f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
483f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
484f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    weights_ = data_->GetWeights(s);
485f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) {
486f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      weights_ = new vector<double>;
487f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      weights_->reserve(fst_->NumArcs(s) + 1);
488dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin      weights_->push_back(FloatLimits<double>::PosInfinity());
489f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      data_->AddWeights(s, weights_);
490f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
491f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
492f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
493f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight Sum(Weight w, Weight v) {
494f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return LogPlus(w, v);
495f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
496f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
497f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  template <class Iterator>
498f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight Sum(Weight w, Iterator *aiter, ssize_t begin,
499f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson             ssize_t end) {
500f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (weights_ == 0) {
501f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      Weight sum = w;
502f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      aiter->Seek(begin);
503f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos)
504f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        sum = LogPlus(sum, aiter->Value().weight);
505f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return sum;
506f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    } else {
507f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      if (weights_->size() <= end)
508f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        for (aiter->Seek(weights_->size() - 1);
509f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson             weights_->size() <= end;
510f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson             aiter->Next())
511f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          weights_->push_back(LogPlus(weights_->back(),
512f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                                      aiter->Value().weight));
513f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin]));
514f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
515f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
516f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
517f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  template <class Iterator>
518f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  size_t LowerBound(double w, Iterator *aiter) {
519f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (weights_ != 0) {
520f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return lower_bound(weights_->begin() + 1,
521f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                         weights_->end(),
522f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                         w,
523f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                         std::greater<double>())
524f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          - weights_->begin() - 1;
525f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    } else {
526f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      size_t n = 0;
527dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin      double x =  FloatLimits<double>::PosInfinity();
528f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) {
529f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        x = LogPlus(x, aiter->Value().weight);
530f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        if (x < w) break;
531f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      }
532f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return n;
533f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
534f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
535f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
536f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool Error() const { return error_; }
537f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
538f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private:
539f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  double LogPosExp(double x) {
540dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin    return x == FloatLimits<double>::PosInfinity() ?
541f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        0.0 : log(1.0F + exp(-x));
542f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
543f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
544f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  double LogMinusExp(double x) {
545dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin    return x == FloatLimits<double>::PosInfinity() ?
546f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        0.0 : log(1.0F - exp(-x));
547f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
548f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
549f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight LogPlus(Weight w, Weight v) {
550f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    double f1 = to_log_weight_(w).Value();
551f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    double f2 = to_log_weight_(v).Value();
552f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (f1 > f2)
553f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return to_weight_(f2 - LogPosExp(f1 - f2));
554f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    else
555f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return to_weight_(f1 - LogPosExp(f2 - f1));
556f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
557f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
558f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  double LogPlus(double f1, Weight v) {
559f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    double f2 = to_log_weight_(v).Value();
560dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin    if (f1 == FloatLimits<double>::PosInfinity())
561f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return f2;
562f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    else if (f1 > f2)
563f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return f2 - LogPosExp(f1 - f2);
564f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    else
565f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return f1 - LogPosExp(f2 - f1);
566f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
567f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
568f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight LogMinus(double f1, double f2) {
569f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (f1 >= f2) {
570f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      FSTERROR() << "CacheLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1
571f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                 << " and f2 = " << f2;
572f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      error_ = true;
573f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return Weight::NoWeight();
574f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
575dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin    if (f2 == FloatLimits<double>::PosInfinity())
576f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return to_weight_(f1);
577f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    else
578f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return to_weight_(f1 - LogMinusExp(f2 - f1));
579f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
580f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
581f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  WeightConvert<Weight, Log64Weight> to_log_weight_;
582f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  WeightConvert<Log64Weight, Weight> to_weight_;
583f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
584f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ssize_t arc_limit_;                    // Minimum # of arcs to cache a state
585f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  vector<double> *weights_;              // Accumulated weights for cur. state
586f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  const Fst<A>* fst_;                    // Input fst
587f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  CacheLogAccumulatorData<A> *data_;     // Cache data
588f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  StateId s_;                            // Current state
589f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool error_;
590f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
591f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void operator=(const CacheLogAccumulator<A> &);   // Disallow
592f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson};
593f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
594f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
595f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Stores shareable data for replace accumulator copies.
596f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Accumulator, class T>
597f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass ReplaceAccumulatorData {
598f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public:
599f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename Accumulator::Arc Arc;
600f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename Arc::StateId StateId;
601f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename Arc::Label Label;
602f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef T StateTable;
603f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename T::StateTuple StateTuple;
604f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
605f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ReplaceAccumulatorData() : state_table_(0) {}
606f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
607f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ReplaceAccumulatorData(const vector<Accumulator*> &accumulators)
608f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      : state_table_(0), accumulators_(accumulators) {}
609f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
610f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ~ReplaceAccumulatorData() {
611f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    for (size_t i = 0; i < fst_array_.size(); ++i)
612f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      delete fst_array_[i];
613f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    for (size_t i = 0; i < accumulators_.size(); ++i)
614f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      delete accumulators_[i];
615f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
616f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
617f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
618f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson       const StateTable *state_table) {
619f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    state_table_ = state_table;
620f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    accumulators_.resize(fst_tuples.size());
621f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    for (size_t i = 0; i < accumulators_.size(); ++i) {
622f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      if (!accumulators_[i])
623f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        accumulators_[i] = new Accumulator;
624f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      accumulators_[i]->Init(*(fst_tuples[i].second));
625f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      fst_array_.push_back(fst_tuples[i].second->Copy());
626f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
627f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
628f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
629f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  const StateTuple &GetTuple(StateId s) const {
630f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return state_table_->Tuple(s);
631f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
632f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
633f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; }
634f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
635f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i]; }
636f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
637f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  int RefCount() const { return ref_count_.count(); }
638f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  int IncrRefCount() { return ref_count_.Incr(); }
639f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  int DecrRefCount() { return ref_count_.Decr(); }
640f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
641f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private:
642f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  const T * state_table_;
643f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  vector<Accumulator*> accumulators_;
644f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  vector<const Fst<Arc>*> fst_array_;
645f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  RefCounter ref_count_;
646f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
647f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData);
648f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson};
649f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
650f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// This class accumulates weights in a ReplaceFst.  The 'Init' method
651f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// takes as input the argument used to build the ReplaceFst and the
652f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// ReplaceFst state table. It uses accumulators of type 'Accumulator'
653f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// in the underlying FSTs.
654f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Accumulator,
655f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          class T = DefaultReplaceStateTable<typename Accumulator::Arc> >
656f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass ReplaceAccumulator {
657f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public:
658f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename Accumulator::Arc Arc;
659f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename Arc::StateId StateId;
660f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename Arc::Label Label;
661f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename Arc::Weight Weight;
662f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef T StateTable;
663f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typedef typename T::StateTuple StateTuple;
664f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
665f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ReplaceAccumulator()
666f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      : init_(false), data_(new ReplaceAccumulatorData<Accumulator, T>()),
667f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        error_(false) {}
668f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
669f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ReplaceAccumulator(const vector<Accumulator*> &accumulators)
670f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      : init_(false),
671f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        data_(new ReplaceAccumulatorData<Accumulator, T>(accumulators)),
672f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        error_(false) {}
673f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
674f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ReplaceAccumulator(const ReplaceAccumulator<Accumulator, T> &acc)
675f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      : init_(acc.init_), data_(acc.data_), error_(acc.error_) {
676f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (!init_)
677f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      FSTERROR() << "ReplaceAccumulator: can't copy unintialized accumulator";
678f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    data_->IncrRefCount();
679f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
680f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
681f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ~ReplaceAccumulator() {
682f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson     if (!data_->DecrRefCount())
683f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      delete data_;
684f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
685f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
686f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // Does not take ownership of the state table, the state table
687f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  // is own by the ReplaceFst
688f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples,
689f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson            const StateTable *state_table) {
690f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    init_ = true;
691f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    data_->Init(fst_tuples, state_table);
692f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
693f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
694f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void SetState(StateId s) {
695f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (!init_) {
696f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      FSTERROR() << "ReplaceAccumulator::SetState: incorrectly initialized.";
697f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      error_ = true;
698f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return;
699f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
700f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    StateTuple tuple = data_->GetTuple(s);
701f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    fst_id_ = tuple.fst_id - 1;  // Replace FST ID is 1-based
702f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state);
703f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if ((tuple.prefix_id != 0) &&
704f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) {
705f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      offset_ = 1;
706f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state);
707f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    } else {
708f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      offset_ = 0;
709f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      offset_weight_ = Weight::Zero();
710f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
711f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
712f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
713f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight Sum(Weight w, Weight v) {
714f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (error_) return Weight::NoWeight();
715f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return data_->GetAccumulator(fst_id_)->Sum(w, v);
716f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
717f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
718f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  template <class ArcIterator>
719f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin,
720f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson             ssize_t end) {
721f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (error_) return Weight::NoWeight();
722f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    Weight sum = begin == end ? Weight::Zero()
723f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        : data_->GetAccumulator(fst_id_)->Sum(
724f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson            w, aiter, begin ? begin - offset_ : 0, end - offset_);
725f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (begin == 0 && end != 0 && offset_ > 0)
726f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      sum = Sum(offset_weight_, sum);
727f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return sum;
728f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
729f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
730f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool Error() const { return error_; }
731f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
732f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private:
733f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool init_;
734f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ReplaceAccumulatorData<Accumulator, T> *data_;
735f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Label fst_id_;
736f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  size_t offset_;
737f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Weight offset_weight_;
738f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  bool error_;
739f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
740f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  void operator=(const ReplaceAccumulator<Accumulator, T> &);   // Disallow
741f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson};
742f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
743f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}  // namespace fst
744f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
745f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#endif  // FST_LIB_ACCUMULATOR_H__
746