push.h revision 4a68b3365c8c50aa93505e99ead2565ab73dcdb0
1// push.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16//
17// \file
18// Class to reweight/push an FST.
19
20#ifndef FST_LIB_PUSH_H__
21#define FST_LIB_PUSH_H__
22
23#include "fst/lib/factor-weight.h"
24#include "fst/lib/fst.h"
25#include "fst/lib/map.h"
26#include "fst/lib/reweight.h"
27#include "fst/lib/shortest-distance.h"
28
29namespace fst {
30
31// Pushes the weights in FST in the direction defined by TYPE.  If
32// pushing towards the initial state, the sum of the weight of the
33// outgoing transitions and final weight at a non-initial state is
34// equal to One() in the resulting machine.  If pushing towards the
35// final state, the same property holds on the reverse machine.
36//
37// Weight needs to be left distributive when pushing towards the
38// initial state and right distributive when pushing towards the final
39// states.
40template <class Arc>
41void Push(MutableFst<Arc> *fst, ReweightType type) {
42  vector<typename Arc::Weight> distance;
43  ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL);
44  Reweight(fst, distance, type);
45}
46
47
48const uint32 kPushWeights = 0x0001;
49const uint32 kPushLabels =  0x0002;
50
51// OFST obtained from IFST by pushing weights and/or labels according
52// to PTYPE in the direction defined by RTYPE.  Weight needs to be
53// left distributive when pushing weights towards the initial state
54// and right distributive when pushing weights towards the final
55// states.
56template <class Arc, ReweightType rtype>
57void Push(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, uint32 ptype) {
58
59  if (ptype == kPushWeights) {
60    *ofst = ifst;
61    Push(ofst, rtype);
62  } else if (ptype & kPushLabels) {
63    const StringType stype = rtype == REWEIGHT_TO_INITIAL
64                             ? STRING_LEFT
65                             : STRING_RIGHT;
66    vector<typename GallicArc<Arc, stype>::Weight> gdistance;
67    VectorFst< GallicArc<Arc, stype> > gfst;
68    Map(ifst, &gfst, ToGallicMapper<Arc, stype>());
69    if (ptype == (kPushWeights | kPushLabels)) {
70      ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL);
71    } else {
72      MapFst<Arc, Arc, RmWeightMapper<Arc> >
73        uwfst(ifst, RmWeightMapper<Arc>());
74      MapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> >
75        guwfst(uwfst, ToGallicMapper<Arc, stype>());
76      ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL);
77    }
78    Reweight(&gfst, gdistance, rtype);
79    FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label,
80      typename Arc::Weight, stype> > fwfst(gfst);
81    Map(fwfst, ofst, FromGallicMapper<Arc, stype>());
82  } else {
83    *ofst = ifst;
84  }
85}
86
87}  // namespace fst
88
89#endif /* FST_LIB_PUSH_H_ */
90