1
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13//
14// Copyright 2005-2010 Google, Inc.
15// All Rights Reserved.
16//
17// Author : Johan Schalkwyk
18//
19// \file
20// Classes to provide symbol-to-integer and integer-to-symbol mappings.
21
22#ifndef FST_LIB_SYMBOL_TABLE_H__
23#define FST_LIB_SYMBOL_TABLE_H__
24
25#include <cstring>
26#include <string>
27#include <utility>
28using std::pair; using std::make_pair;
29#include <vector>
30using std::vector;
31
32
33#include <fst/compat.h>
34#include <iostream>
35#include <fstream>
36
37
38#include <map>
39
40DECLARE_bool(fst_compat_symbols);
41
42namespace fst {
43
44// WARNING: Reading via symbol table read options should
45//          not be used. This is a temporary work around for
46//          reading symbol ranges of previously stored symbol sets.
47struct SymbolTableReadOptions {
48  SymbolTableReadOptions() { }
49
50  SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
51                         const string& source_)
52      : string_hash_ranges(string_hash_ranges_),
53        source(source_) { }
54
55  vector<pair<int64, int64> > string_hash_ranges;
56  string source;
57};
58
59class SymbolTableImpl {
60 public:
61  SymbolTableImpl(const string &name)
62      : name_(name),
63        available_key_(0),
64        dense_key_limit_(0),
65        check_sum_finalized_(false) {}
66
67  explicit SymbolTableImpl(const SymbolTableImpl& impl)
68      : name_(impl.name_),
69        available_key_(0),
70        dense_key_limit_(0),
71        check_sum_finalized_(false) {
72    for (size_t i = 0; i < impl.symbols_.size(); ++i) {
73      AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
74    }
75  }
76
77  ~SymbolTableImpl() {
78    for (size_t i = 0; i < symbols_.size(); ++i)
79      delete[] symbols_[i];
80  }
81
82  // TODO(johans): Add flag to specify whether the symbol
83  //               should be indexed as string or int or both.
84  int64 AddSymbol(const string& symbol, int64 key);
85
86  int64 AddSymbol(const string& symbol) {
87    int64 key = Find(symbol);
88    return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
89  }
90
91  static SymbolTableImpl* ReadText(istream &strm,
92                                   const string &name,
93                                   bool allow_negative = false);
94
95  static SymbolTableImpl* Read(istream &strm,
96                               const SymbolTableReadOptions& opts);
97
98  bool Write(ostream &strm) const;
99
100  //
101  // Return the string associated with the key. If the key is out of
102  // range (<0, >max), return an empty string.
103  string Find(int64 key) const {
104    if (key >=0 && key < dense_key_limit_)
105      return string(symbols_[key]);
106
107    map<int64, const char*>::const_iterator it =
108        key_map_.find(key);
109    if (it == key_map_.end()) {
110      return "";
111    }
112    return string(it->second);
113  }
114
115  //
116  // Return the key associated with the symbol. If the symbol
117  // does not exists, return SymbolTable::kNoSymbol.
118  int64 Find(const string& symbol) const {
119    return Find(symbol.c_str());
120  }
121
122  //
123  // Return the key associated with the symbol. If the symbol
124  // does not exists, return SymbolTable::kNoSymbol.
125  int64 Find(const char* symbol) const {
126    map<const char *, int64, StrCmp>::const_iterator it =
127        symbol_map_.find(symbol);
128    if (it == symbol_map_.end()) {
129      return -1;
130    }
131    return it->second;
132  }
133
134  int64 GetNthKey(ssize_t pos) const {
135    if ((pos < 0) || (pos >= symbols_.size())) return -1;
136    else return Find(symbols_[pos]);
137  }
138
139  const string& Name() const { return name_; }
140
141  int IncrRefCount() const {
142    return ref_count_.Incr();
143  }
144  int DecrRefCount() const {
145    return ref_count_.Decr();
146  }
147  int RefCount() const {
148    return ref_count_.count();
149  }
150
151  string CheckSum() const {
152    MutexLock check_sum_lock(&check_sum_mutex_);
153    MaybeRecomputeCheckSum();
154    return check_sum_string_;
155  }
156
157  string LabeledCheckSum() const {
158    MutexLock check_sum_lock(&check_sum_mutex_);
159    MaybeRecomputeCheckSum();
160    return labeled_check_sum_string_;
161  }
162
163  int64 AvailableKey() const {
164    return available_key_;
165  }
166
167  size_t NumSymbols() const {
168    return symbols_.size();
169  }
170
171 private:
172  // Recomputes the checksums (both of them) if we've had changes since the last
173  // computation (i.e., if check_sum_finalized_ is false).
174  void MaybeRecomputeCheckSum() const;
175
176  struct StrCmp {
177    bool operator()(const char *s1, const char *s2) const {
178      return strcmp(s1, s2) < 0;
179    }
180  };
181
182  string name_;
183  int64 available_key_;
184  int64 dense_key_limit_;
185  vector<const char *> symbols_;
186  map<int64, const char*> key_map_;
187  map<const char *, int64, StrCmp> symbol_map_;
188
189  mutable RefCounter ref_count_;
190  mutable bool check_sum_finalized_;
191  mutable CheckSummer check_sum_;
192  mutable CheckSummer labeled_check_sum_;
193  mutable string check_sum_string_;
194  mutable string labeled_check_sum_string_;
195  mutable Mutex check_sum_mutex_;
196};
197
198//
199// \class SymbolTable
200// \brief Symbol (string) to int and reverse mapping
201//
202// The SymbolTable implements the mappings of labels to strings and reverse.
203// SymbolTables are used to describe the alphabet of the input and output
204// labels for arcs in a Finite State Transducer.
205//
206// SymbolTables are reference counted and can therefore be shared across
207// multiple machines. For example a language model grammar G, with a
208// SymbolTable for the words in the language model can share this symbol
209// table with the lexical representation L o G.
210//
211class SymbolTable {
212 public:
213  static const int64 kNoSymbol = -1;
214
215  // Construct symbol table with a unique name.
216  SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
217
218  // Create a reference counted copy.
219  SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
220    impl_->IncrRefCount();
221  }
222
223  // Derefence implentation object. When reference count hits 0, delete
224  // implementation.
225  virtual ~SymbolTable() {
226    if (!impl_->DecrRefCount()) delete impl_;
227  }
228
229  // Read an ascii representation of the symbol table from an istream. Pass a
230  // name to give the resulting SymbolTable.
231  static SymbolTable* ReadText(istream &strm,
232                               const string& name,
233                               bool allow_negative = false) {
234    SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm,
235                                                      name,
236                                                      allow_negative);
237    if (!impl)
238      return 0;
239    else
240      return new SymbolTable(impl);
241  }
242
243  // read an ascii representation of the symbol table
244  static SymbolTable* ReadText(const string& filename,
245                               bool allow_negative = false) {
246    ifstream strm(filename.c_str(), ifstream::in);
247    if (!strm) {
248      LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
249      return 0;
250    }
251    return ReadText(strm, filename, allow_negative);
252  }
253
254
255  // WARNING: Reading via symbol table read options should
256  //          not be used. This is a temporary work around.
257  static SymbolTable* Read(istream &strm,
258                           const SymbolTableReadOptions& opts) {
259    SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
260    if (!impl)
261      return 0;
262    else
263      return new SymbolTable(impl);
264  }
265
266  // read a binary dump of the symbol table from a stream
267  static SymbolTable* Read(istream &strm, const string& source) {
268    SymbolTableReadOptions opts;
269    opts.source = source;
270    return Read(strm, opts);
271  }
272
273  // read a binary dump of the symbol table
274  static SymbolTable* Read(const string& filename) {
275    ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
276    if (!strm) {
277      LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
278      return 0;
279    }
280    return Read(strm, filename);
281  }
282
283  //--------------------------------------------------------
284  // Derivable Interface (final)
285  //--------------------------------------------------------
286  // create a reference counted copy
287  virtual SymbolTable* Copy() const {
288    return new SymbolTable(*this);
289  }
290
291  // Add a symbol with given key to table. A symbol table also
292  // keeps track of the last available key (highest key value in
293  // the symbol table).
294  virtual int64 AddSymbol(const string& symbol, int64 key) {
295    MutateCheck();
296    return impl_->AddSymbol(symbol, key);
297  }
298
299  // Add a symbol to the table. The associated value key is automatically
300  // assigned by the symbol table.
301  virtual int64 AddSymbol(const string& symbol) {
302    MutateCheck();
303    return impl_->AddSymbol(symbol);
304  }
305
306  // Add another symbol table to this table. All key values will be offset
307  // by the current available key (highest key value in the symbol table).
308  // Note string symbols with the same key value with still have the same
309  // key value after the symbol table has been merged, but a different
310  // value. Adding symbol tables do not result in changes in the base table.
311  virtual void AddTable(const SymbolTable& table);
312
313  // return the name of the symbol table
314  virtual const string& Name() const {
315    return impl_->Name();
316  }
317
318  // Return the label-agnostic MD5 check-sum for this table.  All new symbols
319  // added to the table will result in an updated checksum.
320  // DEPRECATED.
321  virtual string CheckSum() const {
322    return impl_->CheckSum();
323  }
324
325  // Same as CheckSum(), but this returns an label-dependent version.
326  virtual string LabeledCheckSum() const {
327    return impl_->LabeledCheckSum();
328  }
329
330  virtual bool Write(ostream &strm) const {
331    return impl_->Write(strm);
332  }
333
334  bool Write(const string& filename) const {
335    ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
336    if (!strm) {
337      LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
338      return false;
339    }
340    return Write(strm);
341  }
342
343  // Dump an ascii text representation of the symbol table via a stream
344  virtual bool WriteText(ostream &strm) const;
345
346  // Dump an ascii text representation of the symbol table
347  bool WriteText(const string& filename) const {
348    ofstream strm(filename.c_str());
349    if (!strm) {
350      LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
351      return false;
352    }
353    return WriteText(strm);
354  }
355
356  // Return the string associated with the key. If the key is out of
357  // range (<0, >max), log error and return an empty string.
358  virtual string Find(int64 key) const {
359    return impl_->Find(key);
360  }
361
362  // Return the key associated with the symbol. If the symbol
363  // does not exists, log error and  return SymbolTable::kNoSymbol
364  virtual int64 Find(const string& symbol) const {
365    return impl_->Find(symbol);
366  }
367
368  // Return the key associated with the symbol. If the symbol
369  // does not exists, log error and  return SymbolTable::kNoSymbol
370  virtual int64 Find(const char* symbol) const {
371    return impl_->Find(symbol);
372  }
373
374  // Return the current available key (i.e highest key number+1) in
375  // the symbol table
376  virtual int64 AvailableKey(void) const {
377    return impl_->AvailableKey();
378  }
379
380  // Return the current number of symbols in table (not necessarily
381  // equal to AvailableKey())
382  virtual size_t NumSymbols(void) const {
383    return impl_->NumSymbols();
384  }
385
386  virtual int64 GetNthKey(ssize_t pos) const {
387    return impl_->GetNthKey(pos);
388  }
389
390 private:
391  explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
392
393  void MutateCheck() {
394    // Copy on write
395    if (impl_->RefCount() > 1) {
396      impl_->DecrRefCount();
397      impl_ = new SymbolTableImpl(*impl_);
398    }
399  }
400
401  const SymbolTableImpl* Impl() const {
402    return impl_;
403  }
404
405 private:
406  SymbolTableImpl* impl_;
407
408  void operator=(const SymbolTable &table);  // disallow
409};
410
411
412//
413// \class SymbolTableIterator
414// \brief Iterator class for symbols in a symbol table
415class SymbolTableIterator {
416 public:
417  SymbolTableIterator(const SymbolTable& table)
418      : table_(table),
419        pos_(0),
420        nsymbols_(table.NumSymbols()),
421        key_(table.GetNthKey(0)) { }
422
423  ~SymbolTableIterator() { }
424
425  // is iterator done
426  bool Done(void) {
427    return (pos_ == nsymbols_);
428  }
429
430  // return the Value() of the current symbol (int64 key)
431  int64 Value(void) {
432    return key_;
433  }
434
435  // return the string of the current symbol
436  string Symbol(void) {
437    return table_.Find(key_);
438  }
439
440  // advance iterator forward
441  void Next(void) {
442    ++pos_;
443    if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_);
444  }
445
446  // reset iterator
447  void Reset(void) {
448    pos_ = 0;
449    key_ = table_.GetNthKey(0);
450  }
451
452 private:
453  const SymbolTable& table_;
454  ssize_t pos_;
455  size_t nsymbols_;
456  int64 key_;
457};
458
459
460// Tests compatibilty between two sets of symbol tables
461inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
462                          bool warning = true) {
463  if (!FLAGS_fst_compat_symbols) {
464    return true;
465  } else if (!syms1 && !syms2) {
466    return true;
467  } else if (syms1 && !syms2) {
468    if (warning)
469      LOG(WARNING) <<
470          "CompatSymbols: first symbol table present but second missing";
471    return false;
472  } else if (!syms1 && syms2) {
473    if (warning)
474      LOG(WARNING) <<
475          "CompatSymbols: second symbol table present but first missing";
476    return false;
477  } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
478    if (warning)
479      LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
480    return false;
481  } else {
482    return true;
483  }
484}
485
486
487// Relabels a symbol table as specified by the input vector of pairs
488// (old label, new label). The new symbol table only retains symbols
489// for which a relabeling is *explicitely* specified.
490// TODO(allauzen): consider adding options to allow for some form
491// of implicit identity relabeling.
492template <class Label>
493SymbolTable *RelabelSymbolTable(const SymbolTable *table,
494                                const vector<pair<Label, Label> > &pairs) {
495  SymbolTable *new_table = new SymbolTable(
496      table->Name().empty() ? string() :
497      (string("relabeled_") + table->Name()));
498
499  for (size_t i = 0; i < pairs.size(); ++i)
500    new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second);
501
502  return new_table;
503}
504
505}  // namespace fst
506
507#endif  // FST_LIB_SYMBOL_TABLE_H__
508