fst.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// fst.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Finite-State Transducer (FST) - abstract base class definition,
20// state and arc iterator interface, and suggested base implementation.
21//
22
23#ifndef FST_LIB_FST_H__
24#define FST_LIB_FST_H__
25
26#include <stddef.h>
27#include <sys/types.h>
28#include <cmath>
29#include <string>
30
31#include <fst/compat.h>
32#include <fst/types.h>
33
34#include <fst/arc.h>
35#include <fst/properties.h>
36#include <fst/register.h>
37#include <iostream>
38#include <fstream>
39#include <fst/symbol-table.h>
40#include <fst/util.h>
41
42
// Declares the boolean flag 'fst_align'. Its value (FLAGS_fst_align) is
// the default for the 'align' field of FstWriteOptions.
// NOTE(review): the flag's definition is not in this file -- presumably
// it lives in the library's .cc sources; confirm.
DECLARE_bool(fst_align);
44
45namespace fst {
46
// NOTE(review): defined elsewhere; judging by the name it tests whether
// the stream begins with an FST header, with the string naming the
// source -- confirm against the definition.
bool IsFstHeader(istream &, const string &);

// Forward declarations for types that are used before they are defined
// (or that are defined in other headers).
class FstHeader;
template <class A> class StateIteratorData;
template <class A> class ArcIteratorData;
template <class A> class MatcherBase;
53
54struct FstReadOptions {
55  string source;                // Where you're reading from
56  const FstHeader *header;      // Pointer to Fst header. If non-zero, use
57                                // this info (don't read a stream header)
58  const SymbolTable* isymbols;  // Pointer to input symbols. If non-zero, use
59                                // this info (read and skip stream isymbols)
60  const SymbolTable* osymbols;  // Pointer to output symbols. If non-zero, use
61                                // this info (read and skip stream osymbols)
62
63  explicit FstReadOptions(const string& src = "<unspecfied>",
64                          const FstHeader *hdr = 0,
65                          const SymbolTable* isym = 0,
66                          const SymbolTable* osym = 0)
67      : source(src), header(hdr), isymbols(isym), osymbols(osym) {}
68
69  explicit FstReadOptions(const string& src,
70                          const SymbolTable* isym,
71                          const SymbolTable* osym = 0)
72      : source(src), header(0), isymbols(isym), osymbols(osym) {}
73};
74
75
76struct FstWriteOptions {
77  string source;                 // Where you're writing to
78  bool write_header;             // Write the header?
79  bool write_isymbols;           // Write input symbols?
80  bool write_osymbols;           // Write output symbols?
81  bool align;                    // Write data aligned where appropriate;
82                                 // this may fail on pipes
83
84  explicit FstWriteOptions(const string& src = "<unspecifed>",
85                           bool hdr = true, bool isym = true,
86                           bool osym = true, bool alig = FLAGS_fst_align)
87      : source(src), write_header(hdr),
88        write_isymbols(isym), write_osymbols(osym), align(alig) {}
89};
90
91//
92// Fst HEADER CLASS
93//
94// This is the recommended Fst file header representation.
95//
96class FstHeader {
97 public:
98  enum {
99    HAS_ISYMBOLS = 0x1,          // Has input symbol table
100    HAS_OSYMBOLS = 0x2,          // Has output symbol table
101    IS_ALIGNED   = 0x4,          // Memory-aligned (where appropriate)
102  } Flags;
103
104  FstHeader() : version_(0), flags_(0), properties_(0), start_(-1),
105                numstates_(0), numarcs_(0) {}
106  const string &FstType() const { return fsttype_; }
107  const string &ArcType() const { return arctype_; }
108  int32 Version() const { return version_; }
109  int32 GetFlags() const { return flags_; }
110  uint64 Properties() const { return properties_; }
111  int64 Start() const { return start_; }
112  int64 NumStates() const { return numstates_; }
113  int64 NumArcs() const { return numarcs_; }
114
115  void SetFstType(const string& type) { fsttype_ = type; }
116  void SetArcType(const string& type) { arctype_ = type; }
117  void SetVersion(int32 version) { version_ = version; }
118  void SetFlags(int32 flags) { flags_ = flags; }
119  void SetProperties(uint64 properties) { properties_ = properties; }
120  void SetStart(int64 start) { start_ = start; }
121  void SetNumStates(int64 numstates) { numstates_ = numstates; }
122  void SetNumArcs(int64 numarcs) { numarcs_ = numarcs; }
123
124  bool Read(istream &strm, const string &source, bool rewind = false);
125  bool Write(ostream &strm, const string &source) const;
126
127 private:
128
129  string fsttype_;                   // E.g. "vector"
130  string arctype_;                   // E.g. "standard"
131  int32 version_;                    // Type version #
132  int32 flags_;                      // File format bits
133  uint64 properties_;                // FST property bits
134  int64 start_;                      // Start state
135  int64 numstates_;                  // # of states
136  int64 numarcs_;                    // # of arcs
137};
138
139
// Tells a matcher which side of an arc it should match on.
enum MatchType {
  MATCH_INPUT,    // Match input label.
  MATCH_OUTPUT,   // Match output label.
  MATCH_BOTH,     // Match input or output label.
  MATCH_NONE,     // Match nothing.
  MATCH_UNKNOWN   // Match type unknown.
};
146
147//
148// Fst INTERFACE CLASS DEFINITION
149//
150
151// A generic FST, templated on the arc definition, with
152// common-demoninator methods (use StateIterator and ArcIterator to
153// iterate over its states and arcs).
154template <class A>
155class Fst {
156 public:
157  typedef A Arc;
158  typedef typename A::Weight Weight;
159  typedef typename A::StateId StateId;
160
161  virtual ~Fst() {}
162
163  virtual StateId Start() const = 0;          // Initial state
164
165  virtual Weight Final(StateId) const = 0;    // State's final weight
166
167  virtual size_t NumArcs(StateId) const = 0;  // State's arc count
168
169  virtual size_t NumInputEpsilons(StateId)
170      const = 0;                              // State's input epsilon count
171
172  virtual size_t NumOutputEpsilons(StateId)
173      const = 0;                              // State's output epsilon count
174
175  // If test=false, return stored properties bits for mask (some poss. unknown)
176  // If test=true, return property bits for mask (computing o.w. unknown)
177  virtual uint64 Properties(uint64 mask, bool test)
178      const = 0;  // Property bits
179
180  virtual const string& Type() const = 0;    // Fst type name
181
182  // Get a copy of this Fst. The copying behaves as follows:
183  //
184  // (1) The copying is constant time if safe = false or if safe = true
185  // and is on an otherwise unaccessed Fst.
186  //
187  // (2) If safe = true, the copy is thread-safe in that the original
188  // and copy can be safely accessed (but not necessarily mutated) by
189  // separate threads. For some Fst types, 'Copy(true)' should only be
190  // called on an Fst that has not otherwise been accessed. Its behavior
191  // is undefined otherwise.
192  //
193  // (3) If a MutableFst is copied and then mutated, then the original is
194  // unmodified and vice versa (often by a copy-on-write on the initial
195  // mutation, which may not be constant time).
196  virtual Fst<A> *Copy(bool safe = false) const = 0;
197
198  // Read an Fst from an input stream; returns NULL on error
199  static Fst<A> *Read(istream &strm, const FstReadOptions &opts) {
200    FstReadOptions ropts(opts);
201    FstHeader hdr;
202    if (ropts.header)
203      hdr = *opts.header;
204    else {
205      if (!hdr.Read(strm, opts.source))
206        return 0;
207      ropts.header = &hdr;
208    }
209    FstRegister<A> *registr = FstRegister<A>::GetRegister();
210    const typename FstRegister<A>::Reader reader =
211      registr->GetReader(hdr.FstType());
212    if (!reader) {
213      LOG(ERROR) << "Fst::Read: Unknown FST type \"" << hdr.FstType()
214                 << "\" (arc type = \"" << A::Type()
215                 << "\"): " << ropts.source;
216      return 0;
217    }
218    return reader(strm, ropts);
219  };
220
221  // Read an Fst from a file; return NULL on error
222  // Empty filename reads from standard input
223  static Fst<A> *Read(const string &filename) {
224    if (!filename.empty()) {
225      ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
226      if (!strm) {
227        LOG(ERROR) << "Fst::Read: Can't open file: " << filename;
228        return 0;
229      }
230      return Read(strm, FstReadOptions(filename));
231    } else {
232      return Read(std::cin, FstReadOptions("standard input"));
233    }
234  }
235
236  // Write an Fst to an output stream; return false on error
237  virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
238    LOG(ERROR) << "Fst::Write: No write stream method for " << Type()
239               << " Fst type";
240    return false;
241  }
242
243  // Write an Fst to a file; return false on error
244  // Empty filename writes to standard output
245  virtual bool Write(const string &filename) const {
246    LOG(ERROR) << "Fst::Write: No write filename method for " << Type()
247               << " Fst type";
248    return false;
249  }
250
251  // Return input label symbol table; return NULL if not specified
252  virtual const SymbolTable* InputSymbols() const = 0;
253
254  // Return output label symbol table; return NULL if not specified
255  virtual const SymbolTable* OutputSymbols() const = 0;
256
257  // For generic state iterator construction; not normally called
258  // directly by users.
259  virtual void InitStateIterator(StateIteratorData<A> *) const = 0;
260
261  // For generic arc iterator construction; not normally called
262  // directly by users.
263  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *) const = 0;
264
265  // For generic matcher construction; not normally called
266  // directly by users.
267  virtual MatcherBase<A> *InitMatcher(MatchType match_type) const;
268
269 protected:
270
271  bool WriteFile(const string &filename) const {
272    if (!filename.empty()) {
273      ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
274      if (!strm) {
275        LOG(ERROR) << "Fst::Write: Can't open file: " << filename;
276        return false;
277      }
278      return Write(strm, FstWriteOptions(filename));
279    } else {
280      return Write(std::cout, FstWriteOptions("standard output"));
281    }
282  }
283};
284
285
286//
287// STATE and ARC ITERATOR DEFINITIONS
288//
289
// Interface for state iterators, templated on the Arc definition.
// StateIterator specializations returned through the Fst method
// InitStateIterator derive from it.
template <class A>
class StateIteratorBase {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;

