1#ifndef __FST_IO_H__
2#define __FST_IO_H__
3
4// fst-io.h
// This is a copy of the OpenFst SDK application sample files listed below,
// except that their main functions are disabled with #if 0.
7// 2007, 2008 Nuance Communications
8//
9// print-main.h compile-main.h
10//
11// Licensed under the Apache License, Version 2.0 (the "License");
12// you may not use this file except in compliance with the License.
13// You may obtain a copy of the License at
14//
15//      http://www.apache.org/licenses/LICENSE-2.0
16//
17// Unless required by applicable law or agreed to in writing, software
18// distributed under the License is distributed on an "AS IS" BASIS,
19// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20// See the License for the specific language governing permissions and
21// limitations under the License.
22//
23//
// \file
// Classes to print an Fst in textual form (FstPrinter) and to compile a
// VectorFst from textual input (FstReader). The fstprint/fstcompile main
// helpers, templated on the arc type to support multiple and extensible
// arc types, are kept below but disabled with #if 0.
28
29#include <fstream>
30#include <sstream>
31
32#include "fst/lib/fst.h"
33#include "fst/lib/fstlib.h"
34#include "fst/lib/fst-decl.h"
35#include "fst/lib/vector-fst.h"
36#include "fst/lib/arcsort.h"
37#include "fst/lib/invert.h"
38
39namespace fst {
40
41  template <class A> class FstPrinter {
42  public:
43    typedef A Arc;
44    typedef typename A::StateId StateId;
45    typedef typename A::Label Label;
46    typedef typename A::Weight Weight;
47
48    FstPrinter(const Fst<A> &fst,
49	       const SymbolTable *isyms,
50	       const SymbolTable *osyms,
51	       const SymbolTable *ssyms,
52	       bool accep)
53      : fst_(fst), isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
54      accep_(accep && fst.Properties(kAcceptor, true)), ostrm_(0) {}
55
56    // Print Fst to an output strm
57    void Print(ostream *ostrm, const string &dest) {
58      ostrm_ = ostrm;
59      dest_ = dest;
60      StateId start = fst_.Start();
61      if (start == kNoStateId)
62	return;
63      // initial state first
64      PrintState(start);
65      for (StateIterator< Fst<A> > siter(fst_);
66	   !siter.Done();
67	   siter.Next()) {
68	StateId s = siter.Value();
69	if (s != start)
70	  PrintState(s);
71      }
72    }
73
74  private:
75    // Maximum line length in text file.
76    static const int kLineLen = 8096;
77
78    void PrintId(int64 id, const SymbolTable *syms,
79		 const char *name) const {
80      if (syms) {
81	string symbol = syms->Find(id);
82	if (symbol == "") {
83	  LOG(ERROR) << "FstPrinter: Integer " << id
84		     << " is not mapped to any textual symbol"
85		     << ", symbol table = " << syms->Name()
86		     << ", destination = " << dest_;
87	  exit(1);
88	}
89	*ostrm_ << symbol;
90      } else {
91	*ostrm_ << id;
92      }
93    }
94
95    void PrintStateId(StateId s) const {
96      PrintId(s, ssyms_, "state ID");
97    }
98
99    void PrintILabel(Label l) const {
100      PrintId(l, isyms_, "arc input label");
101    }
102
103    void PrintOLabel(Label l) const {
104      PrintId(l, osyms_, "arc output label");
105    }
106
107    void PrintState(StateId s) const {
108      bool output = false;
109      for (ArcIterator< Fst<A> > aiter(fst_, s);
110	   !aiter.Done();
111	   aiter.Next()) {
112	Arc arc = aiter.Value();
113	PrintStateId(s);
114	*ostrm_ << "\t";
115	PrintStateId(arc.nextstate);
116	*ostrm_ << "\t";
117	PrintILabel(arc.ilabel);
118	if (!accep_) {
119	  *ostrm_ << "\t";
120	  PrintOLabel(arc.olabel);
121	}
122	if (arc.weight != Weight::One())
123	  *ostrm_ << "\t" << arc.weight;
124	*ostrm_ << "\n";
125	output = true;
126      }
127      Weight final = fst_.Final(s);
128      if (final != Weight::Zero() || !output) {
129	PrintStateId(s);
130	if (final != Weight::One()) {
131	  *ostrm_ << "\t" << final;
132	}
133	*ostrm_ << "\n";
134      }
135    }
136
137    const Fst<A> &fst_;
138    const SymbolTable *isyms_;     // ilabel symbol table
139    const SymbolTable *osyms_;     // olabel symbol table
140    const SymbolTable *ssyms_;     // slabel symbol table
141    bool accep_;                   // print as acceptor when possible
142    ostream *ostrm_;                // binary FST destination
143    string dest_;                  // binary FST destination name
144    DISALLOW_EVIL_CONSTRUCTORS(FstPrinter);
145  };
146
#if 0
  // Main function for fstprint templated on the arc type (disabled).
  // Reads a binary Fst from 'istrm' and prints it as text to argv[2], or to
  // standard output when argc != 3. Returns 0 on success, 1 on failure.
  template <class Arc>
    int PrintMain(int argc, char **argv, istream &istrm,
                  const FstReadOptions &opts) {
    Fst<Arc> *fst = Fst<Arc>::Read(istrm, opts);
    if (!fst) return 1;

    string dest = "standard output";
    ostream *ostrm = &std::cout;
    if (argc == 3) {
      dest = argv[2];
      ostrm = new ofstream(argv[2]);
      if (!*ostrm) {
        LOG(ERROR) << argv[0] << ": Open failed, file = " << argv[2];
        delete ostrm;
        delete fst;
        return 1;  // an open failure is an error, not success
      }
    }
    ostrm->precision(9);

    const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;

    if (!FLAGS_isymbols.empty() && !FLAGS_numeric) {
      isyms = SymbolTable::ReadText(FLAGS_isymbols);
      if (!isyms) exit(1);
    }

    if (!FLAGS_osymbols.empty() && !FLAGS_numeric) {
      osyms = SymbolTable::ReadText(FLAGS_osymbols);
      if (!osyms) exit(1);
    }

    if (!FLAGS_ssymbols.empty() && !FLAGS_numeric) {
      ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
      if (!ssyms) exit(1);
    }

    // Fall back to the tables stored in the Fst itself; these are owned by
    // 'fst', so 'fst' must stay alive until they have been written below.
    if (!isyms && !FLAGS_numeric)
      isyms = fst->InputSymbols();
    if (!osyms && !FLAGS_numeric)
      osyms = fst->OutputSymbols();

    FstPrinter<Arc> fstprinter(*fst, isyms, osyms, ssyms, FLAGS_acceptor);
    fstprinter.Print(ostrm, dest);

    if (isyms && !FLAGS_save_isymbols.empty())
      isyms->WriteText(FLAGS_save_isymbols);

    if (osyms && !FLAGS_save_osymbols.empty())
      osyms->WriteText(FLAGS_save_osymbols);

    if (ostrm != &std::cout)
      delete ostrm;
    delete fst;
    return 0;
  }
#endif
203
204
205  template <class A> class FstReader {
206  public:
207    typedef A Arc;
208    typedef typename A::StateId StateId;
209    typedef typename A::Label Label;
210    typedef typename A::Weight Weight;
211
212    FstReader(istream &istrm, const string &source,
213	      const SymbolTable *isyms, const SymbolTable *osyms,
214	      const SymbolTable *ssyms, bool accep, bool ikeep,
215	      bool okeep, bool nkeep)
216      : nline_(0), source_(source),
217      isyms_(isyms), osyms_(osyms), ssyms_(ssyms),
218      nstates_(0), keep_state_numbering_(nkeep) {
219      char line[kLineLen];
220      while (istrm.getline(line, kLineLen)) {
221	++nline_;
222	vector<char *> col;
223	SplitToVector(line, "\n\t ", &col, true);
224	if (col.size() == 0 || col[0][0] == '\0')  // empty line
225	  continue;
226	if (col.size() > 5 ||
227	    col.size() > 4 && accep ||
228	    col.size() == 3 && !accep) {
229	  LOG(ERROR) << "FstReader: Bad number of columns, source = " << source_
230		     << ", line = " << nline_;
231	  exit(1);
232	}
233	StateId s = StrToStateId(col[0]);
234	while (s >= fst_.NumStates())
235	  fst_.AddState();
236	if (nline_ == 1)
237	  fst_.SetStart(s);
238
239	Arc arc;
240	StateId d = s;
241	switch (col.size()) {
242	case 1:
243	  fst_.SetFinal(s, Weight::One());
244	  break;
245	case 2:
246	  fst_.SetFinal(s, StrToWeight(col[1], true));
247	  break;
248	case 3:
249	  arc.nextstate = d = StrToStateId(col[1]);
250	  arc.ilabel = StrToILabel(col[2]);
251	  arc.olabel = arc.ilabel;
252	  arc.weight = Weight::One();
253	  fst_.AddArc(s, arc);
254	  break;
255	case 4:
256	  arc.nextstate = d = StrToStateId(col[1]);
257	  arc.ilabel = StrToILabel(col[2]);
258	  if (accep) {
259	    arc.olabel = arc.ilabel;
260	    arc.weight = StrToWeight(col[3], false);
261	  } else {
262	    arc.olabel = StrToOLabel(col[3]);
263	    arc.weight = Weight::One();
264	  }
265	  fst_.AddArc(s, arc);
266	  break;
267	case 5:
268	  arc.nextstate = d = StrToStateId(col[1]);
269	  arc.ilabel = StrToILabel(col[2]);
270	  arc.olabel = StrToOLabel(col[3]);
271	  arc.weight = StrToWeight(col[4], false);
272	  fst_.AddArc(s, arc);
273	}
274	while (d >= fst_.NumStates())
275	  fst_.AddState();
276      }
277      if (ikeep)
278	fst_.SetInputSymbols(isyms);
279      if (okeep)
280	fst_.SetOutputSymbols(osyms);
281    }
282
283    const VectorFst<A> &Fst() const { return fst_; }
284
285  private:
286    // Maximum line length in text file.
287    static const int kLineLen = 8096;
288
289    int64 StrToId(const char *s, const SymbolTable *syms,
290		  const char *name) const {
291      int64 n;
292
293      if (syms) {
294	n = syms->Find(s);
295	if (n < 0) {
296	  LOG(ERROR) << "FstReader: Symbol \"" << s
297		     << "\" is not mapped to any integer " << name
298		     << ", symbol table = " << syms->Name()
299		     << ", source = " << source_ << ", line = " << nline_;
300	  exit(1);
301	}
302      } else {
303	char *p;
304	n = strtoll(s, &p, 10);
305	if (p < s + strlen(s) || n < 0) {
306	  LOG(ERROR) << "FstReader: Bad " << name << " integer = \"" << s
307		     << "\", source = " << source_ << ", line = " << nline_;
308	  exit(1);
309	}
310      }
311      return n;
312    }
313
314    StateId StrToStateId(const char *s) {
315      StateId n = StrToId(s, ssyms_, "state ID");
316
317      if (keep_state_numbering_)
318	return n;
319
320      // remap state IDs to make dense set
321      typename hash_map<StateId, StateId>::const_iterator it = states_.find(n);
322      if (it == states_.end()) {
323	states_[n] = nstates_;
324	return nstates_++;
325      } else {
326	return it->second;
327      }
328    }
329
330    StateId StrToILabel(const char *s) const {
331      return StrToId(s, isyms_, "arc ilabel");
332    }
333
334    StateId StrToOLabel(const char *s) const {
335      return StrToId(s, osyms_, "arc olabel");
336    }
337
338    Weight StrToWeight(const char *s, bool allow_zero) const {
339      Weight w;
340      istringstream strm(s);
341      strm >> w;
342      if (strm.fail() || !allow_zero && w == Weight::Zero()) {
343	LOG(ERROR) << "FstReader: Bad weight = \"" << s
344		   << "\", source = " << source_ << ", line = " << nline_;
345	exit(1);
346      }
347      return w;
348    }
349
350    VectorFst<A> fst_;
351    size_t nline_;
352    string source_;                      // text FST source name
353    const SymbolTable *isyms_;           // ilabel symbol table
354    const SymbolTable *osyms_;           // olabel symbol table
355    const SymbolTable *ssyms_;           // slabel symbol table
356    hash_map<StateId, StateId> states_;  // state ID map
357    StateId nstates_;                    // number of seen states
358    bool keep_state_numbering_;
359    DISALLOW_EVIL_CONSTRUCTORS(FstReader);
360  };
361
#if 0
  // Main function for fstcompile templated on the arc type (disabled).  Last
  // two arguments unneeded since fstcompile passes the arc type as a flag
  // unlike the other mains, which infer the arc type from an input Fst.
  // Reads text from argv[1] (or stdin when absent or "-") and writes the
  // binary Fst to argv[2] (or stdout). Returns 0 on success, 1 on failure.
  template <class Arc>
    int CompileMain(int argc, char **argv, istream& /* strm */,
                    const FstReadOptions & /* opts */) {
    // A string literal must not bind to a non-const char* (ill-formed since
    // C++11).
    const char *ifilename = "standard input";
    istream *istrm = &std::cin;
    if (argc > 1 && strcmp(argv[1], "-") != 0) {
      ifilename = argv[1];
      istrm = new ifstream(ifilename);
      if (!*istrm) {
        LOG(ERROR) << argv[0] << ": Open failed, file = " << ifilename;
        delete istrm;
        return 1;
      }
    }
    const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0;

    if (!FLAGS_isymbols.empty()) {
      isyms = SymbolTable::ReadText(FLAGS_isymbols);
      if (!isyms) exit(1);
    }

    if (!FLAGS_osymbols.empty()) {
      osyms = SymbolTable::ReadText(FLAGS_osymbols);
      if (!osyms) exit(1);
    }

    if (!FLAGS_ssymbols.empty()) {
      ssyms = SymbolTable::ReadText(FLAGS_ssymbols);
      if (!ssyms) exit(1);
    }

    FstReader<Arc> fstreader(*istrm, ifilename, isyms, osyms, ssyms,
                             FLAGS_acceptor, FLAGS_keep_isymbols,
                             FLAGS_keep_osymbols, FLAGS_keep_state_numbering);
    // FstReader consumes the whole stream in its constructor.
    if (istrm != &std::cin)
      delete istrm;

    const Fst<Arc> *fst = &fstreader.Fst();
    const Fst<Arc> *converted = 0;  // owned; only set on conversion
    if (FLAGS_fst_type != "vector") {
      converted = Convert<Arc>(*fst, FLAGS_fst_type);
      if (!converted) return 1;
      fst = converted;
    }
    const bool ok = fst->Write(argc > 2 ? argv[2] : "");
    delete converted;
    return ok ? 0 : 1;
  }
#endif
411
412}  // namespace fst
413
414#endif /* __FST_IO_H__ */
415
416