1// prune.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16//
17// \file
18// Functions implementing pruning.
19
20#ifndef FST_LIB_PRUNE_H__
21#define FST_LIB_PRUNE_H__
22
23#include "fst/lib/arcfilter.h"
24#include "fst/lib/shortest-distance.h"
25
26namespace fst {
27
template <class Arc, class ArcFilter>
class PruneOptions {
 public:
  typedef typename Arc::Weight Weight;

  // Paths heavier than this value Times() the shortest-path weight are
  // pruned.
  Weight threshold;
  // Arc predicate handed to Prune().
  ArcFilter filter;
  // Optional (0 when absent) caller-owned vector holding pre-computed
  // shortest distances from the initial state; Prune() may resize it and
  // never frees it.
  vector<Weight> *idistance;
  // Optional (0 when absent) caller-owned vector holding pre-computed
  // shortest distances to the final states; Prune() may resize it and
  // never frees it.
  vector<Weight> *fdistance;

  PruneOptions(const Weight &prune_threshold,
               ArcFilter arc_filter,
               vector<Weight> *initial_distance = 0,
               vector<Weight> *final_distance = 0)
      : threshold(prune_threshold),
        filter(arc_filter),
        idistance(initial_distance),
        fdistance(final_distance) {}
};
48
49
50// Pruning algorithm: this version modifies its input and it takes an
51// options class as an argment. Delete states and arcs in 'fst' that
52// do not belong to a successful path whose weight is no more than
53// 'opts.threshold' Times() the weight of the shortest path. Weights
54// need to be commutative and have the path property.
55template <class Arc, class ArcFilter>
56void Prune(MutableFst<Arc> *fst,
57           const PruneOptions<Arc, ArcFilter> &opts) {
58  typedef typename Arc::Weight Weight;
59  typedef typename Arc::StateId StateId;
60
61  if ((Weight::Properties() & (kPath | kCommutative))
62      != (kPath | kCommutative))
63    LOG(FATAL) << "Prune: Weight needs to have the path property and"
64               << " be commutative: "
65               << Weight::Type();
66
67  StateId ns = fst->NumStates();
68  if (ns == 0) return;
69
70  vector<Weight> *idistance = opts.idistance;
71  vector<Weight> *fdistance = opts.fdistance;
72
73  if (!idistance) {
74    idistance = new vector<Weight>(ns, Weight::Zero());
75    ShortestDistance(*fst, idistance, false);
76  } else {
77    idistance->resize(ns, Weight::Zero());
78  }
79
80  if (!fdistance) {
81    fdistance = new vector<Weight>(ns, Weight::Zero());
82    ShortestDistance(*fst, fdistance, true);
83  } else {
84    fdistance->resize(ns, Weight::Zero());
85  }
86
87  vector<StateId> dead;
88  dead.push_back(fst->AddState());
89  NaturalLess<Weight> less;
90  Weight ceiling = Times((*fdistance)[fst->Start()], opts.threshold);
91
92  for (StateId state = 0; state < ns; ++state) {
93    if (less(ceiling, Times((*idistance)[state], (*fdistance)[state]))) {
94      dead.push_back(state);
95      continue;
96    }
97    for (MutableArcIterator< MutableFst<Arc> > it(fst, state);
98         !it.Done();
99         it.Next()) {
100      Arc arc = it.Value();
101      if (!opts.filter(arc)) continue;
102      Weight weight = Times(Times((*idistance)[state], arc.weight),
103                           (*fdistance)[arc.nextstate]);
104      if(less(ceiling, weight)) {
105        arc.nextstate = dead[0];
106        it.SetValue(arc);
107      }
108    }
109    if (less(ceiling, Times((*idistance)[state], fst->Final(state))))
110      fst->SetFinal(state, Weight::Zero());
111  }
112
113  fst->DeleteStates(dead);
114
115  if (!opts.idistance)
116    delete idistance;
117  if (!opts.fdistance)
118    delete fdistance;
119}
120
121
122// Pruning algorithm: this version modifies its input and simply takes
123// the pruning threshold as an argument. Delete states and arcs in
124// 'fst' that do not belong to a successful path whose weight is no
125// more than 'opts.threshold' Times() the weight of the shortest
126// path. Weights need to be commutative and have the path property.
127template <class Arc>
128void Prune(MutableFst<Arc> *fst, typename Arc::Weight threshold) {
129  PruneOptions<Arc, AnyArcFilter<Arc> > opts(threshold, AnyArcFilter<Arc>());
130  Prune(fst, opts);
131}
132
133
134// Pruning algorithm: this version writes the pruned input Fst to an
135// output MutableFst and it takes an options class as an argument.
136// 'ofst' contains states and arcs that belong to a successful path in
137// 'ifst' whose weight is no more than 'opts.threshold' Times() the
138// weight of the shortest path. Weights need to be commutative and
139// have the path property.
140template <class Arc, class ArcFilter>
141void Prune(const Fst<Arc> &ifst,
142           MutableFst<Arc> *ofst,
143           const PruneOptions<Arc, ArcFilter> &opts) {
144  typedef typename Arc::Weight Weight;
145  typedef typename Arc::StateId StateId;
146
147  if ((Weight::Properties() & (kPath | kCommutative))
148      != (kPath | kCommutative))
149    LOG(FATAL) << "Prune: Weight needs to have the path property and"
150               << " be commutative: "
151               << Weight::Type();
152
153  ofst->DeleteStates();
154
155  if (ifst.Start() == kNoStateId)
156    return;
157
158  vector<Weight> *idistance = opts.idistance;
159  vector<Weight> *fdistance = opts.fdistance;
160
161  if (!idistance) {
162    idistance = new vector<Weight>;
163    ShortestDistance(ifst, idistance, false);
164  }
165
166  if (!fdistance) {
167    fdistance = new vector<Weight>;
168    ShortestDistance(ifst, fdistance, true);
169  }
170
171  vector<StateId> copy;
172  NaturalLess<Weight> less;
173  while (fdistance->size() <= ifst.Start())
174    fdistance->push_back(Weight::Zero());
175  Weight ceiling = Times((*fdistance)[ifst.Start()], opts.threshold);
176
177  for (StateIterator< Fst<Arc> > sit(ifst);
178       !sit.Done();
179       sit.Next()) {
180    StateId state = sit.Value();
181    while (idistance->size() <= state)
182      idistance->push_back(Weight::Zero());
183    while (fdistance->size() <= state)
184      fdistance->push_back(Weight::Zero());
185    while (copy.size() <= state)
186      copy.push_back(kNoStateId);
187
188    if (less(ceiling, Times((*idistance)[state], (*fdistance)[state])))
189      continue;
190
191    if (copy[state] == kNoStateId)
192      copy[state] = ofst->AddState();
193    if (!less(ceiling, Times((*idistance)[state], ifst.Final(state))))
194      ofst->SetFinal(copy[state], ifst.Final(state));
195
196    for (ArcIterator< Fst<Arc> > ait(ifst, state);
197         !ait.Done();
198         ait.Next()) {
199      Arc arc = ait.Value();
200
201      if (!opts.filter(arc)) continue;
202
203      while (idistance->size() <= arc.nextstate)
204        idistance->push_back(Weight::Zero());
205      while (fdistance->size() <= arc.nextstate)
206        fdistance->push_back(Weight::Zero());
207      while (copy.size() <= arc.nextstate)
208        copy.push_back(kNoStateId);
209
210      Weight weight = Times(Times((*idistance)[state], arc.weight),
211                           (*fdistance)[arc.nextstate]);
212
213      if (!less(ceiling, weight)) {
214        if (copy[arc.nextstate] == kNoStateId)
215          copy[arc.nextstate] = ofst->AddState();
216        arc.nextstate = copy[arc.nextstate];
217        ofst->AddArc(copy[state], arc);
218      }
219    }
220  }
221
222  ofst->SetStart(copy[ifst.Start()]);
223
224  if (!opts.idistance)
225    delete idistance;
226  if (!opts.fdistance)
227    delete fdistance;
228}
229
230
231// Pruning algorithm: this version writes the pruned input Fst to an
232// output MutableFst and simply takes the pruning threshold as an
233// argument.  'ofst' contains states and arcs that belong to a
234// successful path in 'ifst' whose weight is no more than
235// 'opts.threshold' Times() the weight of the shortest path. Weights
236// need to be commutative and have the path property.
237template <class Arc>
238void Prune(const Fst<Arc> &ifst,
239           MutableFst<Arc> *ofst,
240           typename Arc::Weight threshold) {
241  PruneOptions<Arc, AnyArcFilter<Arc> > opts(threshold, AnyArcFilter<Arc>());
242  Prune(ifst, ofst, opts);
243}
244
245} // namespace fst
246
#endif  // FST_LIB_PRUNE_H__
248