1
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13//
14// Copyright 2005-2010 Google, Inc.
15// All Rights Reserved.
16//
17// Author : Johan Schalkwyk
18//
19// \file
20// Classes to provide symbol-to-integer and integer-to-symbol mappings.
21
22#ifndef FST_LIB_SYMBOL_TABLE_H__
23#define FST_LIB_SYMBOL_TABLE_H__
24
25#include <cstring>
26#include <string>
27#include <utility>
28using std::pair; using std::make_pair;
29#include <vector>
30using std::vector;
31
32
33#include <fst/compat.h>
34#include <iostream>
35#include <fstream>
36#include <sstream>
37
38
39#include <map>
40
41DECLARE_bool(fst_compat_symbols);
42
43namespace fst {
44
45// WARNING: Reading via symbol table read options should
46//          not be used. This is a temporary work around for
47//          reading symbol ranges of previously stored symbol sets.
48struct SymbolTableReadOptions {
49  SymbolTableReadOptions() { }
50
51  SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
52                         const string& source_)
53      : string_hash_ranges(string_hash_ranges_),
54        source(source_) { }
55
56  vector<pair<int64, int64> > string_hash_ranges;
57  string source;
58};
59
60struct SymbolTableTextOptions {
61  SymbolTableTextOptions();
62
63  bool allow_negative;
64  string fst_field_separator;
65};
66
67class SymbolTableImpl {
68 public:
69  SymbolTableImpl(const string &name)
70      : name_(name),
71        available_key_(0),
72        dense_key_limit_(0),
73        check_sum_finalized_(false) {}
74
75  explicit SymbolTableImpl(const SymbolTableImpl& impl)
76      : name_(impl.name_),
77        available_key_(0),
78        dense_key_limit_(0),
79        check_sum_finalized_(false) {
80    for (size_t i = 0; i < impl.symbols_.size(); ++i) {
81      AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
82    }
83  }
84
85  ~SymbolTableImpl() {
86    for (size_t i = 0; i < symbols_.size(); ++i)
87      delete[] symbols_[i];
88  }
89
90  // TODO(johans): Add flag to specify whether the symbol
91  //               should be indexed as string or int or both.
92  int64 AddSymbol(const string& symbol, int64 key);
93
94  int64 AddSymbol(const string& symbol) {
95    int64 key = Find(symbol);
96    return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
97  }
98
99  static SymbolTableImpl* ReadText(
100      istream &strm, const string &name,
101      const SymbolTableTextOptions &opts = SymbolTableTextOptions());
102
103  static SymbolTableImpl* Read(istream &strm,
104                               const SymbolTableReadOptions& opts);
105
106  bool Write(ostream &strm) const;
107
108  //
109  // Return the string associated with the key. If the key is out of
110  // range (<0, >max), return an empty string.
111  string Find(int64 key) const {
112    if (key >=0 && key < dense_key_limit_)
113      return string(symbols_[key]);
114
115    map<int64, const char*>::const_iterator it =
116        key_map_.find(key);
117    if (it == key_map_.end()) {
118      return "";
119    }
120    return string(it->second);
121  }
122
123  //
124  // Return the key associated with the symbol. If the symbol
125  // does not exists, return SymbolTable::kNoSymbol.
126  int64 Find(const string& symbol) const {
127    return Find(symbol.c_str());
128  }
129
130  //
131  // Return the key associated with the symbol. If the symbol
132  // does not exists, return SymbolTable::kNoSymbol.
133  int64 Find(const char* symbol) const {
134    map<const char *, int64, StrCmp>::const_iterator it =
135        symbol_map_.find(symbol);
136    if (it == symbol_map_.end()) {
137      return -1;
138    }
139    return it->second;
140  }
141
142  int64 GetNthKey(ssize_t pos) const {
143    if ((pos < 0) || (pos >= symbols_.size())) return -1;
144    else return Find(symbols_[pos]);
145  }
146
147  const string& Name() const { return name_; }
148
149  int IncrRefCount() const {
150    return ref_count_.Incr();
151  }
152  int DecrRefCount() const {
153    return ref_count_.Decr();
154  }
155  int RefCount() const {
156    return ref_count_.count();
157  }
158
159  string CheckSum() const {
160    MaybeRecomputeCheckSum();
161    return check_sum_string_;
162  }
163
164  string LabeledCheckSum() const {
165    MaybeRecomputeCheckSum();
166    return labeled_check_sum_string_;
167  }
168
169  int64 AvailableKey() const {
170    return available_key_;
171  }
172
173  size_t NumSymbols() const {
174    return symbols_.size();
175  }
176
177 private:
178  // Recomputes the checksums (both of them) if we've had changes since the last
179  // computation (i.e., if check_sum_finalized_ is false).
180  // Takes ~2.5 microseconds (dbg) or ~230 nanoseconds (opt) on a 2.67GHz Xeon
181  // if the checksum is up-to-date (requiring no recomputation).
182  void MaybeRecomputeCheckSum() const;
183
184  struct StrCmp {
185    bool operator()(const char *s1, const char *s2) const {
186      return strcmp(s1, s2) < 0;
187    }
188  };
189
190  string name_;
191  int64 available_key_;
192  int64 dense_key_limit_;
193  vector<const char *> symbols_;
194  map<int64, const char*> key_map_;
195  map<const char *, int64, StrCmp> symbol_map_;
196
197  mutable RefCounter ref_count_;
198  mutable bool check_sum_finalized_;
199  mutable string check_sum_string_;
200  mutable string labeled_check_sum_string_;
201  mutable Mutex check_sum_mutex_;
202};
203
204//
205// \class SymbolTable
206// \brief Symbol (string) to int and reverse mapping
207//
208// The SymbolTable implements the mappings of labels to strings and reverse.
209// SymbolTables are used to describe the alphabet of the input and output
210// labels for arcs in a Finite State Transducer.
211//
212// SymbolTables are reference counted and can therefore be shared across
213// multiple machines. For example a language model grammar G, with a
214// SymbolTable for the words in the language model can share this symbol
215// table with the lexical representation L o G.
216//
217class SymbolTable {
218 public:
219  static const int64 kNoSymbol = -1;
220
221  // Construct symbol table with an unspecified name.
222  SymbolTable() : impl_(new SymbolTableImpl("<unspecified>")) {}
223
224  // Construct symbol table with a unique name.
225  SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
226
227  // Create a reference counted copy.
228  SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
229    impl_->IncrRefCount();
230  }
231
232  // Derefence implentation object. When reference count hits 0, delete
233  // implementation.
234  virtual ~SymbolTable() {
235    if (!impl_->DecrRefCount()) delete impl_;
236  }
237
238  // Copys the implemenation from one symbol table to another.
239  void operator=(const SymbolTable &st) {
240    if (impl_ != st.impl_) {
241      st.impl_->IncrRefCount();
242      if (!impl_->DecrRefCount()) delete impl_;
243      impl_ = st.impl_;
244    }
245  }
246
247  // Read an ascii representation of the symbol table from an istream. Pass a
248  // name to give the resulting SymbolTable.
249  static SymbolTable* ReadText(
250      istream &strm, const string& name,
251      const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
252    SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, name, opts);
253    if (!impl)
254      return 0;
255    else
256      return new SymbolTable(impl);
257  }
258
259  // read an ascii representation of the symbol table
260  static SymbolTable* ReadText(const string& filename,
261      const SymbolTableTextOptions &opts = SymbolTableTextOptions()) {
262    ifstream strm(filename.c_str(), ifstream::in);
263    if (!strm) {
264      LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
265      return 0;
266    }
267    return ReadText(strm, filename, opts);
268  }
269
270
271  // WARNING: Reading via symbol table read options should
272  //          not be used. This is a temporary work around.
273  static SymbolTable* Read(istream &strm,
274                           const SymbolTableReadOptions& opts) {
275    SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
276    if (!impl)
277      return 0;
278    else
279      return new SymbolTable(impl);
280  }
281
282  // read a binary dump of the symbol table from a stream
283  static SymbolTable* Read(istream &strm, const string& source) {
284    SymbolTableReadOptions opts;
285    opts.source = source;
286    return Read(strm, opts);
287  }
288
289  // read a binary dump of the symbol table
290  static SymbolTable* Read(const string& filename) {
291    ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
292    if (!strm) {
293      LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
294      return 0;
295    }
296    return Read(strm, filename);
297  }
298
299  //--------------------------------------------------------
300  // Derivable Interface (final)
301  //--------------------------------------------------------
302  // create a reference counted copy
303  virtual SymbolTable* Copy() const {
304    return new SymbolTable(*this);
305  }
306
307  // Add a symbol with given key to table. A symbol table also
308  // keeps track of the last available key (highest key value in
309  // the symbol table).
310  virtual int64 AddSymbol(const string& symbol, int64 key) {
311    MutateCheck();
312    return impl_->AddSymbol(symbol, key);
313  }
314
315  // Add a symbol to the table. The associated value key is automatically
316  // assigned by the symbol table.
317  virtual int64 AddSymbol(const string& symbol) {
318    MutateCheck();
319    return impl_->AddSymbol(symbol);
320  }
321
322  // Add another symbol table to this table. All key values will be offset
323  // by the current available key (highest key value in the symbol table).
324  // Note string symbols with the same key value with still have the same
325  // key value after the symbol table has been merged, but a different
326  // value. Adding symbol tables do not result in changes in the base table.
327  virtual void AddTable(const SymbolTable& table);
328
329  // return the name of the symbol table
330  virtual const string& Name() const {
331    return impl_->Name();
332  }
333
334  // Return the label-agnostic MD5 check-sum for this table.  All new symbols
335  // added to the table will result in an updated checksum.
336  // DEPRECATED.
337  virtual string CheckSum() const {
338    return impl_->CheckSum();
339  }
340
341  // Same as CheckSum(), but this returns an label-dependent version.
342  virtual string LabeledCheckSum() const {
343    return impl_->LabeledCheckSum();
344  }
345
346  virtual bool Write(ostream &strm) const {
347    return impl_->Write(strm);
348  }
349
350  bool Write(const string& filename) const {
351    ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
352    if (!strm) {
353      LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
354      return false;
355    }
356    return Write(strm);
357  }
358
359  // Dump an ascii text representation of the symbol table via a stream
360  virtual bool WriteText(
361      ostream &strm,
362      const SymbolTableTextOptions &opts = SymbolTableTextOptions()) const;
363
364  // Dump an ascii text representation of the symbol table
365  bool WriteText(const string& filename) const {
366    ofstream strm(filename.c_str());
367    if (!strm) {
368      LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
369      return false;
370    }
371    return WriteText(strm);
372  }
373
374  // Return the string associated with the key. If the key is out of
375  // range (<0, >max), log error and return an empty string.
376  virtual string Find(int64 key) const {
377    return impl_->Find(key);
378  }
379
380  // Return the key associated with the symbol. If the symbol
381  // does not exists, log error and  return SymbolTable::kNoSymbol
382  virtual int64 Find(const string& symbol) const {
383    return impl_->Find(symbol);
384  }
385
386  // Return the key associated with the symbol. If the symbol
387  // does not exists, log error and  return SymbolTable::kNoSymbol
388  virtual int64 Find(const char* symbol) const {
389    return impl_->Find(symbol);
390  }
391
392  // Return the current available key (i.e highest key number+1) in
393  // the symbol table
394  virtual int64 AvailableKey(void) const {
395    return impl_->AvailableKey();
396  }
397
398  // Return the current number of symbols in table (not necessarily
399  // equal to AvailableKey())
400  virtual size_t NumSymbols(void) const {
401    return impl_->NumSymbols();
402  }
403
404  virtual int64 GetNthKey(ssize_t pos) const {
405    return impl_->GetNthKey(pos);
406  }
407
408 private:
409  explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
410
411  void MutateCheck() {
412    // Copy on write
413    if (impl_->RefCount() > 1) {
414      impl_->DecrRefCount();
415      impl_ = new SymbolTableImpl(*impl_);
416    }
417  }
418
419  const SymbolTableImpl* Impl() const {
420    return impl_;
421  }
422
423 private:
424  SymbolTableImpl* impl_;
425};
426
427
428//
429// \class SymbolTableIterator
430// \brief Iterator class for symbols in a symbol table
431class SymbolTableIterator {
432 public:
433  SymbolTableIterator(const SymbolTable& table)
434      : table_(table),
435        pos_(0),
436        nsymbols_(table.NumSymbols()),
437        key_(table.GetNthKey(0)) { }
438
439  ~SymbolTableIterator() { }
440
441  // is iterator done
442  bool Done(void) {
443    return (pos_ == nsymbols_);
444  }
445
446  // return the Value() of the current symbol (int64 key)
447  int64 Value(void) {
448    return key_;
449  }
450
451  // return the string of the current symbol
452  string Symbol(void) {
453    return table_.Find(key_);
454  }
455
456  // advance iterator forward
457  void Next(void) {
458    ++pos_;
459    if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_);
460  }
461
462  // reset iterator
463  void Reset(void) {
464    pos_ = 0;
465    key_ = table_.GetNthKey(0);
466  }
467
468 private:
469  const SymbolTable& table_;
470  ssize_t pos_;
471  size_t nsymbols_;
472  int64 key_;
473};
474
475
476// Tests compatibilty between two sets of symbol tables
477inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
478                          bool warning = true) {
479  if (!FLAGS_fst_compat_symbols) {
480    return true;
481  } else if (!syms1 && !syms2) {
482    return true;
483  } else if (syms1 && !syms2) {
484    if (warning)
485      LOG(WARNING) <<
486          "CompatSymbols: first symbol table present but second missing";
487    return false;
488  } else if (!syms1 && syms2) {
489    if (warning)
490      LOG(WARNING) <<
491          "CompatSymbols: second symbol table present but first missing";
492    return false;
493  } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) {
494    if (warning)
495      LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
496    return false;
497  } else {
498    return true;
499  }
500}
501
502
503// Relabels a symbol table as specified by the input vector of pairs
504// (old label, new label). The new symbol table only retains symbols
505// for which a relabeling is *explicitely* specified.
506// TODO(allauzen): consider adding options to allow for some form
507// of implicit identity relabeling.
508template <class Label>
509SymbolTable *RelabelSymbolTable(const SymbolTable *table,
510                                const vector<pair<Label, Label> > &pairs) {
511  SymbolTable *new_table = new SymbolTable(
512      table->Name().empty() ? string() :
513      (string("relabeled_") + table->Name()));
514
515  for (size_t i = 0; i < pairs.size(); ++i)
516    new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second);
517
518  return new_table;
519}
520
521// Symbol Table Serialization
522inline void SymbolTableToString(const SymbolTable *table, string *result) {
523  ostringstream ostrm;
524  table->Write(ostrm);
525  *result = ostrm.str();
526}
527
528inline SymbolTable *StringToSymbolTable(const string &s) {
529  istringstream istrm(s);
530  return SymbolTable::Read(istrm, SymbolTableReadOptions());
531}
532
533
534
535}  // namespace fst
536
537#endif  // FST_LIB_SYMBOL_TABLE_H__
538