compile-strings.h revision dfd8b8327b93660601d016cdc6f29f433b45a8d8
1
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13//
14// Copyright 2005-2010 Google, Inc.
15// Authors: allauzen@google.com (Cyril Allauzen)
16//          ttai@google.com (Terry Tai)
17//          jpr@google.com (Jake Ratkiewicz)
18
19
20#ifndef FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
21#define FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
22
23#include <libgen.h>
24#include <string>
25#include <vector>
26using std::vector;
27
28#include <fst/extensions/far/far.h>
29#include <fst/string.h>
30
31namespace fst {
32
33// Construct a reader that provides FSTs from a file (stream) either on a
34// line-by-line basis or on a per-stream basis.  Note that the freshly
35// constructed reader is already set to the first input.
36//
37// Sample Usage:
38//   for (StringReader<Arc> reader(...); !reader.Done(); reader.Next()) {
39//     Fst *fst = reader.GetVectorFst();
40//   }
41template <class A>
42class StringReader {
43 public:
44  typedef A Arc;
45  typedef typename A::Label Label;
46  typedef typename A::Weight Weight;
47  typedef typename StringCompiler<A>::TokenType TokenType;
48
49  enum EntryType { LINE = 1, FILE = 2 };
50
51  StringReader(istream &istrm,
52               const string &source,
53               EntryType entry_type,
54               TokenType token_type,
55               bool allow_negative_labels,
56               const SymbolTable *syms = 0,
57               Label unknown_label = kNoStateId)
58      : nline_(0), strm_(istrm), source_(source), entry_type_(entry_type),
59        token_type_(token_type), symbols_(syms), done_(false),
60        compiler_(token_type, syms, unknown_label, allow_negative_labels) {
61    Next();  // Initialize the reader to the first input.
62  }
63
64  bool Done() {
65    return done_;
66  }
67
68  void Next() {
69    VLOG(1) << "Processing source " << source_ << " at line " << nline_;
70    if (!strm_) {                    // We're done if we have no more input.
71      done_ = true;
72      return;
73    }
74    if (entry_type_ == LINE) {
75      getline(strm_, content_);
76      ++nline_;
77    } else {
78      content_.clear();
79      string line;
80      while (getline(strm_, line)) {
81        ++nline_;
82        content_.append(line);
83        content_.append("\n");
84      }
85    }
86    if (!strm_ && content_.empty())  // We're also done if we read off all the
87      done_ = true;                  // whitespace at the end of a file.
88  }
89
90  VectorFst<A> *GetVectorFst(bool keep_symbols = false) {
91    VectorFst<A> *fst = new VectorFst<A>;
92    if (keep_symbols) {
93      fst->SetInputSymbols(symbols_);
94      fst->SetOutputSymbols(symbols_);
95    }
96    if (compiler_(content_, fst)) {
97      return fst;
98    } else {
99      delete fst;
100      return NULL;
101    }
102  }
103
104  CompactFst<A, StringCompactor<A> > *GetCompactFst(bool keep_symbols = false) {
105    CompactFst<A, StringCompactor<A> > *fst;
106    if (keep_symbols) {
107      VectorFst<A> tmp;
108      tmp.SetInputSymbols(symbols_);
109      tmp.SetOutputSymbols(symbols_);
110      fst = new CompactFst<A, StringCompactor<A> >(tmp);
111    } else {
112      fst = new CompactFst<A, StringCompactor<A> >;
113    }
114    if (compiler_(content_, fst)) {
115      return fst;
116    } else {
117      delete fst;
118      return NULL;
119    }
120  }
121
122 private:
123  size_t nline_;
124  istream &strm_;
125  string source_;
126  EntryType entry_type_;
127  TokenType token_type_;
128  const SymbolTable *symbols_;
129  bool done_;
130  StringCompiler<A> compiler_;
131  string content_;  // The actual content of the input stream's next FST.
132
133  DISALLOW_COPY_AND_ASSIGN(StringReader);
134};
135
// Returns how many decimal digits are needed to write the largest line number
// of the named file. FarCompileStrings() uses this as the zero-padded width of
// per-line keys. The definition is not part of this header.
int KeySize(const char *path);
139
140template <class Arc>
141void FarCompileStrings(const vector<string> &in_fnames,
142                       const string &out_fname,
143                       const string &fst_type,
144                       const FarType &far_type,
145                       int32 generate_keys,
146                       FarEntryType fet,
147                       FarTokenType tt,
148                       const string &symbols_fname,
149                       const string &unknown_symbol,
150                       bool keep_symbols,
151                       bool initial_symbols,
152                       bool allow_negative_labels,
153                       bool file_list_input,
154                       const string &key_prefix,
155                       const string &key_suffix) {
156  typename StringReader<Arc>::EntryType entry_type;
157  if (fet == FET_LINE) {
158    entry_type = StringReader<Arc>::LINE;
159  } else if (fet == FET_FILE) {
160    entry_type = StringReader<Arc>::FILE;
161  } else {
162    FSTERROR() << "FarCompileStrings: unknown entry type";
163    return;
164  }
165
166  typename StringCompiler<Arc>::TokenType token_type;
167  if (tt == FTT_SYMBOL) {
168    token_type = StringCompiler<Arc>::SYMBOL;
169  } else if (tt == FTT_BYTE) {
170    token_type = StringCompiler<Arc>::BYTE;
171  } else if (tt == FTT_UTF8) {
172    token_type = StringCompiler<Arc>::UTF8;
173  } else {
174    FSTERROR() << "FarCompileStrings: unknown token type";
175    return;
176  }
177
178  bool compact;
179  if (fst_type.empty() || (fst_type == "vector")) {
180    compact = false;
181  } else if (fst_type == "compact") {
182    compact = true;
183  } else {
184    FSTERROR() << "FarCompileStrings: unknown fst type: "
185               << fst_type;
186    return;
187  }
188
189  const SymbolTable *syms = 0;
190  typename Arc::Label unknown_label = kNoLabel;
191  if (!symbols_fname.empty()) {
192    SymbolTableTextOptions opts;
193    opts.allow_negative = allow_negative_labels;
194    syms = SymbolTable::ReadText(symbols_fname, opts);
195    if (!syms) {
196      FSTERROR() << "FarCompileStrings: error reading symbol table: "
197                 << symbols_fname;
198      return;
199    }
200    if (!unknown_symbol.empty()) {
201      unknown_label = syms->Find(unknown_symbol);
202      if (unknown_label == kNoLabel) {
203        FSTERROR() << "FarCompileStrings: unknown label \"" << unknown_label
204                   << "\" missing from symbol table: " << symbols_fname;
205        return;
206      }
207    }
208  }
209
210  FarWriter<Arc> *far_writer =
211      FarWriter<Arc>::Create(out_fname, far_type);
212  if (!far_writer) return;
213
214  vector<string> inputs;
215  if (file_list_input) {
216    for (int i = 1; i < in_fnames.size(); ++i) {
217      istream *istrm = in_fnames.empty() ? &cin :
218          new ifstream(in_fnames[i].c_str());
219      string str;
220      while (getline(*istrm, str))
221        inputs.push_back(str);
222      if (!in_fnames.empty())
223        delete istrm;
224    }
225  } else {
226    inputs = in_fnames;
227  }
228
229  for (int i = 0, n = 0; i < inputs.size(); ++i) {
230    if (generate_keys == 0 && inputs[i].empty()) {
231      FSTERROR() << "FarCompileStrings: read from a file instead of stdin or"
232                 << " set the --generate_keys flags.";
233      delete far_writer;
234      delete syms;
235      return;
236    }
237    int key_size = generate_keys ? generate_keys :
238        (entry_type == StringReader<Arc>::FILE ? 1 :
239         KeySize(inputs[i].c_str()));
240    istream *istrm = inputs[i].empty() ? &cin :
241        new ifstream(inputs[i].c_str());
242
243    bool keep_syms = keep_symbols;
244    for (StringReader<Arc> reader(
245             *istrm, inputs[i].empty() ? "stdin" : inputs[i],
246             entry_type, token_type, allow_negative_labels,
247             syms, unknown_label);
248         !reader.Done();
249         reader.Next()) {
250      ++n;
251      const Fst<Arc> *fst;
252      if (compact)
253        fst = reader.GetCompactFst(keep_syms);
254      else
255        fst = reader.GetVectorFst(keep_syms);
256      if (initial_symbols)
257        keep_syms = false;
258      if (!fst) {
259        FSTERROR() << "FarCompileStrings: compiling string number " << n
260                   << " in file " << inputs[i] << " failed with token_type = "
261                   << (tt == FTT_BYTE ? "byte" :
262                       (tt == FTT_UTF8 ? "utf8" :
263                        (tt == FTT_SYMBOL ? "symbol" : "unknown")))
264                   << " and entry_type = "
265                   << (fet == FET_LINE ? "line" :
266                       (fet == FET_FILE ? "file" : "unknown"));
267        delete far_writer;
268        delete syms;
269        if (!inputs[i].empty()) delete istrm;
270        return;
271      }
272      ostringstream keybuf;
273      keybuf.width(key_size);
274      keybuf.fill('0');
275      keybuf << n;
276      string key;
277      if (generate_keys > 0) {
278        key = keybuf.str();
279      } else {
280        char* filename = new char[inputs[i].size() + 1];
281        strcpy(filename, inputs[i].c_str());
282        key = basename(filename);
283        if (entry_type != StringReader<Arc>::FILE) {
284          key += "-";
285          key += keybuf.str();
286        }
287        delete[] filename;
288      }
289      far_writer->Add(key_prefix + key + key_suffix, *fst);
290      delete fst;
291    }
292    if (generate_keys == 0)
293      n = 0;
294    if (!inputs[i].empty())
295      delete istrm;
296  }
297
298  delete far_writer;
299}
300
301}  // namespace fst
302
303
304#endif  // FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
305