state-map.h revision 3da1eb108d36da35333b2d655202791af854996b
1// map.h 2 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Copyright 2005-2010 Google, Inc. 16// Author: riley@google.com (Michael Riley) 17// 18// \file 19// Class to map over/transform states e.g., sort transitions 20// Consider using when operation does not change the number of states. 21 22#ifndef FST_LIB_STATE_MAP_H__ 23#define FST_LIB_STATE_MAP_H__ 24 25#include <algorithm> 26#include <tr1/unordered_map> 27using std::tr1::unordered_map; 28using std::tr1::unordered_multimap; 29#include <string> 30#include <utility> 31using std::pair; using std::make_pair; 32 33#include <fst/cache.h> 34#include <fst/arc-map.h> 35#include <fst/mutable-fst.h> 36 37 38namespace fst { 39 40// StateMapper Interface - class determinies how states are mapped. 41// Useful for implementing operations that do not change the number of states. 42// 43// class StateMapper { 44// public: 45// typedef A FromArc; 46// typedef B ToArc; 47// 48// // Typical constructor 49// StateMapper(const Fst<A> &fst); 50// // Required copy constructor that allows updating Fst argument; 51// // pass only if relevant and changed. 
//   StateMapper(const StateMapper &mapper, const Fst<A> *fst = 0);
//
//   // Specifies initial state of result
//   B::StateId Start() const;
//   // Specifies state's final weight in result
//   B::Weight Final(B::StateId s) const;
//
//   // These methods iterate through a state's arcs in result
//   // Specifies state to iterate over
//   void SetState(B::StateId s);
//   // End of arcs?
//   bool Done() const;
//   // Current arc
//   const B &Value() const;
//   // Advance to next arc (when !Done)
//   void Next();
//
//   // Specifies input symbol table action the mapper requires
//   // (see MapSymbolsAction).
//   MapSymbolsAction InputSymbolsAction() const;
//   // Specifies output symbol table action the mapper requires
//   // (see MapSymbolsAction).
//   MapSymbolsAction OutputSymbolsAction() const;
//   // This specifies the known properties of an Fst mapped by this
//   // mapper. It takes as argument the input Fst's known properties.
//   uint64 Properties(uint64 props) const;
// };
//
// We include various state map versions below. One dimension of
// variation is whether the mapping mutates its input, writes to a
// new result Fst, or is an on-the-fly Fst. Another dimension is how
// we pass the mapper. We allow passing the mapper by pointer
// for cases where we need to change the state of the user's mapper.
// We also include map versions that pass the mapper
// by value or const reference when this suffices.

