prune.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// prune.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
19// Functions implementing pruning.
20
21#ifndef FST_LIB_PRUNE_H__
22#define FST_LIB_PRUNE_H__
23
24#include <vector>
25using std::vector;
26
27#include <fst/arcfilter.h>
28#include <fst/heap.h>
29#include <fst/shortest-distance.h>
30
31
32namespace fst {
33
34template <class A, class ArcFilter>
35class PruneOptions {
36 public:
37  typedef typename A::Weight Weight;
38  typedef typename A::StateId StateId;
39
40  // Pruning weight threshold.
41  Weight weight_threshold;
42  // Pruning state threshold.
43  StateId state_threshold;
44  // Arc filter.
45  ArcFilter filter;
46  // If non-zero, passes in pre-computed shortest distance to final states.
47  const vector<Weight> *distance;
48  // Determines the degree of convergence required when computing shortest
49  // distances.
50  float delta;
51
52  explicit PruneOptions(const Weight& w, StateId s, ArcFilter f,
53                        vector<Weight> *d = 0, float e = kDelta)
54      : weight_threshold(w),
55        state_threshold(s),
56        filter(f),
57        distance(d),
58        delta(e) {}
59 private:
60  PruneOptions();  // disallow
61};
62
63
64template <class S, class W>
65class PruneCompare {
66 public:
67  typedef S StateId;
68  typedef W Weight;
69
70  PruneCompare(const vector<Weight> &idistance,
71               const vector<Weight> &fdistance)
72      : idistance_(idistance), fdistance_(fdistance) {}
73
74  bool operator()(const StateId x, const StateId y) const {
75    Weight wx = Times(x < idistance_.size() ? idistance_[x] : Weight::Zero(),
76                      x < fdistance_.size() ? fdistance_[x] : Weight::Zero());
77    Weight wy = Times(y < idistance_.size() ? idistance_[y] : Weight::Zero(),
78                      y < fdistance_.size() ? fdistance_[y] : Weight::Zero());
79    return less_(wx, wy);
80  }
81
82 private:
83  const vector<Weight> &idistance_;
84  const vector<Weight> &fdistance_;
85  NaturalLess<Weight> less_;
86};
87
88
89
90// Pruning algorithm: this version modifies its input and it takes an
91// options class as an argment. Delete states and arcs in 'fst' that
92// do not belong to a successful path whose weight is no more than
93// the weight of the shortest path Times() 'opts.weight_threshold'.
94// When 'opts.state_threshold != kNoStateId', the resulting transducer
95// will restricted further to have at most 'opts.state_threshold'
96// states. Weights need to be commutative and have the path
97// property. The weight 'w' of any cycle needs to be bounded, i.e.,
98// 'Plus(w, W::One()) = One()'.
99template <class Arc, class ArcFilter>
100void Prune(MutableFst<Arc> *fst,
101           const PruneOptions<Arc, ArcFilter> &opts) {
102  typedef typename Arc::Weight Weight;
103  typedef typename Arc::StateId StateId;
104
105  if ((Weight::Properties() & (kPath | kCommutative))
106      != (kPath | kCommutative)) {
107    FSTERROR() << "Prune: Weight needs to have the path property and"
108               << " be commutative: "
109               << Weight::Type();
110    fst->SetProperties(kError, kError);
111    return;
112  }
113  StateId ns = fst->NumStates();
114  if (ns == 0) return;
115  vector<Weight> idistance(ns, Weight::Zero());
116  vector<Weight> tmp;
117  if (!opts.distance) {
118    tmp.reserve(ns);
119    ShortestDistance(*fst, &tmp, true, opts.delta);
120  }
121  const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
122
123  if ((opts.state_threshold == 0) ||
124      (fdistance->size() <= fst->Start()) ||
125      ((*fdistance)[fst->Start()] == Weight::Zero())) {
126    fst->DeleteStates();
127    return;
128  }
129  PruneCompare<StateId, Weight> compare(idistance, *fdistance);
130  Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
131  vector<bool> visited(ns, false);
132  vector<size_t> enqueued(ns, kNoKey);
133  vector<StateId> dead;
134  dead.push_back(fst->AddState());
135  NaturalLess<Weight> less;
136  Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold);
137
138  StateId num_visited = 0;
139  StateId s = fst->Start();
140  if (!less(limit, (*fdistance)[s])) {
141    idistance[s] = Weight::One();
142    enqueued[s] = heap.Insert(s);
143    ++num_visited;
144  }
145
146  while (!heap.Empty()) {
147    s = heap.Top();
148    heap.Pop();
149    enqueued[s] = kNoKey;
150    visited[s] = true;
151    if (less(limit, Times(idistance[s], fst->Final(s))))
152      fst->SetFinal(s, Weight::Zero());
153    for (MutableArcIterator< MutableFst<Arc> > ait(fst, s);
154         !ait.Done();
155         ait.Next()) {
156      Arc arc = ait.Value();
157      if (!opts.filter(arc)) continue;
158      Weight weight = Times(Times(idistance[s], arc.weight),
159                            arc.nextstate < fdistance->size()
160                            ? (*fdistance)[arc.nextstate]
161                            : Weight::Zero());
162      if (less(limit, weight)) {
163        arc.nextstate = dead[0];
164        ait.SetValue(arc);
165        continue;
166      }
167      if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate]))
168        idistance[arc.nextstate] = Times(idistance[s], arc.weight);
169      if (visited[arc.nextstate]) continue;
170      if ((opts.state_threshold != kNoStateId) &&
171          (num_visited >= opts.state_threshold))
172        continue;
173      if (enqueued[arc.nextstate] == kNoKey) {
174        enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
175        ++num_visited;
176      } else {
177        heap.Update(enqueued[arc.nextstate], arc.nextstate);
178      }
179    }
180  }
181  for (size_t i = 0; i < visited.size(); ++i)
182    if (!visited[i]) dead.push_back(i);
183  fst->DeleteStates(dead);
184}
185
186
187// Pruning algorithm: this version modifies its input and simply takes
188// the pruning threshold as an argument. Delete states and arcs in
189// 'fst' that do not belong to a successful path whose weight is no
190// more than the weight of the shortest path Times()
191// 'weight_threshold'.  When 'state_threshold != kNoStateId', the
192// resulting transducer will be restricted further to have at most
193// 'opts.state_threshold' states. Weights need to be commutative and
194// have the path property. The weight 'w' of any cycle needs to be
195// bounded, i.e., 'Plus(w, W::One()) = One()'.
196template <class Arc>
197void Prune(MutableFst<Arc> *fst,
198           typename Arc::Weight weight_threshold,
199           typename Arc::StateId state_threshold = kNoStateId,
200           double delta = kDelta) {
201  PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
202                                             AnyArcFilter<Arc>(), 0, delta);
203  Prune(fst, opts);
204}
205
206
207// Pruning algorithm: this version writes the pruned input Fst to an
208// output MutableFst and it takes an options class as an argument.
209// 'ofst' contains states and arcs that belong to a successful path in
210// 'ifst' whose weight is no more than the weight of the shortest path
211// Times() 'opts.weight_threshold'. When 'opts.state_threshold !=
212// kNoStateId', 'ofst' will be restricted further to have at most
213// 'opts.state_threshold' states. Weights need to be commutative and
214// have the path property. The weight 'w' of any cycle needs to be
215// bounded, i.e., 'Plus(w, W::One()) = One()'.
216template <class Arc, class ArcFilter>
217void Prune(const Fst<Arc> &ifst,
218           MutableFst<Arc> *ofst,
219           const PruneOptions<Arc, ArcFilter> &opts) {
220  typedef typename Arc::Weight Weight;
221  typedef typename Arc::StateId StateId;
222
223  if ((Weight::Properties() & (kPath | kCommutative))
224      != (kPath | kCommutative)) {
225    FSTERROR() << "Prune: Weight needs to have the path property and"
226               << " be commutative: "
227               << Weight::Type();
228    ofst->SetProperties(kError, kError);
229    return;
230  }
231  ofst->DeleteStates();
232  ofst->SetInputSymbols(ifst.InputSymbols());
233  ofst->SetOutputSymbols(ifst.OutputSymbols());
234  if (ifst.Start() == kNoStateId)
235    return;
236  NaturalLess<Weight> less;
237  if (less(opts.weight_threshold, Weight::One()) ||
238      (opts.state_threshold == 0))
239    return;
240  vector<Weight> idistance;
241  vector<Weight> tmp;
242  if (!opts.distance)
243    ShortestDistance(ifst, &tmp, true, opts.delta);
244  const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp;
245
246  if ((fdistance->size() <= ifst.Start()) ||
247      ((*fdistance)[ifst.Start()] == Weight::Zero())) {
248    return;
249  }
250  PruneCompare<StateId, Weight> compare(idistance, *fdistance);
251  Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare);
252  vector<StateId> copy;
253  vector<size_t> enqueued;
254  vector<bool> visited;
255
256  StateId s = ifst.Start();
257  Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(),
258                         opts.weight_threshold);
259  while (copy.size() <= s)
260    copy.push_back(kNoStateId);
261  copy[s] = ofst->AddState();
262  ofst->SetStart(copy[s]);
263  while (idistance.size() <= s)
264    idistance.push_back(Weight::Zero());
265  idistance[s] = Weight::One();
266  while (enqueued.size() <= s) {
267    enqueued.push_back(kNoKey);
268    visited.push_back(false);
269  }
270  enqueued[s] = heap.Insert(s);
271
272  while (!heap.Empty()) {
273    s = heap.Top();
274    heap.Pop();
275    enqueued[s] = kNoKey;
276    visited[s] = true;
277    if (!less(limit, Times(idistance[s], ifst.Final(s))))
278      ofst->SetFinal(copy[s], ifst.Final(s));
279    for (ArcIterator< Fst<Arc> > ait(ifst, s);
280         !ait.Done();
281         ait.Next()) {
282      const Arc &arc = ait.Value();
283      if (!opts.filter(arc)) continue;
284      Weight weight = Times(Times(idistance[s], arc.weight),
285                            arc.nextstate < fdistance->size()
286                            ? (*fdistance)[arc.nextstate]
287                            : Weight::Zero());
288      if (less(limit, weight)) continue;
289      if ((opts.state_threshold != kNoStateId) &&
290          (ofst->NumStates() >= opts.state_threshold))
291        continue;
292      while (idistance.size() <= arc.nextstate)
293        idistance.push_back(Weight::Zero());
294      if (less(Times(idistance[s], arc.weight),
295               idistance[arc.nextstate]))
296        idistance[arc.nextstate] = Times(idistance[s], arc.weight);
297      while (copy.size() <= arc.nextstate)
298        copy.push_back(kNoStateId);
299      if (copy[arc.nextstate] == kNoStateId)
300        copy[arc.nextstate] = ofst->AddState();
301      ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight,
302                                copy[arc.nextstate]));
303      while (enqueued.size() <= arc.nextstate) {
304        enqueued.push_back(kNoKey);
305        visited.push_back(false);
306      }
307      if (visited[arc.nextstate]) continue;
308      if (enqueued[arc.nextstate] == kNoKey)
309        enqueued[arc.nextstate] = heap.Insert(arc.nextstate);
310      else
311        heap.Update(enqueued[arc.nextstate], arc.nextstate);
312    }
313  }
314}
315
316
317// Pruning algorithm: this version writes the pruned input Fst to an
318// output MutableFst and simply takes the pruning threshold as an
319// argument.  'ofst' contains states and arcs that belong to a
320// successful path in 'ifst' whose weight is no more than
321// the weight of the shortest path Times() 'weight_threshold'. When
322// 'state_threshold != kNoStateId', 'ofst' will be restricted further
323// to have at most 'opts.state_threshold' states. Weights need to be
324// commutative and have the path property. The weight 'w' of any cycle
325// needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'.
326template <class Arc>
327void Prune(const Fst<Arc> &ifst,
328           MutableFst<Arc> *ofst,
329           typename Arc::Weight weight_threshold,
330           typename Arc::StateId state_threshold = kNoStateId,
331           float delta = kDelta) {
332  PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold,
333                                             AnyArcFilter<Arc>(), 0, delta);
334  Prune(ifst, ofst, opts);
335}
336
337}  // namespace fst
338
#endif  // FST_LIB_PRUNE_H__
340