fst.h revision 4a68b3365c8c50aa93505e99ead2565ab73dcdb0
1// fst.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// Finite-State Transducer (FST) - abstract base class definition,
18// state and arc iterator interface, and suggested base implementation.
19
20#ifndef FST_LIB_FST_H__
21#define FST_LIB_FST_H__
22
23#include "fst/lib/arc.h"
24#include "fst/lib/compat.h"
25#include "fst/lib/properties.h"
26#include "fst/lib/register.h"
27#include "fst/lib/symbol-table.h"
28#include "fst/lib/util.h"
29
30namespace fst {
31
32class FstHeader;
33template <class A> class StateIteratorData;
34template <class A> class ArcIteratorData;
35
36struct FstReadOptions  {
37  string source;                // Where you're reading from
38  const FstHeader *header;      // Pointer to Fst header (if non-zero)
39  const SymbolTable* isymbols;  // Pointer to input symbols (if non-zero)
40  const SymbolTable* osymbols;  // Pointer to output symbols (if non-zero)
41
42  explicit FstReadOptions(const string& src = "<unspecified>",
43                          const FstHeader *hdr = 0,
44                          const SymbolTable* isym = 0,
45                          const SymbolTable* osym = 0)
46      : source(src), header(hdr), isymbols(isym), osymbols(osym) {}
47};
48
49
// Options passed to the Fst write routines.
struct FstWriteOptions {
  string source;                    // Where you're writing to
  bool write_header;                // Write the header?
  bool write_isymbols;              // Write input symbols?
  bool write_osymbols;              // Write output symbols?

  // All arguments are optional. By default everything (header and both
  // symbol tables) is written and the destination name is "<unspecified>",
  // the same placeholder FstReadOptions uses for its source.
  explicit FstWriteOptions(const string &src = "<unspecified>",
                           bool hdr = true, bool isym = true,
                           bool osym = true)
      : source(src), write_header(hdr),
        write_isymbols(isym), write_osymbols(osym) {}
};
62
63//
64// Fst HEADER CLASS
65//
66// This is the recommended Fst file header representation.
67//
68
69class FstHeader {
70 public:
71  enum {
72    HAS_ISYMBOLS = 1,                           // Has input symbol table
73    HAS_OSYMBOLS = 2                            // Has output symbol table
74  } Flags;
75
76  FstHeader() : version_(0), flags_(0), properties_(0), start_(-1),
77                numstates_(0), numarcs_(0) {}
78  const string &FstType() const { return fsttype_; }
79  const string &ArcType() const { return arctype_; }
80  int32 Version() const { return version_; }
81  int32 GetFlags() const { return flags_; }
82  uint64 Properties() const { return properties_; }
83  int64 Start() const { return start_; }
84  int64 NumStates() const { return numstates_; }
85  int64 NumArcs() const { return numarcs_; }
86
87  void SetFstType(const string& type) { fsttype_ = type; }
88  void SetArcType(const string& type) { arctype_ = type; }
89  void SetVersion(int32 version) { version_ = version; }
90  void SetFlags(int32 flags) { flags_ = flags; }
91  void SetProperties(uint64 properties) { properties_ = properties; }
92  void SetStart(int64 start) { start_ = start; }
93  void SetNumStates(int64 numstates) { numstates_ = numstates; }
94  void SetNumArcs(int64 numarcs) { numarcs_ = numarcs; }
95
96  bool Read(istream &strm, const string &source);
97  bool Write(ostream &strm, const string &source) const;
98
99 private:
100  string fsttype_;                   // E.g. "vector"
101  string arctype_;                   // E.g. "standard"
102  int32 version_;                    // Type version #
103  int32 flags_;                      // File format bits
104  uint64 properties_;                // FST property bits
105  int64 start_;                      // Start state
106  int64 numstates_;                  // # of states
107  int64 numarcs_;                    // # of arcs
108};
109
110//
111// Fst INTERFACE CLASS DEFINITION
112//
113
114// A generic FST, templated on the arc definition, with
115// common-demoninator methods (use StateIterator and ArcIterator to
116// iterate over its states and arcs).
117template <class A>
118class Fst {
119 public:
120  typedef A Arc;
121  typedef typename A::Weight Weight;
122  typedef typename A::StateId StateId;
123
124  virtual ~Fst() {}
125
126  virtual StateId Start() const = 0;          // Initial state
127
128  virtual Weight Final(StateId) const = 0;    // State's final weight
129
130  virtual size_t NumArcs(StateId) const = 0;  // State's arc count
131
132  virtual size_t NumInputEpsilons(StateId)
133      const = 0;                              // State's input epsilon count
134
135  virtual size_t NumOutputEpsilons(StateId)
136      const = 0;                              // State's output epsilon count
137
138  // If test=false, return stored properties bits for mask (some poss. unknown)
139  // If test=true, return property bits for mask (computing o.w. unknown)
140  virtual uint64 Properties(uint64 mask, bool test)
141      const = 0;  // Property bits
142
143  virtual const string& Type() const = 0;    // Fst type name
144
145  // Get a copy of this Fst.
146  virtual Fst<A> *Copy() const = 0;
147  // Read an Fst from an input stream; returns NULL on error
148
149  static Fst<A> *Read(istream &strm, const FstReadOptions &opts) {
150    FstReadOptions ropts(opts);
151    FstHeader hdr;
152    if (ropts.header)
153      hdr = *opts.header;
154    else {
155      if (!hdr.Read(strm, opts.source))
156        return 0;
157      ropts.header = &hdr;
158    }
159    FstRegister<A> *registr = FstRegister<A>::GetRegister();
160    const typename FstRegister<A>::Reader reader =
161        registr->GetReader(hdr.FstType());
162    if (!reader) {
163      LOG(ERROR) << "Fst::Read: Unknown FST type \"" << hdr.FstType()
164                 << "\" (arc type = \"" << A::Type()
165                 << "\"): " << ropts.source;
166      return 0;
167    }
168    return reader(strm, ropts);
169  };
170
171  // Read an Fst from a file; return NULL on error
172  static Fst<A> *Read(const string &filename) {
173    ifstream strm(filename.c_str());
174    if (!strm) {
175      LOG(ERROR) << "Fst::Read: Can't open file: " << filename;
176      return 0;
177    }
178    return Read(strm, FstReadOptions(filename));
179  }
180
181  // Write an Fst to an output stream; return false on error
182  virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
183    LOG(ERROR) << "Fst::Write: No write method for " << Type() << " Fst type";
184    return false;
185  }
186
187  // Write an Fst to a file; return false on error
188  virtual bool Write(const string &filename) const {
189    LOG(ERROR) << "Fst::Write: No write method for "
190               << Type() << " Fst type: "
191               << (filename.empty() ? "standard output" : filename);
192    return false;
193  }
194
195  // Return input label symbol table; return NULL if not specified
196  virtual const SymbolTable* InputSymbols() const = 0;
197
198  // Return output label symbol table; return NULL if not specified
199  virtual const SymbolTable* OutputSymbols() const = 0;
200
201  // For generic state iterator construction; not normally called
202  // directly by users.
203  virtual void InitStateIterator(StateIteratorData<A> *) const = 0;
204
205  // For generic arc iterator construction; not normally called
206  // directly by users.
207  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *) const = 0;
208};
209
210
211//
212// STATE and ARC ITERATOR DEFINITIONS
213//
214
// Abstract state iterator, templated on the Arc definition. It is the
// base of the StateIterator specializations that InitStateIterator
// hands back.
template <class A>
class StateIteratorBase {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;

