1// reweight.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
19// Function to reweight an FST.
20
21#ifndef FST_LIB_REWEIGHT_H__
22#define FST_LIB_REWEIGHT_H__
23
24#include <vector>
25using std::vector;
26
27#include <fst/mutable-fst.h>
28
29
30namespace fst {
31
// Selects the direction in which Reweight() pushes the potentials.
enum ReweightType {
  REWEIGHT_TO_INITIAL = 0,  // Push weight towards the initial state.
  REWEIGHT_TO_FINAL = 1     // Push weight towards the final states.
};
33
34// Reweight FST according to the potentials defined by the POTENTIAL
35// vector in the direction defined by TYPE. Weight needs to be left
36// distributive when reweighting towards the initial state and right
37// distributive when reweighting towards the final states.
38//
39// An arc of weight w, with an origin state of potential p and
40// destination state of potential q, is reweighted by p\wq when
41// reweighting towards the initial state and by pw/q when reweighting
42// towards the final states.
43template <class Arc>
44void Reweight(MutableFst<Arc> *fst,
45              const vector<typename Arc::Weight> &potential,
46              ReweightType type) {
47  typedef typename Arc::Weight Weight;
48
49  if (fst->NumStates() == 0)
50    return;
51
52  if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) {
53    FSTERROR() << "Reweight: Reweighting to the final states requires "
54               << "Weight to be right distributive: "
55               << Weight::Type();
56    fst->SetProperties(kError, kError);
57    return;
58  }
59
60  if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) {
61    FSTERROR() << "Reweight: Reweighting to the initial state requires "
62               << "Weight to be left distributive: "
63               << Weight::Type();
64    fst->SetProperties(kError, kError);
65    return;
66  }
67
68  StateIterator< MutableFst<Arc> > sit(*fst);
69  for (; !sit.Done(); sit.Next()) {
70    typename Arc::StateId state = sit.Value();
71    if (state == potential.size())
72      break;
73    typename Arc::Weight weight = potential[state];
74    if (weight != Weight::Zero()) {
75      for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
76           !ait.Done();
77           ait.Next()) {
78        Arc arc = ait.Value();
79        if (arc.nextstate >= potential.size())
80          continue;
81        typename Arc::Weight nextweight = potential[arc.nextstate];
82        if (nextweight == Weight::Zero())
83          continue;
84        if (type == REWEIGHT_TO_INITIAL)
85          arc.weight = Divide(Times(arc.weight, nextweight), weight,
86                              DIVIDE_LEFT);
87        if (type == REWEIGHT_TO_FINAL)
88          arc.weight = Divide(Times(weight, arc.weight), nextweight,
89                              DIVIDE_RIGHT);
90        ait.SetValue(arc);
91      }
92      if (type == REWEIGHT_TO_INITIAL)
93        fst->SetFinal(state, Divide(fst->Final(state), weight, DIVIDE_LEFT));
94    }
95    if (type == REWEIGHT_TO_FINAL)
96      fst->SetFinal(state, Times(weight, fst->Final(state)));
97  }
98
99  // This handles elements past the end of the potentials array.
100  for (; !sit.Done(); sit.Next()) {
101    typename Arc::StateId state = sit.Value();
102    if (type == REWEIGHT_TO_FINAL)
103      fst->SetFinal(state, Times(Weight::Zero(), fst->Final(state)));
104  }
105
106  typename Arc::Weight startweight = fst->Start() < potential.size() ?
107      potential[fst->Start()] : Weight::Zero();
108  if ((startweight != Weight::One()) && (startweight != Weight::Zero())) {
109    if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
110      typename Arc::StateId state = fst->Start();
111      for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
112           !ait.Done();
113           ait.Next()) {
114        Arc arc = ait.Value();
115        if (type == REWEIGHT_TO_INITIAL)
116          arc.weight = Times(startweight, arc.weight);
117        else
118          arc.weight = Times(
119              Divide(Weight::One(), startweight, DIVIDE_RIGHT),
120              arc.weight);
121        ait.SetValue(arc);
122      }
123      if (type == REWEIGHT_TO_INITIAL)
124        fst->SetFinal(state, Times(startweight, fst->Final(state)));
125      else
126        fst->SetFinal(state, Times(Divide(Weight::One(), startweight,
127                                          DIVIDE_RIGHT),
128                                   fst->Final(state)));
129    } else {
130      typename Arc::StateId state = fst->AddState();
131      Weight w = type == REWEIGHT_TO_INITIAL ?  startweight :
132                 Divide(Weight::One(), startweight, DIVIDE_RIGHT);
133      Arc arc(0, 0, w, fst->Start());
134      fst->AddArc(state, arc);
135      fst->SetStart(state);
136    }
137  }
138
139  fst->SetProperties(ReweightProperties(
140                         fst->Properties(kFstProperties, false)),
141                     kFstProperties);
142}
143
144}  // namespace fst
145
#endif  // FST_LIB_REWEIGHT_H__
147