  virtual ~StateIteratorBase() {}

  // End of iterator?
  bool Done() const {
    return Done_();
  }

  // Current state (when !Done).
  StateId Value() const {
    return Value_();
  }

  // Advance to next state (when !Done).
  void Next() {
    Next_();
  }

  // Return to initial condition.
  void Reset() {
    Reset_();
  }

 private:
  // The public methods above forward to these virtuals. This allows base
  // class virtual access to non-virtual derived-class members of the
  // same name. It makes the derived class more efficient to use but
  // unsafe to further derive.
  virtual bool Done_() const = 0;
  virtual StateId Value_() const = 0;
  virtual void Next_() = 0;
  virtual void Reset_() = 0;
};
315
316
317// StateIterator initialization data
318
319template <class A> struct StateIteratorData {
320  StateIteratorBase<A> *base;   // Specialized iterator if non-zero
321  typename A::StateId nstates;  // O.w. total # of states
322};
323
324
325// Generic state iterator, templated on the FST definition
326// - a wrapper around pointer to specific one.
327// Here is a typical use: \code
328//   for (StateIterator<StdFst> siter(fst);
329//        !siter.Done();
330//        siter.Next()) {
331//     StateId s = siter.Value();
332//     ...
333//   } \endcode
334template <class F>
335class StateIterator {
336 public:
337  typedef F FST;
338  typedef typename F::Arc Arc;
339  typedef typename Arc::StateId StateId;
340
341  explicit StateIterator(const F &fst) : s_(0) {
342    fst.InitStateIterator(&data_);
343  }
344
345  ~StateIterator() { if (data_.base) delete data_.base; }
346
347  bool Done() const {
348    return data_.base ? data_.base->Done() : s_ >= data_.nstates;
349  }
350
351  StateId Value() const { return data_.base ? data_.base->Value() : s_; }
352
353  void Next() {
354    if (data_.base)
355      data_.base->Next();
356    else
357      ++s_;
358  }
359
360  void Reset() {
361    if (data_.base)
362      data_.base->Reset();
363    else
364      s_ = 0;
365  }
366
367 private:
368  StateIteratorData<Arc> data_;
369  StateId s_;
370
371  DISALLOW_COPY_AND_ASSIGN(StateIterator);
372};
373
374
375// Flags to control the behavior on an arc iterator:
376static const uint32 kArcILabelValue    = 0x0001;  // Value() gives valid ilabel
377static const uint32 kArcOLabelValue    = 0x0002;  //  "       "     "    olabel
378static const uint32 kArcWeightValue    = 0x0004;  //  "       "     "    weight
379static const uint32 kArcNextStateValue = 0x0008;  //  "       "     " nextstate
380static const uint32 kArcNoCache   = 0x0010;       // No need to cache arcs
381
382static const uint32 kArcValueFlags =
383                  kArcILabelValue | kArcOLabelValue |
384                  kArcWeightValue | kArcNextStateValue;
385
386static const uint32 kArcFlags = kArcValueFlags | kArcNoCache;
387
388
389// Arc iterator interface, templated on the Arc definition; used
390// for Arc iterator specializations that are returned by the InitArcIterator
391// Fst method.
392template <class A>
393class ArcIteratorBase {
394 public:
395  typedef A Arc;
396  typedef typename A::StateId StateId;
397
398  virtual ~ArcIteratorBase() {}
399
400  bool Done() const { return Done_(); }            // End of iterator?
401  const A& Value() const { return Value_(); }      // Current arc (when !Done)
402  void Next() { Next_(); }           // Advance to next arc (when !Done)
403  size_t Position() const { return Position_(); }  // Return current position
404  void Reset() { Reset_(); }         // Return to initial condition
405  void Seek(size_t a) { Seek_(a); }  // Random arc access by position
406  uint32 Flags() const { return Flags_(); }  // Return current behavorial flags
407  void SetFlags(uint32 flags, uint32 mask) {  // Set behavorial flags
408    SetFlags_(flags, mask);
409  }
410
411 private:
412  // This allows base class virtual access to non-virtual derived-
413  // class members of the same name. It makes the derived class more
414  // efficient to use but unsafe to further derive.
415  virtual bool Done_() const = 0;
416  virtual const A& Value_() const = 0;
417  virtual void Next_() = 0;
418  virtual size_t Position_() const = 0;
419  virtual void Reset_() = 0;
420  virtual void Seek_(size_t a) = 0;
421  virtual uint32 Flags_() const = 0;
422  virtual void SetFlags_(uint32 flags, uint32 mask) = 0;
423};
424
425
426// ArcIterator initialization data
427template <class A> struct ArcIteratorData {
428  ArcIteratorBase<A> *base;  // Specialized iterator if non-zero
429  const A *arcs;             // O.w. arcs pointer
430  size_t narcs;              // ... and arc count
431  int *ref_count;            // ... and reference count if non-zero
432};
433
434
435// Generic arc iterator, templated on the FST definition
436// - a wrapper around pointer to specific one.
437// Here is a typical use: \code
438//   for (ArcIterator<StdFst> aiter(fst, s));
439//        !aiter.Done();
440//         aiter.Next()) {
441//     StdArc &arc = aiter.Value();
442//     ...
443//   } \endcode
444template <class F>
445class ArcIterator {
446   public:
447  typedef F FST;
448  typedef typename F::Arc Arc;
449  typedef typename Arc::StateId StateId;
450
451  ArcIterator(const F &fst, StateId s) : i_(0) {
452    fst.InitArcIterator(s, &data_);
453  }
454
455  explicit ArcIterator(const ArcIteratorData<Arc> &data) : data_(data), i_(0) {
456    if (data_.ref_count)
457      ++(*data_.ref_count);
458  }
459
460  ~ArcIterator() {
461    if (data_.base)
462      delete data_.base;
463    else if (data_.ref_count)
464      --(*data_.ref_count);
465  }
466
467  bool Done() const {
468    return data_.base ?  data_.base->Done() : i_ >= data_.narcs;
469  }
470
471  const Arc& Value() const {
472    return data_.base ? data_.base->Value() : data_.arcs[i_];
473  }
474
475  void Next() {
476    if (data_.base)
477      data_.base->Next();
478    else
479      ++i_;
480  }
481
482  void Reset() {
483    if (data_.base)
484      data_.base->Reset();
485    else
486      i_ = 0;
487  }
488
489  void Seek(size_t a) {
490    if (data_.base)
491      data_.base->Seek(a);
492    else
493      i_ = a;
494  }
495
496  size_t Position() const {
497    return data_.base ? data_.base->Position() : i_;
498  }
499
500  uint32 Flags() const {
501    if (data_.base)
502      return data_.base->Flags();
503    else
504      return kArcValueFlags;
505  }
506
507  void SetFlags(uint32 flags, uint32 mask) {
508    if (data_.base)
509      data_.base->SetFlags(flags, mask);
510  }
511
512 private:
513  ArcIteratorData<Arc> data_;
514  size_t i_;
515  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
516};
517
518//
519// MATCHER DEFINITIONS
520//
521
522template <class A>
523MatcherBase<A> *Fst<A>::InitMatcher(MatchType match_type) const {
524  return 0;  // Use the default matcher
525}
526
527
528//
529// FST ACCESSORS - Useful functions in high-performance cases.
530//
531
532namespace internal {
533
534// General case - requires non-abstract, 'final' methods. Use for inlining.
535template <class F> inline
536typename F::Arc::Weight Final(const F &fst, typename F::Arc::StateId s) {
537  return fst.F::Final(s);
538}
539
540template <class F> inline
541ssize_t NumArcs(const F &fst, typename F::Arc::StateId s) {
542  return fst.F::NumArcs(s);
543}
544
545template <class F> inline
546ssize_t NumInputEpsilons(const F &fst, typename F::Arc::StateId s) {
547  return fst.F::NumInputEpsilons(s);
548}
549
550template <class F> inline
551ssize_t NumOutputEpsilons(const F &fst, typename F::Arc::StateId s) {
552  return fst.F::NumOutputEpsilons(s);
553}
554
555
556//  Fst<A> case - abstract methods.
557template <class A> inline
558typename A::Weight Final(const Fst<A> &fst, typename A::StateId s) {
559  return fst.Final(s);
560}
561
562template <class A> inline
563ssize_t NumArcs(const Fst<A> &fst, typename A::StateId s) {
564  return fst.NumArcs(s);
565}
566
567template <class A> inline
568ssize_t NumInputEpsilons(const Fst<A> &fst, typename A::StateId s) {
569  return fst.NumInputEpsilons(s);
570}
571
572template <class A> inline
573ssize_t NumOutputEpsilons(const Fst<A> &fst, typename A::StateId s) {
574  return fst.NumOutputEpsilons(s);
575}
576
577}  // namespace internal
578
// A useful alias when using StdArc: the Fst interface instantiated on
// StdArc (which is declared outside this file).
typedef Fst<StdArc> StdFst;
581
582
583//
584//  CONSTANT DEFINITIONS
585//
586
// Sentinel meaning "not a valid state ID".
const int kNoStateId = -1;