  virtual ~StateIteratorBase() {}

  // End of iterator?
  virtual bool Done() const = 0;

  // Current state (when !Done)
  virtual StateId Value() const = 0;

  // Advance to next state (when !Done)
  virtual void Next() = 0;

  // Return to initial condition
  virtual void Reset() = 0;
};
229
230
231// StateIterator initialization data
232template <class A> struct StateIteratorData {
233  StateIteratorBase<A> *base;   // Specialized iterator if non-zero
234  typename A::StateId nstates;  // O.w. total # of states
235};
236
237
238// Generic state iterator, templated on the FST definition
239// - a wrapper around pointer to specific one.
240// Here is a typical use: \code
241//   for (StateIterator<StdFst> siter(fst);
242//        !siter.Done();
243//        siter.Next()) {
244//     StateId s = siter.Value();
245//     ...
246//   } \endcode
247template <class F>
248class StateIterator {
249 public:
250  typedef typename F::Arc Arc;
251  typedef typename Arc::StateId StateId;
252
253  explicit StateIterator(const F &fst) : s_(0) {
254    fst.InitStateIterator(&data_);
255  }
256
257  ~StateIterator() { if (data_.base) delete data_.base; }
258
259  bool Done() const {
260    return data_.base ? data_.base->Done() : s_ >= data_.nstates;
261  }
262
263  StateId Value() const { return data_.base ? data_.base->Value() : s_; }
264
265  void Next() {
266    if (data_.base)
267      data_.base->Next();
268    else
269      ++s_;
270  }
271
272  void Reset() {
273    if (data_.base)
274      data_.base->Reset();
275    else
276      s_ = 0;
277  }
278
279 private:
280  StateIteratorData<Arc> data_;
281  StateId s_;
282  DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
283};
284
285
// Abstract arc iterator, templated on the Arc definition. It is the
// base of the arc iterator specializations that InitArcIterator hands
// back.
template <class A>
class ArcIteratorBase {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;

