reweight.h revision 4a68b3365c8c50aa93505e99ead2565ab73dcdb0
1// reweight.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16//
17// \file
18// Function to reweight an FST.
19
20#ifndef FST_LIB_REWEIGHT_H__
21#define FST_LIB_REWEIGHT_H__
22
23#include "fst/lib/mutable-fst.h"
24
25namespace fst {
26
// Direction in which Reweight() moves weight through the FST.
enum ReweightType {
  REWEIGHT_TO_INITIAL,  // Reweight towards the initial state.
  REWEIGHT_TO_FINAL     // Reweight towards the final states.
};
28
29// Reweight FST according to the potentials defined by the POTENTIAL
30// vector in the direction defined by TYPE. Weight needs to be left
31// distributive when reweighting towards the initial state and right
32// distributive when reweighting towards the final states.
33//
34// An arc of weight w, with an origin state of potential p and
35// destination state of potential q, is reweighted by p\wq when
36// reweighting towards the initial state and by pw/q when reweighting
37// towards the final states.
38template <class Arc>
39void Reweight(MutableFst<Arc> *fst, vector<typename Arc::Weight> potential,
40              ReweightType type) {
41  typedef typename Arc::Weight Weight;
42
43  if (!fst->NumStates())
44    return;
45  while ( (int64)potential.size() < (int64)fst->NumStates())
46    potential.push_back(Weight::Zero());
47
48  if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring))
49    LOG(FATAL) << "Reweight: Reweighting to the final states requires "
50               << "Weight to be right distributive: "
51               << Weight::Type();
52
53  if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring))
54    LOG(FATAL) << "Reweight: Reweighting to the initial state requires "
55               << "Weight to be left distributive: "
56               << Weight::Type();
57
58  for (StateIterator< MutableFst<Arc> > sit(*fst);
59       !sit.Done();
60       sit.Next()) {
61    typename Arc::StateId state = sit.Value();
62    for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
63         !ait.Done();
64         ait.Next()) {
65      Arc arc = ait.Value();
66      if ((potential[state] == Weight::Zero()) ||
67	  (potential[arc.nextstate] == Weight::Zero()))
68	continue; //temp fix: needs to find best solution for zeros
69      if ((type == REWEIGHT_TO_INITIAL)
70	  && (potential[state] != Weight::Zero()))
71        arc.weight = Divide(Times(arc.weight, potential[arc.nextstate]),
72			    potential[state], DIVIDE_LEFT);
73      else if ((type == REWEIGHT_TO_FINAL)
74	       && (potential[arc.nextstate] != Weight::Zero()))
75        arc.weight = Divide(Times(potential[state], arc.weight),
76                            potential[arc.nextstate], DIVIDE_RIGHT);
77      ait.SetValue(arc);
78    }
79    if ((type == REWEIGHT_TO_INITIAL)
80	&& (potential[state] != Weight::Zero()))
81      fst->SetFinal(state,
82                    Divide(fst->Final(state), potential[state], DIVIDE_LEFT));
83    else if (type == REWEIGHT_TO_FINAL)
84      fst->SetFinal(state, Times(potential[state], fst->Final(state)));
85  }
86
87  if ((potential[fst->Start()] != Weight::One()) &&
88      (potential[fst->Start()] != Weight::Zero())) {
89    if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
90      typename Arc::StateId state = fst->Start();
91      for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
92           !ait.Done();
93           ait.Next()) {
94        Arc arc = ait.Value();
95        if (type == REWEIGHT_TO_INITIAL)
96          arc.weight = Times(potential[state], arc.weight);
97        else
98          arc.weight = Times(
99              Divide(Weight::One(), potential[state], DIVIDE_RIGHT),
100              arc.weight);
101        ait.SetValue(arc);
102      }
103      if (type == REWEIGHT_TO_INITIAL)
104        fst->SetFinal(state, Times(potential[state], fst->Final(state)));
105      else
106        fst->SetFinal(state, Times(Divide(Weight::One(), potential[state],
107                                          DIVIDE_RIGHT),
108                                   fst->Final(state)));
109    }
110    else {
111      typename Arc::StateId state = fst->AddState();
112      Weight w = type == REWEIGHT_TO_INITIAL ?
113                 potential[fst->Start()] :
114                 Divide(Weight::One(), potential[fst->Start()], DIVIDE_RIGHT);
115      Arc arc (0, 0, w, fst->Start());
116      fst->AddArc(state, arc);
117      fst->SetStart(state);
118    }
119  }
120
121  fst->SetProperties(ReweightProperties(
122                         fst->Properties(kFstProperties, false)),
123                     kFstProperties);
124}
125
126}  // namespace fst
127
128#endif /* FST_LIB_REWEIGHT_H_ */
129