1// push.h 2 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Copyright 2005-2010 Google, Inc. 16// Author: allauzen@google.com (Cyril Allauzen) 17// 18// \file 19// Class to reweight/push an FST. 20 21#ifndef FST_LIB_PUSH_H__ 22#define FST_LIB_PUSH_H__ 23 24#include <vector> 25using std::vector; 26 27#include <fst/factor-weight.h> 28#include <fst/fst.h> 29#include <fst/arc-map.h> 30#include <fst/reweight.h> 31#include <fst/shortest-distance.h> 32 33 34namespace fst { 35 36// Private helper functions for Push 37namespace internal { 38 39// Compute the total weight (sum of the weights of all accepting paths) from 40// the output of ShortestDistance. 'distance' is the shortest distance from the 41// initial state when 'reverse == false' and to the final states when 42// 'reverse == true'. 43template <class Arc> 44typename Arc::Weight ComputeTotalWeight( 45 const Fst<Arc> &fst, 46 const vector<typename Arc::Weight> &distance, 47 bool reverse) { 48 if (reverse) 49 return fst.Start() < distance.size() ? 50 distance[fst.Start()] : Arc::Weight::Zero(); 51 52 typename Arc::Weight sum = Arc::Weight::Zero(); 53 for (typename Arc::StateId s = 0; s < distance.size(); ++s) 54 sum = Plus(sum, Times(distance[s], fst.Final(s))); 55 return sum; 56} 57 58// Divide the weight of every accepting path by 'w'. The weight 'w' is 59// divided at the final states if 'at_final == true' and at the 60// initial state otherwise. 61template <class Arc> 62void RemoveWeight(MutableFst<Arc> *fst, typename Arc::Weight w, bool at_final) { 63 if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero())) 64 return; 65 66 if (at_final) { 67 // Remove 'w' from the final states 68 for (StateIterator< MutableFst<Arc> > sit(*fst); 69 !sit.Done(); 70 sit.Next()) 71 fst->SetFinal(sit.Value(), 72 Divide(fst->Final(sit.Value()), w, DIVIDE_RIGHT)); 73 } else { // at_final == false 74 // Remove 'w' from the initial state 75 typename Arc::StateId start = fst->Start(); 76 for (MutableArcIterator<MutableFst<Arc> > ait(fst, start); 77 !ait.Done(); 78 ait.Next()) { 79 Arc arc = ait.Value(); 80 arc.weight = Divide(arc.weight, w, DIVIDE_LEFT); 81 ait.SetValue(arc); 82 } 83 fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT)); 84 } 85} 86} // namespace internal 87 88// Pushes the weights in FST in the direction defined by TYPE. If 89// pushing towards the initial state, the sum of the weight of the 90// outgoing transitions and final weight at a non-initial state is 91// equal to One() in the resulting machine. If pushing towards the 92// final state, the same property holds on the reverse machine. 93// 94// Weight needs to be left distributive when pushing towards the 95// initial state and right distributive when pushing towards the final 96// states. 97template <class Arc> 98void Push(MutableFst<Arc> *fst, 99 ReweightType type, 100 float delta = kDelta, 101 bool remove_total_weight = false) { 102 vector<typename Arc::Weight> distance; 103 ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta); 104 typename Arc::Weight total_weight = Arc::Weight::One(); 105 if (remove_total_weight) 106 total_weight = internal::ComputeTotalWeight(*fst, distance, 107 type == REWEIGHT_TO_INITIAL); 108 Reweight(fst, distance, type); 109 if (remove_total_weight) 110 internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL); 111} 112 113const uint32 kPushWeights = 0x0001; 114const uint32 kPushLabels = 0x0002; 115const uint32 kPushRemoveTotalWeight = 0x0004; 116const uint32 kPushRemoveCommonAffix = 0x0008; 117 118// OFST obtained from IFST by pushing weights and/or labels according 119// to PTYPE in the direction defined by RTYPE. Weight needs to be 120// left distributive when pushing weights towards the initial state 121// and right distributive when pushing weights towards the final 122// states. 123template <class Arc, ReweightType rtype> 124void Push(const Fst<Arc> &ifst, 125 MutableFst<Arc> *ofst, 126 uint32 ptype, 127 float delta = kDelta) { 128 129 if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) { 130 *ofst = ifst; 131 Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight); 132 } else if (ptype & kPushLabels) { 133 const StringType stype = rtype == REWEIGHT_TO_INITIAL 134 ? STRING_LEFT 135 : STRING_RIGHT; 136 vector<typename GallicArc<Arc, stype>::Weight> gdistance; 137 VectorFst<GallicArc<Arc, stype> > gfst; 138 ArcMap(ifst, &gfst, ToGallicMapper<Arc, stype>()); 139 if (ptype & kPushWeights ) { 140 ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); 141 } else { 142 ArcMapFst<Arc, Arc, RmWeightMapper<Arc> > 143 uwfst(ifst, RmWeightMapper<Arc>()); 144 ArcMapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> > 145 guwfst(uwfst, ToGallicMapper<Arc, stype>()); 146 ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); 147 } 148 typename GallicArc<Arc, stype>::Weight total_weight = 149 GallicArc<Arc, stype>::Weight::One(); 150 if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { 151 total_weight = internal::ComputeTotalWeight( 152 gfst, gdistance, rtype == REWEIGHT_TO_INITIAL); 153 total_weight = typename GallicArc<Arc, stype>::Weight( 154 ptype & kPushRemoveCommonAffix ? total_weight.Value1() 155 : StringWeight<typename Arc::Label, stype>::One(), 156 ptype & kPushRemoveTotalWeight ? total_weight.Value2() 157 : Arc::Weight::One()); 158 } 159 Reweight(&gfst, gdistance, rtype); 160 if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) 161 internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL); 162 FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label, 163 typename Arc::Weight, stype> > fwfst(gfst); 164 ArcMap(fwfst, ofst, FromGallicMapper<Arc, stype>()); 165 ofst->SetOutputSymbols(ifst.OutputSymbols()); 166 } else { 167 LOG(WARNING) << "Push: pushing type is set to 0: " 168 << "pushing neither labels nor weights."; 169 *ofst = ifst; 170 } 171} 172 173} // namespace fst 174 175#endif /* FST_LIB_PUSH_H_ */ 176