// accumulator.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Classes to accumulate arc weights. Useful for weight lookahead.
20f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 21f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#ifndef FST_LIB_ACCUMULATOR_H__ 22f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#define FST_LIB_ACCUMULATOR_H__ 23f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 24f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <algorithm> 25f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <functional> 263da1eb108d36da35333b2d655202791af854996bPrzemyslaw Szczepaniak#include <tr1/unordered_map> 27f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonusing std::tr1::unordered_map; 28f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonusing std::tr1::unordered_multimap; 29f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <vector> 30f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonusing std::vector; 31f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 32f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/arcfilter.h> 33f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/arcsort.h> 34f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/dfs-visit.h> 35f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/expanded-fst.h> 36f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/replace.h> 37f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 38f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonnamespace fst { 39f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 40f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// This class accumulates arc weights using the semiring Plus(). 
41f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class A> 42f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass DefaultAccumulator { 43f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public: 44f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef A Arc; 45f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename A::StateId StateId; 46f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename A::Weight Weight; 47f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 48f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson DefaultAccumulator() {} 49f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 50f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson DefaultAccumulator(const DefaultAccumulator<A> &acc) {} 51f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 52f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void Init(const Fst<A>& fst, bool copy = false) {} 53f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 54f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void SetState(StateId) {} 55f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 56f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight Sum(Weight w, Weight v) { 57f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return Plus(w, v); 58f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 59f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 60f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson template <class ArcIterator> 61f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, 62f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t end) { 63f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight sum = w; 64f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson aiter->Seek(begin); 65f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) 66f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson sum = Plus(sum, aiter->Value().weight); 
67f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return sum; 68f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 69f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 70f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool Error() const { return false; } 71f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 72f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private: 73f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void operator=(const DefaultAccumulator<A> &); // Disallow 74f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}; 75f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 76f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 77f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// This class accumulates arc weights using the log semiring Plus() 78f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// assuming an arc weight has a WeightConvert specialization to 79f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// and from log64 weights. 80f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class A> 81f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass LogAccumulator { 82f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public: 83f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef A Arc; 84f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename A::StateId StateId; 85f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename A::Weight Weight; 86f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 87f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson LogAccumulator() {} 88f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 89f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson LogAccumulator(const LogAccumulator<A> &acc) {} 90f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 91f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void Init(const Fst<A>& fst, bool copy = false) {} 92f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 93f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void SetState(StateId) {} 
94f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 95f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight Sum(Weight w, Weight v) { 96f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return LogPlus(w, v); 97f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 98f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 99f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson template <class ArcIterator> 100f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, 101f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t end) { 102f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight sum = w; 103f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson aiter->Seek(begin); 104f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) 105f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson sum = LogPlus(sum, aiter->Value().weight); 106f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return sum; 107f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 108f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 109f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool Error() const { return false; } 110f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 111f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private: 112f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double LogPosExp(double x) { return log(1.0F + exp(-x)); } 113f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 114f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight LogPlus(Weight w, Weight v) { 115f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double f1 = to_log_weight_(w).Value(); 116f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double f2 = to_log_weight_(v).Value(); 117f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (f1 > f2) 118f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return to_weight_(f2 - LogPosExp(f1 - f2)); 119f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson else 
120f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return to_weight_(f1 - LogPosExp(f2 - f1)); 121f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 122f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 123f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson WeightConvert<Weight, Log64Weight> to_log_weight_; 124f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson WeightConvert<Log64Weight, Weight> to_weight_; 125f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 126f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void operator=(const LogAccumulator<A> &); // Disallow 127f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}; 128f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 129f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 130f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Stores shareable data for fast log accumulator copies. 131f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass FastLogAccumulatorData { 132f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public: 133f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FastLogAccumulatorData() {} 134f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 135f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<double> *Weights() { return &weights_; } 136f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<ssize_t> *WeightPositions() { return &weight_positions_; } 137f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double *WeightEnd() { return &(weights_[weights_.size() - 1]); }; 138f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson int RefCount() const { return ref_count_.count(); } 139f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson int IncrRefCount() { return ref_count_.Incr(); } 140f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson int DecrRefCount() { return ref_count_.Decr(); } 141f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 142f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private: 143f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Cummulative weight per state for all 
states s.t. # of arcs > 144f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // arc_limit_ with arcs in order. Special first element per state 145f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // being Log64Weight::Zero(); 146f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<double> weights_; 147f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Maps from state to corresponding beginning weight position in 148f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // weights_. Position -1 means no pre-computed weights for that 149f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // state. 150f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<ssize_t> weight_positions_; 151f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson RefCounter ref_count_; // Reference count. 152f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 153f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson DISALLOW_COPY_AND_ASSIGN(FastLogAccumulatorData); 154f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}; 155f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 156f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 157f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// This class accumulates arc weights using the log semiring Plus() 158f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// assuming an arc weight has a WeightConvert specialization to and 159f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// from log64 weights. The member function Init(fst) has to be called 160f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// to setup pre-computed weight information. 
161f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class A> 162f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass FastLogAccumulator { 163f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public: 164f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef A Arc; 165f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename A::StateId StateId; 166f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename A::Weight Weight; 167f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 168f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10) 169f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : arc_limit_(arc_limit), 170f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson arc_period_(arc_period), 171f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson data_(new FastLogAccumulatorData()), 172f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_(false) {} 173f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 174f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FastLogAccumulator(const FastLogAccumulator<A> &acc) 175f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : arc_limit_(acc.arc_limit_), 176f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson arc_period_(acc.arc_period_), 177f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson data_(acc.data_), 178f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_(acc.error_) { 179f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson data_->IncrRefCount(); 180f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 181f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 182f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ~FastLogAccumulator() { 183f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (!data_->DecrRefCount()) 184f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson delete data_; 185f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 186f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 
187f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void SetState(StateId s) { 188f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<double> &weights = *data_->Weights(); 189f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<ssize_t> &weight_positions = *data_->WeightPositions(); 190f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 191f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (weight_positions.size() <= s) { 192f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FSTERROR() << "FastLogAccumulator::SetState: invalid state id."; 193f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_ = true; 194f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return; 195f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 196f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 197f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t pos = weight_positions[s]; 198f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (pos >= 0) 199f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson state_weights_ = &(weights[pos]); 200f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson else 201f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson state_weights_ = 0; 202f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 203f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 204f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight Sum(Weight w, Weight v) { 205f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return LogPlus(w, v); 206f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 207f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 208f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson template <class ArcIterator> 209f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, 210f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t end) { 211f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (error_) return Weight::NoWeight(); 212f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight sum = w; 
213f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Finds begin and end of pre-stored weights 214f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t index_begin = -1, index_end = -1; 215f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t stored_begin = end, stored_end = end; 216f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (state_weights_ != 0) { 217f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson index_begin = begin > 0 ? (begin - 1)/ arc_period_ + 1 : 0; 218f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson index_end = end / arc_period_; 219f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson stored_begin = index_begin * arc_period_; 220f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson stored_end = index_end * arc_period_; 221f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 222f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Computes sum before pre-stored weights 223f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (begin < stored_begin) { 224f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t pos_end = min(stored_begin, end); 225f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson aiter->Seek(begin); 226f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (ssize_t pos = begin; pos < pos_end; aiter->Next(), ++pos) 227f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson sum = LogPlus(sum, aiter->Value().weight); 228f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 229f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Computes sum between pre-stored weights 230f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (stored_begin < stored_end) { 231f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson sum = LogPlus(sum, LogMinus(state_weights_[index_end], 232f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson state_weights_[index_begin])); 233f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 234f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Computes sum after pre-stored weights 
235f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (stored_end < end) { 236f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t pos_start = max(stored_begin, stored_end); 237f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson aiter->Seek(pos_start); 238f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (ssize_t pos = pos_start; pos < end; aiter->Next(), ++pos) 239f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson sum = LogPlus(sum, aiter->Value().weight); 240f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 241f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return sum; 242f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 243f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 244f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson template <class F> 245f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void Init(const F &fst, bool copy = false) { 246f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (copy) 247f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return; 248f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<double> &weights = *data_->Weights(); 249f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<ssize_t> &weight_positions = *data_->WeightPositions(); 250f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (!weights.empty() || arc_limit_ < arc_period_) { 251f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FSTERROR() << "FastLogAccumulator: initialization error."; 252f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_ = true; 253f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return; 254f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 255f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weight_positions.reserve(CountStates(fst)); 256f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 257f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t weight_position = 0; 258f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for(StateIterator<F> siter(fst); !siter.Done(); siter.Next()) { 
259f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson StateId s = siter.Value(); 260f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (fst.NumArcs(s) >= arc_limit_) { 261dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin double sum = FloatLimits<double>::PosInfinity(); 262f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weight_positions.push_back(weight_position); 263f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weights.push_back(sum); 264f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ++weight_position; 265f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t narcs = 0; 266f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for(ArcIterator<F> aiter(fst, s); !aiter.Done(); aiter.Next()) { 267f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson const A &arc = aiter.Value(); 268f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson sum = LogPlus(sum, arc.weight); 269f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Stores cumulative weight distribution per arc_period_. 
270f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (++narcs % arc_period_ == 0) { 271f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weights.push_back(sum); 272f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ++weight_position; 273f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 274f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 275f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else { 276f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weight_positions.push_back(-1); 277f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 278f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 279f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 280f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 281f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool Error() const { return error_; } 282f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 283f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private: 284f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double LogPosExp(double x) { 285dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin return x == FloatLimits<double>::PosInfinity() ? 286f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 0.0 : log(1.0F + exp(-x)); 287f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 288f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 289f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double LogMinusExp(double x) { 290dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin return x == FloatLimits<double>::PosInfinity() ? 
291f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 0.0 : log(1.0F - exp(-x)); 292f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 293f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 294f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight LogPlus(Weight w, Weight v) { 295f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double f1 = to_log_weight_(w).Value(); 296f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double f2 = to_log_weight_(v).Value(); 297f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (f1 > f2) 298f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return to_weight_(f2 - LogPosExp(f1 - f2)); 299f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson else 300f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return to_weight_(f1 - LogPosExp(f2 - f1)); 301f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 302f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 303f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double LogPlus(double f1, Weight v) { 304f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double f2 = to_log_weight_(v).Value(); 305dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin if (f1 == FloatLimits<double>::PosInfinity()) 306f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return f2; 307f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson else if (f1 > f2) 308f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return f2 - LogPosExp(f1 - f2); 309f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson else 310f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return f1 - LogPosExp(f2 - f1); 311f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 312f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 313f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight LogMinus(double f1, double f2) { 314f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (f1 >= f2) { 315f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FSTERROR() << "FastLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1 
316f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson << " and f2 = " << f2; 317f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_ = true; 318f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return Weight::NoWeight(); 319f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 320dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin if (f2 == FloatLimits<double>::PosInfinity()) 321f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return to_weight_(f1); 322f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson else 323f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return to_weight_(f1 - LogMinusExp(f2 - f1)); 324f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 325f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 326f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson WeightConvert<Weight, Log64Weight> to_log_weight_; 327f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson WeightConvert<Log64Weight, Weight> to_weight_; 328f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 329f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t arc_limit_; // Minimum # of arcs to pre-compute state 330f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t arc_period_; // Save cumulative weights per 'arc_period_'. 331f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool init_; // Cumulative weights initialized? 
332f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FastLogAccumulatorData *data_; 333f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double *state_weights_; 334f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool error_; 335f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 336f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void operator=(const FastLogAccumulator<A> &); // Disallow 337f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}; 338f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 339f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 340f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Stores shareable data for cache log accumulator copies. 341f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// All copies share the same cache. 342f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class A> 343f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass CacheLogAccumulatorData { 344f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public: 345f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef A Arc; 346f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename A::StateId StateId; 347f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename A::Weight Weight; 348f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 349f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson CacheLogAccumulatorData(bool gc, size_t gc_limit) 350f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {} 351f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 352f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ~CacheLogAccumulatorData() { 353f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for(typename unordered_map<StateId, CacheState>::iterator it = cache_.begin(); 354f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson it != cache_.end(); 355f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ++it) 356f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson delete it->second.weights; 
357f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 358f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 359f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; } 360f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 361f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<double> *GetWeights(StateId s) { 362f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typename unordered_map<StateId, CacheState>::iterator it = cache_.find(s); 363f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (it != cache_.end()) { 364f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson it->second.recent = true; 365f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return it->second.weights; 366f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else { 367f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return 0; 368f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 369f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 370f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 371f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void AddWeights(StateId s, vector<double> *weights) { 372f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (cache_gc_ && cache_size_ >= cache_limit_) 373f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson GC(false); 374f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson cache_.insert(make_pair(s, CacheState(weights, true))); 375f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (cache_gc_) 376f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson cache_size_ += weights->capacity() * sizeof(double); 377f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 378f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 379f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson int RefCount() const { return ref_count_.count(); } 380f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson int IncrRefCount() { return ref_count_.Incr(); } 381f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson int DecrRefCount() { 
return ref_count_.Decr(); } 382f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 383f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private: 384f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Cached information for a given state. 385f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson struct CacheState { 386f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<double>* weights; // Accumulated weights for this state. 387f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool recent; // Has this state been accessed since last GC? 388f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 389f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson CacheState(vector<double> *w, bool r) : weights(w), recent(r) {} 390f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson }; 391f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 392f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Garbage collect: Delete from cache states that have not been 393f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // accessed since the last GC ('free_recent = false') until 394f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // 'cache_size_' is 2/3 of 'cache_limit_'. If it does not free enough 395f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // memory, start deleting recently accessed states. 
396f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void GC(bool free_recent) { 397f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson size_t cache_target = (2 * cache_limit_)/3 + 1; 398f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typename unordered_map<StateId, CacheState>::iterator it = cache_.begin(); 399f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson while (it != cache_.end() && cache_size_ > cache_target) { 400f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson CacheState &cs = it->second; 401f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (free_recent || !cs.recent) { 402f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson cache_size_ -= cs.weights->capacity() * sizeof(double); 403f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson delete cs.weights; 404f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson cache_.erase(it++); 405f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else { 406f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson cs.recent = false; 407f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ++it; 408f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 409f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 410f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (!free_recent && cache_size_ > cache_target) 411f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson GC(true); 412f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 413f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 414f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson unordered_map<StateId, CacheState> cache_; // Cache 415f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool cache_gc_; // Enable garbage collection 416f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson size_t cache_limit_; // # of bytes cached 417f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson size_t cache_size_; // # of bytes allowed before GC 418f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson RefCounter ref_count_; 419f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 
420f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson DISALLOW_COPY_AND_ASSIGN(CacheLogAccumulatorData); 421f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}; 422f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 423f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// This class accumulates arc weights using the log semiring Plus() 424f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// has a WeightConvert specialization to and from log64 weights. It 425f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// is similar to the FastLogAccumator. However here, the accumulated 426f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// weights are pre-computed and stored only for the states that are 427f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// visited. The member function Init(fst) has to be called to setup 428f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// this accumulator. 429f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class A> 430f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass CacheLogAccumulator { 431f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public: 432f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef A Arc; 433f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename A::StateId StateId; 434f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename A::Weight Weight; 435f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 436f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false, 437f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson size_t gc_limit = 10 * 1024 * 1024) 438f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : arc_limit_(arc_limit), fst_(0), data_( 439f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson new CacheLogAccumulatorData<A>(gc, gc_limit)), s_(kNoStateId), 440f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_(false) {} 441f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 
442f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson CacheLogAccumulator(const CacheLogAccumulator<A> &acc) 443f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : arc_limit_(acc.arc_limit_), fst_(acc.fst_ ? acc.fst_->Copy() : 0), 444f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson data_(acc.data_), s_(kNoStateId), error_(acc.error_) { 445f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson data_->IncrRefCount(); 446f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 447f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 448f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ~CacheLogAccumulator() { 449f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (fst_) 450f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson delete fst_; 451f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (!data_->DecrRefCount()) 452f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson delete data_; 453f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 454f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 455f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Arg 'arc_limit' specifies minimum # of arcs to pre-compute state. 
456f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void Init(const Fst<A> &fst, bool copy = false) { 457f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (copy) { 458f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson delete fst_; 459f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else if (fst_) { 460f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FSTERROR() << "CacheLogAccumulator: initialization error."; 461f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_ = true; 462f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return; 463f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 464f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson fst_ = fst.Copy(); 465f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 466f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 467f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void SetState(StateId s, int depth = 0) { 468f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (s == s_) 469f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return; 470f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson s_ = s; 471f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 472f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (data_->CacheDisabled() || error_) { 473f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weights_ = 0; 474f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return; 475f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 476f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 477f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (!fst_) { 478f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FSTERROR() << "CacheLogAccumulator::SetState: incorrectly initialized."; 479f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_ = true; 480f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weights_ = 0; 481f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return; 482f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 483f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 
484f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weights_ = data_->GetWeights(s); 485f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if ((weights_ == 0) && (fst_->NumArcs(s) >= arc_limit_)) { 486f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weights_ = new vector<double>; 487f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weights_->reserve(fst_->NumArcs(s) + 1); 488dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin weights_->push_back(FloatLimits<double>::PosInfinity()); 489f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson data_->AddWeights(s, weights_); 490f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 491f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 492f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 493f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight Sum(Weight w, Weight v) { 494f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return LogPlus(w, v); 495f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 496f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 497f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson template <class Iterator> 498f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight Sum(Weight w, Iterator *aiter, ssize_t begin, 499f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t end) { 500f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (weights_ == 0) { 501f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight sum = w; 502f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson aiter->Seek(begin); 503f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (ssize_t pos = begin; pos < end; aiter->Next(), ++pos) 504f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson sum = LogPlus(sum, aiter->Value().weight); 505f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return sum; 506f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else { 507f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (weights_->size() <= end) 508f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for 
(aiter->Seek(weights_->size() - 1); 509f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weights_->size() <= end; 510f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson aiter->Next()) 511f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weights_->push_back(LogPlus(weights_->back(), 512f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson aiter->Value().weight)); 513f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return LogPlus(w, LogMinus((*weights_)[end], (*weights_)[begin])); 514f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 515f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 516f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 517f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson template <class Iterator> 518f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson size_t LowerBound(double w, Iterator *aiter) { 519f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (weights_ != 0) { 520f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return lower_bound(weights_->begin() + 1, 521f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson weights_->end(), 522f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson w, 523f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson std::greater<double>()) 524f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson - weights_->begin() - 1; 525f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else { 526f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson size_t n = 0; 527dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin double x = FloatLimits<double>::PosInfinity(); 528f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for(aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) { 529f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson x = LogPlus(x, aiter->Value().weight); 530f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (x < w) break; 531f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 532f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return n; 533f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 
534f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 535f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 536f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool Error() const { return error_; } 537f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 538f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private: 539f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double LogPosExp(double x) { 540dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin return x == FloatLimits<double>::PosInfinity() ? 541f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 0.0 : log(1.0F + exp(-x)); 542f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 543f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 544f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double LogMinusExp(double x) { 545dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin return x == FloatLimits<double>::PosInfinity() ? 546f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 0.0 : log(1.0F - exp(-x)); 547f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 548f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 549f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight LogPlus(Weight w, Weight v) { 550f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double f1 = to_log_weight_(w).Value(); 551f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double f2 = to_log_weight_(v).Value(); 552f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (f1 > f2) 553f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return to_weight_(f2 - LogPosExp(f1 - f2)); 554f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson else 555f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return to_weight_(f1 - LogPosExp(f2 - f1)); 556f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 557f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 558f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double LogPlus(double f1, Weight v) { 559f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson double f2 = to_log_weight_(v).Value(); 
560dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin if (f1 == FloatLimits<double>::PosInfinity()) 561f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return f2; 562f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson else if (f1 > f2) 563f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return f2 - LogPosExp(f1 - f2); 564f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson else 565f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return f1 - LogPosExp(f2 - f1); 566f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 567f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 568f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight LogMinus(double f1, double f2) { 569f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (f1 >= f2) { 570f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FSTERROR() << "CacheLogAcumulator::LogMinus: f1 >= f2 with f1 = " << f1 571f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson << " and f2 = " << f2; 572f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_ = true; 573f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return Weight::NoWeight(); 574f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 575dfd8b8327b93660601d016cdc6f29f433b45a8d8Alexander Gutkin if (f2 == FloatLimits<double>::PosInfinity()) 576f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return to_weight_(f1); 577f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson else 578f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return to_weight_(f1 - LogMinusExp(f2 - f1)); 579f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 580f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 581f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson WeightConvert<Weight, Log64Weight> to_log_weight_; 582f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson WeightConvert<Log64Weight, Weight> to_weight_; 583f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 584f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t arc_limit_; // Minimum # of arcs to cache a state 
585f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<double> *weights_; // Accumulated weights for cur. state 586f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson const Fst<A>* fst_; // Input fst 587f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson CacheLogAccumulatorData<A> *data_; // Cache data 588f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson StateId s_; // Current state 589f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool error_; 590f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 591f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void operator=(const CacheLogAccumulator<A> &); // Disallow 592f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}; 593f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 594f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 595f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Stores shareable data for replace accumulator copies. 596f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Accumulator, class T> 597f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass ReplaceAccumulatorData { 598f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public: 599f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename Accumulator::Arc Arc; 600f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename Arc::StateId StateId; 601f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename Arc::Label Label; 602f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef T StateTable; 603f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename T::StateTuple StateTuple; 604f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 605f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ReplaceAccumulatorData() : state_table_(0) {} 606f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 607f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ReplaceAccumulatorData(const vector<Accumulator*> &accumulators) 608f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : state_table_(0), 
accumulators_(accumulators) {} 609f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 610f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ~ReplaceAccumulatorData() { 611f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (size_t i = 0; i < fst_array_.size(); ++i) 612f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson delete fst_array_[i]; 613f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (size_t i = 0; i < accumulators_.size(); ++i) 614f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson delete accumulators_[i]; 615f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 616f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 617f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples, 618f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson const StateTable *state_table) { 619f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson state_table_ = state_table; 620f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson accumulators_.resize(fst_tuples.size()); 621f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (size_t i = 0; i < accumulators_.size(); ++i) { 622f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (!accumulators_[i]) 623f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson accumulators_[i] = new Accumulator; 624f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson accumulators_[i]->Init(*(fst_tuples[i].second)); 625f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson fst_array_.push_back(fst_tuples[i].second->Copy()); 626f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 627f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 628f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 629f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson const StateTuple &GetTuple(StateId s) const { 630f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return state_table_->Tuple(s); 631f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 632f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 
633f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Accumulator *GetAccumulator(size_t i) { return accumulators_[i]; } 634f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 635f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson const Fst<Arc> *GetFst(size_t i) const { return fst_array_[i]; } 636f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 637f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson int RefCount() const { return ref_count_.count(); } 638f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson int IncrRefCount() { return ref_count_.Incr(); } 639f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson int DecrRefCount() { return ref_count_.Decr(); } 640f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 641f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private: 642f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson const T * state_table_; 643f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<Accumulator*> accumulators_; 644f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<const Fst<Arc>*> fst_array_; 645f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson RefCounter ref_count_; 646f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 647f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson DISALLOW_COPY_AND_ASSIGN(ReplaceAccumulatorData); 648f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}; 649f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 650f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// This class accumulates weights in a ReplaceFst. The 'Init' method 651f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// takes as input the argument used to build the ReplaceFst and the 652f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// ReplaceFst state table. It uses accumulators of type 'Accumulator' 653f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// in the underlying FSTs. 
654f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Accumulator, 655f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson class T = DefaultReplaceStateTable<typename Accumulator::Arc> > 656f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonclass ReplaceAccumulator { 657f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson public: 658f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename Accumulator::Arc Arc; 659f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename Arc::StateId StateId; 660f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename Arc::Label Label; 661f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename Arc::Weight Weight; 662f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef T StateTable; 663f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typedef typename T::StateTuple StateTuple; 664f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 665f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ReplaceAccumulator() 666f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : init_(false), data_(new ReplaceAccumulatorData<Accumulator, T>()), 667f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_(false) {} 668f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 669f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ReplaceAccumulator(const vector<Accumulator*> &accumulators) 670f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : init_(false), 671f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson data_(new ReplaceAccumulatorData<Accumulator, T>(accumulators)), 672f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_(false) {} 673f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 674f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ReplaceAccumulator(const ReplaceAccumulator<Accumulator, T> &acc) 675f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : init_(acc.init_), data_(acc.data_), error_(acc.error_) { 676f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (!init_) 
677f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FSTERROR() << "ReplaceAccumulator: can't copy unintialized accumulator"; 678f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson data_->IncrRefCount(); 679f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 680f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 681f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ~ReplaceAccumulator() { 682f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (!data_->DecrRefCount()) 683f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson delete data_; 684f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 685f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 686f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Does not take ownership of the state table, the state table 687f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // is own by the ReplaceFst 688f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void Init(const vector<pair<Label, const Fst<Arc>*> > &fst_tuples, 689f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson const StateTable *state_table) { 690f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson init_ = true; 691f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson data_->Init(fst_tuples, state_table); 692f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 693f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 694f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void SetState(StateId s) { 695f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (!init_) { 696f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FSTERROR() << "ReplaceAccumulator::SetState: incorrectly initialized."; 697f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson error_ = true; 698f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return; 699f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 700f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson StateTuple tuple = data_->GetTuple(s); 701f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based 
702f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state); 703f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if ((tuple.prefix_id != 0) && 704f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) { 705f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson offset_ = 1; 706f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state); 707f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else { 708f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson offset_ = 0; 709f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson offset_weight_ = Weight::Zero(); 710f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 711f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 712f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 713f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight Sum(Weight w, Weight v) { 714f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (error_) return Weight::NoWeight(); 715f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return data_->GetAccumulator(fst_id_)->Sum(w, v); 716f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 717f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 718f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson template <class ArcIterator> 719f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight Sum(Weight w, ArcIterator *aiter, ssize_t begin, 720f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ssize_t end) { 721f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (error_) return Weight::NoWeight(); 722f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight sum = begin == end ? Weight::Zero() 723f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : data_->GetAccumulator(fst_id_)->Sum( 724f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson w, aiter, begin ? 
begin - offset_ : 0, end - offset_); 725f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (begin == 0 && end != 0 && offset_ > 0) 726f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson sum = Sum(offset_weight_, sum); 727f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return sum; 728f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 729f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 730f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool Error() const { return error_; } 731f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 732f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson private: 733f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool init_; 734f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ReplaceAccumulatorData<Accumulator, T> *data_; 735f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Label fst_id_; 736f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson size_t offset_; 737f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Weight offset_weight_; 738f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool error_; 739f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 740f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson void operator=(const ReplaceAccumulator<Accumulator, T> &); // Disallow 741f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}; 742f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 743f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson} // namespace fst 744f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 745f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#endif // FST_LIB_ACCUMULATOR_H__ 746