// compile-impl.h revision 3da1eb108d36da35333b2d655202791af854996b
// compile-impl.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
// Class to compile a binary Fst from textual input.
20
21#ifndef FST_SCRIPT_COMPILE_IMPL_H_
22#define FST_SCRIPT_COMPILE_IMPL_H_
23
24#include <tr1/unordered_map>
25using std::tr1::unordered_map;
26using std::tr1::unordered_multimap;
27#include <sstream>
28#include <string>
29#include <vector>
30using std::vector;
31
32#include <iostream>
33#include <fstream>
34#include <sstream>
35#include <fst/fst.h>
36#include <fst/util.h>
37#include <fst/vector-fst.h>
38
39DECLARE_string(fst_field_separator);
40
41namespace fst {
42
43// Compile a binary Fst from textual input, helper class for fstcompile.cc
44// WARNING: Stand-alone use of this class not recommended, most code should
45// read/write using the binary format which is much more efficient.
46template <class A> class FstCompiler {
47 public:
48  typedef A Arc;
49  typedef typename A::StateId StateId;
50  typedef typename A::Label Label;
51  typedef typename A::Weight Weight;
52
53  // WARNING: use of 'allow_negative_labels = true' not recommended; may
54  // cause conflicts
55  FstCompiler(istream &istrm, const string &source,
56            const SymbolTable *isyms, const SymbolTable *osyms,
57            const SymbolTable *ssyms, bool accep, bool ikeep,
58              bool okeep, bool nkeep, bool allow_negative_labels = false)
59      : nline_(0), source_(source),
60        isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
61        nstates_(0), keep_state_numbering_(nkeep),
62        allow_negative_labels_(allow_negative_labels) {
63    char line[kLineLen];
64    while (istrm.getline(line, kLineLen)) {
65      ++nline_;
66      vector<char *> col;
67      string separator = FLAGS_fst_field_separator + "\n";
68      SplitToVector(line, separator.c_str(), &col, true);
69      if (col.size() == 0 || col[0][0] == '\0')  // empty line
70        continue;
71      if (col.size() > 5 ||
72          (col.size() > 4 && accep) ||
73          (col.size() == 3 && !accep)) {
74        FSTERROR() << "FstCompiler: Bad number of columns, source = "
75                   << source_
76                   << ", line = " << nline_;
77        fst_.SetProperties(kError, kError);
78        return;
79      }
80      StateId s = StrToStateId(col[0]);
81      while (s >= fst_.NumStates())
82        fst_.AddState();
83      if (nline_ == 1)
84        fst_.SetStart(s);
85
86      Arc arc;
87      StateId d = s;
88      switch (col.size()) {
89      case 1:
90        fst_.SetFinal(s, Weight::One());
91        break;
92      case 2:
93        fst_.SetFinal(s, StrToWeight(col[1], true));
94        break;
95      case 3:
96        arc.nextstate = d = StrToStateId(col[1]);
97        arc.ilabel = StrToILabel(col[2]);
98        arc.olabel = arc.ilabel;
99        arc.weight = Weight::One();
100        fst_.AddArc(s, arc);
101        break;
102      case 4:
103        arc.nextstate = d = StrToStateId(col[1]);
104        arc.ilabel = StrToILabel(col[2]);
105        if (accep) {
106          arc.olabel = arc.ilabel;
107          arc.weight = StrToWeight(col[3], false);
108        } else {
109          arc.olabel = StrToOLabel(col[3]);
110          arc.weight = Weight::One();
111        }
112        fst_.AddArc(s, arc);
113        break;
114      case 5:
115        arc.nextstate = d = StrToStateId(col[1]);
116        arc.ilabel = StrToILabel(col[2]);
117        arc.olabel = StrToOLabel(col[3]);
118        arc.weight = StrToWeight(col[4], false);
119        fst_.AddArc(s, arc);
120      }
121      while (d >= fst_.NumStates())
122        fst_.AddState();
123    }
124    if (ikeep)
125      fst_.SetInputSymbols(isyms);
126    if (okeep)
127      fst_.SetOutputSymbols(osyms);
128  }
129
130  const VectorFst<A> &Fst() const {
131    return fst_;
132  }
133
134 private:
135  // Maximum line length in text file.
136  static const int kLineLen = 8096;
137
138  int64 StrToId(const char *s, const SymbolTable *syms,
139                const char *name, bool allow_negative = false) const {
140    int64 n = 0;
141
142    if (syms) {
143      n = syms->Find(s);
144      if (n == -1 || (!allow_negative && n < 0)) {
145        FSTERROR() << "FstCompiler: Symbol \"" << s
146                   << "\" is not mapped to any integer " << name
147                   << ", symbol table = " << syms->Name()
148                   << ", source = " << source_ << ", line = " << nline_;
149        fst_.SetProperties(kError, kError);
150      }
151    } else {
152      char *p;
153      n = strtoll(s, &p, 10);
154      if (p < s + strlen(s) || (!allow_negative && n < 0)) {
155        FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s
156                   << "\", source = " << source_ << ", line = " << nline_;
157        fst_.SetProperties(kError, kError);
158      }
159    }
160    return n;
161  }
162
163  StateId StrToStateId(const char *s) {
164    StateId n = StrToId(s, ssyms_, "state ID");
165
166    if (keep_state_numbering_)
167      return n;
168
169    // remap state IDs to make dense set
170    typename unordered_map<StateId, StateId>::const_iterator it = states_.find(n);
171    if (it == states_.end()) {
172      states_[n] = nstates_;
173      return nstates_++;
174    } else {
175      return it->second;
176    }
177  }
178
179  StateId StrToILabel(const char *s) const {
180    return StrToId(s, isyms_, "arc ilabel", allow_negative_labels_);
181  }
182
183  StateId StrToOLabel(const char *s) const {
184    return StrToId(s, osyms_, "arc olabel", allow_negative_labels_);
185  }
186
187  Weight StrToWeight(const char *s, bool allow_zero) const {
188    Weight w;
189    istringstream strm(s);
190    strm >> w;
191    if (!strm || (!allow_zero && w == Weight::Zero())) {
192      FSTERROR() << "FstCompiler: Bad weight = \"" << s
193                 << "\", source = " << source_ << ", line = " << nline_;
194      fst_.SetProperties(kError, kError);
195      w = Weight::NoWeight();
196    }
197    return w;
198  }
199
200  mutable VectorFst<A> fst_;
201  size_t nline_;
202  string source_;                      // text FST source name
203  const SymbolTable *isyms_;           // ilabel symbol table
204  const SymbolTable *osyms_;           // olabel symbol table
205  const SymbolTable *ssyms_;           // slabel symbol table
206  unordered_map<StateId, StateId> states_;  // state ID map
207  StateId nstates_;                    // number of seen states
208  bool keep_state_numbering_;
209  bool allow_negative_labels_;         // not recommended; may cause conflicts
210
211  DISALLOW_COPY_AND_ASSIGN(FstCompiler);
212};
213
214}  // namespace fst
215
216#endif  // FST_SCRIPT_COMPILE_IMPL_H_
217