1
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13//
14// Copyright 2005-2010 Google, Inc.
15// Author: jpr@google.com (Jake Ratkiewicz)
16
17#ifndef FST_SCRIPT_FST_CLASS_H_
18#define FST_SCRIPT_FST_CLASS_H_
19
20#include <string>
21
22#include <fst/fst.h>
23#include <fst/mutable-fst.h>
24#include <fst/vector-fst.h>
25#include <iostream>
26#include <fstream>
27#include <sstream>
28
29// Classes to support "boxing" all existing types of FST arcs in a single
30// FstClass which hides the arc types. This allows clients to load
31// and work with FSTs without knowing the arc type.
32
33// These classes are only recommended for use in high-level scripting
34// applications. Most users should use the lower-level templated versions
35// corresponding to these classes.
36
37namespace fst {
38namespace script {
39
//
// Abstract base class defining the set of functionalities implemented
// by all impls and passed through by all interface classes. Below
// FstClassBase, the class hierarchy bifurcates: FstClassImplBase serves
// as the base class for all implementations (of which FstClassImpl is
// currently the only one) and FstClass serves as the base class for all
// interfaces.
//
class FstClassBase {
 public:
  // Name of the arc type of the wrapped FST (FstClassImpl<Arc> reports
  // Arc::Type()).
  virtual const string &ArcType() const = 0;
  // Type name of the wrapped FST (FstClassImpl forwards to Fst::Type()).
  virtual const string &FstType() const = 0;
  // Name of the weight type of the wrapped FST's arcs.
  virtual const string &WeightType() const = 0;
  // Symbol tables attached to the wrapped FST, as returned by the
  // underlying Fst.
  virtual const SymbolTable *InputSymbols() const = 0;
  virtual const SymbolTable *OutputSymbols() const = 0;
  // Writes the wrapped FST to the named file and returns the underlying
  // Fst::Write() result (NOTE(review): presumably true on success -- confirm).
  virtual bool Write(const string& fname) const = 0;
  // Writes the wrapped FST to `ostr` with the given options and returns the
  // underlying Fst::Write() result.
  virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const = 0;
  // Returns the wrapped FST's properties selected by `mask`; `test` is
  // forwarded unchanged to the underlying Fst::Properties().
  virtual uint64 Properties(uint64 mask, bool test) const = 0;
  // Virtual because implementations are owned and deleted through base
  // pointers (FstClass holds and deletes an FstClassImplBase *).
  virtual ~FstClassBase() { }
};
60
61class FstClassImplBase : public FstClassBase {
62 public:
63  virtual FstClassImplBase *Copy() = 0;
64  virtual void SetInputSymbols(SymbolTable *is) = 0;
65  virtual void SetOutputSymbols(SymbolTable *is) = 0;
66  virtual ~FstClassImplBase() { }
67};
68
69
70//
71// CONTAINER CLASS
72// Wraps an Fst<Arc>, hiding its arc type. Whether this Fst<Arc>
73// pointer refers to a special kind of FST (e.g. a MutableFst) is
74// known by the type of interface class that owns the pointer to this
75// container.
76//
77
78template<class Arc>
79class FstClassImpl : public FstClassImplBase {
80 public:
81  explicit FstClassImpl(Fst<Arc> *impl,
82                        bool should_own = false) :
83      impl_(should_own ? impl : impl->Copy()) { }
84
85  explicit FstClassImpl(const Fst<Arc> &impl) : impl_(impl.Copy()) {  }
86
87  virtual const string &ArcType() const {
88    return Arc::Type();
89  }
90
91  virtual const string &FstType() const {
92    return impl_->Type();
93  }
94
95  virtual const string &WeightType() const {
96    return Arc::Weight::Type();
97  }
98
99  virtual const SymbolTable *InputSymbols() const {
100    return impl_->InputSymbols();
101  }
102
103  virtual const SymbolTable *OutputSymbols() const {
104    return impl_->OutputSymbols();
105  }
106
107  // Warning: calling this method casts the FST to a mutable FST.
108  virtual void SetInputSymbols(SymbolTable *is) {
109    static_cast<MutableFst<Arc> *>(impl_)->SetInputSymbols(is);
110  }
111
112  // Warning: calling this method casts the FST to a mutable FST.
113  virtual void SetOutputSymbols(SymbolTable *os) {
114    static_cast<MutableFst<Arc> *>(impl_)->SetOutputSymbols(os);
115  }
116
117  virtual bool Write(const string &fname) const {
118    return impl_->Write(fname);
119  }
120
121  virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const {
122    return impl_->Write(ostr, opts);
123  }
124
125  virtual uint64 Properties(uint64 mask, bool test) const {
126    return impl_->Properties(mask, test);
127  }
128
129  virtual ~FstClassImpl() { delete impl_; }
130
131  Fst<Arc> *GetImpl() const { return impl_; }
132
133  Fst<Arc> *GetImpl() { return impl_; }
134
135  virtual FstClassImpl *Copy() {
136    return new FstClassImpl<Arc>(impl_);
137  }
138
139 private:
140  Fst<Arc> *impl_;
141};
142
143//
144// BASE CLASS DEFINITIONS
145//
146
147class MutableFstClass;
148
149class FstClass : public FstClassBase {
150 public:
151  template<class Arc>
152  static FstClass *Read(istream &stream,
153                        const FstReadOptions &opts) {
154    if (!opts.header) {
155      FSTERROR() << "FstClass::Read: options header not specified";
156      return 0;
157    }
158    const FstHeader &hdr = *opts.header;
159
160    if (hdr.Properties() & kMutable) {
161      return ReadTypedFst<MutableFstClass, MutableFst<Arc> >(stream, opts);
162    } else {
163      return ReadTypedFst<FstClass, Fst<Arc> >(stream, opts);
164    }
165  }
166
167  FstClass() : impl_(NULL) {
168  }
169
170  template<class Arc>
171  explicit FstClass(const Fst<Arc> &fst) : impl_(new FstClassImpl<Arc>(fst)) {
172  }
173
174  FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { }
175
176  FstClass &operator=(const FstClass &other) {
177    delete impl_;
178    impl_ = other.impl_->Copy();
179    return *this;
180  }
181
182  static FstClass *Read(const string &fname);
183
184  static FstClass *Read(istream &istr, const string &source);
185
186  virtual const string &ArcType() const {
187    return impl_->ArcType();
188  }
189
190  virtual const string& FstType() const {
191    return impl_->FstType();
192  }
193
194  virtual const SymbolTable *InputSymbols() const {
195    return impl_->InputSymbols();
196  }
197
198  virtual const SymbolTable *OutputSymbols() const {
199    return impl_->OutputSymbols();
200  }
201
202  virtual const string& WeightType() const {
203    return impl_->WeightType();
204  }
205
206  virtual bool Write(const string &fname) const {
207    return impl_->Write(fname);
208  }
209
210  virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const {
211    return impl_->Write(ostr, opts);
212  }
213
214  virtual uint64 Properties(uint64 mask, bool test) const {
215    return impl_->Properties(mask, test);
216  }
217
218  template<class Arc>
219  const Fst<Arc> *GetFst() const {
220    if (Arc::Type() != ArcType()) {
221      return NULL;
222    } else {
223      FstClassImpl<Arc> *typed_impl = static_cast<FstClassImpl<Arc> *>(impl_);
224      return typed_impl->GetImpl();
225    }
226  }
227
228  virtual ~FstClass() { delete impl_; }
229
230  // These methods are required by IO registration
231  template<class Arc>
232  static FstClassImplBase *Convert(const FstClass &other) {
233    LOG(ERROR) << "Doesn't make sense to convert any class to type FstClass.";
234    return 0;
235  }
236
237  template<class Arc>
238  static FstClassImplBase *Create() {
239    LOG(ERROR) << "Doesn't make sense to create an FstClass with a "
240               << "particular arc type.";
241    return 0;
242  }
243
244
245 protected:
246  explicit FstClass(FstClassImplBase *impl) : impl_(impl) { }
247
248  // Generic template method for reading an arc-templated FST of type
249  // UnderlyingT, and returning it wrapped as FstClassT, with appropriate
250  // error checking. Called from arc-templated Read() static methods.
251  template<class FstClassT, class UnderlyingT>
252  static FstClassT* ReadTypedFst(istream &stream,
253                                     const FstReadOptions &opts) {
254    UnderlyingT *u = UnderlyingT::Read(stream, opts);
255    if (!u) {
256      return 0;
257    } else {
258      FstClassT *r = new FstClassT(*u);
259      delete u;
260      return r;
261    }
262  }
263
264  FstClassImplBase *GetImpl() const { return impl_; }
265
266  FstClassImplBase *GetImpl() { return impl_; }
267
268//  friend ostream &operator<<(ostream&, const FstClass&);
269
270 private:
271  FstClassImplBase *impl_;
272};
273
274//
275// Specific types of FstClass with special properties
276//
277
278class MutableFstClass : public FstClass {
279 public:
280  template<class Arc>
281  explicit MutableFstClass(const MutableFst<Arc> &fst) :
282      FstClass(fst) { }
283
284  template<class Arc>
285  MutableFst<Arc> *GetMutableFst() {
286    Fst<Arc> *fst = const_cast<Fst<Arc> *>(this->GetFst<Arc>());
287    MutableFst<Arc> *mfst = static_cast<MutableFst<Arc> *>(fst);
288
289    return mfst;
290  }
291
292  template<class Arc>
293  static MutableFstClass *Read(istream &stream,
294                               const FstReadOptions &opts) {
295    MutableFst<Arc> *mfst = MutableFst<Arc>::Read(stream, opts);
296    if (!mfst) {
297      return 0;
298    } else {
299      MutableFstClass *retval = new MutableFstClass(*mfst);
300      delete mfst;
301      return retval;
302    }
303  }
304
305  virtual bool Write(const string &fname) const {
306    return GetImpl()->Write(fname);
307  }
308
309  virtual bool Write(ostream &ostr, const FstWriteOptions &opts) const {
310    return GetImpl()->Write(ostr, opts);
311  }
312
313  static MutableFstClass *Read(const string &fname, bool convert = false);
314
315  virtual void SetInputSymbols(SymbolTable *is) {
316    GetImpl()->SetInputSymbols(is);
317  }
318
319  virtual void SetOutputSymbols(SymbolTable *os) {
320    GetImpl()->SetOutputSymbols(os);
321  }
322
323  // These methods are required by IO registration
324  template<class Arc>
325  static FstClassImplBase *Convert(const FstClass &other) {
326    LOG(ERROR) << "Doesn't make sense to convert any class to type "
327               << "MutableFstClass.";
328    return 0;
329  }
330
331  template<class Arc>
332  static FstClassImplBase *Create() {
333    LOG(ERROR) << "Doesn't make sense to create a MutableFstClass with a "
334               << "particular arc type.";
335    return 0;
336  }
337
338 protected:
339  explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) { }
340};
341
342
343class VectorFstClass : public MutableFstClass {
344 public:
345  explicit VectorFstClass(const FstClass &other);
346  explicit VectorFstClass(const string &arc_type);
347
348  template<class Arc>
349  explicit VectorFstClass(const VectorFst<Arc> &fst) :
350      MutableFstClass(fst) { }
351
352  template<class Arc>
353  static VectorFstClass *Read(istream &stream,
354                              const FstReadOptions &opts) {
355    VectorFst<Arc> *vfst = VectorFst<Arc>::Read(stream, opts);
356    if (!vfst) {
357      return 0;
358    } else {
359      VectorFstClass *retval = new VectorFstClass(*vfst);
360      delete vfst;
361      return retval;
362    }
363  }
364
365  static VectorFstClass *Read(const string &fname);
366
367  // Converter / creator for known arc types
368  template<class Arc>
369  static FstClassImplBase *Convert(const FstClass &other) {
370    return new FstClassImpl<Arc>(new VectorFst<Arc>(
371        *other.GetFst<Arc>()), true);
372  }
373
374  template<class Arc>
375  static FstClassImplBase *Create() {
376    return new FstClassImpl<Arc>(new VectorFst<Arc>(), true);
377  }
378};
379
380}  // namespace script
381}  // namespace fst
382#endif  // FST_SCRIPT_FST_CLASS_H_
383