14a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// reweight.h
24a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project//
34a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// Licensed under the Apache License, Version 2.0 (the "License");
44a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// you may not use this file except in compliance with the License.
54a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// You may obtain a copy of the License at
64a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project//
74a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project//      http://www.apache.org/licenses/LICENSE-2.0
84a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project//
94a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// Unless required by applicable law or agreed to in writing, software
104a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// distributed under the License is distributed on an "AS IS" BASIS,
114a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
124a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// See the License for the specific language governing permissions and
134a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// limitations under the License.
144a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project//
154a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
164a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project//
174a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// \file
184a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// Function to reweight an FST.
194a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
204a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project#ifndef FST_LIB_REWEIGHT_H__
214a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project#define FST_LIB_REWEIGHT_H__
224a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
234a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project#include "fst/lib/mutable-fst.h"
244a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
254a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Projectnamespace fst {
264a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
// Selects the direction in which Reweight() applies the potentials.
enum ReweightType {
  REWEIGHT_TO_INITIAL,  // Reweight towards the initial state.
  REWEIGHT_TO_FINAL     // Reweight towards the final states.
};
284a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
294a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// Reweight FST according to the potentials defined by the POTENTIAL
304a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// vector in the direction defined by TYPE. Weight needs to be left
314a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// distributive when reweighting towards the initial state and right
324a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// distributive when reweighting towards the final states.
334a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project//
344a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// An arc of weight w, with an origin state of potential p and
354a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// destination state of potential q, is reweighted by p\wq when
364a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// reweighting towards the initial state and by pw/q when reweighting
374a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project// towards the final states.
384a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Projecttemplate <class Arc>
394a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Projectvoid Reweight(MutableFst<Arc> *fst, vector<typename Arc::Weight> potential,
404a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project              ReweightType type) {
414a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project  typedef typename Arc::Weight Weight;
424a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
434a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project  if (!fst->NumStates())
444a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    return;
454a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project  while ( (int64)potential.size() < (int64)fst->NumStates())
464a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    potential.push_back(Weight::Zero());
474a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
484a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project  if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring))
494a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    LOG(FATAL) << "Reweight: Reweighting to the final states requires "
504a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project               << "Weight to be right distributive: "
514a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project               << Weight::Type();
524a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
534a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project  if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring))
544a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    LOG(FATAL) << "Reweight: Reweighting to the initial state requires "
554a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project               << "Weight to be left distributive: "
564a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project               << Weight::Type();
574a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
584a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project  for (StateIterator< MutableFst<Arc> > sit(*fst);
594a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project       !sit.Done();
604a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project       sit.Next()) {
614a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    typename Arc::StateId state = sit.Value();
624a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
634a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project         !ait.Done();
644a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project         ait.Next()) {
654a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      Arc arc = ait.Value();
664a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      if ((potential[state] == Weight::Zero()) ||
674a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project	  (potential[arc.nextstate] == Weight::Zero()))
684a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project	continue; //temp fix: needs to find best solution for zeros
694a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      if ((type == REWEIGHT_TO_INITIAL)
704a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project	  && (potential[state] != Weight::Zero()))
714a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project        arc.weight = Divide(Times(arc.weight, potential[arc.nextstate]),
724a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project			    potential[state], DIVIDE_LEFT);
734a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      else if ((type == REWEIGHT_TO_FINAL)
744a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project	       && (potential[arc.nextstate] != Weight::Zero()))
754a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project        arc.weight = Divide(Times(potential[state], arc.weight),
764a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project                            potential[arc.nextstate], DIVIDE_RIGHT);
774a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      ait.SetValue(arc);
784a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    }
794a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    if ((type == REWEIGHT_TO_INITIAL)
804a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project	&& (potential[state] != Weight::Zero()))
814a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      fst->SetFinal(state,
824a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project                    Divide(fst->Final(state), potential[state], DIVIDE_LEFT));
834a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    else if (type == REWEIGHT_TO_FINAL)
844a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      fst->SetFinal(state, Times(potential[state], fst->Final(state)));
854a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project  }
864a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
874a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project  if ((potential[fst->Start()] != Weight::One()) &&
884a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      (potential[fst->Start()] != Weight::Zero())) {
894a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
904a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      typename Arc::StateId state = fst->Start();
914a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
924a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project           !ait.Done();
934a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project           ait.Next()) {
944a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project        Arc arc = ait.Value();
954a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project        if (type == REWEIGHT_TO_INITIAL)
964a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project          arc.weight = Times(potential[state], arc.weight);
974a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project        else
984a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project          arc.weight = Times(
994a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project              Divide(Weight::One(), potential[state], DIVIDE_RIGHT),
1004a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project              arc.weight);
1014a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project        ait.SetValue(arc);
1024a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      }
1034a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      if (type == REWEIGHT_TO_INITIAL)
1044a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project        fst->SetFinal(state, Times(potential[state], fst->Final(state)));
1054a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      else
1064a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project        fst->SetFinal(state, Times(Divide(Weight::One(), potential[state],
1074a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project                                          DIVIDE_RIGHT),
1084a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project                                   fst->Final(state)));
1094a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    }
1104a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    else {
1114a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      typename Arc::StateId state = fst->AddState();
1124a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      Weight w = type == REWEIGHT_TO_INITIAL ?
1134a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project                 potential[fst->Start()] :
1144a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project                 Divide(Weight::One(), potential[fst->Start()], DIVIDE_RIGHT);
1154a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      Arc arc (0, 0, w, fst->Start());
1164a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      fst->AddArc(state, arc);
1174a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project      fst->SetStart(state);
1184a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project    }
1194a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project  }
1204a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
1214a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project  fst->SetProperties(ReweightProperties(
1224a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project                         fst->Properties(kFstProperties, false)),
1234a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project                     kFstProperties);
1244a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project}
1254a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
1264a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project}  // namespace fst
1274a68b3365c8c50aa93505e99ead2565ab73dcdb0The Android Open Source Project
#endif /* FST_LIB_REWEIGHT_H__ */
129