1// symbol-table.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// Classes to provide symbol-to-integer and integer-to-symbol mappings.
18
19#ifndef FST_LIB_SYMBOL_TABLE_H__
20#define FST_LIB_SYMBOL_TABLE_H__
21
22#include <fstream>
23#include <iostream>
24#include <string>
25#include <unordered_map>
26#include <vector>
27
28#include "fst/lib/compat.h"
29
30
31
32DECLARE_bool(fst_compat_symbols);
33
34namespace fst {
35
36class SymbolTableImpl {
37  friend class SymbolTableIterator;
38 public:
39  SymbolTableImpl(const string &name)
40      : name_(name), available_key_(0), ref_count_(1),
41        check_sum_finalized_(false) {}
42  ~SymbolTableImpl() {
43    for (size_t i = 0; i < symbols_.size(); ++i)
44      delete[] symbols_[i];
45  }
46
47  int64 AddSymbol(const string& symbol, int64 key);
48
49  int64 AddSymbol(const string& symbol) {
50    int64 key = Find(symbol);
51    return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
52  }
53
54  void AddTable(SymbolTableImpl* table) {
55    for (size_t i = 0; i < table->symbols_.size(); ++i) {
56      AddSymbol(table->symbols_[i]);
57    }
58  }
59
60  static SymbolTableImpl* ReadText(const string& filename);
61
62  static SymbolTableImpl* Read(istream &strm, const string& source);
63
64  bool Write(ostream &strm) const;
65
66  bool WriteText(ostream &strm) const;
67
68  //
69  // Return the string associated with the key. If the key is out of
70  // range (<0, >max), return an empty string.
71  string Find(int64 key) const {
72    std::unordered_map<int64, string>::const_iterator it =
73      key_map_.find(key);
74    if (it == key_map_.end()) {
75      return "";
76    }
77    return it->second;
78  }
79
80  //
81  // Return the key associated with the symbol. If the symbol
82  // does not exists, return -1.
83  int64 Find(const string& symbol) const {
84    return Find(symbol.c_str());
85  }
86
87  //
88  // Return the key associated with the symbol. If the symbol
89  // does not exists, return -1.
90  int64 Find(const char* symbol) const {
91    unordered_map<string, int64>::const_iterator it =
92      symbol_map_.find(symbol);
93    if (it == symbol_map_.end()) {
94      return -1;
95    }
96    return it->second;
97  }
98
99  const string& Name() const { return name_; }
100
101  int IncrRefCount() const {
102    return ++ref_count_;
103  }
104  int DecrRefCount() const {
105    return --ref_count_;
106  }
107
108  string CheckSum() const {
109    if (!check_sum_finalized_) {
110      RecomputeCheckSum();
111      check_sum_string_ = check_sum_.Digest();
112    }
113    return check_sum_string_;
114  }
115
116  int64 AvailableKey() const {
117    return available_key_;
118  }
119
120  // private support methods
121 private:
122  void RecomputeCheckSum() const;
123  static SymbolTableImpl* Read1(istream &, const string &);
124
125  string name_;
126  int64 available_key_;
127  vector<const char *> symbols_;
128  std::unordered_map<int64, string> key_map_;
129  std::unordered_map<string, int64> symbol_map_;
130
131  mutable int ref_count_;
132  mutable bool check_sum_finalized_;
133  mutable MD5 check_sum_;
134  mutable string check_sum_string_;
135
136  DISALLOW_EVIL_CONSTRUCTORS(SymbolTableImpl);
137};
138
139
140class SymbolTableIterator;
141
142//
143// \class SymbolTable
144// \brief Symbol (string) to int and reverse mapping
145//
146// The SymbolTable implements the mappings of labels to strings and reverse.
147// SymbolTables are used to describe the alphabet of the input and output
148// labels for arcs in a Finite State Transducer.
149//
150// SymbolTables are reference counted and can therefore be shared across
151// multiple machines. For example a language model grammar G, with a
152// SymbolTable for the words in the language model can share this symbol
153// table with the lexical representation L o G.
154//
155class SymbolTable {
156  friend class SymbolTableIterator;
157 public:
158  static const int64 kNoSymbol = -1;
159
160  // Construct symbol table with a unique name.
161  SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
162
163  // Create a reference counted copy.
164  SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
165    impl_->IncrRefCount();
166  }
167
168  // Derefence implentation object. When reference count hits 0, delete
169  // implementation.
170  ~SymbolTable() {
171    if (!impl_->DecrRefCount()) delete impl_;
172  }
173
174  // create a reference counted copy
175  SymbolTable* Copy() const {
176    return new SymbolTable(*this);
177  }
178
179  // Add a symbol with given key to table. A symbol table also
180  // keeps track of the last available key (highest key value in
181  // the symbol table).
182  //
183  // \param symbol string symbol to add
184  // \param key associated key for string symbol
185  // \return the key created by the symbol table. Symbols allready added to
186  //         the symbol table will not get a different key.
187  int64 AddSymbol(const string& symbol, int64 key) {
188    return impl_->AddSymbol(symbol, key);
189  }
190
191  // Add a symbol to the table. The associated value key is automatically
192  // assigned by the symbol table.
193  //
194  // \param symbol string to add to the table
195  // \return the value key assigned to the associated string symbol
196  int64 AddSymbol(const string& symbol) {
197    return impl_->AddSymbol(symbol);
198  }
199
200  // Add another symbol table to this table. All key values will be offset
201  // by the current available key (highest key value in the symbol table).
202  // Note string symbols with the same key value with still have the same
203  // key value after the symbol table has been merged, but a different
204  // value. Adding symbol tables do not result in changes in the base table.
205  //
206  // Merging N symbol tables is often useful when combining the various
207  // name spaces of transducers to a unified representation.
208  //
209  // \param table the symbol table to add to this table
210  void AddTable(const SymbolTable& table) {
211    return impl_->AddTable(table.impl_);
212  }
213
214  // return the name of the symbol table
215  const string& Name() const {
216    return impl_->Name();
217  }
218
219  // return the MD5 check-sum for this table. All new symbols added to
220  // the table will result in an updated checksum.
221  string CheckSum() const {
222    return impl_->CheckSum();
223  }
224
225  // read an ascii representation of the symbol table
226  static SymbolTable* ReadText(const string& filename) {
227    SymbolTableImpl* impl = SymbolTableImpl::ReadText(filename);
228    if (!impl)
229      return 0;
230    else
231      return new SymbolTable(impl);
232  }
233
234  // read a binary dump of the symbol table
235  static SymbolTable* Read(istream &strm, const string& source) {
236    SymbolTableImpl* impl = SymbolTableImpl::Read(strm, source);
237    if (!impl)
238      return 0;
239    else
240      return new SymbolTable(impl);
241  }
242
243  // read a binary dump of the symbol table
244  static SymbolTable* Read(const string& filename) {
245    ifstream strm(filename.c_str());
246    if (!strm) {
247      LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
248      return 0;
249    }
250    return Read(strm, filename);
251  }
252
253  bool Write(ostream  &strm) const {
254    return impl_->Write(strm);
255  }
256
257  bool Write(const string& filename) const {
258    ofstream strm(filename.c_str());
259    if (!strm) {
260      LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
261      return false;
262    }
263    return Write(strm);
264  }
265
266  // Dump an ascii text representation of the symbol table
267  bool WriteText(ostream &strm) const {
268    return impl_->WriteText(strm);
269  }
270
271  // Dump an ascii text representation of the symbol table
272  bool WriteText(const string& filename) const {
273    ofstream strm(filename.c_str());
274    if (!strm) {
275      LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
276      return false;
277    }
278    return WriteText(strm);
279  }
280
281  // Return the string associated with the key. If the key is out of
282  // range (<0, >max), log error and return an empty string.
283  string Find(int64 key) const {
284    return impl_->Find(key);
285  }
286
287  // Return the key associated with the symbol. If the symbol
288  // does not exists, log error and  return -1
289  int64 Find(const string& symbol) const {
290    return impl_->Find(symbol);
291  }
292
293  // Return the key associated with the symbol. If the symbol
294  // does not exists, log error and  return -1
295  int64 Find(const char* symbol) const {
296    return impl_->Find(symbol);
297  }
298
299  // return the current available key (i.e highest key number) in
300  // the symbol table
301  int64 AvailableKey(void) const {
302    return impl_->AvailableKey();
303  }
304
305 protected:
306  explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
307
308  const SymbolTableImpl* Impl() const {
309    return impl_;
310  }
311
312 private:
313  SymbolTableImpl* impl_;
314
315
316  void operator=(const SymbolTable &table);  // disallow
317};
318
319
320//
321// \class SymbolTableIterator
322// \brief Iterator class for symbols in a symbol table
323class SymbolTableIterator {
324 public:
325  // Constructor creates a refcounted copy of underlying implementation
326  SymbolTableIterator(const SymbolTable& symbol_table) {
327    impl_ = symbol_table.Impl();
328    impl_->IncrRefCount();
329    pos_ = 0;
330    size_ = impl_->symbols_.size();
331  }
332
333  // decrement implementation refcount, and delete if 0
334  ~SymbolTableIterator() {
335    if (!impl_->DecrRefCount()) delete impl_;
336  }
337
338  // is iterator done
339  bool Done(void) {
340    return (pos_ == size_);
341  }
342
343  // return the Value() of the current symbol (in64 key)
344  int64 Value(void) {
345    return impl_->Find(impl_->symbols_[pos_]);
346  }
347
348  // return the string of the current symbol
349  const char* Symbol(void) {
350    return impl_->symbols_[pos_];
351  }
352
353  // advance iterator forward
354  void Next(void) {
355    if (Done()) return;
356    ++pos_;
357  }
358
359  // reset iterator
360  void Reset(void) {
361    pos_ = 0;
362  }
363
364 private:
365  const SymbolTableImpl* impl_;
366  size_t pos_;
367  size_t size_;
368};
369
370
371// Tests compatibilty between two sets of symbol tables
372inline bool CompatSymbols(const SymbolTable *syms1,
373                          const SymbolTable *syms2) {
374  if (!FLAGS_fst_compat_symbols)
375    return true;
376  else if (!syms1 && !syms2)
377    return true;
378  else if ((syms1 && !syms2) || (!syms1 && syms2))
379    return false;
380  else
381    return syms1->CheckSum() == syms2->CheckSum();
382}
383
384}  // namespace fst
385
386#endif  // FST_LIB_SYMBOL_TABLE_H__
387