// expanded-fst.h revision 4a68b3365c8c50aa93505e99ead2565ab73dcdb0
1// expanded-fst.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// Generic FST augmented with state count - interface class definition.
18
19#ifndef FST_LIB_EXPANDED_FST_H__
20#define FST_LIB_EXPANDED_FST_H__
21
22#include "fst/lib/fst.h"
23
24namespace fst {
25
26// A generic FST plus state count.
27template <class A>
28class ExpandedFst : public Fst<A> {
29 public:
30  typedef A Arc;
31  typedef typename A::StateId StateId;
32
33  virtual StateId NumStates() const = 0;  // State count
34
35  // Get a copy of this ExpandedFst.
36  virtual ExpandedFst<A> *Copy() const = 0;
37  // Read an ExpandedFst from an input stream; return NULL on error.
38  static ExpandedFst<A> *Read(istream &strm, const FstReadOptions &opts) {
39    FstReadOptions ropts(opts);
40    FstHeader hdr;
41    if (ropts.header)
42      hdr = *opts.header;
43    else {
44      if (!hdr.Read(strm, opts.source))
45        return 0;
46      ropts.header = &hdr;
47    }
48    if (!(hdr.Properties() & kExpanded)) {
49      LOG(ERROR) << "ExpandedFst::Read: Not an ExpandedFst: " << ropts.source;
50      return 0;
51    }
52    FstRegister<A> *registr = FstRegister<A>::GetRegister();
53    const typename FstRegister<A>::Reader reader =
54      registr->GetReader(hdr.FstType());
55    if (!reader) {
56      LOG(ERROR) << "ExpandedFst::Read: Unknown FST type \"" << hdr.FstType()
57                 << "\" (arc type = \"" << A::Type()
58                 << "\"): " << ropts.source;
59      return 0;
60    }
61    Fst<A> *fst = reader(strm, ropts);
62    if (!fst) return 0;
63    return down_cast<ExpandedFst<A> *>(fst);
64  }
65  // Read an ExpandedFst from a file; return NULL on error.
66  static ExpandedFst<A> *Read(const string &filename) {
67    ifstream strm(filename.c_str());
68    if (!strm) {
69      LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << filename;
70      return 0;
71    }
72    return Read(strm, FstReadOptions(filename));
73  }
74};
75
// A useful alias when using StdArc: the ExpandedFst interface instantiated
// with the standard arc type.
typedef ExpandedFst<StdArc> StdExpandedFst;
78
79// Function to return the number of states in an FST, counting them
80// if necessary.
81template <class Arc>
82typename Arc::StateId CountStates(const Fst<Arc> &fst) {
83  if (fst.Properties(kExpanded, false)) {
84    const ExpandedFst<Arc> *efst = down_cast<const ExpandedFst<Arc> *>(&fst);
85    return efst->NumStates();
86  } else {
87    typename Arc::StateId nstates = 0;
88    for (StateIterator< Fst<Arc> > siter(fst); !siter.Done(); siter.Next())
89      ++nstates;
90    return nstates;
91  }
92}
93
}  // namespace fst
95
96#endif  // FST_LIB_EXPANDED_FST_H__
97