encode.h revision 560eaab489316778f491132c7b05a647b098d2a0
1// encode.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
// Classes to encode and decode an Fst.
18
19#ifndef FST_LIB_ENCODE_H__
20#define FST_LIB_ENCODE_H__
21
#include <climits>
#include <memory>

#include "fst/lib/map.h"
#include "fst/lib/rmfinalepsilon.h"
24
25namespace fst {
26
27static const uint32 kEncodeLabels = 0x00001;
28static const uint32 kEncodeWeights  = 0x00002;
29
30enum EncodeType { ENCODE = 1, DECODE = 2 };
31
32// Identifies stream data as an encode table (and its endianity)
33static const int32 kEncodeMagicNumber = 2129983209;
34
35
36// The following class encapsulates implementation details for the
37// encoding and decoding of label/weight tuples used for encoding
38// and decoding of Fsts. The EncodeTable is bidirectional. I.E it
39// stores both the Tuple of encode labels and weights to a unique
40// label, and the reverse.
41template <class A>  class EncodeTable {
42 public:
43  typedef typename A::Label Label;
44  typedef typename A::Weight Weight;
45
46  // Encoded data consists of arc input/output labels and arc weight
47  struct Tuple {
48    Tuple() {}
49    Tuple(Label ilabel_, Label olabel_, Weight weight_)
50        : ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
51    Tuple(const Tuple& tuple)
52        : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
53
54    Label ilabel;
55    Label olabel;
56    Weight weight;
57  };
58
59  // Comparison object for hashing EncodeTable Tuple(s).
60  class TupleEqual {
61   public:
62    bool operator()(const Tuple* x, const Tuple* y) const {
63      return (x->ilabel == y->ilabel &&
64              x->olabel == y->olabel &&
65              x->weight == y->weight);
66    }
67  };
68
69  // Hash function for EncodeTabe Tuples. Based on the encode flags
70  // we either hash the labels, weights or compbination of them.
71  class TupleKey {
72    static const int kPrime = 7853;
73   public:
74    TupleKey()
75        : encode_flags_(kEncodeLabels | kEncodeWeights) {}
76
77    TupleKey(const TupleKey& key)
78        : encode_flags_(key.encode_flags_) {}
79
80    explicit TupleKey(uint32 encode_flags)
81        : encode_flags_(encode_flags) {}
82
83    size_t operator()(const Tuple* x) const {
84      int lshift = x->ilabel % kPrime;
85      int rshift = sizeof(size_t) - lshift;
86      size_t hash = x->ilabel << lshift;
87      if (encode_flags_ & kEncodeLabels) hash ^= x->olabel >> rshift;
88      if (encode_flags_ & kEncodeWeights)  hash ^= x->weight.Hash();
89      return hash;
90    }
91
92   private:
93    int32 encode_flags_;
94  };
95
96  typedef std::unordered_map<const Tuple*, Label, TupleKey, TupleEqual> EncodeHash;
97
98  explicit EncodeTable(uint32 encode_flags)
99      : flags_(encode_flags),
100        encode_hash_(1024, TupleKey(encode_flags)) {}
101
102  ~EncodeTable() {
103    for (size_t i = 0; i < encode_tuples_.size(); ++i) {
104      delete encode_tuples_[i];
105    }
106  }
107
108  // Given an arc encode either input/ouptut labels or input/costs or both
109  Label Encode(const A &arc) {
110    const Tuple tuple(arc.ilabel,
111                      flags_ & kEncodeLabels ? arc.olabel : 0,
112                      flags_ & kEncodeWeights ? arc.weight : Weight::One());
113    typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
114    if (it == encode_hash_.end()) {
115      encode_tuples_.push_back(new Tuple(tuple));
116      encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
117      return encode_tuples_.size();
118    } else {
119      return it->second;
120    }
121  }
122
123  // Given an encode arc Label decode back to input/output labels and costs
124  const Tuple* Decode(Label key) {
125    return key <= (Label)encode_tuples_.size() ? encode_tuples_[key - 1] : 0;
126  }
127
128  bool Write(ostream &strm, const string &source) const {
129    WriteType(strm, kEncodeMagicNumber);
130    WriteType(strm, flags_);
131    int64 size = encode_tuples_.size();
132    WriteType(strm, size);
133    for (size_t i = 0;  i < size; ++i) {
134      const Tuple* tuple = encode_tuples_[i];
135      WriteType(strm, tuple->ilabel);
136      WriteType(strm, tuple->olabel);
137      tuple->weight.Write(strm);
138    }
139    strm.flush();
140    if (!strm) {
141      LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
142      return false;
143    }
144    return true;
145  }
146
147  bool Read(istream &strm, const string &source) {
148    encode_tuples_.clear();
149    encode_hash_.clear();
150    int32 magic_number = 0;
151    ReadType(strm, &magic_number);
152    if (magic_number != kEncodeMagicNumber) {
153      LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
154      return false;
155    }
156    ReadType(strm, &flags_);
157    int64 size;
158    ReadType(strm, &size);
159    if (!strm) {
160      LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
161      return false;
162    }
163    for (size_t i = 0; i < size; ++i) {
164      Tuple* tuple = new Tuple();
165      ReadType(strm, &tuple->ilabel);
166      ReadType(strm, &tuple->olabel);
167      tuple->weight.Read(strm);
168      encode_tuples_.push_back(tuple);
169      encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
170    }
171    if (!strm) {
172      LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
173      return false;
174    }
175    return true;
176  }
177
178  uint32 flags() const { return flags_; }
179 private:
180  uint32 flags_;
181  vector<Tuple*> encode_tuples_;
182  EncodeHash encode_hash_;
183
184  DISALLOW_EVIL_CONSTRUCTORS(EncodeTable);
185};
186
187
188// A mapper to encode/decode weighted transducers. Encoding of an
189// Fst is useful for performing classical determinization or minimization
190// on a weighted transducer by treating it as an unweighted acceptor over
191// encoded labels.
192//
193// The Encode mapper stores the encoding in a local hash table (EncodeTable)
194// This table is shared (and reference counted) between the encoder and
195// decoder. A decoder has read only access to the EncodeTable.
196//
197// The EncodeMapper allows on the fly encoding of the machine. As the
198// EncodeTable is generated the same table may by used to decode the machine
199// on the fly. For example in the following sequence of operations
200//
201//  Encode -> Determinize -> Decode
202//
203// we will use the encoding table generated during the encode step in the
204// decode, even though the encoding is not complete.
205//
206template <class A> class EncodeMapper {
207  typedef typename A::Weight Weight;
208  typedef typename A::Label  Label;
209 public:
210  EncodeMapper(uint32 flags, EncodeType type)
211    : ref_count_(1), flags_(flags), type_(type),
212      table_(new EncodeTable<A>(flags)) {}
213
214  EncodeMapper(const EncodeMapper& mapper)
215      : ref_count_(mapper.ref_count_ + 1),
216        flags_(mapper.flags_),
217        type_(mapper.type_),
218        table_(mapper.table_) { }
219
220  // Copy constructor but setting the type, typically to DECODE
221  EncodeMapper(const EncodeMapper& mapper, EncodeType type)
222      : ref_count_(mapper.ref_count_ + 1),
223        flags_(mapper.flags_),
224        type_(type),
225        table_(mapper.table_) { }
226
227  ~EncodeMapper() {
228    if (--ref_count_ == 0) delete table_;
229  }
230
231  A operator()(const A &arc) {
232    if (type_ == ENCODE) {  // labels and/or weights to single label
233      if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) ||
234          (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) &&
235           arc.weight == Weight::Zero())) {
236        return arc;
237      } else {
238        Label label = table_->Encode(arc);
239        return A(label,
240                 flags_ & kEncodeLabels ? label : arc.olabel,
241                 flags_ & kEncodeWeights ? Weight::One() : arc.weight,
242                 arc.nextstate);
243      }
244    } else {
245      if (arc.nextstate == kNoStateId) {
246        return arc;
247      } else {
248        const typename EncodeTable<A>::Tuple* tuple =
249          table_->Decode(arc.ilabel);
250        return A(tuple->ilabel,
251                 flags_ & kEncodeLabels ? tuple->olabel : arc.olabel,
252                 flags_ & kEncodeWeights ? tuple->weight : arc.weight,
253                 arc.nextstate);;
254      }
255    }
256  }
257
258  uint64 Properties(uint64 props) {
259    uint64 mask = kFstProperties;
260    if (flags_ & kEncodeLabels)
261      mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
262    if (flags_ & kEncodeWeights)
263      mask &= kILabelInvariantProperties & kWeightInvariantProperties &
264          (type_ == ENCODE ? kAddSuperFinalProperties :
265           kRmSuperFinalProperties);
266    return props & mask;
267  }
268
269
270  MapFinalAction FinalAction() const {
271    return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
272                   MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
273  }
274
275  uint32 flags() const { return flags_; }
276  EncodeType type() const { return type_; }
277
278  bool Write(ostream &strm, const string& source) {
279    return table_->Write(strm, source);
280  }
281
282  bool Write(const string& filename) {
283    ofstream strm(filename.c_str());
284    if (!strm) {
285      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
286      return false;
287    }
288    return Write(strm, filename);
289  }
290
291  static EncodeMapper<A> *Read(istream &strm,
292                               const string& source, EncodeType type) {
293    EncodeTable<A> *table = new EncodeTable<A>(0);
294    bool r = table->Read(strm, source);
295    return r ? new EncodeMapper(table->flags(), type, table) : 0;
296  }
297
298  static EncodeMapper<A> *Read(const string& filename, EncodeType type) {
299    ifstream strm(filename.c_str());
300    if (!strm) {
301      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
302      return false;
303    }
304    return Read(strm, filename, type);
305  }
306
307 private:
308  uint32  ref_count_;
309  uint32  flags_;
310  EncodeType type_;
311  EncodeTable<A>* table_;
312
313  explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
314      : ref_count_(1), flags_(flags), type_(type), table_(table) {}
315  void operator=(const EncodeMapper &);  // Disallow.
316};
317
318
319// Complexity: O(nstates + narcs)
320template<class A> inline
321void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
322  Map(fst, mapper);
323}
324
325
326template<class A> inline
327void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
328  Map(fst, EncodeMapper<A>(mapper, DECODE));
329  RmFinalEpsilon(fst);
330}
331
332
333// On the fly label and/or weight encoding of input Fst
334//
335// Complexity:
336// - Constructor: O(1)
337// - Traversal: O(nstates_visited + narcs_visited), assuming constant
338//   time to visit an input state or arc.
339template <class A>
340class EncodeFst : public MapFst<A, A, EncodeMapper<A> > {
341 public:
342  typedef A Arc;
343  typedef EncodeMapper<A> C;
344
345  EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder)
346      : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
347
348  EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
349      : MapFst<A, A, C>(fst, encoder, MapFstOptions()) {}
350
351  EncodeFst(const EncodeFst<A> &fst)
352      : MapFst<A, A, C>(fst) {}
353
354  virtual EncodeFst<A> *Copy() const { return new EncodeFst(*this); }
355};
356
357
358// On the fly label and/or weight encoding of input Fst
359//
360// Complexity:
361// - Constructor: O(1)
362// - Traversal: O(nstates_visited + narcs_visited), assuming constant
363//   time to visit an input state or arc.
364template <class A>
365class DecodeFst : public MapFst<A, A, EncodeMapper<A> > {
366 public:
367  typedef A Arc;
368  typedef EncodeMapper<A> C;
369
370  DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder)
371      : MapFst<A, A, C>(fst,
372                            EncodeMapper<A>(encoder, DECODE),
373                            MapFstOptions()) {}
374
375  DecodeFst(const EncodeFst<A> &fst)
376      : MapFst<A, A, C>(fst) {}
377
378  virtual DecodeFst<A> *Copy() const { return new DecodeFst(*this); }
379};
380
381
382// Specialization for EncodeFst.
383template <class A>
384class StateIterator< EncodeFst<A> >
385    : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
386 public:
387  explicit StateIterator(const EncodeFst<A> &fst)
388      : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
389};
390
391
392// Specialization for EncodeFst.
393template <class A>
394class ArcIterator< EncodeFst<A> >
395    : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
396 public:
397  ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
398      : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
399};
400
401
402// Specialization for DecodeFst.
403template <class A>
404class StateIterator< DecodeFst<A> >
405    : public StateIterator< MapFst<A, A, EncodeMapper<A> > > {
406 public:
407  explicit StateIterator(const DecodeFst<A> &fst)
408      : StateIterator< MapFst<A, A, EncodeMapper<A> > >(fst) {}
409};
410
411
412// Specialization for DecodeFst.
413template <class A>
414class ArcIterator< DecodeFst<A> >
415    : public ArcIterator< MapFst<A, A, EncodeMapper<A> > > {
416 public:
417  ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
418      : ArcIterator< MapFst<A, A, EncodeMapper<A> > >(fst, s) {}
419};
420
421
422// Useful aliases when using StdArc.
423typedef EncodeFst<StdArc> StdEncodeFst;
424
425typedef DecodeFst<StdArc> StdDecodeFst;
426
427}
428
429#endif  // FST_LIB_ENCODE_H__
430