1#ifndef __FST_IO_H__ 2#define __FST_IO_H__ 3 4// fst-io.h 5// This is a copy of the OPENFST SDK application sample files ... 6// except for the main functions ifdef'ed out 7// 2007, 2008 Nuance Communications 8// 9// print-main.h compile-main.h 10// 11// Licensed under the Apache License, Version 2.0 (the "License"); 12// you may not use this file except in compliance with the License. 13// You may obtain a copy of the License at 14// 15// http://www.apache.org/licenses/LICENSE-2.0 16// 17// Unless required by applicable law or agreed to in writing, software 18// distributed under the License is distributed on an "AS IS" BASIS, 19// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20// See the License for the specific language governing permissions and 21// limitations under the License. 22// 23// 24// \file 25// Classes and functions to compile a binary Fst from textual input. 26// Includes helper function for fstcompile.cc that templates the main 27// on the arc type to support multiple and extensible arc types. 28 29#include <fstream> 30#include <sstream> 31 32#include "fst/lib/fst.h" 33#include "fst/lib/fstlib.h" 34#include "fst/lib/fst-decl.h" 35#include "fst/lib/vector-fst.h" 36#include "fst/lib/arcsort.h" 37#include "fst/lib/invert.h" 38 39namespace fst { 40 41 template <class A> class FstPrinter { 42 public: 43 typedef A Arc; 44 typedef typename A::StateId StateId; 45 typedef typename A::Label Label; 46 typedef typename A::Weight Weight; 47 48 FstPrinter(const Fst<A> &fst, 49 const SymbolTable *isyms, 50 const SymbolTable *osyms, 51 const SymbolTable *ssyms, 52 bool accep) 53 : fst_(fst), isyms_(isyms), osyms_(osyms), ssyms_(ssyms), 54 accep_(accep && fst.Properties(kAcceptor, true)), ostrm_(0) {} 55 56 // Print Fst to an output strm 57 void Print(ostream *ostrm, const string &dest) { 58 ostrm_ = ostrm; 59 dest_ = dest; 60 StateId start = fst_.Start(); 61 if (start == kNoStateId) 62 return; 63 // initial state first 64 PrintState(start); 65 for (StateIterator< Fst<A> > siter(fst_); 66 !siter.Done(); 67 siter.Next()) { 68 StateId s = siter.Value(); 69 if (s != start) 70 PrintState(s); 71 } 72 } 73 74 private: 75 // Maximum line length in text file. 76 static const int kLineLen = 8096; 77 78 void PrintId(int64 id, const SymbolTable *syms, 79 const char *name) const { 80 if (syms) { 81 string symbol = syms->Find(id); 82 if (symbol == "") { 83 LOG(ERROR) << "FstPrinter: Integer " << id 84 << " is not mapped to any textual symbol" 85 << ", symbol table = " << syms->Name() 86 << ", destination = " << dest_; 87 exit(1); 88 } 89 *ostrm_ << symbol; 90 } else { 91 *ostrm_ << id; 92 } 93 } 94 95 void PrintStateId(StateId s) const { 96 PrintId(s, ssyms_, "state ID"); 97 } 98 99 void PrintILabel(Label l) const { 100 PrintId(l, isyms_, "arc input label"); 101 } 102 103 void PrintOLabel(Label l) const { 104 PrintId(l, osyms_, "arc output label"); 105 } 106 107 void PrintState(StateId s) const { 108 bool output = false; 109 for (ArcIterator< Fst<A> > aiter(fst_, s); 110 !aiter.Done(); 111 aiter.Next()) { 112 Arc arc = aiter.Value(); 113 PrintStateId(s); 114 *ostrm_ << "\t"; 115 PrintStateId(arc.nextstate); 116 *ostrm_ << "\t"; 117 PrintILabel(arc.ilabel); 118 if (!accep_) { 119 *ostrm_ << "\t"; 120 PrintOLabel(arc.olabel); 121 } 122 if (arc.weight != Weight::One()) 123 *ostrm_ << "\t" << arc.weight; 124 *ostrm_ << "\n"; 125 output = true; 126 } 127 Weight final = fst_.Final(s); 128 if (final != Weight::Zero() || !output) { 129 PrintStateId(s); 130 if (final != Weight::One()) { 131 *ostrm_ << "\t" << final; 132 } 133 *ostrm_ << "\n"; 134 } 135 } 136 137 const Fst<A> &fst_; 138 const SymbolTable *isyms_; // ilabel symbol table 139 const SymbolTable *osyms_; // olabel symbol table 140 const SymbolTable *ssyms_; // slabel symbol table 141 bool accep_; // print as acceptor when possible 142 ostream *ostrm_; // binary FST destination 143 string dest_; // binary FST destination name 144 DISALLOW_EVIL_CONSTRUCTORS(FstPrinter); 145 }; 146 147#if 0 148 // Main function for fstprint templated on the arc type. 149 template <class Arc> 150 int PrintMain(int argc, char **argv, istream &istrm, 151 const FstReadOptions &opts) { 152 Fst<Arc> *fst = Fst<Arc>::Read(istrm, opts); 153 if (!fst) return 1; 154 155 string dest = "standard output"; 156 ostream *ostrm = &std::cout; 157 if (argc == 3) { 158 dest = argv[2]; 159 ostrm = new ofstream(argv[2]); 160 if (!*ostrm) { 161 LOG(ERROR) << argv[0] << ": Open failed, file = " << argv[2]; 162 return 0; 163 } 164 } 165 ostrm->precision(9); 166 167 const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0; 168 169 if (!FLAGS_isymbols.empty() && !FLAGS_numeric) { 170 isyms = SymbolTable::ReadText(FLAGS_isymbols); 171 if (!isyms) exit(1); 172 } 173 174 if (!FLAGS_osymbols.empty() && !FLAGS_numeric) { 175 osyms = SymbolTable::ReadText(FLAGS_osymbols); 176 if (!osyms) exit(1); 177 } 178 179 if (!FLAGS_ssymbols.empty() && !FLAGS_numeric) { 180 ssyms = SymbolTable::ReadText(FLAGS_ssymbols); 181 if (!ssyms) exit(1); 182 } 183 184 if (!isyms && !FLAGS_numeric) 185 isyms = fst->InputSymbols(); 186 if (!osyms && !FLAGS_numeric) 187 osyms = fst->OutputSymbols(); 188 189 FstPrinter<Arc> fstprinter(*fst, isyms, osyms, ssyms, FLAGS_acceptor); 190 fstprinter.Print(ostrm, dest); 191 192 if (isyms && !FLAGS_save_isymbols.empty()) 193 isyms->WriteText(FLAGS_save_isymbols); 194 195 if (osyms && !FLAGS_save_osymbols.empty()) 196 osyms->WriteText(FLAGS_save_osymbols); 197 198 if (ostrm != &std::cout) 199 delete ostrm; 200 return 0; 201 } 202#endif 203 204 205 template <class A> class FstReader { 206 public: 207 typedef A Arc; 208 typedef typename A::StateId StateId; 209 typedef typename A::Label Label; 210 typedef typename A::Weight Weight; 211 212 FstReader(istream &istrm, const string &source, 213 const SymbolTable *isyms, const SymbolTable *osyms, 214 const SymbolTable *ssyms, bool accep, bool ikeep, 215 bool okeep, bool nkeep) 216 : nline_(0), source_(source), 217 isyms_(isyms), osyms_(osyms), ssyms_(ssyms), 218 nstates_(0), keep_state_numbering_(nkeep) { 219 char line[kLineLen]; 220 while (istrm.getline(line, kLineLen)) { 221 ++nline_; 222 vector<char *> col; 223 SplitToVector(line, "\n\t ", &col, true); 224 if (col.size() == 0 || col[0][0] == '\0') // empty line 225 continue; 226 if (col.size() > 5 || 227 col.size() > 4 && accep || 228 col.size() == 3 && !accep) { 229 LOG(ERROR) << "FstReader: Bad number of columns, source = " << source_ 230 << ", line = " << nline_; 231 exit(1); 232 } 233 StateId s = StrToStateId(col[0]); 234 while (s >= fst_.NumStates()) 235 fst_.AddState(); 236 if (nline_ == 1) 237 fst_.SetStart(s); 238 239 Arc arc; 240 StateId d = s; 241 switch (col.size()) { 242 case 1: 243 fst_.SetFinal(s, Weight::One()); 244 break; 245 case 2: 246 fst_.SetFinal(s, StrToWeight(col[1], true)); 247 break; 248 case 3: 249 arc.nextstate = d = StrToStateId(col[1]); 250 arc.ilabel = StrToILabel(col[2]); 251 arc.olabel = arc.ilabel; 252 arc.weight = Weight::One(); 253 fst_.AddArc(s, arc); 254 break; 255 case 4: 256 arc.nextstate = d = StrToStateId(col[1]); 257 arc.ilabel = StrToILabel(col[2]); 258 if (accep) { 259 arc.olabel = arc.ilabel; 260 arc.weight = StrToWeight(col[3], false); 261 } else { 262 arc.olabel = StrToOLabel(col[3]); 263 arc.weight = Weight::One(); 264 } 265 fst_.AddArc(s, arc); 266 break; 267 case 5: 268 arc.nextstate = d = StrToStateId(col[1]); 269 arc.ilabel = StrToILabel(col[2]); 270 arc.olabel = StrToOLabel(col[3]); 271 arc.weight = StrToWeight(col[4], false); 272 fst_.AddArc(s, arc); 273 } 274 while (d >= fst_.NumStates()) 275 fst_.AddState(); 276 } 277 if (ikeep) 278 fst_.SetInputSymbols(isyms); 279 if (okeep) 280 fst_.SetOutputSymbols(osyms); 281 } 282 283 const VectorFst<A> &Fst() const { return fst_; } 284 285 private: 286 // Maximum line length in text file. 287 static const int kLineLen = 8096; 288 289 int64 StrToId(const char *s, const SymbolTable *syms, 290 const char *name) const { 291 int64 n; 292 293 if (syms) { 294 n = syms->Find(s); 295 if (n < 0) { 296 LOG(ERROR) << "FstReader: Symbol \"" << s 297 << "\" is not mapped to any integer " << name 298 << ", symbol table = " << syms->Name() 299 << ", source = " << source_ << ", line = " << nline_; 300 exit(1); 301 } 302 } else { 303 char *p; 304 n = strtoll(s, &p, 10); 305 if (p < s + strlen(s) || n < 0) { 306 LOG(ERROR) << "FstReader: Bad " << name << " integer = \"" << s 307 << "\", source = " << source_ << ", line = " << nline_; 308 exit(1); 309 } 310 } 311 return n; 312 } 313 314 StateId StrToStateId(const char *s) { 315 StateId n = StrToId(s, ssyms_, "state ID"); 316 317 if (keep_state_numbering_) 318 return n; 319 320 // remap state IDs to make dense set 321 typename hash_map<StateId, StateId>::const_iterator it = states_.find(n); 322 if (it == states_.end()) { 323 states_[n] = nstates_; 324 return nstates_++; 325 } else { 326 return it->second; 327 } 328 } 329 330 StateId StrToILabel(const char *s) const { 331 return StrToId(s, isyms_, "arc ilabel"); 332 } 333 334 StateId StrToOLabel(const char *s) const { 335 return StrToId(s, osyms_, "arc olabel"); 336 } 337 338 Weight StrToWeight(const char *s, bool allow_zero) const { 339 Weight w; 340 istringstream strm(s); 341 strm >> w; 342 if (strm.fail() || !allow_zero && w == Weight::Zero()) { 343 LOG(ERROR) << "FstReader: Bad weight = \"" << s 344 << "\", source = " << source_ << ", line = " << nline_; 345 exit(1); 346 } 347 return w; 348 } 349 350 VectorFst<A> fst_; 351 size_t nline_; 352 string source_; // text FST source name 353 const SymbolTable *isyms_; // ilabel symbol table 354 const SymbolTable *osyms_; // olabel symbol table 355 const SymbolTable *ssyms_; // slabel symbol table 356 hash_map<StateId, StateId> states_; // state ID map 357 StateId nstates_; // number of seen states 358 bool keep_state_numbering_; 359 DISALLOW_EVIL_CONSTRUCTORS(FstReader); 360 }; 361 362#if 0 363 // Main function for fstcompile templated on the arc type. Last two 364 // arguments unneeded since fstcompile passes the arc type as a flag 365 // unlike the other mains, which infer the arc type from an input Fst. 366 template <class Arc> 367 int CompileMain(int argc, char **argv, istream& /* strm */, 368 const FstReadOptions & /* opts */) { 369 char *ifilename = "standard input"; 370 istream *istrm = &std::cin; 371 if (argc > 1 && strcmp(argv[1], "-") != 0) { 372 ifilename = argv[1]; 373 istrm = new ifstream(ifilename); 374 if (!*istrm) { 375 LOG(ERROR) << argv[0] << ": Open failed, file = " << ifilename; 376 return 1; 377 } 378 } 379 const SymbolTable *isyms = 0, *osyms = 0, *ssyms = 0; 380 381 if (!FLAGS_isymbols.empty()) { 382 isyms = SymbolTable::ReadText(FLAGS_isymbols); 383 if (!isyms) exit(1); 384 } 385 386 if (!FLAGS_osymbols.empty()) { 387 osyms = SymbolTable::ReadText(FLAGS_osymbols); 388 if (!osyms) exit(1); 389 } 390 391 if (!FLAGS_ssymbols.empty()) { 392 ssyms = SymbolTable::ReadText(FLAGS_ssymbols); 393 if (!ssyms) exit(1); 394 } 395 396 FstReader<Arc> fstreader(*istrm, ifilename, isyms, osyms, ssyms, 397 FLAGS_acceptor, FLAGS_keep_isymbols, 398 FLAGS_keep_osymbols, FLAGS_keep_state_numbering); 399 400 const Fst<Arc> *fst = &fstreader.Fst(); 401 if (FLAGS_fst_type != "vector") { 402 fst = Convert<Arc>(*fst, FLAGS_fst_type); 403 if (!fst) return 1; 404 } 405 fst->Write(argc > 2 ? argv[2] : ""); 406 if (istrm != &std::cin) 407 delete istrm; 408 return 0; 409 } 410#endif 411 412} // namespace fst 413 414#endif /* __FST_IO_H__ */ 415 416