1f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// push.h 2f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 3f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Licensed under the Apache License, Version 2.0 (the "License"); 4f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// you may not use this file except in compliance with the License. 5f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// You may obtain a copy of the License at 6f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// 7f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// http://www.apache.org/licenses/LICENSE-2.0 8f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// 9f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Unless required by applicable law or agreed to in writing, software 10f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// distributed under the License is distributed on an "AS IS" BASIS, 11f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// See the License for the specific language governing permissions and 13f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// limitations under the License. 14f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// 15f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Copyright 2005-2010 Google, Inc. 16f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Author: allauzen@google.com (Cyril Allauzen) 17f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// 18f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// \file 19f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Class to reweight/push an FST. 20f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 21f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#ifndef FST_LIB_PUSH_H__ 22f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#define FST_LIB_PUSH_H__ 23f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 24f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <vector> 25f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonusing std::vector; 26f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 27f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/factor-weight.h> 28f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/fst.h> 29f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/arc-map.h> 30f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/reweight.h> 31f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/shortest-distance.h> 32f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 33f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 34f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonnamespace fst { 35f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 36f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Private helper functions for Push 37f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonnamespace internal { 38f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 39f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Compute the total weight (sum of the weights of all accepting paths) from 40f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// the output of ShortestDistance. 'distance' is the shortest distance from the 41f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// initial state when 'reverse == false' and to the final states when 42f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// 'reverse == true'. 43f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Arc> 44f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontypename Arc::Weight ComputeTotalWeight( 45f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson const Fst<Arc> &fst, 46f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson const vector<typename Arc::Weight> &distance, 47f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool reverse) { 48f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (reverse) 49f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return fst.Start() < distance.size() ? 50f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson distance[fst.Start()] : Arc::Weight::Zero(); 51f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 52f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typename Arc::Weight sum = Arc::Weight::Zero(); 53f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (typename Arc::StateId s = 0; s < distance.size(); ++s) 54f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson sum = Plus(sum, Times(distance[s], fst.Final(s))); 55f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return sum; 56f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson} 57f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 58f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Divide the weight of every accepting path by 'w'. The weight 'w' is 59f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// divided at the final states if 'at_final == true' and at the 60f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// initial state otherwise. 61f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Arc> 62f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonvoid RemoveWeight(MutableFst<Arc> *fst, typename Arc::Weight w, bool at_final) { 63f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero())) 64f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson return; 65f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 66f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (at_final) { 67f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Remove 'w' from the final states 68f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (StateIterator< MutableFst<Arc> > sit(*fst); 69f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson !sit.Done(); 70f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson sit.Next()) 71f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson fst->SetFinal(sit.Value(), 72f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Divide(fst->Final(sit.Value()), w, DIVIDE_RIGHT)); 73f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else { // at_final == false 74f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson // Remove 'w' from the initial state 75f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typename Arc::StateId start = fst->Start(); 76f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson for (MutableArcIterator<MutableFst<Arc> > ait(fst, start); 77f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson !ait.Done(); 78f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ait.Next()) { 79f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Arc arc = ait.Value(); 80f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson arc.weight = Divide(arc.weight, w, DIVIDE_LEFT); 81f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ait.SetValue(arc); 82f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 83f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT)); 84f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 85f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson} 86f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson} // namespace internal 87f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 88f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Pushes the weights in FST in the direction defined by TYPE. If 89f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// pushing towards the initial state, the sum of the weight of the 90f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// outgoing transitions and final weight at a non-initial state is 91f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// equal to One() in the resulting machine. If pushing towards the 92f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// final state, the same property holds on the reverse machine. 93f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// 94f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Weight needs to be left distributive when pushing towards the 95f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// initial state and right distributive when pushing towards the final 96f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// states. 97f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Arc> 98f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonvoid Push(MutableFst<Arc> *fst, 99f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ReweightType type, 100f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson float delta = kDelta, 101f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson bool remove_total_weight = false) { 102f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<typename Arc::Weight> distance; 103f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta); 104f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typename Arc::Weight total_weight = Arc::Weight::One(); 105f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (remove_total_weight) 106f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson total_weight = internal::ComputeTotalWeight(*fst, distance, 107f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson type == REWEIGHT_TO_INITIAL); 108f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Reweight(fst, distance, type); 109f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (remove_total_weight) 110f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL); 111f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson} 112f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 113f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonconst uint32 kPushWeights = 0x0001; 114f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonconst uint32 kPushLabels = 0x0002; 115f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonconst uint32 kPushRemoveTotalWeight = 0x0004; 116f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonconst uint32 kPushRemoveCommonAffix = 0x0008; 117f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 118f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// OFST obtained from IFST by pushing weights and/or labels according 119f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// to PTYPE in the direction defined by RTYPE. Weight needs to be 120f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// left distributive when pushing weights towards the initial state 121f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// and right distributive when pushing weights towards the final 122f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// states. 123f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Arc, ReweightType rtype> 124f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonvoid Push(const Fst<Arc> &ifst, 125f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson MutableFst<Arc> *ofst, 126f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson uint32 ptype, 127f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson float delta = kDelta) { 128f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 129f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) { 130f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson *ofst = ifst; 131f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight); 132f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else if (ptype & kPushLabels) { 133f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson const StringType stype = rtype == REWEIGHT_TO_INITIAL 134f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ? STRING_LEFT 135f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : STRING_RIGHT; 136f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson vector<typename GallicArc<Arc, stype>::Weight> gdistance; 137f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson VectorFst<GallicArc<Arc, stype> > gfst; 138f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ArcMap(ifst, &gfst, ToGallicMapper<Arc, stype>()); 139f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (ptype & kPushWeights ) { 140f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); 141f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else { 142f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ArcMapFst<Arc, Arc, RmWeightMapper<Arc> > 143f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson uwfst(ifst, RmWeightMapper<Arc>()); 144f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ArcMapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> > 145f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson guwfst(uwfst, ToGallicMapper<Arc, stype>()); 146f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); 147f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 148f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typename GallicArc<Arc, stype>::Weight total_weight = 149f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson GallicArc<Arc, stype>::Weight::One(); 150f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { 151f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson total_weight = internal::ComputeTotalWeight( 152f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson gfst, gdistance, rtype == REWEIGHT_TO_INITIAL); 153f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson total_weight = typename GallicArc<Arc, stype>::Weight( 154f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ptype & kPushRemoveCommonAffix ? total_weight.Value1() 155f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : StringWeight<typename Arc::Label, stype>::One(), 156f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ptype & kPushRemoveTotalWeight ? total_weight.Value2() 157f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson : Arc::Weight::One()); 158f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 159f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson Reweight(&gfst, gdistance, rtype); 160f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) 161f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL); 162f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label, 163f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson typename Arc::Weight, stype> > fwfst(gfst); 164f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ArcMap(fwfst, ofst, FromGallicMapper<Arc, stype>()); 165f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson ofst->SetOutputSymbols(ifst.OutputSymbols()); 166f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } else { 167f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson LOG(WARNING) << "Push: pushing type is set to 0: " 168f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson << "pushing neither labels nor weights."; 169f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson *ofst = ifst; 170f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson } 171f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson} 172f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 173f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson} // namespace fst 174f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson 175f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#endif /* FST_LIB_PUSH_H_ */ 176