1
2// string.h
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16// Copyright 2005-2010 Google, Inc.
17// Author: allauzen@google.com (Cyril Allauzen)
18//
19// \file
20// Utilities to convert strings into FSTs.
21//
22
23#ifndef FST_LIB_STRING_H_
24#define FST_LIB_STRING_H_
25
26#include <fst/compact-fst.h>
27#include <fst/mutable-fst.h>
28
29DECLARE_string(fst_field_separator);
30
31namespace fst {
32
33// Functor compiling a string in an FST
34template <class A>
35class StringCompiler {
36 public:
37  typedef A Arc;
38  typedef typename A::Label Label;
39  typedef typename A::Weight Weight;
40
41  enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
42
43  StringCompiler(TokenType type, const SymbolTable *syms = 0,
44                 Label unknown_label = kNoLabel,
45                 bool allow_negative = false)
46      : token_type_(type), syms_(syms), unknown_label_(unknown_label),
47        allow_negative_(allow_negative) {}
48
49  // Compile string 's' into FST 'fst'.
50  template <class F>
51  bool operator()(const string &s, F *fst) {
52    vector<Label> labels;
53    if (!ConvertStringToLabels(s, &labels))
54      return false;
55    Compile(labels, fst);
56    return true;
57  }
58
59 private:
60  bool ConvertStringToLabels(const string &str, vector<Label> *labels) const {
61    labels->clear();
62    if (token_type_ == BYTE) {
63      for (size_t i = 0; i < str.size(); ++i)
64        labels->push_back(static_cast<unsigned char>(str[i]));
65    } else if (token_type_ == UTF8) {
66      return UTF8StringToLabels(str, labels);
67    } else {
68      char *c_str = new char[str.size() + 1];
69      str.copy(c_str, str.size());
70      c_str[str.size()] = 0;
71      vector<char *> vec;
72      string separator = "\n" + FLAGS_fst_field_separator;
73      SplitToVector(c_str, separator.c_str(), &vec, true);
74      for (size_t i = 0; i < vec.size(); ++i) {
75        Label label;
76        if (!ConvertSymbolToLabel(vec[i], &label))
77          return false;
78        labels->push_back(label);
79      }
80      delete[] c_str;
81    }
82    return true;
83  }
84
85  void Compile(const vector<Label> &labels, MutableFst<A> *fst) const {
86    fst->DeleteStates();
87    while (fst->NumStates() <= labels.size())
88      fst->AddState();
89    for (size_t i = 0; i < labels.size(); ++i)
90      fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
91    fst->SetStart(0);
92    fst->SetFinal(labels.size(), Weight::One());
93  }
94
95  template <class Unsigned>
96  void Compile(const vector<Label> &labels, CompactFst<A, StringCompactor<A>,
97               Unsigned> *fst) const {
98    fst->SetCompactElements(labels.begin(), labels.end());
99  }
100
101  bool ConvertSymbolToLabel(const char *s, Label* output) const {
102    int64 n;
103    if (syms_) {
104      n = syms_->Find(s);
105      if ((n == -1) && (unknown_label_ != kNoLabel))
106        n = unknown_label_;
107      if (n == -1 || (!allow_negative_ && n < 0)) {
108        VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s
109                << "\" is not mapped to any integer label, symbol table = "
110                 << syms_->Name();
111        return false;
112      }
113    } else {
114      char *p;
115      n = strtoll(s, &p, 10);
116      if (p < s + strlen(s) || (!allow_negative_ && n < 0)) {
117        VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer "
118                << "= \"" << s << "\"";
119        return false;
120      }
121    }
122    *output = n;
123    return true;
124  }
125
126  TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
127  const SymbolTable *syms_;  // Symbol table used when token type is symbol
128  Label unknown_label_;      // Label for token missing from symbol table
129  bool allow_negative_;      // Negative labels allowed?
130
131  DISALLOW_COPY_AND_ASSIGN(StringCompiler);
132};
133
134// Functor to print a string FST as a string.
135template <class A>
136class StringPrinter {
137 public:
138  typedef A Arc;
139  typedef typename A::Label Label;
140  typedef typename A::StateId StateId;
141  typedef typename A::Weight Weight;
142
143  enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
144
145  StringPrinter(TokenType token_type,
146                const SymbolTable *syms = 0)
147      : token_type_(token_type), syms_(syms) {}
148
149  // Convert the FST 'fst' into the string 'output'
150  bool operator()(const Fst<A> &fst, string *output) {
151    bool is_a_string = FstToLabels(fst);
152    if (!is_a_string) {
153      VLOG(1) << "StringPrinter::operator(): Fst is not a string.";
154      return false;
155    }
156
157    output->clear();
158
159    if (token_type_ == SYMBOL) {
160      stringstream sstrm;
161      for (size_t i = 0; i < labels_.size(); ++i) {
162        if (i)
163          sstrm << *(FLAGS_fst_field_separator.rbegin());
164        if (!PrintLabel(labels_[i], sstrm))
165          return false;
166      }
167      *output = sstrm.str();
168    } else if (token_type_ == BYTE) {
169      for (size_t i = 0; i < labels_.size(); ++i) {
170        output->push_back(labels_[i]);
171      }
172    } else if (token_type_ == UTF8) {
173      return LabelsToUTF8String(labels_, output);
174    } else {
175      VLOG(1) << "StringPrinter::operator(): Unknown token type: "
176              << token_type_;
177      return false;
178    }
179    return true;
180  }
181
182 private:
183  bool FstToLabels(const Fst<A> &fst) {
184    labels_.clear();
185
186    StateId s = fst.Start();
187    if (s == kNoStateId) {
188      VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
189              << "string fst.";
190      return false;
191    }
192
193    while (fst.Final(s) == Weight::Zero()) {
194      ArcIterator<Fst<A> > aiter(fst, s);
195      if (aiter.Done()) {
196        VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does "
197                << "not reach final state.";
198        return false;
199      }
200
201      const A& arc = aiter.Value();
202      labels_.push_back(arc.olabel);
203
204      s = arc.nextstate;
205      if (s == kNoStateId) {
206        VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid "
207                << "state.";
208        return false;
209      }
210
211      aiter.Next();
212      if (!aiter.Done()) {
213        VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
214                << "outgoing arcs found.";
215        return false;
216      }
217    }
218
219    return true;
220  }
221
222  bool PrintLabel(Label lab, ostream& ostrm) {
223    if (syms_) {
224      string symbol = syms_->Find(lab);
225      if (symbol == "") {
226        VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not "
227                << "mapped to any textual symbol, symbol table = "
228                 << syms_->Name();
229        return false;
230      }
231      ostrm << symbol;
232    } else {
233      ostrm << lab;
234    }
235    return true;
236  }
237
238  TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
239  const SymbolTable *syms_;  // Symbol table used when token type is symbol
240  vector<Label> labels_;     // Input FST labels.
241
242  DISALLOW_COPY_AND_ASSIGN(StringPrinter);
243};
244
245}  // namespace fst
246
247#endif // FST_LIB_STRING_H_
248