compile-impl.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// compile-impl.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Class to compile a binary Fst from textual input.
20
21#ifndef FST_SCRIPT_COMPILE_IMPL_H_
22#define FST_SCRIPT_COMPILE_IMPL_H_
23
#include <cerrno>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
using std::vector;

#include <fst/fst.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
37
38DECLARE_string(fst_field_separator);
39
40namespace fst {
41
42// Compile a binary Fst from textual input, helper class for fstcompile.cc
43// WARNING: Stand-alone use of this class not recommended, most code should
44// read/write using the binary format which is much more efficient.
45template <class A> class FstCompiler {
46 public:
47  typedef A Arc;
48  typedef typename A::StateId StateId;
49  typedef typename A::Label Label;
50  typedef typename A::Weight Weight;
51
52  // WARNING: use of 'allow_negative_labels = true' not recommended; may
53  // cause conflicts
54  FstCompiler(istream &istrm, const string &source,
55            const SymbolTable *isyms, const SymbolTable *osyms,
56            const SymbolTable *ssyms, bool accep, bool ikeep,
57              bool okeep, bool nkeep, bool allow_negative_labels = false)
58      : nline_(0), source_(source),
59        isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
60        nstates_(0), keep_state_numbering_(nkeep),
61        allow_negative_labels_(allow_negative_labels) {
62    char line[kLineLen];
63    while (istrm.getline(line, kLineLen)) {
64      ++nline_;
65      vector<char *> col;
66      string separator = FLAGS_fst_field_separator + "\n";
67      SplitToVector(line, separator.c_str(), &col, true);
68      if (col.size() == 0 || col[0][0] == '\0')  // empty line
69        continue;
70      if (col.size() > 5 ||
71          (col.size() > 4 && accep) ||
72          (col.size() == 3 && !accep)) {
73        FSTERROR() << "FstCompiler: Bad number of columns, source = "
74                   << source_
75                   << ", line = " << nline_;
76        fst_.SetProperties(kError, kError);
77        return;
78      }
79      StateId s = StrToStateId(col[0]);
80      while (s >= fst_.NumStates())
81        fst_.AddState();
82      if (nline_ == 1)
83        fst_.SetStart(s);
84
85      Arc arc;
86      StateId d = s;
87      switch (col.size()) {
88      case 1:
89        fst_.SetFinal(s, Weight::One());
90        break;
91      case 2:
92        fst_.SetFinal(s, StrToWeight(col[1], true));
93        break;
94      case 3:
95        arc.nextstate = d = StrToStateId(col[1]);
96        arc.ilabel = StrToILabel(col[2]);
97        arc.olabel = arc.ilabel;
98        arc.weight = Weight::One();
99        fst_.AddArc(s, arc);
100        break;
101      case 4:
102        arc.nextstate = d = StrToStateId(col[1]);
103        arc.ilabel = StrToILabel(col[2]);
104        if (accep) {
105          arc.olabel = arc.ilabel;
106          arc.weight = StrToWeight(col[3], false);
107        } else {
108          arc.olabel = StrToOLabel(col[3]);
109          arc.weight = Weight::One();
110        }
111        fst_.AddArc(s, arc);
112        break;
113      case 5:
114        arc.nextstate = d = StrToStateId(col[1]);
115        arc.ilabel = StrToILabel(col[2]);
116        arc.olabel = StrToOLabel(col[3]);
117        arc.weight = StrToWeight(col[4], false);
118        fst_.AddArc(s, arc);
119      }
120      while (d >= fst_.NumStates())
121        fst_.AddState();
122    }
123    if (ikeep)
124      fst_.SetInputSymbols(isyms);
125    if (okeep)
126      fst_.SetOutputSymbols(osyms);
127  }
128
129  const VectorFst<A> &Fst() const {
130    return fst_;
131  }
132
133 private:
134  // Maximum line length in text file.
135  static const int kLineLen = 8096;
136
137  int64 StrToId(const char *s, const SymbolTable *syms,
138                const char *name, bool allow_negative = false) const {
139    int64 n = 0;
140
141    if (syms) {
142      n = syms->Find(s);
143      if (n == -1 || (!allow_negative && n < 0)) {
144        FSTERROR() << "FstCompiler: Symbol \"" << s
145                   << "\" is not mapped to any integer " << name
146                   << ", symbol table = " << syms->Name()
147                   << ", source = " << source_ << ", line = " << nline_;
148        fst_.SetProperties(kError, kError);
149      }
150    } else {
151      char *p;
152      n = strtoll(s, &p, 10);
153      if (p < s + strlen(s) || (!allow_negative && n < 0)) {
154        FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s
155                   << "\", source = " << source_ << ", line = " << nline_;
156        fst_.SetProperties(kError, kError);
157      }
158    }
159    return n;
160  }
161
162  StateId StrToStateId(const char *s) {
163    StateId n = StrToId(s, ssyms_, "state ID");
164
165    if (keep_state_numbering_)
166      return n;
167
168    // remap state IDs to make dense set
169    typename unordered_map<StateId, StateId>::const_iterator it = states_.find(n);
170    if (it == states_.end()) {
171      states_[n] = nstates_;
172      return nstates_++;
173    } else {
174      return it->second;
175    }
176  }
177
178  StateId StrToILabel(const char *s) const {
179    return StrToId(s, isyms_, "arc ilabel", allow_negative_labels_);
180  }
181
182  StateId StrToOLabel(const char *s) const {
183    return StrToId(s, osyms_, "arc olabel", allow_negative_labels_);
184  }
185
186  Weight StrToWeight(const char *s, bool allow_zero) const {
187    Weight w;
188    istringstream strm(s);
189    strm >> w;
190    if (!strm || (!allow_zero && w == Weight::Zero())) {
191      FSTERROR() << "FstCompiler: Bad weight = \"" << s
192                 << "\", source = " << source_ << ", line = " << nline_;
193      fst_.SetProperties(kError, kError);
194      w = Weight::NoWeight();
195    }
196    return w;
197  }
198
199  mutable VectorFst<A> fst_;
200  size_t nline_;
201  string source_;                      // text FST source name
202  const SymbolTable *isyms_;           // ilabel symbol table
203  const SymbolTable *osyms_;           // olabel symbol table
204  const SymbolTable *ssyms_;           // slabel symbol table
205  unordered_map<StateId, StateId> states_;  // state ID map
206  StateId nstates_;                    // number of seen states
207  bool keep_state_numbering_;
208  bool allow_negative_labels_;         // not recommended; may cause conflicts
209
210  DISALLOW_COPY_AND_ASSIGN(FstCompiler);
211};
212
213}  // namespace fst
214
215#endif  // FST_SCRIPT_COMPILE_IMPL_H_
216