// Sentinel meaning "not a valid label".
const int kNoLabel = -1;
589
590//
591// Fst IMPLEMENTATION BASE
592//
593// This is the recommended Fst implementation base class. It will
594// handle reference counts, property bits, type information and symbols.
595//
596
597template <class A> class FstImpl {
598 public:
599  typedef typename A::Weight Weight;
600  typedef typename A::StateId StateId;
601
602  FstImpl()
603      : properties_(0), type_("null"), isymbols_(0), osymbols_(0) {}
604
605  FstImpl(const FstImpl<A> &impl)
606      : properties_(impl.properties_), type_(impl.type_),
607        isymbols_(impl.isymbols_ ? impl.isymbols_->Copy() : 0),
608        osymbols_(impl.osymbols_ ? impl.osymbols_->Copy() : 0) {}
609
610  virtual ~FstImpl() {
611    delete isymbols_;
612    delete osymbols_;
613  }
614
615  const string& Type() const { return type_; }
616
617  void SetType(const string &type) { type_ = type; }
618
619  virtual uint64 Properties() const { return properties_; }
620
621  virtual uint64 Properties(uint64 mask) const { return properties_ & mask; }
622
623  void SetProperties(uint64 props) {
624    properties_ &= kError;          // kError can't be cleared
625    properties_ |= props;
626  }
627
628  void SetProperties(uint64 props, uint64 mask) {
629    properties_ &= ~mask | kError;  // kError can't be cleared
630    properties_ |= props & mask;
631  }
632
633  // Allows (only) setting error bit on const FST impls
634  void SetProperties(uint64 props, uint64 mask) const {
635    if (mask != kError)
636      FSTERROR() << "FstImpl::SetProperties() const: can only set kError";
637    properties_ |= kError;
638  }
639
640  const SymbolTable* InputSymbols() const { return isymbols_; }
641
642  const SymbolTable* OutputSymbols() const { return osymbols_; }
643
644  SymbolTable* InputSymbols() { return isymbols_; }
645
646  SymbolTable* OutputSymbols() { return osymbols_; }
647
648  void SetInputSymbols(const SymbolTable* isyms) {
649    if (isymbols_) delete isymbols_;
650    isymbols_ = isyms ? isyms->Copy() : 0;
651  }
652
653  void SetOutputSymbols(const SymbolTable* osyms) {
654    if (osymbols_) delete osymbols_;
655    osymbols_ = osyms ? osyms->Copy() : 0;
656  }
657
658  int RefCount() const {
659    return ref_count_.count();
660  }
661
662  int IncrRefCount() {
663    return ref_count_.Incr();
664  }
665
666  int DecrRefCount() {
667    return ref_count_.Decr();
668  }
669
670  // Read-in header and symbols from input stream, initialize Fst, and
671  // return the header.  If opts.header is non-null, skip read-in and
672  // use the option value.  If opts.[io]symbols is non-null, read-in
673  // (if present), but use the option value.
674  bool ReadHeader(istream &strm, const FstReadOptions& opts,
675                  int min_version, FstHeader *hdr);
676
677  // Write-out header and symbols from output stream.
678  // If a opts.header is false, skip writing header.
679  // If opts.[io]symbols is false, skip writing those symbols.
680  // This method is needed for Impl's that implement Write methods.
681  void WriteHeader(ostream &strm, const FstWriteOptions& opts,
682                   int version, FstHeader *hdr) const {
683    if (opts.write_header) {
684      hdr->SetFstType(type_);
685      hdr->SetArcType(A::Type());
686      hdr->SetVersion(version);
687      hdr->SetProperties(properties_);
688      int32 file_flags = 0;
689      if (isymbols_ && opts.write_isymbols)
690        file_flags |= FstHeader::HAS_ISYMBOLS;
691      if (osymbols_ && opts.write_osymbols)
692        file_flags |= FstHeader::HAS_OSYMBOLS;
693      if (opts.align)
694        file_flags |= FstHeader::IS_ALIGNED;
695      hdr->SetFlags(file_flags);
696      hdr->Write(strm, opts.source);
697    }
698    if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm);
699    if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm);
700  }
701
702  // Write-out header and symbols to output stream.
703  // If a opts.header is false, skip writing header.
704  // If opts.[io]symbols is false, skip writing those symbols.
705  // type is the Fst type being written.
706  // This method is used in the cross-type serialization methods Fst::WriteFst.
707  static void WriteFstHeader(const Fst<A> &fst, ostream &strm,
708                             const FstWriteOptions& opts, int version,
709                             const string &type, FstHeader *hdr) {
710    if (opts.write_header) {
711      hdr->SetFstType(type);
712      hdr->SetArcType(A::Type());
713      hdr->SetVersion(version);
714      hdr->SetProperties(fst.Properties(kFstProperties, false));
715      int32 file_flags = 0;
716      if (fst.InputSymbols() && opts.write_isymbols)
717        file_flags |= FstHeader::HAS_ISYMBOLS;
718      if (fst.OutputSymbols() && opts.write_osymbols)
719        file_flags |= FstHeader::HAS_OSYMBOLS;
720      if (opts.align)
721        file_flags |= FstHeader::IS_ALIGNED;
722      hdr->SetFlags(file_flags);
723      hdr->Write(strm, opts.source);
724    }
725    if (fst.InputSymbols() && opts.write_isymbols) {
726      fst.InputSymbols()->Write(strm);
727    }
728    if (fst.OutputSymbols() && opts.write_osymbols) {
729      fst.OutputSymbols()->Write(strm);
730    }
731  }
732
733  // In serialization routines where the header cannot be written until after
734  // the machine has been serialized, this routine can be called to seek to
735  // the beginning of the file an rewrite the header with updated fields.
736  // It repositions the file pointer back at the end of the file.
737  // returns true on success, false on failure.
738  static bool UpdateFstHeader(const Fst<A> &fst, ostream &strm,
739                              const FstWriteOptions& opts, int version,
740                              const string &type, FstHeader *hdr,
741                              size_t header_offset) {
742    strm.seekp(header_offset);
743    if (!strm) {
744      LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source;
745      return false;
746    }
747    WriteFstHeader(fst, strm, opts, version, type, hdr);
748    if (!strm) {
749      LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source;
750      return false;
751    }
752    strm.seekp(0, ios_base::end);
753    if (!strm) {
754      LOG(ERROR) << "Fst::UpdateFstHeader: write failed: " << opts.source;
755      return false;
756    }
757    return true;
758  }
759
760 protected:
761  mutable uint64 properties_;           // Property bits
762
763 private:
764  string type_;                 // Unique name of Fst class
765  SymbolTable *isymbols_;       // Ilabel symbol table
766  SymbolTable *osymbols_;       // Olabel symbol table
767  RefCounter ref_count_;        // Reference count
768
769  void operator=(const FstImpl<A> &impl);  // disallow
770};
771
772template <class A> inline
773bool FstImpl<A>::ReadHeader(istream &strm, const FstReadOptions& opts,
774                            int min_version, FstHeader *hdr) {
775  if (opts.header)
776    *hdr = *opts.header;
777  else if (!hdr->Read(strm, opts.source))
778    return false;
779
780  if (FLAGS_v >= 2) {
781    LOG(INFO) << "FstImpl::ReadHeader: source: " << opts.source
782              << ", fst_type: " << hdr->FstType()
783              << ", arc_type: " << A::Type()
784              << ", version: " << hdr->Version()
785              << ", flags: " << hdr->GetFlags();
786  }
787
788  if (hdr->FstType() != type_) {
789    LOG(ERROR) << "FstImpl::ReadHeader: Fst not of type \"" << type_
790               << "\": " << opts.source;
791    return false;
792  }
793  if (hdr->ArcType() != A::Type()) {
794    LOG(ERROR) << "FstImpl::ReadHeader: Arc not of type \"" << A::Type()
795               << "\": " << opts.source;
796    return false;
797  }
798  if (hdr->Version() < min_version) {
799    LOG(ERROR) << "FstImpl::ReadHeader: Obsolete " << type_
800               << " Fst version: " << opts.source;
801    return false;
802  }
803  properties_ = hdr->Properties();
804  if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS)
805    isymbols_ = SymbolTable::Read(strm, opts.source);
806  if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS)
807    osymbols_ =SymbolTable::Read(strm, opts.source);
808
809  if (opts.isymbols) {
810    delete isymbols_;
811    isymbols_ = opts.isymbols->Copy();
812  }
813  if (opts.osymbols) {
814    delete osymbols_;
815    osymbols_ = opts.osymbols->Copy();
816  }
817  return true;
818}
819
820
// Forward declaration; the definition is outside this file. It is called
// by ImplToFst::Properties(mask, true), which stores the returned bits in
// the implementation using *known as the mask.
// NOTE(review): presumably computes the properties selected by 'mask'
// and reports which bits are now known -- confirm against the definition.
template<class Arc>
uint64 TestProperties(const Fst<Arc> &fst, uint64 mask, uint64 *known);
823
824
825// This is a helper class template useful for attaching an Fst interface to
826// its implementation, handling reference counting.
827template < class I, class F = Fst<typename I::Arc> >
828class ImplToFst : public F {
829 public:
830  typedef typename I::Arc Arc;
831  typedef typename Arc::Weight Weight;
832  typedef typename Arc::StateId StateId;
833
834  virtual ~ImplToFst() { if (!impl_->DecrRefCount()) delete impl_;  }
835
836  virtual StateId Start() const { return impl_->Start(); }
837
838  virtual Weight Final(StateId s) const { return impl_->Final(s); }
839
840  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
841
842  virtual size_t NumInputEpsilons(StateId s) const {
843    return impl_->NumInputEpsilons(s);
844  }
845
846  virtual size_t NumOutputEpsilons(StateId s) const {
847    return impl_->NumOutputEpsilons(s);
848  }
849
850  virtual uint64 Properties(uint64 mask, bool test) const {
851    if (test) {
852      uint64 knownprops, testprops = TestProperties(*this, mask, &knownprops);
853      impl_->SetProperties(testprops, knownprops);
854      return testprops & mask;
855    } else {
856      return impl_->Properties(mask);
857    }
858  }
859
860  virtual const string& Type() const { return impl_->Type(); }
861
862  virtual const SymbolTable* InputSymbols() const {
863    return impl_->InputSymbols();
864  }
865
866  virtual const SymbolTable* OutputSymbols() const {
867    return impl_->OutputSymbols();
868  }
869
870 protected:
871  ImplToFst() : impl_(0) {}
872
873  ImplToFst(I *impl) : impl_(impl) {}
874
875  ImplToFst(const ImplToFst<I, F> &fst) {
876    impl_ = fst.impl_;
877    impl_->IncrRefCount();
878  }
879
880  // This constructor presumes there is a copy constructor for the
881  // implementation.
882  ImplToFst(const ImplToFst<I, F> &fst, bool safe) {
883    if (safe) {
884      impl_ = new I(*(fst.impl_));
885    } else {
886      impl_ = fst.impl_;
887      impl_->IncrRefCount();
888    }
889  }
890
891  I *GetImpl() const { return impl_; }
892
893  // Change Fst implementation pointer. If 'own_impl' is true,
894  // ownership of the input implementation is given to this
895  // object; otherwise, the input implementation's reference count
896  // should be incremented.
897  void SetImpl(I *impl, bool own_impl = true) {
898    if (!own_impl)
899      impl->IncrRefCount();
900    if (impl_ && !impl_->DecrRefCount()) delete impl_;
901    impl_ = impl;
902  }
903
904 private:
905  // Disallow
906  ImplToFst<I, F> &operator=(const ImplToFst<I, F> &fst);
907
908  ImplToFst<I, F> &operator=(const Fst<Arc> &fst) {
909    FSTERROR() << "ImplToFst: Assignment operator disallowed";
910    GetImpl()->SetProperties(kError, kError);
911    return *this;
912  }
913
914  I *impl_;
915};
916
917
// Converts FSTs by casting their implementations, where this makes
// sense (which excludes implementations with weight-dependent virtual
// methods). Must be a friend of the Fst classes involved (currently
// the concrete Fsts: VectorFst, ConstFst, CompactFst).
// The implementation of 'ifst' is handed to 'ofst' with own_impl = false,
// i.e. it is shared rather than transferred.
template<class F, class G> void Cast(const F &ifst, G *ofst) {
  typedef typename G::Impl TargetImpl;
  TargetImpl *shared_impl = reinterpret_cast<TargetImpl *>(ifst.GetImpl());
  ofst->SetImpl(shared_impl, false);
}
925
926// Fst Serialization
927template <class A>
928void FstToString(const Fst<A> &fst, string *result) {
929  ostringstream ostrm;
930  fst.Write(ostrm, FstWriteOptions("FstToString"));
931  *result = ostrm.str();
932}
933
934template <class A>
935Fst<A> *StringToFst(const string &s) {
936  istringstream istrm(s);
937  return Fst<A>::Read(istrm, FstReadOptions("StringToFst"));
938}
939
940}  // namespace fst
941
942#endif  // FST_LIB_FST_H__
943