  virtual ~ArcIteratorBase() {}

  // End of iterator?
  virtual bool Done() const = 0;

  // Current arc (when !Done)
  virtual const A& Value() const = 0;

  // Advance to next arc (when !Done)
  virtual void Next() = 0;

  // Return to initial condition
  virtual void Reset() = 0;

  // Random arc access by position
  virtual void Seek(size_t a) = 0;
};
301
302
303// ArcIterator initialization data
304template <class A> struct ArcIteratorData {
305  ArcIteratorBase<A> *base;  // Specialized iterator if non-zero
306  const A *arcs;             // O.w. arcs pointer
307  size_t narcs;              // ... and arc count
308  int *ref_count;            // ... and reference count if non-zero
309};
310
311
312// Generic arc iterator, templated on the FST definition
313// - a wrapper around pointer to specific one.
314// Here is a typical use: \code
315//   for (ArcIterator<StdFst> aiter(fst, s));
316//        !aiter.Done();
317//         aiter.Next()) {
318//     StdArc &arc = aiter.Value();
319//     ...
320//   } \endcode
321template <class F>
322class ArcIterator {
323   public:
324  typedef typename F::Arc Arc;
325  typedef typename Arc::StateId StateId;
326
327  ArcIterator(const F &fst, StateId s) : i_(0) {
328    fst.InitArcIterator(s, &data_);
329  }
330
331  ~ArcIterator() {
332    if (data_.base)
333      delete data_.base;
334    else if (data_.ref_count)
335    --(*data_.ref_count);
336  }
337
338  bool Done() const {
339    return data_.base ?  data_.base->Done() : i_ >= data_.narcs;
340  }
341
342  const Arc& Value() const {
343    return data_.base ? data_.base->Value() : data_.arcs[i_];
344  }
345
346  void Next() {
347    if (data_.base)
348      data_.base->Next();
349    else
350      ++i_;
351  }
352
353  void Reset() {
354    if (data_.base)
355      data_.base->Reset();
356    else
357      i_ = 0;
358  }
359
360  void Seek(size_t a) {
361    if (data_.base)
362      data_.base->Seek(a);
363    else
364      i_ = a;
365  }
366
367 private:
368  ArcIteratorData<Arc> data_;
369  size_t i_;
370  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
371};
372
373
374// A useful alias when using StdArc.
375typedef Fst<StdArc> StdFst;
376
377
378//
379//  CONSTANT DEFINITIONS
380//
381
// Not a valid state ID
const int kNoStateId = -1;
// Not a valid label
const int kNoLabel = -1;
// Failure transition label
const int kPhiLabel = -2;
// Matches o.w. unmatched labels (lib. internal)
const int kRhoLabel = -3;
// Matches all labels in alphabet.
const int kSigmaLabel = -4;
387
388
389//
390// Fst IMPLEMENTATION BASE
391//
392// This is the recommended Fst implementation base class. It will
393// handle reference counts, property bits, type information and symbols.
394//
395
396template <class A> class FstImpl {
397 public:
398  typedef typename A::Weight Weight;
399  typedef typename A::StateId StateId;
400
401  FstImpl()
402      : properties_(0), type_("null"), isymbols_(0), osymbols_(0),
403        ref_count_(1) {}
404
405  FstImpl(const FstImpl<A> &impl)
406      : properties_(impl.properties_), type_(impl.type_),
407        isymbols_(impl.isymbols_ ? new SymbolTable(impl.isymbols_) : 0),
408        osymbols_(impl.osymbols_ ? new SymbolTable(impl.osymbols_) : 0),
409        ref_count_(1) {}
410
411  ~FstImpl() {
412    delete isymbols_;
413    delete osymbols_;
414  }
415
416  const string& Type() const { return type_; }
417
418  void SetType(const string &type) { type_ = type; }
419
420  uint64 Properties() const { return properties_; }
421
422  uint64 Properties(uint64 mask) const { return properties_ & mask; }
423
424  void SetProperties(uint64 props) { properties_ = props; }
425
426  void SetProperties(uint64 props, uint64 mask) {
427    properties_ &= ~mask;
428    properties_ |= props & mask;
429  }
430
431  const SymbolTable* InputSymbols() const { return isymbols_; }
432
433  const SymbolTable* OutputSymbols() const { return osymbols_; }
434
435  SymbolTable* InputSymbols() { return isymbols_; }
436
437  SymbolTable* OutputSymbols() { return osymbols_; }
438
439  void SetInputSymbols(const SymbolTable* isyms) {
440    if (isymbols_) delete isymbols_;
441    isymbols_ = isyms ? isyms->Copy() : 0;
442  }
443
444  void SetOutputSymbols(const SymbolTable* osyms) {
445    if (osymbols_) delete osymbols_;
446    osymbols_ = osyms ? osyms->Copy() : 0;
447  }
448
449  int RefCount() const { return ref_count_; }
450
451  int IncrRefCount() { return ++ref_count_; }
452
453  int DecrRefCount() { return --ref_count_; }
454
455  // Read-in header and symbols, initialize Fst, and return the header.
456  // If opts.header is non-null, skip read-in and use the option value.
457  // If opts.[io]symbols is non-null, read-in but use the option value.
458  bool ReadHeaderAndSymbols(istream &strm, const FstReadOptions& opts,
459                  int min_version, FstHeader *hdr) {
460    if (opts.header)
461      *hdr = *opts.header;
462    else if (!hdr->Read(strm, opts.source))
463      return false;
464    if (hdr->FstType() != type_) {
465      LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Fst not of type \""
466                 << type_ << "\": " << opts.source;
467      return false;
468    }
469    if (hdr->ArcType() != A::Type()) {
470      LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Arc not of type \""
471                 << A::Type()
472                 << "\": " << opts.source;
473      return false;
474    }
475    if (hdr->Version() < min_version) {
476      LOG(ERROR) << "FstImpl::ReadHeaderAndSymbols: Obsolete "
477                 << type_ << " Fst version: " << opts.source;
478      return false;
479    }
480    properties_ = hdr->Properties();
481    if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS)
482      isymbols_ = SymbolTable::Read(strm, opts.source);
483    if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS)
484      osymbols_ =SymbolTable::Read(strm, opts.source);
485
486    if (opts.isymbols) {
487      delete isymbols_;
488      isymbols_ = opts.isymbols->Copy();
489    }
490    if (opts.osymbols) {
491      delete osymbols_;
492      osymbols_ = opts.osymbols->Copy();
493    }
494    return true;
495  }
496
497  // Write-out header and symbols.
498  // If a opts.header is false, skip writing header.
499  // If opts.[io]symbols is false, skip writing those symbols.
500  void WriteHeaderAndSymbols(ostream &strm, const FstWriteOptions& opts,
501                             int version, FstHeader *hdr) const {
502    if (opts.write_header) {
503      hdr->SetFstType(type_);
504      hdr->SetArcType(A::Type());
505      hdr->SetVersion(version);
506      hdr->SetProperties(properties_);
507      int32 file_flags = 0;
508      if (isymbols_ && opts.write_isymbols)
509        file_flags |= FstHeader::HAS_ISYMBOLS;
510      if (osymbols_ && opts.write_osymbols)
511        file_flags |= FstHeader::HAS_OSYMBOLS;
512      hdr->SetFlags(file_flags);
513      hdr->Write(strm, opts.source);
514    }
515    if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm);
516    if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm);
517  }
518
519 protected:
520  uint64 properties_;           // Property bits
521
522 private:
523  string type_;                 // Unique name of Fst class
524  SymbolTable *isymbols_;       // Ilabel symbol table
525  SymbolTable *osymbols_;       // Olabel symbol table
526  int ref_count_;               // Reference count
527
528  void operator=(const FstImpl<A> &impl);  // disallow
529};
530
531}  // namespace fst;
532
533#endif  // FST_LIB_FST_H__
534