1f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// push.h
2f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
3f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Licensed under the Apache License, Version 2.0 (the "License");
4f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// you may not use this file except in compliance with the License.
5f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// You may obtain a copy of the License at
6f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//
7f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//     http://www.apache.org/licenses/LICENSE-2.0
8f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//
9f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Unless required by applicable law or agreed to in writing, software
10f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// distributed under the License is distributed on an "AS IS" BASIS,
11f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// See the License for the specific language governing permissions and
13f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// limitations under the License.
14f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//
15f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Copyright 2005-2010 Google, Inc.
16f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Author: allauzen@google.com (Cyril Allauzen)
17f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//
18f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// \file
// Functions to reweight/push an FST.
20f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
21f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#ifndef FST_LIB_PUSH_H__
22f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#define FST_LIB_PUSH_H__
23f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
24f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <vector>
25f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonusing std::vector;
26f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
27f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/factor-weight.h>
28f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/fst.h>
29f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/arc-map.h>
30f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/reweight.h>
31f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#include <fst/shortest-distance.h>
32f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
33f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
34f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonnamespace fst {
35f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
36f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Private helper functions for Push
37f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonnamespace internal {
38f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
39f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Compute the total weight (sum of the weights of all accepting paths) from
40f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// the output of ShortestDistance. 'distance' is the shortest distance from the
41f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// initial state when 'reverse == false' and to the final states when
42f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// 'reverse == true'.
43f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Arc>
44f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontypename Arc::Weight ComputeTotalWeight(
45f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    const Fst<Arc> &fst,
46f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    const vector<typename Arc::Weight> &distance,
47f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    bool reverse) {
48f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  if (reverse)
49f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    return fst.Start() < distance.size() ?
50f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        distance[fst.Start()] : Arc::Weight::Zero();
51f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
52f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typename Arc::Weight sum = Arc::Weight::Zero();
53f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  for (typename Arc::StateId s = 0; s < distance.size(); ++s)
54f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    sum = Plus(sum, Times(distance[s], fst.Final(s)));
55f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  return sum;
56f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}
57f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
58f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Divide the weight of every accepting path by 'w'. The weight 'w' is
59f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// divided at the final states if 'at_final == true' and at the
60f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// initial state otherwise.
61f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Arc>
62f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonvoid RemoveWeight(MutableFst<Arc> *fst, typename Arc::Weight w, bool at_final) {
63f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero()))
64f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      return;
65f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
66f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  if (at_final) {
67f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    // Remove 'w' from the final states
68f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    for (StateIterator< MutableFst<Arc> > sit(*fst);
69f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson         !sit.Done();
70f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson         sit.Next())
71f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      fst->SetFinal(sit.Value(),
72f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                    Divide(fst->Final(sit.Value()), w,  DIVIDE_RIGHT));
73f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  } else {  // at_final == false
74f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    // Remove 'w' from the initial state
75f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    typename Arc::StateId start = fst->Start();
76f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    for (MutableArcIterator<MutableFst<Arc> > ait(fst, start);
77f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson         !ait.Done();
78f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson         ait.Next()) {
79f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      Arc arc = ait.Value();
80f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      arc.weight = Divide(arc.weight, w, DIVIDE_LEFT);
81f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      ait.SetValue(arc);
82f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
83f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT));
84f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
85f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}
86f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}  // namespace internal
87f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
88f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Pushes the weights in FST in the direction defined by TYPE.  If
89f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// pushing towards the initial state, the sum of the weight of the
90f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// outgoing transitions and final weight at a non-initial state is
91f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// equal to One() in the resulting machine.  If pushing towards the
92f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// final state, the same property holds on the reverse machine.
93f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson//
94f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// Weight needs to be left distributive when pushing towards the
95f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// initial state and right distributive when pushing towards the final
96f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// states.
97f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Arc>
98f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonvoid Push(MutableFst<Arc> *fst,
99f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          ReweightType type,
100f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          float delta = kDelta,
101f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          bool remove_total_weight = false) {
102f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  vector<typename Arc::Weight> distance;
103f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta);
104f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  typename Arc::Weight total_weight = Arc::Weight::One();
105f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  if (remove_total_weight)
106f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    total_weight = internal::ComputeTotalWeight(*fst, distance,
107f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                                                type == REWEIGHT_TO_INITIAL);
108f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  Reweight(fst, distance, type);
109f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  if (remove_total_weight)
110f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL);
111f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}
112f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
113f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonconst uint32 kPushWeights = 0x0001;
114f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonconst uint32 kPushLabels =  0x0002;
115f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonconst uint32 kPushRemoveTotalWeight = 0x0004;
116f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonconst uint32 kPushRemoveCommonAffix = 0x0008;
117f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
118f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// OFST obtained from IFST by pushing weights and/or labels according
119f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// to PTYPE in the direction defined by RTYPE.  Weight needs to be
120f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// left distributive when pushing weights towards the initial state
121f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// and right distributive when pushing weights towards the final
122f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson// states.
123f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsontemplate <class Arc, ReweightType rtype>
124f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodsonvoid Push(const Fst<Arc> &ifst,
125f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          MutableFst<Arc> *ofst,
126f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          uint32 ptype,
127f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          float delta = kDelta) {
128f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
129f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) {
130f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    *ofst = ifst;
131f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight);
132f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  } else if (ptype & kPushLabels) {
133f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    const StringType stype = rtype == REWEIGHT_TO_INITIAL
134f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                             ? STRING_LEFT
135f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                             : STRING_RIGHT;
136f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    vector<typename GallicArc<Arc, stype>::Weight> gdistance;
137f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    VectorFst<GallicArc<Arc, stype> > gfst;
138f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    ArcMap(ifst, &gfst, ToGallicMapper<Arc, stype>());
139f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (ptype & kPushWeights ) {
140f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
141f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    } else {
142f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      ArcMapFst<Arc, Arc, RmWeightMapper<Arc> >
143f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        uwfst(ifst, RmWeightMapper<Arc>());
144f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      ArcMapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> >
145f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        guwfst(uwfst, ToGallicMapper<Arc, stype>());
146f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
147f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
148f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    typename GallicArc<Arc, stype>::Weight total_weight =
149f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson        GallicArc<Arc, stype>::Weight::One();
150f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
151f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      total_weight = internal::ComputeTotalWeight(
152f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          gfst, gdistance, rtype == REWEIGHT_TO_INITIAL);
153f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      total_weight = typename GallicArc<Arc, stype>::Weight(
154f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          ptype & kPushRemoveCommonAffix ? total_weight.Value1()
155f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          : StringWeight<typename Arc::Label, stype>::One(),
156f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          ptype & kPushRemoveTotalWeight ? total_weight.Value2()
157f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson          : Arc::Weight::One());
158f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    }
159f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    Reweight(&gfst, gdistance, rtype);
160f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix))
161f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL);
162f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label,
163f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson      typename Arc::Weight, stype> > fwfst(gfst);
164f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    ArcMap(fwfst, ofst, FromGallicMapper<Arc, stype>());
165f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    ofst->SetOutputSymbols(ifst.OutputSymbols());
166f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  } else {
167f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    LOG(WARNING) << "Push: pushing type is set to 0: "
168f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson                 << "pushing neither labels nor weights.";
169f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson    *ofst = ifst;
170f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson  }
171f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}
172f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
173f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson}  // namespace fst
174f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson
175f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2Ian Hodson#endif /* FST_LIB_PUSH_H_ */
176