string.h revision 5b6dc79427b8f7eeb6a7ff68034ab8548ce670ea
1
2// string.h
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16// Copyright 2005-2010 Google, Inc.
17// Author: allauzen@google.com (Cyril Allauzen)
18//
19// \file
20// Utilities to convert strings into FSTs.
21//
22
23#ifndef FST_LIB_STRING_H_
24#define FST_LIB_STRING_H_
25
26#include <fst/compact-fst.h>
27#include <fst/icu.h>
28#include <fst/mutable-fst.h>
29
30DECLARE_string(fst_field_separator);
31
32namespace fst {
33
34// Functor compiling a string in an FST
35template <class A>
36class StringCompiler {
37 public:
38  typedef A Arc;
39  typedef typename A::Label Label;
40  typedef typename A::Weight Weight;
41
42  enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
43
44  StringCompiler(TokenType type, const SymbolTable *syms = 0,
45                 Label unknown_label = kNoLabel,
46                 bool allow_negative = false)
47      : token_type_(type), syms_(syms), unknown_label_(unknown_label),
48        allow_negative_(allow_negative) {}
49
50  // Compile string 's' into FST 'fst'.
51  template <class F>
52  bool operator()(const string &s, F *fst) const {
53    vector<Label> labels;
54    if (!ConvertStringToLabels(s, &labels))
55      return false;
56    Compile(labels, fst);
57    return true;
58  }
59
60  template <class F>
61  bool operator()(const string &s, F *fst, Weight w) const {
62    vector<Label> labels;
63    if (!ConvertStringToLabels(s, &labels))
64      return false;
65    Compile(labels, fst, w);
66    return true;
67  }
68
69 private:
70  bool ConvertStringToLabels(const string &str, vector<Label> *labels) const {
71    labels->clear();
72    if (token_type_ == BYTE) {
73      for (size_t i = 0; i < str.size(); ++i)
74        labels->push_back(static_cast<unsigned char>(str[i]));
75    } else if (token_type_ == UTF8) {
76      return UTF8StringToLabels(str, labels);
77    } else {
78      char *c_str = new char[str.size() + 1];
79      str.copy(c_str, str.size());
80      c_str[str.size()] = 0;
81      vector<char *> vec;
82      string separator = "\n" + FLAGS_fst_field_separator;
83      SplitToVector(c_str, separator.c_str(), &vec, true);
84      for (size_t i = 0; i < vec.size(); ++i) {
85        Label label;
86        if (!ConvertSymbolToLabel(vec[i], &label))
87          return false;
88        labels->push_back(label);
89      }
90      delete[] c_str;
91    }
92    return true;
93  }
94
95  void Compile(const vector<Label> &labels, MutableFst<A> *fst,
96               const Weight &weight = Weight::One()) const {
97    fst->DeleteStates();
98    while (fst->NumStates() <= labels.size())
99      fst->AddState();
100    for (size_t i = 0; i < labels.size(); ++i)
101      fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
102    fst->SetStart(0);
103    fst->SetFinal(labels.size(), weight);
104  }
105
106  template <class Unsigned>
107  void Compile(const vector<Label> &labels,
108               CompactFst<A, StringCompactor<A>, Unsigned> *fst) const {
109    fst->SetCompactElements(labels.begin(), labels.end());
110  }
111
112  template <class Unsigned>
113  void Compile(const vector<Label> &labels,
114               CompactFst<A, WeightedStringCompactor<A>, Unsigned> *fst,
115               const Weight &weight = Weight::One()) const {
116    vector<pair<Label, Weight> > compacts;
117    compacts.reserve(labels.size());
118    for (size_t i = 0; i < labels.size(); ++i)
119      compacts.push_back(make_pair(labels[i], Weight::One()));
120    compacts.back().second = weight;
121    fst->SetCompactElements(compacts.begin(), compacts.end());
122  }
123
124  bool ConvertSymbolToLabel(const char *s, Label* output) const {
125    int64 n;
126    if (syms_) {
127      n = syms_->Find(s);
128      if ((n == -1) && (unknown_label_ != kNoLabel))
129        n = unknown_label_;
130      if (n == -1 || (!allow_negative_ && n < 0)) {
131        VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s
132                << "\" is not mapped to any integer label, symbol table = "
133                 << syms_->Name();
134        return false;
135      }
136    } else {
137      char *p;
138      n = strtoll(s, &p, 10);
139      if (p < s + strlen(s) || (!allow_negative_ && n < 0)) {
140        VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer "
141                << "= \"" << s << "\"";
142        return false;
143      }
144    }
145    *output = n;
146    return true;
147  }
148
149  TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
150  const SymbolTable *syms_;  // Symbol table used when token type is symbol
151  Label unknown_label_;      // Label for token missing from symbol table
152  bool allow_negative_;      // Negative labels allowed?
153
154  DISALLOW_COPY_AND_ASSIGN(StringCompiler);
155};
156
157// Functor to print a string FST as a string.
158template <class A>
159class StringPrinter {
160 public:
161  typedef A Arc;
162  typedef typename A::Label Label;
163  typedef typename A::StateId StateId;
164  typedef typename A::Weight Weight;
165
166  enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
167
168  StringPrinter(TokenType token_type,
169                const SymbolTable *syms = 0)
170      : token_type_(token_type), syms_(syms) {}
171
172  // Convert the FST 'fst' into the string 'output'
173  bool operator()(const Fst<A> &fst, string *output) {
174    bool is_a_string = FstToLabels(fst);
175    if (!is_a_string) {
176      VLOG(1) << "StringPrinter::operator(): Fst is not a string.";
177      return false;
178    }
179
180    output->clear();
181
182    if (token_type_ == SYMBOL) {
183      stringstream sstrm;
184      for (size_t i = 0; i < labels_.size(); ++i) {
185        if (i)
186          sstrm << *(FLAGS_fst_field_separator.rbegin());
187        if (!PrintLabel(labels_[i], sstrm))
188          return false;
189      }
190      *output = sstrm.str();
191    } else if (token_type_ == BYTE) {
192      output->reserve(labels_.size());
193      for (size_t i = 0; i < labels_.size(); ++i) {
194        output->push_back(labels_[i]);
195      }
196    } else if (token_type_ == UTF8) {
197      return LabelsToUTF8String(labels_, output);
198    } else {
199      VLOG(1) << "StringPrinter::operator(): Unknown token type: "
200              << token_type_;
201      return false;
202    }
203    return true;
204  }
205
206 private:
207  bool FstToLabels(const Fst<A> &fst) {
208    labels_.clear();
209
210    StateId s = fst.Start();
211    if (s == kNoStateId) {
212      VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
213              << "string fst.";
214      return false;
215    }
216
217    while (fst.Final(s) == Weight::Zero()) {
218      ArcIterator<Fst<A> > aiter(fst, s);
219      if (aiter.Done()) {
220        VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does "
221                << "not reach final state.";
222        return false;
223      }
224
225      const A& arc = aiter.Value();
226      labels_.push_back(arc.olabel);
227
228      s = arc.nextstate;
229      if (s == kNoStateId) {
230        VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid "
231                << "state.";
232        return false;
233      }
234
235      aiter.Next();
236      if (!aiter.Done()) {
237        VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
238                << "outgoing arcs found.";
239        return false;
240      }
241    }
242
243    return true;
244  }
245
246  bool PrintLabel(Label lab, ostream& ostrm) {
247    if (syms_) {
248      string symbol = syms_->Find(lab);
249      if (symbol == "") {
250        VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not "
251                << "mapped to any textual symbol, symbol table = "
252                 << syms_->Name();
253        return false;
254      }
255      ostrm << symbol;
256    } else {
257      ostrm << lab;
258    }
259    return true;
260  }
261
262  TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
263  const SymbolTable *syms_;  // Symbol table used when token type is symbol
264  vector<Label> labels_;     // Input FST labels.
265
266  DISALLOW_COPY_AND_ASSIGN(StringPrinter);
267};
268
269}  // namespace fst
270
271#endif // FST_LIB_STRING_H_
272