string.h revision dfd8b8327b93660601d016cdc6f29f433b45a8d8
1
2// string.h
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//     http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16// Copyright 2005-2010 Google, Inc.
17// Author: allauzen@google.com (Cyril Allauzen)
18//
19// \file
20// Utilities to convert strings into FSTs.
21//
22
23#ifndef FST_LIB_STRING_H_
24#define FST_LIB_STRING_H_
25
26#include <fst/compact-fst.h>
27#include <fst/icu.h>
28#include <fst/mutable-fst.h>
29
30DECLARE_string(fst_field_separator);
31
32namespace fst {
33
34// Functor compiling a string in an FST
35template <class A>
36class StringCompiler {
37 public:
38  typedef A Arc;
39  typedef typename A::Label Label;
40  typedef typename A::Weight Weight;
41
42  enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
43
44  StringCompiler(TokenType type, const SymbolTable *syms = 0,
45                 Label unknown_label = kNoLabel,
46                 bool allow_negative = false)
47      : token_type_(type), syms_(syms), unknown_label_(unknown_label),
48        allow_negative_(allow_negative) {}
49
50  // Compile string 's' into FST 'fst'.
51  template <class F>
52  bool operator()(const string &s, F *fst) const {
53    vector<Label> labels;
54    if (!ConvertStringToLabels(s, &labels))
55      return false;
56    Compile(labels, fst);
57    return true;
58  }
59
60 private:
61  bool ConvertStringToLabels(const string &str, vector<Label> *labels) const {
62    labels->clear();
63    if (token_type_ == BYTE) {
64      for (size_t i = 0; i < str.size(); ++i)
65        labels->push_back(static_cast<unsigned char>(str[i]));
66    } else if (token_type_ == UTF8) {
67      return UTF8StringToLabels(str, labels);
68    } else {
69      char *c_str = new char[str.size() + 1];
70      str.copy(c_str, str.size());
71      c_str[str.size()] = 0;
72      vector<char *> vec;
73      string separator = "\n" + FLAGS_fst_field_separator;
74      SplitToVector(c_str, separator.c_str(), &vec, true);
75      for (size_t i = 0; i < vec.size(); ++i) {
76        Label label;
77        if (!ConvertSymbolToLabel(vec[i], &label))
78          return false;
79        labels->push_back(label);
80      }
81      delete[] c_str;
82    }
83    return true;
84  }
85
86  void Compile(const vector<Label> &labels, MutableFst<A> *fst) const {
87    fst->DeleteStates();
88    while (fst->NumStates() <= labels.size())
89      fst->AddState();
90    for (size_t i = 0; i < labels.size(); ++i)
91      fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
92    fst->SetStart(0);
93    fst->SetFinal(labels.size(), Weight::One());
94  }
95
96  template <class Unsigned>
97  void Compile(const vector<Label> &labels, CompactFst<A, StringCompactor<A>,
98               Unsigned> *fst) const {
99    fst->SetCompactElements(labels.begin(), labels.end());
100  }
101
102  bool ConvertSymbolToLabel(const char *s, Label* output) const {
103    int64 n;
104    if (syms_) {
105      n = syms_->Find(s);
106      if ((n == -1) && (unknown_label_ != kNoLabel))
107        n = unknown_label_;
108      if (n == -1 || (!allow_negative_ && n < 0)) {
109        VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s
110                << "\" is not mapped to any integer label, symbol table = "
111                 << syms_->Name();
112        return false;
113      }
114    } else {
115      char *p;
116      n = strtoll(s, &p, 10);
117      if (p < s + strlen(s) || (!allow_negative_ && n < 0)) {
118        VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer "
119                << "= \"" << s << "\"";
120        return false;
121      }
122    }
123    *output = n;
124    return true;
125  }
126
127  TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
128  const SymbolTable *syms_;  // Symbol table used when token type is symbol
129  Label unknown_label_;      // Label for token missing from symbol table
130  bool allow_negative_;      // Negative labels allowed?
131
132  DISALLOW_COPY_AND_ASSIGN(StringCompiler);
133};
134
135// Functor to print a string FST as a string.
136template <class A>
137class StringPrinter {
138 public:
139  typedef A Arc;
140  typedef typename A::Label Label;
141  typedef typename A::StateId StateId;
142  typedef typename A::Weight Weight;
143
144  enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
145
146  StringPrinter(TokenType token_type,
147                const SymbolTable *syms = 0)
148      : token_type_(token_type), syms_(syms) {}
149
150  // Convert the FST 'fst' into the string 'output'
151  bool operator()(const Fst<A> &fst, string *output) {
152    bool is_a_string = FstToLabels(fst);
153    if (!is_a_string) {
154      VLOG(1) << "StringPrinter::operator(): Fst is not a string.";
155      return false;
156    }
157
158    output->clear();
159
160    if (token_type_ == SYMBOL) {
161      stringstream sstrm;
162      for (size_t i = 0; i < labels_.size(); ++i) {
163        if (i)
164          sstrm << *(FLAGS_fst_field_separator.rbegin());
165        if (!PrintLabel(labels_[i], sstrm))
166          return false;
167      }
168      *output = sstrm.str();
169    } else if (token_type_ == BYTE) {
170      for (size_t i = 0; i < labels_.size(); ++i) {
171        output->push_back(labels_[i]);
172      }
173    } else if (token_type_ == UTF8) {
174      return LabelsToUTF8String(labels_, output);
175    } else {
176      VLOG(1) << "StringPrinter::operator(): Unknown token type: "
177              << token_type_;
178      return false;
179    }
180    return true;
181  }
182
183 private:
184  bool FstToLabels(const Fst<A> &fst) {
185    labels_.clear();
186
187    StateId s = fst.Start();
188    if (s == kNoStateId) {
189      VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
190              << "string fst.";
191      return false;
192    }
193
194    while (fst.Final(s) == Weight::Zero()) {
195      ArcIterator<Fst<A> > aiter(fst, s);
196      if (aiter.Done()) {
197        VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does "
198                << "not reach final state.";
199        return false;
200      }
201
202      const A& arc = aiter.Value();
203      labels_.push_back(arc.olabel);
204
205      s = arc.nextstate;
206      if (s == kNoStateId) {
207        VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid "
208                << "state.";
209        return false;
210      }
211
212      aiter.Next();
213      if (!aiter.Done()) {
214        VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
215                << "outgoing arcs found.";
216        return false;
217      }
218    }
219
220    return true;
221  }
222
223  bool PrintLabel(Label lab, ostream& ostrm) {
224    if (syms_) {
225      string symbol = syms_->Find(lab);
226      if (symbol == "") {
227        VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not "
228                << "mapped to any textual symbol, symbol table = "
229                 << syms_->Name();
230        return false;
231      }
232      ostrm << symbol;
233    } else {
234      ostrm << lab;
235    }
236    return true;
237  }
238
239  TokenType token_type_;     // Token type: symbol, byte or utf8 encoded
240  const SymbolTable *syms_;  // Symbol table used when token type is symbol
241  vector<Label> labels_;     // Input FST labels.
242
243  DISALLOW_COPY_AND_ASSIGN(StringPrinter);
244};
245
246}  // namespace fst
247
248#endif // FST_LIB_STRING_H_
249