1// encode.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: johans@google.com (Johan Schalkwyk)
17//
18// \file
// Classes to encode and decode an Fst.
20
21#ifndef FST_LIB_ENCODE_H__
22#define FST_LIB_ENCODE_H__
23
24#include <climits>
25#include <tr1/unordered_map>
26using std::tr1::unordered_map;
27using std::tr1::unordered_multimap;
28#include <string>
29#include <vector>
30using std::vector;
31
32#include <fst/arc-map.h>
33#include <fst/rmfinalepsilon.h>
34
35
36namespace fst {
37
38static const uint32 kEncodeLabels      = 0x0001;
39static const uint32 kEncodeWeights     = 0x0002;
40static const uint32 kEncodeFlags       = 0x0003;  // All non-internal flags
41
42static const uint32 kEncodeHasISymbols = 0x0004;  // For internal use
43static const uint32 kEncodeHasOSymbols = 0x0008;  // For internal use
44
45enum EncodeType { ENCODE = 1, DECODE = 2 };
46
47// Identifies stream data as an encode table (and its endianity)
48static const int32 kEncodeMagicNumber = 2129983209;
49
50
// The following class encapsulates implementation details for the
// encoding and decoding of label/weight tuples used for encoding
// and decoding of Fsts. The EncodeTable is bidirectional, i.e., it
// maps each tuple of labels and weight to a unique label, and maps
// each such label back to its tuple.
56template <class A>  class EncodeTable {
57 public:
58  typedef typename A::Label Label;
59  typedef typename A::Weight Weight;
60
61  // Encoded data consists of arc input/output labels and arc weight
62  struct Tuple {
63    Tuple() {}
64    Tuple(Label ilabel_, Label olabel_, Weight weight_)
65        : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
66    Tuple(const Tuple& tuple)
67        : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
68
69    Label ilabel;
70    Label olabel;
71    Weight weight;
72  };
73
74  // Comparison object for hashing EncodeTable Tuple(s).
75  class TupleEqual {
76   public:
77    bool operator()(const Tuple* x, const Tuple* y) const {
78      return (x->ilabel == y->ilabel &&
79              x->olabel == y->olabel &&
80              x->weight == y->weight);
81    }
82  };
83
84  // Hash function for EncodeTabe Tuples. Based on the encode flags
85  // we either hash the labels, weights or combination of them.
86  class TupleKey {
87   public:
88    TupleKey()
89        : encode_flags_(kEncodeLabels | kEncodeWeights) {}
90
91    TupleKey(const TupleKey& key)
92        : encode_flags_(key.encode_flags_) {}
93
94    explicit TupleKey(uint32 encode_flags)
95        : encode_flags_(encode_flags) {}
96
97    size_t operator()(const Tuple* x) const {
98      size_t hash = x->ilabel;
99      const int lshift = 5;
100      const int rshift = CHAR_BIT * sizeof(size_t) - 5;
101      if (encode_flags_ & kEncodeLabels)
102        hash = hash << lshift ^ hash >> rshift ^ x->olabel;
103      if (encode_flags_ & kEncodeWeights)
104        hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash();
105      return hash;
106    }
107
108   private:
109    int32 encode_flags_;
110  };
111
112  typedef unordered_map<const Tuple*,
113                   Label,
114                   TupleKey,
115                   TupleEqual> EncodeHash;
116
117  explicit EncodeTable(uint32 encode_flags)
118      : flags_(encode_flags),
119        encode_hash_(1024, TupleKey(encode_flags)),
120        isymbols_(0), osymbols_(0) {}
121
122  ~EncodeTable() {
123    for (size_t i = 0; i < encode_tuples_.size(); ++i) {
124      delete encode_tuples_[i];
125    }
126    delete isymbols_;
127    delete osymbols_;
128  }
129
130  // Given an arc encode either input/ouptut labels or input/costs or both
131  Label Encode(const A &arc) {
132    const Tuple tuple(arc.ilabel,
133                      flags_ & kEncodeLabels ? arc.olabel : 0,
134                      flags_ & kEncodeWeights ? arc.weight : Weight::One());
135    typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
136    if (it == encode_hash_.end()) {
137      encode_tuples_.push_back(new Tuple(tuple));
138      encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
139      return encode_tuples_.size();
140    } else {
141      return it->second;
142    }
143  }
144
145  // Given an arc, look up its encoded label. Returns kNoLabel if not found.
146  Label GetLabel(const A &arc) const {
147    const Tuple tuple(arc.ilabel,
148                      flags_ & kEncodeLabels ? arc.olabel : 0,
149                      flags_ & kEncodeWeights ? arc.weight : Weight::One());
150    typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
151    if (it == encode_hash_.end()) {
152      return kNoLabel;
153    } else {
154      return it->second;
155    }
156  }
157
158  // Given an encode arc Label decode back to input/output labels and costs
159  const Tuple* Decode(Label key) const {
160    if (key < 1 || key > encode_tuples_.size()) {
161      LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key;
162      return 0;
163    }
164    return encode_tuples_[key - 1];
165  }
166
167  size_t Size() const { return encode_tuples_.size(); }
168
169  bool Write(ostream &strm, const string &source) const;
170
171  static EncodeTable<A> *Read(istream &strm, const string &source);
172
173  const uint32 flags() const { return flags_ & kEncodeFlags; }
174
175  int RefCount() const { return ref_count_.count(); }
176  int IncrRefCount() { return ref_count_.Incr(); }
177  int DecrRefCount() { return ref_count_.Decr(); }
178
179
180  SymbolTable *InputSymbols() const { return isymbols_; }
181
182  SymbolTable *OutputSymbols() const { return osymbols_; }
183
184  void SetInputSymbols(const SymbolTable* syms) {
185    if (isymbols_) delete isymbols_;
186    if (syms) {
187      isymbols_ = syms->Copy();
188      flags_ |= kEncodeHasISymbols;
189    } else {
190      isymbols_ = 0;
191      flags_ &= ~kEncodeHasISymbols;
192    }
193  }
194
195  void SetOutputSymbols(const SymbolTable* syms) {
196    if (osymbols_) delete osymbols_;
197    if (syms) {
198      osymbols_ = syms->Copy();
199      flags_ |= kEncodeHasOSymbols;
200    } else {
201      osymbols_ = 0;
202      flags_ &= ~kEncodeHasOSymbols;
203    }
204  }
205
206 private:
207  uint32 flags_;
208  vector<Tuple*> encode_tuples_;
209  EncodeHash encode_hash_;
210  RefCounter ref_count_;
211  SymbolTable *isymbols_;       // Pre-encoded ilabel symbol table
212  SymbolTable *osymbols_;       // Pre-encoded olabel symbol table
213
214  DISALLOW_COPY_AND_ASSIGN(EncodeTable);
215};
216
217template <class A> inline
218bool EncodeTable<A>::Write(ostream &strm, const string &source) const {
219  WriteType(strm, kEncodeMagicNumber);
220  WriteType(strm, flags_);
221  int64 size = encode_tuples_.size();
222  WriteType(strm, size);
223  for (size_t i = 0;  i < size; ++i) {
224    const Tuple* tuple = encode_tuples_[i];
225    WriteType(strm, tuple->ilabel);
226    WriteType(strm, tuple->olabel);
227    tuple->weight.Write(strm);
228  }
229
230  if (flags_ & kEncodeHasISymbols)
231    isymbols_->Write(strm);
232
233  if (flags_ & kEncodeHasOSymbols)
234    osymbols_->Write(strm);
235
236  strm.flush();
237  if (!strm) {
238    LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
239    return false;
240  }
241  return true;
242}
243
244template <class A> inline
245EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) {
246  int32 magic_number = 0;
247  ReadType(strm, &magic_number);
248  if (magic_number != kEncodeMagicNumber) {
249    LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
250    return 0;
251  }
252  uint32 flags;
253  ReadType(strm, &flags);
254  EncodeTable<A> *table = new EncodeTable<A>(flags);
255
256  int64 size;
257  ReadType(strm, &size);
258  if (!strm) {
259    LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
260    return 0;
261  }
262
263  for (size_t i = 0; i < size; ++i) {
264    Tuple* tuple = new Tuple();
265    ReadType(strm, &tuple->ilabel);
266    ReadType(strm, &tuple->olabel);
267    tuple->weight.Read(strm);
268    if (!strm) {
269      LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
270      return 0;
271    }
272    table->encode_tuples_.push_back(tuple);
273    table->encode_hash_[table->encode_tuples_.back()] =
274        table->encode_tuples_.size();
275  }
276
277  if (flags & kEncodeHasISymbols)
278    table->isymbols_ = SymbolTable::Read(strm, source);
279
280  if (flags & kEncodeHasOSymbols)
281    table->osymbols_ = SymbolTable::Read(strm, source);
282
283  return table;
284}
285
286
287// A mapper to encode/decode weighted transducers. Encoding of an
288// Fst is useful for performing classical determinization or minimization
289// on a weighted transducer by treating it as an unweighted acceptor over
290// encoded labels.
291//
292// The Encode mapper stores the encoding in a local hash table (EncodeTable)
293// This table is shared (and reference counted) between the encoder and
294// decoder. A decoder has read only access to the EncodeTable.
295//
// The EncodeMapper allows on-the-fly encoding of the machine. As the
// EncodeTable is generated, the same table may be used to decode the machine
// on the fly. For example, in the following sequence of operations
//
//  Encode -> Determinize -> Decode
//
// we will use the encoding table generated during the encode step in the
// decode, even though the encoding is not complete.
304//
305template <class A> class EncodeMapper {
306  typedef typename A::Weight Weight;
307  typedef typename A::Label  Label;
308 public:
309  EncodeMapper(uint32 flags, EncodeType type)
310    : flags_(flags),
311      type_(type),
312      table_(new EncodeTable<A>(flags)),
313      error_(false) {}
314
315  EncodeMapper(const EncodeMapper& mapper)
316      : flags_(mapper.flags_),
317        type_(mapper.type_),
318        table_(mapper.table_),
319        error_(false) {
320    table_->IncrRefCount();
321  }
322
323  // Copy constructor but setting the type, typically to DECODE
324  EncodeMapper(const EncodeMapper& mapper, EncodeType type)
325      : flags_(mapper.flags_),
326        type_(type),
327        table_(mapper.table_),
328        error_(mapper.error_) {
329    table_->IncrRefCount();
330  }
331
332  ~EncodeMapper() {
333    if (!table_->DecrRefCount()) delete table_;
334  }
335
336  A operator()(const A &arc);
337
338  MapFinalAction FinalAction() const {
339    return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
340                   MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
341  }
342
343  MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }
344
345  MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;}
346
347  uint64 Properties(uint64 inprops) {
348    uint64 outprops = inprops;
349    if (error_) outprops |= kError;
350
351    uint64 mask = kFstProperties;
352    if (flags_ & kEncodeLabels)
353      mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
354    if (flags_ & kEncodeWeights)
355      mask &= kILabelInvariantProperties & kWeightInvariantProperties &
356          (type_ == ENCODE ? kAddSuperFinalProperties :
357           kRmSuperFinalProperties);
358
359    return outprops & mask;
360  }
361
362  const uint32 flags() const { return flags_; }
363  const EncodeType type() const { return type_; }
364  const EncodeTable<A> &table() const { return *table_; }
365
366  bool Write(ostream &strm, const string& source) {
367    return table_->Write(strm, source);
368  }
369
370  bool Write(const string& filename) {
371    ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
372    if (!strm) {
373      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
374      return false;
375    }
376    return Write(strm, filename);
377  }
378
379  static EncodeMapper<A> *Read(istream &strm,
380                               const string& source,
381                               EncodeType type = ENCODE) {
382    EncodeTable<A> *table = EncodeTable<A>::Read(strm, source);
383    return table ? new EncodeMapper(table->flags(), type, table) : 0;
384  }
385
386  static EncodeMapper<A> *Read(const string& filename,
387                               EncodeType type = ENCODE) {
388    ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
389    if (!strm) {
390      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
391      return NULL;
392    }
393    return Read(strm, filename, type);
394  }
395
396  SymbolTable *InputSymbols() const { return table_->InputSymbols(); }
397
398  SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); }
399
400  void SetInputSymbols(const SymbolTable* syms) {
401    table_->SetInputSymbols(syms);
402  }
403
404  void SetOutputSymbols(const SymbolTable* syms) {
405    table_->SetOutputSymbols(syms);
406  }
407
408 private:
409  uint32 flags_;
410  EncodeType type_;
411  EncodeTable<A>* table_;
412  bool error_;
413
414  explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
415      : flags_(flags), type_(type), table_(table) {}
416  void operator=(const EncodeMapper &);  // Disallow.
417};
418
419template <class A> inline
420A EncodeMapper<A>::operator()(const A &arc) {
421  if (type_ == ENCODE) {  // labels and/or weights to single label
422    if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) ||
423        (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) &&
424         arc.weight == Weight::Zero())) {
425      return arc;
426    } else {
427      Label label = table_->Encode(arc);
428      return A(label,
429               flags_ & kEncodeLabels ? label : arc.olabel,
430               flags_ & kEncodeWeights ? Weight::One() : arc.weight,
431               arc.nextstate);
432    }
433  } else {  // type_ == DECODE
434    if (arc.nextstate == kNoStateId) {
435      return arc;
436    } else {
437      if (arc.ilabel == 0) return arc;
438      if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) {
439        FSTERROR() << "EncodeMapper: Label-encoded arc has different "
440            "input and output labels";
441        error_ = true;
442      }
443      if (flags_ & kEncodeWeights && arc.weight != Weight::One()) {
444        FSTERROR() <<
445            "EncodeMapper: Weight-encoded arc has non-trivial weight";
446        error_ = true;
447      }
448      const typename EncodeTable<A>::Tuple* tuple = table_->Decode(arc.ilabel);
449      if (!tuple) {
450        FSTERROR() << "EncodeMapper: decode failed";
451        error_ = true;
452        return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate);
453      } else {
454        return A(tuple->ilabel,
455                 flags_ & kEncodeLabels ? tuple->olabel : arc.olabel,
456                 flags_ & kEncodeWeights ? tuple->weight : arc.weight,
457                 arc.nextstate);
458      }
459    }
460  }
461}
462
463
464// Complexity: O(nstates + narcs)
465template<class A> inline
466void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
467  mapper->SetInputSymbols(fst->InputSymbols());
468  mapper->SetOutputSymbols(fst->OutputSymbols());
469  ArcMap(fst, mapper);
470}
471
472template<class A> inline
473void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
474  ArcMap(fst, EncodeMapper<A>(mapper, DECODE));
475  RmFinalEpsilon(fst);
476  fst->SetInputSymbols(mapper.InputSymbols());
477  fst->SetOutputSymbols(mapper.OutputSymbols());
478}
479
480
481// On the fly label and/or weight encoding of input Fst
482//
483// Complexity:
484// - Constructor: O(1)
485// - Traversal: O(nstates_visited + narcs_visited), assuming constant
486//   time to visit an input state or arc.
487template <class A>
488class EncodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
489 public:
490  typedef A Arc;
491  typedef EncodeMapper<A> C;
492  typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
493  using ImplToFst<Impl>::GetImpl;
494
495  EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder)
496      : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {
497    encoder->SetInputSymbols(fst.InputSymbols());
498    encoder->SetOutputSymbols(fst.OutputSymbols());
499  }
500
501  EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
502      : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {}
503
504  // See Fst<>::Copy() for doc.
505  EncodeFst(const EncodeFst<A> &fst, bool copy = false)
506      : ArcMapFst<A, A, C>(fst, copy) {}
507
508  // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc.
509  virtual EncodeFst<A> *Copy(bool safe = false) const {
510    if (safe) {
511      FSTERROR() << "EncodeFst::Copy(true): not allowed.";
512      GetImpl()->SetProperties(kError, kError);
513    }
514    return new EncodeFst(*this);
515  }
516};
517
518
// On the fly label and/or weight decoding of input Fst
520//
521// Complexity:
522// - Constructor: O(1)
523// - Traversal: O(nstates_visited + narcs_visited), assuming constant
524//   time to visit an input state or arc.
525template <class A>
526class DecodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
527 public:
528  typedef A Arc;
529  typedef EncodeMapper<A> C;
530  typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
531  using ImplToFst<Impl>::GetImpl;
532
533  DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
534      : ArcMapFst<A, A, C>(fst,
535                            EncodeMapper<A>(encoder, DECODE),
536                            ArcMapFstOptions()) {
537    GetImpl()->SetInputSymbols(encoder.InputSymbols());
538    GetImpl()->SetOutputSymbols(encoder.OutputSymbols());
539  }
540
541  // See Fst<>::Copy() for doc.
542  DecodeFst(const DecodeFst<A> &fst, bool safe = false)
543      : ArcMapFst<A, A, C>(fst, safe) {}
544
545  // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc.
546  virtual DecodeFst<A> *Copy(bool safe = false) const {
547    return new DecodeFst(*this, safe);
548  }
549};
550
551
552// Specialization for EncodeFst.
553template <class A>
554class StateIterator< EncodeFst<A> >
555    : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
556 public:
557  explicit StateIterator(const EncodeFst<A> &fst)
558      : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
559};
560
561
562// Specialization for EncodeFst.
563template <class A>
564class ArcIterator< EncodeFst<A> >
565    : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
566 public:
567  ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
568      : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
569};
570
571
572// Specialization for DecodeFst.
573template <class A>
574class StateIterator< DecodeFst<A> >
575    : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
576 public:
577  explicit StateIterator(const DecodeFst<A> &fst)
578      : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
579};
580
581
582// Specialization for DecodeFst.
583template <class A>
584class ArcIterator< DecodeFst<A> >
585    : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
586 public:
587  ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
588      : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
589};
590
591
592// Useful aliases when using StdArc.
// Lazy encoder over StdArc FSTs.
typedef EncodeFst<StdArc> StdEncodeFst;

// Lazy decoder over StdArc FSTs.
typedef DecodeFst<StdArc> StdDecodeFst;
596
597}  // namespace fst
598
599#endif  // FST_LIB_ENCODE_H__
600