arcsort.h revision 73018b4a1d088cdda0e7bd059fddf1f308a8195a
1// arcsort.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// Functions and classes to sort arcs in an FST.
18
19#ifndef FST_LIB_ARCSORT_H__
20#define FST_LIB_ARCSORT_H__
21
22#include <algorithm>
23
24#include "fst/lib/cache.h"
25#include "fst/lib/test-properties.h"
26
27namespace fst {
28
29// Sorts the arcs in an FST according to function object 'comp' of
30// type Compare. This version modifies its input.  Comparison function
31// objects IlabelCompare and OlabelCompare are provived by the
32// library. In general, Compare must meet the requirements for an STL
33// sort comparision function object. It must also have a member
34// Properties(uint64) that specifies the known properties of the
35// sorted FST; it takes as argument the input FST's known properties
36// before the sort.
37//
38// Complexity:
39// - Time: O(V + D log D)
40// - Space: O(D)
41// where V = # of states and D = maximum out-degree.
42template<class Arc, class Compare>
43void ArcSort(MutableFst<Arc> *fst, Compare comp) {
44  typedef typename Arc::StateId StateId;
45
46  uint64 props = fst->Properties(kFstProperties, false);
47
48  vector<Arc> arcs;
49  for (StateIterator< MutableFst<Arc> > siter(*fst);
50       !siter.Done();
51       siter.Next()) {
52    StateId s = siter.Value();
53    arcs.clear();
54    for (ArcIterator< MutableFst<Arc> > aiter(*fst, s);
55         !aiter.Done();
56         aiter.Next())
57      arcs.push_back(aiter.Value());
58    sort(arcs.begin(), arcs.end(), comp);
59    fst->DeleteArcs(s);
60    for (size_t a = 0; a < arcs.size(); ++a)
61      fst->AddArc(s, arcs[a]);
62  }
63
64  fst->SetProperties(comp.Properties(props), kFstProperties);
65}
66
67typedef CacheOptions ArcSortFstOptions;
68
69// Implementation of delayed ArcSortFst.
70template<class A, class C>
71class ArcSortFstImpl : public CacheImpl<A> {
72 public:
73  using FstImpl<A>::SetType;
74  using FstImpl<A>::SetProperties;
75  using FstImpl<A>::Properties;
76  using FstImpl<A>::SetInputSymbols;
77  using FstImpl<A>::SetOutputSymbols;
78  using FstImpl<A>::InputSymbols;
79  using FstImpl<A>::OutputSymbols;
80
81  using VectorFstBaseImpl<typename CacheImpl<A>::State>::NumStates;
82
83  using CacheImpl<A>::HasArcs;
84  using CacheImpl<A>::HasFinal;
85  using CacheImpl<A>::HasStart;
86
87  typedef typename A::Weight Weight;
88  typedef typename A::StateId StateId;
89
90  ArcSortFstImpl(const Fst<A> &fst, const C &comp,
91                 const ArcSortFstOptions &opts)
92      : CacheImpl<A>(opts), fst_(fst.Copy()), comp_(comp) {
93    SetType("arcsort");
94    uint64 props = fst_->Properties(kCopyProperties, false);
95    SetProperties(comp_.Properties(props));
96    SetInputSymbols(fst.InputSymbols());
97    SetOutputSymbols(fst.OutputSymbols());
98  }
99
100  ArcSortFstImpl(const ArcSortFstImpl& impl)
101      : fst_(impl.fst_->Copy()), comp_(impl.comp_) {
102    SetType("arcsort");
103    SetProperties(impl.Properties(), kCopyProperties);
104    SetInputSymbols(impl.InputSymbols());
105    SetOutputSymbols(impl.OutputSymbols());
106  }
107
108  ~ArcSortFstImpl() { delete fst_; }
109
110  StateId Start() {
111    if (!HasStart())
112      SetStart(fst_->Start());
113    return CacheImpl<A>::Start();
114  }
115
116  Weight Final(StateId s) {
117    if (!HasFinal(s))
118      SetFinal(s, fst_->Final(s));
119    return CacheImpl<A>::Final(s);
120  }
121
122  size_t NumArcs(StateId s) {
123    if (!HasArcs(s))
124      Expand(s);
125    return CacheImpl<A>::NumArcs(s);
126  }
127
128  size_t NumInputEpsilons(StateId s) {
129    if (!HasArcs(s))
130      Expand(s);
131    return CacheImpl<A>::NumInputEpsilons(s);
132  }
133
134  size_t NumOutputEpsilons(StateId s) {
135    if (!HasArcs(s))
136      Expand(s);
137    return CacheImpl<A>::NumOutputEpsilons(s);
138  }
139
140  void InitStateIterator(StateIteratorData<A> *data) const {
141    fst_->InitStateIterator(data);
142  }
143
144  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
145    if (!HasArcs(s))
146      Expand(s);
147    CacheImpl<A>::InitArcIterator(s, data);
148  }
149
150  void Expand(StateId s) {
151    for (ArcIterator< Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next())
152      AddArc(s, aiter.Value());
153    SetArcs(s);
154
155    if (s < NumStates()) {  // ensure state exists
156      vector<A> &carcs = GetState(s)->arcs;
157      sort(carcs.begin(), carcs.end(), comp_);
158    }
159  }
160
161 private:
162  const Fst<A> *fst_;
163  C comp_;
164
165  void operator=(const ArcSortFstImpl<A, C> &impl);  // Disallow
166};
167
168
169// Sorts the arcs in an FST according to function object 'comp' of
170// type Compare. This version is a delayed Fst.  Comparsion function
171// objects IlabelCompare and OlabelCompare are provided by the
172// library. In general, Compare must meet the requirements for an STL
173// comparision function object (e.g. as used for STL sort). It must
174// also have a member Properties(uint64) that specifies the known
175// properties of the sorted FST; it takes as argument the input FST's
176// known properties.
177//
178// Complexity:
179// - Time: O(v + d log d)
180// - Space: O(v + d)
181// where v = # of states visited, d = maximum out-degree of states
182// visited. Constant time and space to visit an input state is assumed
183// and exclusive of caching.
184template <class A, class C>
185class ArcSortFst : public Fst<A> {
186 public:
187  friend class CacheArcIterator< ArcSortFst<A, C> >;
188  friend class ArcIterator< ArcSortFst<A, C> >;
189
190  typedef A Arc;
191  typedef C Compare;
192  typedef typename A::Weight Weight;
193  typedef typename A::StateId StateId;
194  typedef CacheState<A> State;
195
196  ArcSortFst(const Fst<A> &fst, const C &comp)
197      : impl_(new ArcSortFstImpl<A, C>(fst, comp, ArcSortFstOptions())) {}
198
199  ArcSortFst(const Fst<A> &fst, const C &comp, const ArcSortFstOptions &opts)
200      : impl_(new ArcSortFstImpl<A, C>(fst, comp, opts)) {}
201
202  ArcSortFst(const ArcSortFst<A, C> &fst) :
203      impl_(new ArcSortFstImpl<A, C>(*(fst.impl_))) {}
204
205  virtual ~ArcSortFst() { if (!impl_->DecrRefCount()) delete impl_; }
206
207  virtual StateId Start() const { return impl_->Start(); }
208
209  virtual Weight Final(StateId s) const { return impl_->Final(s); }
210
211  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
212
213  virtual size_t NumInputEpsilons(StateId s) const {
214    return impl_->NumInputEpsilons(s);
215  }
216
217  virtual size_t NumOutputEpsilons(StateId s) const {
218    return impl_->NumOutputEpsilons(s);
219  }
220
221  virtual uint64 Properties(uint64 mask, bool test) const {
222    if (test) {
223      uint64 known, test = TestProperties(*this, mask, &known);
224      impl_->SetProperties(test, known);
225      return test & mask;
226    } else {
227      return impl_->Properties(mask);
228    }
229  }
230
231  virtual const string& Type() const { return impl_->Type(); }
232
233  virtual ArcSortFst<A, C> *Copy() const {
234    return new ArcSortFst<A, C>(*this);
235  }
236
237  virtual const SymbolTable* InputSymbols() const {
238    return impl_->InputSymbols();
239  }
240
241  virtual const SymbolTable* OutputSymbols() const {
242    return impl_->OutputSymbols();
243  }
244
245  virtual void InitStateIterator(StateIteratorData<A> *data) const {
246    impl_->InitStateIterator(data);
247  }
248
249  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
250    impl_->InitArcIterator(s, data);
251  }
252
253 private:
254  ArcSortFstImpl<A, C> *impl_;
255
256  void operator=(const ArcSortFst<A, C> &fst);  // Disallow
257};
258
259
260// Specialization for ArcSortFst.
261template <class A, class C>
262class ArcIterator< ArcSortFst<A, C> >
263    : public CacheArcIterator< ArcSortFst<A, C> > {
264 public:
265  typedef typename A::StateId StateId;
266
267  ArcIterator(const ArcSortFst<A, C> &fst, StateId s)
268      : CacheArcIterator< ArcSortFst<A, C> >(fst, s) {
269    if (!fst.impl_->HasArcs(s))
270      fst.impl_->Expand(s);
271  }
272
273 private:
274  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
275};
276
277
278// Compare class for comparing input labels of arcs.
279template<class A> class ILabelCompare {
280 public:
281  bool operator() (A arc1, A arc2) const {
282    return arc1.ilabel < arc2.ilabel;
283  }
284
285  uint64 Properties(uint64 props) const {
286    return (props & kArcSortProperties) | kILabelSorted;
287  }
288};
289
290
291// Compare class for comparing output labels of arcs.
292template<class A> class OLabelCompare {
293 public:
294  bool operator() (const A &arc1, const A &arc2) const {
295    return arc1.olabel < arc2.olabel;
296  }
297
298  uint64 Properties(uint64 props) const {
299    return (props & kArcSortProperties) | kOLabelSorted;
300  }
301};
302
303
304// Useful aliases when using StdArc.
305template<class C> class StdArcSortFst : public ArcSortFst<StdArc, C> {
306 public:
307  typedef StdArc Arc;
308  typedef C Compare;
309};
310
311typedef ILabelCompare<StdArc> StdILabelCompare;
312
313typedef OLabelCompare<StdArc> StdOLabelCompare;
314
315}  // namespace fst
316
317#endif  // FST_LIB_ARCSORT_H__
318