1// push.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
// Functions to reweight/push an FST.
20
21#ifndef FST_LIB_PUSH_H__
22#define FST_LIB_PUSH_H__
23
24#include <vector>
25using std::vector;
26
27#include <fst/factor-weight.h>
28#include <fst/fst.h>
29#include <fst/arc-map.h>
30#include <fst/reweight.h>
31#include <fst/shortest-distance.h>
32
33
34namespace fst {
35
36// Private helper functions for Push
37namespace internal {
38
39// Compute the total weight (sum of the weights of all accepting paths) from
40// the output of ShortestDistance. 'distance' is the shortest distance from the
41// initial state when 'reverse == false' and to the final states when
42// 'reverse == true'.
43template <class Arc>
44typename Arc::Weight ComputeTotalWeight(
45    const Fst<Arc> &fst,
46    const vector<typename Arc::Weight> &distance,
47    bool reverse) {
48  if (reverse)
49    return fst.Start() < distance.size() ?
50        distance[fst.Start()] : Arc::Weight::Zero();
51
52  typename Arc::Weight sum = Arc::Weight::Zero();
53  for (typename Arc::StateId s = 0; s < distance.size(); ++s)
54    sum = Plus(sum, Times(distance[s], fst.Final(s)));
55  return sum;
56}
57
58// Divide the weight of every accepting path by 'w'. The weight 'w' is
59// divided at the final states if 'at_final == true' and at the
60// initial state otherwise.
61template <class Arc>
62void RemoveWeight(MutableFst<Arc> *fst, typename Arc::Weight w, bool at_final) {
63  if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero()))
64      return;
65
66  if (at_final) {
67    // Remove 'w' from the final states
68    for (StateIterator< MutableFst<Arc> > sit(*fst);
69         !sit.Done();
70         sit.Next())
71      fst->SetFinal(sit.Value(),
72                    Divide(fst->Final(sit.Value()), w,  DIVIDE_RIGHT));
73  } else {  // at_final == false
74    // Remove 'w' from the initial state
75    typename Arc::StateId start = fst->Start();
76    for (MutableArcIterator<MutableFst<Arc> > ait(fst, start);
77         !ait.Done();
78         ait.Next()) {
79      Arc arc = ait.Value();
80      arc.weight = Divide(arc.weight, w, DIVIDE_LEFT);
81      ait.SetValue(arc);
82    }
83    fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT));
84  }
85}
86}  // namespace internal
87
88// Pushes the weights in FST in the direction defined by TYPE.  If
89// pushing towards the initial state, the sum of the weight of the
90// outgoing transitions and final weight at a non-initial state is
91// equal to One() in the resulting machine.  If pushing towards the
92// final state, the same property holds on the reverse machine.
93//
94// Weight needs to be left distributive when pushing towards the
95// initial state and right distributive when pushing towards the final
96// states.
97template <class Arc>
98void Push(MutableFst<Arc> *fst,
99          ReweightType type,
100          float delta = kDelta,
101          bool remove_total_weight = false) {
102  vector<typename Arc::Weight> distance;
103  ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta);
104  typename Arc::Weight total_weight = Arc::Weight::One();
105  if (remove_total_weight)
106    total_weight = internal::ComputeTotalWeight(*fst, distance,
107                                                type == REWEIGHT_TO_INITIAL);
108  Reweight(fst, distance, type);
109  if (remove_total_weight)
110    internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL);
111}
112
113const uint32 kPushWeights = 0x0001;
114const uint32 kPushLabels =  0x0002;
115const uint32 kPushRemoveTotalWeight = 0x0004;
116const uint32 kPushRemoveCommonAffix = 0x0008;
117
118// OFST obtained from IFST by pushing weights and/or labels according
119// to PTYPE in the direction defined by RTYPE.  Weight needs to be
120// left distributive when pushing weights towards the initial state
121// and right distributive when pushing weights towards the final
122// states.
123template <class Arc, ReweightType rtype>
124void Push(const Fst<Arc> &ifst,
125          MutableFst<Arc> *ofst,
126          uint32 ptype,
127          float delta = kDelta) {
128
129  if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) {
130    *ofst = ifst;
131    Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight);
132  } else if (ptype & kPushLabels) {
133    const StringType stype = rtype == REWEIGHT_TO_INITIAL
134                             ? STRING_LEFT
135                             : STRING_RIGHT;
136    vector<typename GallicArc<Arc, stype>::Weight> gdistance;
137    VectorFst<GallicArc<Arc, stype> > gfst;
138    ArcMap(ifst, &gfst, ToGallicMapper<Arc, stype>());
139    if (ptype & kPushWeights ) {
140      ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
141    } else {
142      ArcMapFst<Arc, Arc, RmWeightMapper<Arc> >
143        uwfst(ifst, RmWeightMapper<Arc>());
144      ArcMapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> >
145        guwfst(uwfst, ToGallicMapper<Arc, stype>());
146      ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
147    }
148    typename GallicArc<Arc, stype>::Weight total_weight =
149        GallicArc<Arc, stype>::Weight::One();
150    if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
151      total_weight = internal::ComputeTotalWeight(
152          gfst, gdistance, rtype == REWEIGHT_TO_INITIAL);
153      total_weight = typename GallicArc<Arc, stype>::Weight(
154          ptype & kPushRemoveCommonAffix ? total_weight.Value1()
155          : StringWeight<typename Arc::Label, stype>::One(),
156          ptype & kPushRemoveTotalWeight ? total_weight.Value2()
157          : Arc::Weight::One());
158    }
159    Reweight(&gfst, gdistance, rtype);
160    if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix))
161      internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL);
162    FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label,
163      typename Arc::Weight, stype> > fwfst(gfst);
164    ArcMap(fwfst, ofst, FromGallicMapper<Arc, stype>());
165    ofst->SetOutputSymbols(ifst.OutputSymbols());
166  } else {
167    LOG(WARNING) << "Push: pushing type is set to 0: "
168                 << "pushing neither labels nor weights.";
169    *ofst = ifst;
170  }
171}
172
173}  // namespace fst
174
#endif  // FST_LIB_PUSH_H__
176