// Maps an arc type A using a mapper function object C, passed
// by pointer. This version modifies its Fst input.
89template<class A, class C> 90void StateMap(MutableFst<A> *fst, C* mapper) { 91 typedef typename A::StateId StateId; 92 typedef typename A::Weight Weight; 93 94 if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) 95 fst->SetInputSymbols(0); 96 97 if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) 98 fst->SetOutputSymbols(0); 99 100 if (fst->Start() == kNoStateId) 101 return; 102 103 uint64 props = fst->Properties(kFstProperties, false); 104 105 fst->SetStart(mapper->Start()); 106 107 for (StateId s = 0; s < fst->NumStates(); ++s) { 108 mapper->SetState(s); 109 fst->DeleteArcs(s); 110 for (; !mapper->Done(); mapper->Next()) 111 fst->AddArc(s, mapper->Value()); 112 fst->SetFinal(s, mapper->Final(s)); 113 } 114 115 fst->SetProperties(mapper->Properties(props), kFstProperties); 116} 117 118// Maps an arc type A using a mapper function object C, passed 119// by value. This version modifies its Fst input. 120template<class A, class C> 121void StateMap(MutableFst<A> *fst, C mapper) { 122 StateMap(fst, &mapper); 123} 124 125 126// Maps an arc type A to an arc type B using mapper function 127// object C, passed by pointer. This version writes the mapped 128// input Fst to an output MutableFst. 
129template<class A, class B, class C> 130void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) { 131 typedef typename A::StateId StateId; 132 typedef typename A::Weight Weight; 133 134 ofst->DeleteStates(); 135 136 if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) 137 ofst->SetInputSymbols(ifst.InputSymbols()); 138 else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) 139 ofst->SetInputSymbols(0); 140 141 if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) 142 ofst->SetOutputSymbols(ifst.OutputSymbols()); 143 else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) 144 ofst->SetOutputSymbols(0); 145 146 uint64 iprops = ifst.Properties(kCopyProperties, false); 147 148 if (ifst.Start() == kNoStateId) { 149 if (iprops & kError) ofst->SetProperties(kError, kError); 150 return; 151 } 152 153 // Add all states. 154 if (ifst.Properties(kExpanded, false)) 155 ofst->ReserveStates(CountStates(ifst)); 156 for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) 157 ofst->AddState(); 158 159 ofst->SetStart(mapper->Start()); 160 161 for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) { 162 StateId s = siter.Value(); 163 mapper->SetState(s); 164 for (; !mapper->Done(); mapper->Next()) 165 ofst->AddArc(s, mapper->Value()); 166 ofst->SetFinal(s, mapper->Final(s)); 167 } 168 169 uint64 oprops = ofst->Properties(kFstProperties, false); 170 ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties); 171} 172 173// Maps an arc type A to an arc type B using mapper function 174// object C, passed by value. This version writes the mapped input 175// Fst to an output MutableFst. 176template<class A, class B, class C> 177void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) { 178 StateMap(ifst, ofst, &mapper); 179} 180 181typedef CacheOptions StateMapFstOptions; 182 183template <class A, class B, class C> class StateMapFst; 184 185// Implementation of delayed StateMapFst. 
186template <class A, class B, class C> 187class StateMapFstImpl : public CacheImpl<B> { 188 public: 189 using FstImpl<B>::SetType; 190 using FstImpl<B>::SetProperties; 191 using FstImpl<B>::SetInputSymbols; 192 using FstImpl<B>::SetOutputSymbols; 193 194 using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates; 195 196 using CacheImpl<B>::PushArc; 197 using CacheImpl<B>::HasArcs; 198 using CacheImpl<B>::HasFinal; 199 using CacheImpl<B>::HasStart; 200 using CacheImpl<B>::SetArcs; 201 using CacheImpl<B>::SetFinal; 202 using CacheImpl<B>::SetStart; 203 204 friend class StateIterator< StateMapFst<A, B, C> >; 205 206 typedef B Arc; 207 typedef typename B::Weight Weight; 208 typedef typename B::StateId StateId; 209 210 StateMapFstImpl(const Fst<A> &fst, const C &mapper, 211 const StateMapFstOptions& opts) 212 : CacheImpl<B>(opts), 213 fst_(fst.Copy()), 214 mapper_(new C(mapper, fst_)), 215 own_mapper_(true) { 216 Init(); 217 } 218 219 StateMapFstImpl(const Fst<A> &fst, C *mapper, 220 const StateMapFstOptions& opts) 221 : CacheImpl<B>(opts), 222 fst_(fst.Copy()), 223 mapper_(mapper), 224 own_mapper_(false) { 225 Init(); 226 } 227 228 StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl) 229 : CacheImpl<B>(impl), 230 fst_(impl.fst_->Copy(true)), 231 mapper_(new C(*impl.mapper_, fst_)), 232 own_mapper_(true) { 233 Init(); 234 } 235 236 ~StateMapFstImpl() { 237 delete fst_; 238 if (own_mapper_) delete mapper_; 239 } 240 241 StateId Start() { 242 if (!HasStart()) 243 SetStart(mapper_->Start()); 244 return CacheImpl<B>::Start(); 245 } 246 247 Weight Final(StateId s) { 248 if (!HasFinal(s)) 249 SetFinal(s, mapper_->Final(s)); 250 return CacheImpl<B>::Final(s); 251 } 252 253 size_t NumArcs(StateId s) { 254 if (!HasArcs(s)) 255 Expand(s); 256 return CacheImpl<B>::NumArcs(s); 257 } 258 259 size_t NumInputEpsilons(StateId s) { 260 if (!HasArcs(s)) 261 Expand(s); 262 return CacheImpl<B>::NumInputEpsilons(s); 263 } 264 265 size_t NumOutputEpsilons(StateId s) { 266 if 
(!HasArcs(s)) 267 Expand(s); 268 return CacheImpl<B>::NumOutputEpsilons(s); 269 } 270 271 void InitStateIterator(StateIteratorData<A> *data) const { 272 fst_->InitStateIterator(data); 273 } 274 275 void InitArcIterator(StateId s, ArcIteratorData<B> *data) { 276 if (!HasArcs(s)) 277 Expand(s); 278 CacheImpl<B>::InitArcIterator(s, data); 279 } 280 281 uint64 Properties() const { return Properties(kFstProperties); } 282 283 // Set error if found; return FST impl properties. 284 uint64 Properties(uint64 mask) const { 285 if ((mask & kError) && (fst_->Properties(kError, false) || 286 (mapper_->Properties(0) & kError))) 287 SetProperties(kError, kError); 288 return FstImpl<Arc>::Properties(mask); 289 } 290 291 void Expand(StateId s) { 292 // Add exiting arcs. 293 for (mapper_->SetState(s); !mapper_->Done(); mapper_->Next()) 294 PushArc(s, mapper_->Value()); 295 SetArcs(s); 296 } 297 298 const Fst<A> &GetFst() const { 299 return *fst_; 300 } 301 302 private: 303 void Init() { 304 SetType("statemap"); 305 306 if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) 307 SetInputSymbols(fst_->InputSymbols()); 308 else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) 309 SetInputSymbols(0); 310 311 if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) 312 SetOutputSymbols(fst_->OutputSymbols()); 313 else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) 314 SetOutputSymbols(0); 315 316 uint64 props = fst_->Properties(kCopyProperties, false); 317 SetProperties(mapper_->Properties(props)); 318 } 319 320 const Fst<A> *fst_; 321 C* mapper_; 322 bool own_mapper_; 323 324 void operator=(const StateMapFstImpl<A, B, C> &); // disallow 325}; 326 327 328// Maps an arc type A to an arc type B using Mapper function object 329// C. This version is a delayed Fst. 
330template <class A, class B, class C> 331class StateMapFst : public ImplToFst< StateMapFstImpl<A, B, C> > { 332 public: 333 friend class ArcIterator< StateMapFst<A, B, C> >; 334 335 typedef B Arc; 336 typedef typename B::Weight Weight; 337 typedef typename B::StateId StateId; 338 typedef CacheState<B> State; 339 typedef StateMapFstImpl<A, B, C> Impl; 340 341 StateMapFst(const Fst<A> &fst, const C &mapper, 342 const StateMapFstOptions& opts) 343 : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {} 344 345 StateMapFst(const Fst<A> &fst, C* mapper, const StateMapFstOptions& opts) 346 : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {} 347 348 StateMapFst(const Fst<A> &fst, const C &mapper) 349 : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {} 350 351 StateMapFst(const Fst<A> &fst, C* mapper) 352 : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {} 353 354 // See Fst<>::Copy() for doc. 355 StateMapFst(const StateMapFst<A, B, C> &fst, bool safe = false) 356 : ImplToFst<Impl>(fst, safe) {} 357 358 // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc. 359 virtual StateMapFst<A, B, C> *Copy(bool safe = false) const { 360 return new StateMapFst<A, B, C>(*this, safe); 361 } 362 363 virtual void InitStateIterator(StateIteratorData<A> *data) const { 364 GetImpl()->InitStateIterator(data); 365 } 366 367 virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const { 368 GetImpl()->InitArcIterator(s, data); 369 } 370 371 protected: 372 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } 373 374 private: 375 void operator=(const StateMapFst<A, B, C> &fst); // disallow 376}; 377 378 379// Specialization for StateMapFst. 
380template <class A, class B, class C> 381class ArcIterator< StateMapFst<A, B, C> > 382 : public CacheArcIterator< StateMapFst<A, B, C> > { 383 public: 384 typedef typename A::StateId StateId; 385 386 ArcIterator(const StateMapFst<A, B, C> &fst, StateId s) 387 : CacheArcIterator< StateMapFst<A, B, C> >(fst.GetImpl(), s) { 388 if (!fst.GetImpl()->HasArcs(s)) 389 fst.GetImpl()->Expand(s); 390 } 391 392 private: 393 DISALLOW_COPY_AND_ASSIGN(ArcIterator); 394}; 395 396// 397// Utility Mappers 398// 399 400// Mapper that returns its input. 401template <class A> 402class IdentityStateMapper { 403 public: 404 typedef A FromArc; 405 typedef A ToArc; 406 407 typedef typename A::StateId StateId; 408 typedef typename A::Weight Weight; 409 410 explicit IdentityStateMapper(const Fst<A> &fst) : fst_(fst), aiter_(0) {} 411 412 // Allows updating Fst argument; pass only if changed. 413 IdentityStateMapper(const IdentityStateMapper<A> &mapper, 414 const Fst<A> *fst = 0) 415 : fst_(fst ? *fst : mapper.fst_), aiter_(0) {} 416 417 ~IdentityStateMapper() { delete aiter_; } 418 419 StateId Start() const { return fst_.Start(); } 420 421 Weight Final(StateId s) const { return fst_.Final(s); } 422 423 void SetState(StateId s) { 424 if (aiter_) delete aiter_; 425 aiter_ = new ArcIterator< Fst<A> >(fst_, s); 426 } 427 428 bool Done() const { return aiter_->Done(); } 429 const A &Value() const { return aiter_->Value(); } 430 void Next() { aiter_->Next(); } 431 432 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } 433 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} 434 435 uint64 Properties(uint64 props) const { return props; } 436 437 private: 438 const Fst<A> &fst_; 439 ArcIterator< Fst<A> > *aiter_; 440}; 441 442template <class A> 443class ArcSumMapper { 444 public: 445 typedef A FromArc; 446 typedef A ToArc; 447 448 typedef typename A::StateId StateId; 449 typedef typename A::Weight Weight; 450 451 explicit ArcSumMapper(const Fst<A> &fst) : 
fst_(fst), i_(0) {} 452 453 // Allows updating Fst argument; pass only if changed. 454 ArcSumMapper(const ArcSumMapper<A> &mapper, 455 const Fst<A> *fst = 0) 456 : fst_(fst ? *fst : mapper.fst_), i_(0) {} 457 458 StateId Start() const { return fst_.Start(); } 459 Weight Final(StateId s) const { return fst_.Final(s); } 460 461 void SetState(StateId s) { 462 i_ = 0; 463 arcs_.clear(); 464 arcs_.reserve(fst_.NumArcs(s)); 465 for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) 466 arcs_.push_back(aiter.Value()); 467 468 // First sorts the exiting arcs by input label, output label 469 // and destination state and then sums weights of arcs with 470 // the same input label, output label, and destination state. 471 sort(arcs_.begin(), arcs_.end(), comp_); 472 size_t narcs = 0; 473 for (size_t i = 0; i < arcs_.size(); ++i) { 474 if (narcs > 0 && equal_(arcs_[i], arcs_[narcs - 1])) { 475 arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight, 476 arcs_[i].weight); 477 } else { 478 arcs_[narcs++] = arcs_[i]; 479 } 480 } 481 arcs_.resize(narcs); 482 } 483 484 bool Done() const { return i_ >= arcs_.size(); } 485 const A &Value() const { return arcs_[i_]; } 486 void Next() { ++i_; } 487 488 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } 489 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } 490 491 uint64 Properties(uint64 props) const { 492 return props & kArcSortProperties & 493 kDeleteArcsProperties & kWeightInvariantProperties; 494 } 495 496 private: 497 struct Compare { 498 bool operator()(const A& x, const A& y) { 499 if (x.ilabel < y.ilabel) return true; 500 if (x.ilabel > y.ilabel) return false; 501 if (x.olabel < y.olabel) return true; 502 if (x.olabel > y.olabel) return false; 503 if (x.nextstate < y.nextstate) return true; 504 if (x.nextstate > y.nextstate) return false; 505 return false; 506 } 507 }; 508 509 struct Equal { 510 bool operator()(const A& x, const A& y) { 511 return (x.ilabel == y.ilabel 
&& 512 x.olabel == y.olabel && 513 x.nextstate == y.nextstate); 514 } 515 }; 516 517 const Fst<A> &fst_; 518 Compare comp_; 519 Equal equal_; 520 vector<A> arcs_; 521 ssize_t i_; // current arc position 522 523 void operator=(const ArcSumMapper<A> &); // disallow 524}; 525 526template <class A> 527class ArcUniqueMapper { 528 public: 529 typedef A FromArc; 530 typedef A ToArc; 531 532 typedef typename A::StateId StateId; 533 typedef typename A::Weight Weight; 534 535 explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {} 536 537 // Allows updating Fst argument; pass only if changed. 538 ArcUniqueMapper(const ArcSumMapper<A> &mapper, 539 const Fst<A> *fst = 0) 540 : fst_(fst ? *fst : mapper.fst_), i_(0) {} 541 542 StateId Start() const { return fst_.Start(); } 543 Weight Final(StateId s) const { return fst_.Final(s); } 544 545 void SetState(StateId s) { 546 i_ = 0; 547 arcs_.clear(); 548 arcs_.reserve(fst_.NumArcs(s)); 549 for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) 550 arcs_.push_back(aiter.Value()); 551 552 // First sorts the exiting arcs by input label, output label 553 // and destination state and then uniques identical arcs 554 sort(arcs_.begin(), arcs_.end(), comp_); 555 typename vector<A>::iterator unique_end = 556 unique(arcs_.begin(), arcs_.end(), equal_); 557 arcs_.resize(unique_end - arcs_.begin()); 558 } 559 560 bool Done() const { return i_ >= arcs_.size(); } 561 const A &Value() const { return arcs_[i_]; } 562 void Next() { ++i_; } 563 564 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } 565 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } 566 567 uint64 Properties(uint64 props) const { 568 return props & kArcSortProperties & kDeleteArcsProperties; 569 } 570 571 private: 572 struct Compare { 573 bool operator()(const A& x, const A& y) { 574 if (x.ilabel < y.ilabel) return true; 575 if (x.ilabel > y.ilabel) return false; 576 if (x.olabel < y.olabel) return true; 577 if 
(x.olabel > y.olabel) return false; 578 if (x.nextstate < y.nextstate) return true; 579 if (x.nextstate > y.nextstate) return false; 580 return false; 581 } 582 }; 583 584 struct Equal { 585 bool operator()(const A& x, const A& y) { 586 return (x.ilabel == y.ilabel && 587 x.olabel == y.olabel && 588 x.nextstate == y.nextstate && 589 x.weight == y.weight); 590 } 591 }; 592 593 const Fst<A> &fst_; 594 Compare comp_; 595 Equal equal_; 596 vector<A> arcs_; 597 ssize_t i_; // current arc position 598 599 void operator=(const ArcUniqueMapper<A> &); // disallow 600}; 601 602 603} // namespace fst 604 605#endif // FST_LIB_STATE_MAP_H__ 606