1// encode.h 2 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Copyright 2005-2010 Google, Inc. 16// Author: johans@google.com (Johan Schalkwyk) 17// 18// \file 19// Class to encode and decoder an fst. 20 21#ifndef FST_LIB_ENCODE_H__ 22#define FST_LIB_ENCODE_H__ 23 24#include <climits> 25#include <tr1/unordered_map> 26using std::tr1::unordered_map; 27using std::tr1::unordered_multimap; 28#include <string> 29#include <vector> 30using std::vector; 31 32#include <fst/arc-map.h> 33#include <fst/rmfinalepsilon.h> 34 35 36namespace fst { 37 38static const uint32 kEncodeLabels = 0x0001; 39static const uint32 kEncodeWeights = 0x0002; 40static const uint32 kEncodeFlags = 0x0003; // All non-internal flags 41 42static const uint32 kEncodeHasISymbols = 0x0004; // For internal use 43static const uint32 kEncodeHasOSymbols = 0x0008; // For internal use 44 45enum EncodeType { ENCODE = 1, DECODE = 2 }; 46 47// Identifies stream data as an encode table (and its endianity) 48static const int32 kEncodeMagicNumber = 2129983209; 49 50 51// The following class encapsulates implementation details for the 52// encoding and decoding of label/weight tuples used for encoding 53// and decoding of Fsts. The EncodeTable is bidirectional. I.E it 54// stores both the Tuple of encode labels and weights to a unique 55// label, and the reverse. 56template <class A> class EncodeTable { 57 public: 58 typedef typename A::Label Label; 59 typedef typename A::Weight Weight; 60 61 // Encoded data consists of arc input/output labels and arc weight 62 struct Tuple { 63 Tuple() {} 64 Tuple(Label ilabel_, Label olabel_, Weight weight_) 65 : ilabel(ilabel_), olabel(olabel_), weight(weight_) {} 66 Tuple(const Tuple& tuple) 67 : ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {} 68 69 Label ilabel; 70 Label olabel; 71 Weight weight; 72 }; 73 74 // Comparison object for hashing EncodeTable Tuple(s). 75 class TupleEqual { 76 public: 77 bool operator()(const Tuple* x, const Tuple* y) const { 78 return (x->ilabel == y->ilabel && 79 x->olabel == y->olabel && 80 x->weight == y->weight); 81 } 82 }; 83 84 // Hash function for EncodeTabe Tuples. Based on the encode flags 85 // we either hash the labels, weights or combination of them. 86 class TupleKey { 87 public: 88 TupleKey() 89 : encode_flags_(kEncodeLabels | kEncodeWeights) {} 90 91 TupleKey(const TupleKey& key) 92 : encode_flags_(key.encode_flags_) {} 93 94 explicit TupleKey(uint32 encode_flags) 95 : encode_flags_(encode_flags) {} 96 97 size_t operator()(const Tuple* x) const { 98 size_t hash = x->ilabel; 99 const int lshift = 5; 100 const int rshift = CHAR_BIT * sizeof(size_t) - 5; 101 if (encode_flags_ & kEncodeLabels) 102 hash = hash << lshift ^ hash >> rshift ^ x->olabel; 103 if (encode_flags_ & kEncodeWeights) 104 hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash(); 105 return hash; 106 } 107 108 private: 109 int32 encode_flags_; 110 }; 111 112 typedef unordered_map<const Tuple*, 113 Label, 114 TupleKey, 115 TupleEqual> EncodeHash; 116 117 explicit EncodeTable(uint32 encode_flags) 118 : flags_(encode_flags), 119 encode_hash_(1024, TupleKey(encode_flags)), 120 isymbols_(0), osymbols_(0) {} 121 122 ~EncodeTable() { 123 for (size_t i = 0; i < encode_tuples_.size(); ++i) { 124 delete encode_tuples_[i]; 125 } 126 delete isymbols_; 127 delete osymbols_; 128 } 129 130 // Given an arc encode either input/ouptut labels or input/costs or both 131 Label Encode(const A &arc) { 132 const Tuple tuple(arc.ilabel, 133 flags_ & kEncodeLabels ? arc.olabel : 0, 134 flags_ & kEncodeWeights ? arc.weight : Weight::One()); 135 typename EncodeHash::const_iterator it = encode_hash_.find(&tuple); 136 if (it == encode_hash_.end()) { 137 encode_tuples_.push_back(new Tuple(tuple)); 138 encode_hash_[encode_tuples_.back()] = encode_tuples_.size(); 139 return encode_tuples_.size(); 140 } else { 141 return it->second; 142 } 143 } 144 145 // Given an arc, look up its encoded label. Returns kNoLabel if not found. 146 Label GetLabel(const A &arc) const { 147 const Tuple tuple(arc.ilabel, 148 flags_ & kEncodeLabels ? arc.olabel : 0, 149 flags_ & kEncodeWeights ? arc.weight : Weight::One()); 150 typename EncodeHash::const_iterator it = encode_hash_.find(&tuple); 151 if (it == encode_hash_.end()) { 152 return kNoLabel; 153 } else { 154 return it->second; 155 } 156 } 157 158 // Given an encode arc Label decode back to input/output labels and costs 159 const Tuple* Decode(Label key) const { 160 if (key < 1 || key > encode_tuples_.size()) { 161 LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key; 162 return 0; 163 } 164 return encode_tuples_[key - 1]; 165 } 166 167 size_t Size() const { return encode_tuples_.size(); } 168 169 bool Write(ostream &strm, const string &source) const; 170 171 static EncodeTable<A> *Read(istream &strm, const string &source); 172 173 const uint32 flags() const { return flags_ & kEncodeFlags; } 174 175 int RefCount() const { return ref_count_.count(); } 176 int IncrRefCount() { return ref_count_.Incr(); } 177 int DecrRefCount() { return ref_count_.Decr(); } 178 179 180 SymbolTable *InputSymbols() const { return isymbols_; } 181 182 SymbolTable *OutputSymbols() const { return osymbols_; } 183 184 void SetInputSymbols(const SymbolTable* syms) { 185 if (isymbols_) delete isymbols_; 186 if (syms) { 187 isymbols_ = syms->Copy(); 188 flags_ |= kEncodeHasISymbols; 189 } else { 190 isymbols_ = 0; 191 flags_ &= ~kEncodeHasISymbols; 192 } 193 } 194 195 void SetOutputSymbols(const SymbolTable* syms) { 196 if (osymbols_) delete osymbols_; 197 if (syms) { 198 osymbols_ = syms->Copy(); 199 flags_ |= kEncodeHasOSymbols; 200 } else { 201 osymbols_ = 0; 202 flags_ &= ~kEncodeHasOSymbols; 203 } 204 } 205 206 private: 207 uint32 flags_; 208 vector<Tuple*> encode_tuples_; 209 EncodeHash encode_hash_; 210 RefCounter ref_count_; 211 SymbolTable *isymbols_; // Pre-encoded ilabel symbol table 212 SymbolTable *osymbols_; // Pre-encoded olabel symbol table 213 214 DISALLOW_COPY_AND_ASSIGN(EncodeTable); 215}; 216 217template <class A> inline 218bool EncodeTable<A>::Write(ostream &strm, const string &source) const { 219 WriteType(strm, kEncodeMagicNumber); 220 WriteType(strm, flags_); 221 int64 size = encode_tuples_.size(); 222 WriteType(strm, size); 223 for (size_t i = 0; i < size; ++i) { 224 const Tuple* tuple = encode_tuples_[i]; 225 WriteType(strm, tuple->ilabel); 226 WriteType(strm, tuple->olabel); 227 tuple->weight.Write(strm); 228 } 229 230 if (flags_ & kEncodeHasISymbols) 231 isymbols_->Write(strm); 232 233 if (flags_ & kEncodeHasOSymbols) 234 osymbols_->Write(strm); 235 236 strm.flush(); 237 if (!strm) { 238 LOG(ERROR) << "EncodeTable::Write: write failed: " << source; 239 return false; 240 } 241 return true; 242} 243 244template <class A> inline 245EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) { 246 int32 magic_number = 0; 247 ReadType(strm, &magic_number); 248 if (magic_number != kEncodeMagicNumber) { 249 LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source; 250 return 0; 251 } 252 uint32 flags; 253 ReadType(strm, &flags); 254 EncodeTable<A> *table = new EncodeTable<A>(flags); 255 256 int64 size; 257 ReadType(strm, &size); 258 if (!strm) { 259 LOG(ERROR) << "EncodeTable::Read: read failed: " << source; 260 return 0; 261 } 262 263 for (size_t i = 0; i < size; ++i) { 264 Tuple* tuple = new Tuple(); 265 ReadType(strm, &tuple->ilabel); 266 ReadType(strm, &tuple->olabel); 267 tuple->weight.Read(strm); 268 if (!strm) { 269 LOG(ERROR) << "EncodeTable::Read: read failed: " << source; 270 return 0; 271 } 272 table->encode_tuples_.push_back(tuple); 273 table->encode_hash_[table->encode_tuples_.back()] = 274 table->encode_tuples_.size(); 275 } 276 277 if (flags & kEncodeHasISymbols) 278 table->isymbols_ = SymbolTable::Read(strm, source); 279 280 if (flags & kEncodeHasOSymbols) 281 table->osymbols_ = SymbolTable::Read(strm, source); 282 283 return table; 284} 285 286 287// A mapper to encode/decode weighted transducers. Encoding of an 288// Fst is useful for performing classical determinization or minimization 289// on a weighted transducer by treating it as an unweighted acceptor over 290// encoded labels. 291// 292// The Encode mapper stores the encoding in a local hash table (EncodeTable) 293// This table is shared (and reference counted) between the encoder and 294// decoder. A decoder has read only access to the EncodeTable. 295// 296// The EncodeMapper allows on the fly encoding of the machine. As the 297// EncodeTable is generated the same table may by used to decode the machine 298// on the fly. For example in the following sequence of operations 299// 300// Encode -> Determinize -> Decode 301// 302// we will use the encoding table generated during the encode step in the 303// decode, even though the encoding is not complete. 304// 305template <class A> class EncodeMapper { 306 typedef typename A::Weight Weight; 307 typedef typename A::Label Label; 308 public: 309 EncodeMapper(uint32 flags, EncodeType type) 310 : flags_(flags), 311 type_(type), 312 table_(new EncodeTable<A>(flags)), 313 error_(false) {} 314 315 EncodeMapper(const EncodeMapper& mapper) 316 : flags_(mapper.flags_), 317 type_(mapper.type_), 318 table_(mapper.table_), 319 error_(false) { 320 table_->IncrRefCount(); 321 } 322 323 // Copy constructor but setting the type, typically to DECODE 324 EncodeMapper(const EncodeMapper& mapper, EncodeType type) 325 : flags_(mapper.flags_), 326 type_(type), 327 table_(mapper.table_), 328 error_(mapper.error_) { 329 table_->IncrRefCount(); 330 } 331 332 ~EncodeMapper() { 333 if (!table_->DecrRefCount()) delete table_; 334 } 335 336 A operator()(const A &arc); 337 338 MapFinalAction FinalAction() const { 339 return (type_ == ENCODE && (flags_ & kEncodeWeights)) ? 340 MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL; 341 } 342 343 MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; } 344 345 MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;} 346 347 uint64 Properties(uint64 inprops) { 348 uint64 outprops = inprops; 349 if (error_) outprops |= kError; 350 351 uint64 mask = kFstProperties; 352 if (flags_ & kEncodeLabels) 353 mask &= kILabelInvariantProperties & kOLabelInvariantProperties; 354 if (flags_ & kEncodeWeights) 355 mask &= kILabelInvariantProperties & kWeightInvariantProperties & 356 (type_ == ENCODE ? kAddSuperFinalProperties : 357 kRmSuperFinalProperties); 358 359 return outprops & mask; 360 } 361 362 const uint32 flags() const { return flags_; } 363 const EncodeType type() const { return type_; } 364 const EncodeTable<A> &table() const { return *table_; } 365 366 bool Write(ostream &strm, const string& source) { 367 return table_->Write(strm, source); 368 } 369 370 bool Write(const string& filename) { 371 ofstream strm(filename.c_str(), ofstream::out | ofstream::binary); 372 if (!strm) { 373 LOG(ERROR) << "EncodeMap: Can't open file: " << filename; 374 return false; 375 } 376 return Write(strm, filename); 377 } 378 379 static EncodeMapper<A> *Read(istream &strm, 380 const string& source, 381 EncodeType type = ENCODE) { 382 EncodeTable<A> *table = EncodeTable<A>::Read(strm, source); 383 return table ? new EncodeMapper(table->flags(), type, table) : 0; 384 } 385 386 static EncodeMapper<A> *Read(const string& filename, 387 EncodeType type = ENCODE) { 388 ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); 389 if (!strm) { 390 LOG(ERROR) << "EncodeMap: Can't open file: " << filename; 391 return NULL; 392 } 393 return Read(strm, filename, type); 394 } 395 396 SymbolTable *InputSymbols() const { return table_->InputSymbols(); } 397 398 SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); } 399 400 void SetInputSymbols(const SymbolTable* syms) { 401 table_->SetInputSymbols(syms); 402 } 403 404 void SetOutputSymbols(const SymbolTable* syms) { 405 table_->SetOutputSymbols(syms); 406 } 407 408 private: 409 uint32 flags_; 410 EncodeType type_; 411 EncodeTable<A>* table_; 412 bool error_; 413 414 explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table) 415 : flags_(flags), type_(type), table_(table) {} 416 void operator=(const EncodeMapper &); // Disallow. 417}; 418 419template <class A> inline 420A EncodeMapper<A>::operator()(const A &arc) { 421 if (type_ == ENCODE) { // labels and/or weights to single label 422 if ((arc.nextstate == kNoStateId && !(flags_ & kEncodeWeights)) || 423 (arc.nextstate == kNoStateId && (flags_ & kEncodeWeights) && 424 arc.weight == Weight::Zero())) { 425 return arc; 426 } else { 427 Label label = table_->Encode(arc); 428 return A(label, 429 flags_ & kEncodeLabels ? label : arc.olabel, 430 flags_ & kEncodeWeights ? Weight::One() : arc.weight, 431 arc.nextstate); 432 } 433 } else { // type_ == DECODE 434 if (arc.nextstate == kNoStateId) { 435 return arc; 436 } else { 437 if (arc.ilabel == 0) return arc; 438 if (flags_ & kEncodeLabels && arc.ilabel != arc.olabel) { 439 FSTERROR() << "EncodeMapper: Label-encoded arc has different " 440 "input and output labels"; 441 error_ = true; 442 } 443 if (flags_ & kEncodeWeights && arc.weight != Weight::One()) { 444 FSTERROR() << 445 "EncodeMapper: Weight-encoded arc has non-trivial weight"; 446 error_ = true; 447 } 448 const typename EncodeTable<A>::Tuple* tuple = table_->Decode(arc.ilabel); 449 if (!tuple) { 450 FSTERROR() << "EncodeMapper: decode failed"; 451 error_ = true; 452 return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate); 453 } else { 454 return A(tuple->ilabel, 455 flags_ & kEncodeLabels ? tuple->olabel : arc.olabel, 456 flags_ & kEncodeWeights ? tuple->weight : arc.weight, 457 arc.nextstate); 458 } 459 } 460 } 461} 462 463 464// Complexity: O(nstates + narcs) 465template<class A> inline 466void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) { 467 mapper->SetInputSymbols(fst->InputSymbols()); 468 mapper->SetOutputSymbols(fst->OutputSymbols()); 469 ArcMap(fst, mapper); 470} 471 472template<class A> inline 473void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) { 474 ArcMap(fst, EncodeMapper<A>(mapper, DECODE)); 475 RmFinalEpsilon(fst); 476 fst->SetInputSymbols(mapper.InputSymbols()); 477 fst->SetOutputSymbols(mapper.OutputSymbols()); 478} 479 480 481// On the fly label and/or weight encoding of input Fst 482// 483// Complexity: 484// - Constructor: O(1) 485// - Traversal: O(nstates_visited + narcs_visited), assuming constant 486// time to visit an input state or arc. 487template <class A> 488class EncodeFst : public ArcMapFst<A, A, EncodeMapper<A> > { 489 public: 490 typedef A Arc; 491 typedef EncodeMapper<A> C; 492 typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl; 493 using ImplToFst<Impl>::GetImpl; 494 495 EncodeFst(const Fst<A> &fst, EncodeMapper<A>* encoder) 496 : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) { 497 encoder->SetInputSymbols(fst.InputSymbols()); 498 encoder->SetOutputSymbols(fst.OutputSymbols()); 499 } 500 501 EncodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) 502 : ArcMapFst<A, A, C>(fst, encoder, ArcMapFstOptions()) {} 503 504 // See Fst<>::Copy() for doc. 505 EncodeFst(const EncodeFst<A> &fst, bool copy = false) 506 : ArcMapFst<A, A, C>(fst, copy) {} 507 508 // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc. 509 virtual EncodeFst<A> *Copy(bool safe = false) const { 510 if (safe) { 511 FSTERROR() << "EncodeFst::Copy(true): not allowed."; 512 GetImpl()->SetProperties(kError, kError); 513 } 514 return new EncodeFst(*this); 515 } 516}; 517 518 519// On the fly label and/or weight encoding of input Fst 520// 521// Complexity: 522// - Constructor: O(1) 523// - Traversal: O(nstates_visited + narcs_visited), assuming constant 524// time to visit an input state or arc. 525template <class A> 526class DecodeFst : public ArcMapFst<A, A, EncodeMapper<A> > { 527 public: 528 typedef A Arc; 529 typedef EncodeMapper<A> C; 530 typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl; 531 using ImplToFst<Impl>::GetImpl; 532 533 DecodeFst(const Fst<A> &fst, const EncodeMapper<A>& encoder) 534 : ArcMapFst<A, A, C>(fst, 535 EncodeMapper<A>(encoder, DECODE), 536 ArcMapFstOptions()) { 537 GetImpl()->SetInputSymbols(encoder.InputSymbols()); 538 GetImpl()->SetOutputSymbols(encoder.OutputSymbols()); 539 } 540 541 // See Fst<>::Copy() for doc. 542 DecodeFst(const DecodeFst<A> &fst, bool safe = false) 543 : ArcMapFst<A, A, C>(fst, safe) {} 544 545 // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc. 546 virtual DecodeFst<A> *Copy(bool safe = false) const { 547 return new DecodeFst(*this, safe); 548 } 549}; 550 551 552// Specialization for EncodeFst. 553template <class A> 554class StateIterator< EncodeFst<A> > 555 : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > { 556 public: 557 explicit StateIterator(const EncodeFst<A> &fst) 558 : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {} 559}; 560 561 562// Specialization for EncodeFst. 563template <class A> 564class ArcIterator< EncodeFst<A> > 565 : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > { 566 public: 567 ArcIterator(const EncodeFst<A> &fst, typename A::StateId s) 568 : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {} 569}; 570 571 572// Specialization for DecodeFst. 573template <class A> 574class StateIterator< DecodeFst<A> > 575 : public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > { 576 public: 577 explicit StateIterator(const DecodeFst<A> &fst) 578 : StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {} 579}; 580 581 582// Specialization for DecodeFst. 583template <class A> 584class ArcIterator< DecodeFst<A> > 585 : public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > { 586 public: 587 ArcIterator(const DecodeFst<A> &fst, typename A::StateId s) 588 : ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {} 589}; 590 591 592// Useful aliases when using StdArc. 593typedef EncodeFst<StdArc> StdEncodeFst; 594 595typedef DecodeFst<StdArc> StdDecodeFst; 596 597} // namespace fst 598 599#endif // FST_LIB_ENCODE_H__ 600