fst-class.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13//
14// Copyright 2005-2010 Google, Inc.
15// Author: jpr@google.com (Jake Ratkiewicz)
16
17#ifndef FST_SCRIPT_FST_CLASS_H_
18#define FST_SCRIPT_FST_CLASS_H_
19
20#include <string>
21
22#include <fst/fst.h>
23#include <fst/mutable-fst.h>
24#include <fst/vector-fst.h>
25#include <iostream>
26#include <fstream>
27
28// Classes to support "boxing" all existing types of FST arcs in a single
29// FstClass which hides the arc types. This allows clients to load
30// and work with FSTs without knowing the arc type.
31
32// These classes are only recommended for use in high-level scripting
33// applications. Most users should use the lower-level templated versions
34// corresponding to these classes.
35
36namespace fst {
37namespace script {
38
//
// Abstract base class defining the set of functionality implemented
// in all impls and passed through by all bases. Below FstClassBase
// the class hierarchy bifurcates: FstClassImplBase serves as the base
// class for all implementations (of which FstClassImpl is currently
// the only one), and FstClass serves as the base class for all
// interfaces.
//
47class FstClassBase {
48 public:
49  virtual const string &ArcType() const = 0;
50  virtual const string &FstType() const = 0;
51  virtual const string &WeightType() const = 0;
52  virtual const SymbolTable *InputSymbols() const = 0;
53  virtual const SymbolTable *OutputSymbols() const = 0;
54  virtual void Write(const string& fname) const = 0;
55  virtual uint64 Properties(uint64 mask, bool test) const = 0;
56  virtual ~FstClassBase() { }
57};
58
59class FstClassImplBase : public FstClassBase {
60 public:
61  virtual FstClassImplBase *Copy() = 0;
62  virtual void SetInputSymbols(SymbolTable *is) = 0;
63  virtual void SetOutputSymbols(SymbolTable *is) = 0;
64  virtual ~FstClassImplBase() { }
65};
66
67
68//
69// CONTAINER CLASS
70// Wraps an Fst<Arc>, hiding its arc type. Whether this Fst<Arc>
71// pointer refers to a special kind of FST (e.g. a MutableFst) is
72// known by the type of interface class that owns the pointer to this
73// container.
74//
75
76template<class Arc>
77class FstClassImpl : public FstClassImplBase {
78 public:
79  explicit FstClassImpl(Fst<Arc> *impl,
80                        bool should_own = false) :
81      impl_(should_own ? impl : impl->Copy()) { }
82
83  virtual const string &ArcType() const {
84    return Arc::Type();
85  }
86
87  virtual const string &FstType() const {
88    return impl_->Type();
89  }
90
91  virtual const string &WeightType() const {
92    return Arc::Weight::Type();
93  }
94
95  virtual const SymbolTable *InputSymbols() const {
96    return impl_->InputSymbols();
97  }
98
99  virtual const SymbolTable *OutputSymbols() const {
100    return impl_->OutputSymbols();
101  }
102
103  // Warning: calling this method casts the FST to a mutable FST.
104  virtual void SetInputSymbols(SymbolTable *is) {
105    static_cast<MutableFst<Arc> *>(impl_)->SetInputSymbols(is);
106  }
107
108  // Warning: calling this method casts the FST to a mutable FST.
109  virtual void SetOutputSymbols(SymbolTable *os) {
110    static_cast<MutableFst<Arc> *>(impl_)->SetOutputSymbols(os);
111  }
112
113  virtual void Write(const string &fname) const {
114    impl_->Write(fname);
115  }
116
117  virtual uint64 Properties(uint64 mask, bool test) const {
118    return impl_->Properties(mask, test);
119  }
120
121  virtual ~FstClassImpl() { delete impl_; }
122
123  Fst<Arc> *GetImpl() { return impl_; }
124
125  virtual FstClassImpl *Copy() {
126    return new FstClassImpl<Arc>(impl_);
127  }
128
129 private:
130  Fst<Arc> *impl_;
131};
132
133//
134// BASE CLASS DEFINITIONS
135//
136
137class MutableFstClass;
138
139class FstClass : public FstClassBase {
140 public:
141  template<class Arc>
142  static FstClass *Read(istream &stream,
143                        const FstReadOptions &opts) {
144    if (!opts.header) {
145      FSTERROR() << "FstClass::Read: options header not specified";
146      return 0;
147    }
148    const FstHeader &hdr = *opts.header;
149
150    if (hdr.Properties() & kMutable) {
151      return ReadTypedFst<MutableFstClass, MutableFst<Arc> >(stream, opts);
152    } else {
153      return ReadTypedFst<FstClass, Fst<Arc> >(stream, opts);
154    }
155  }
156
157  template<class Arc>
158  explicit FstClass(Fst<Arc> *fst) : impl_(new FstClassImpl<Arc>(fst)) { }
159
160  explicit FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { }
161
162  static FstClass *Read(const string &fname);
163
164  virtual const string &ArcType() const {
165    return impl_->ArcType();
166  }
167
168  virtual const string& FstType() const {
169    return impl_->FstType();
170  }
171
172  virtual const SymbolTable *InputSymbols() const {
173    return impl_->InputSymbols();
174  }
175
176  virtual const SymbolTable *OutputSymbols() const {
177    return impl_->OutputSymbols();
178  }
179
180  virtual const string& WeightType() const {
181    return impl_->WeightType();
182  }
183
184  virtual void Write(const string &fname) const {
185    impl_->Write(fname);
186  }
187
188  virtual uint64 Properties(uint64 mask, bool test) const {
189    return impl_->Properties(mask, test);
190  }
191
192  template<class Arc>
193  const Fst<Arc> *GetFst() const {
194    if (Arc::Type() != ArcType()) {
195      return NULL;
196    } else {
197      FstClassImpl<Arc> *typed_impl = static_cast<FstClassImpl<Arc> *>(impl_);
198      return typed_impl->GetImpl();
199    }
200  }
201
202  virtual ~FstClass() { delete impl_; }
203
204  // These methods are required by IO registration
205  template<class Arc>
206  static FstClassImplBase *Convert(const FstClass &other) {
207    LOG(ERROR) << "Doesn't make sense to convert any class to type FstClass.";
208    return 0;
209  }
210
211  template<class Arc>
212  static FstClassImplBase *Create() {
213    LOG(ERROR) << "Doesn't make sense to create an FstClass with a "
214               << "particular arc type.";
215    return 0;
216  }
217 protected:
218  explicit FstClass(FstClassImplBase *impl) : impl_(impl) { }
219
220  // Generic template method for reading an arc-templated FST of type
221  // UnderlyingT, and returning it wrapped as FstClassT, with appropriate
222  // error checking. Called from arc-templated Read() static methods.
223  template<class FstClassT, class UnderlyingT>
224  static FstClassT* ReadTypedFst(istream &stream,
225                                     const FstReadOptions &opts) {
226    UnderlyingT *u = UnderlyingT::Read(stream, opts);
227    if (!u) {
228      return 0;
229    } else {
230      FstClassT *r = new FstClassT(u);
231      delete u;
232      return r;
233    }
234  }
235
236  FstClassImplBase *GetImpl() { return impl_; }
237 private:
238  FstClassImplBase *impl_;
239};
240
241//
242// Specific types of FstClass with special properties
243//
244
245class MutableFstClass : public FstClass {
246 public:
247  template<class Arc>
248  explicit MutableFstClass(MutableFst<Arc> *fst) :
249      FstClass(fst) { }
250
251  template<class Arc>
252  MutableFst<Arc> *GetMutableFst() {
253    Fst<Arc> *fst = const_cast<Fst<Arc> *>(this->GetFst<Arc>());
254    MutableFst<Arc> *mfst = static_cast<MutableFst<Arc> *>(fst);
255
256    return mfst;
257  }
258
259  template<class Arc>
260  static MutableFstClass *Read(istream &stream,
261                               const FstReadOptions &opts) {
262    MutableFst<Arc> *mfst = MutableFst<Arc>::Read(stream, opts);
263    if (!mfst) {
264      return 0;
265    } else {
266      MutableFstClass *retval = new MutableFstClass(mfst);
267      delete mfst;
268      return retval;
269    }
270  }
271
272  static MutableFstClass *Read(const string &fname, bool convert = false);
273
274  virtual void SetInputSymbols(SymbolTable *is) {
275    GetImpl()->SetInputSymbols(is);
276  }
277
278  virtual void SetOutputSymbols(SymbolTable *os) {
279    GetImpl()->SetOutputSymbols(os);
280  }
281
282  // These methods are required by IO registration
283  template<class Arc>
284  static FstClassImplBase *Convert(const FstClass &other) {
285    LOG(ERROR) << "Doesn't make sense to convert any class to type "
286               << "MutableFstClass.";
287    return 0;
288  }
289
290  template<class Arc>
291  static FstClassImplBase *Create() {
292    LOG(ERROR) << "Doesn't make sense to create a MutableFstClass with a "
293               << "particular arc type.";
294    return 0;
295  }
296
297 protected:
298  explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) { }
299};
300
301
302class VectorFstClass : public MutableFstClass {
303 public:
304  explicit VectorFstClass(const FstClass &other);
305  explicit VectorFstClass(const string &arc_type);
306
307  template<class Arc>
308  explicit VectorFstClass(VectorFst<Arc> *fst) :
309      MutableFstClass(fst) { }
310
311  template<class Arc>
312  static VectorFstClass *Read(istream &stream,
313                              const FstReadOptions &opts) {
314    VectorFst<Arc> *vfst = VectorFst<Arc>::Read(stream, opts);
315    if (!vfst) {
316      return 0;
317    } else {
318      VectorFstClass *retval = new VectorFstClass(vfst);
319      delete vfst;
320      return retval;
321    }
322  }
323
324  static VectorFstClass *Read(const string &fname);
325
326  // Converter / creator for known arc types
327  template<class Arc>
328  static FstClassImplBase *Convert(const FstClass &other) {
329    return new FstClassImpl<Arc>(new VectorFst<Arc>(
330        *other.GetFst<Arc>()), true);
331  }
332
333  template<class Arc>
334  static FstClassImplBase *Create() {
335    return new FstClassImpl<Arc>(new VectorFst<Arc>(), true);
336  }
337};
338
339}  // namespace script
340}  // namespace fst
341
342
343#endif  // FST_SCRIPT_FST_CLASS_H_
344