1// queue.h 2 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Copyright 2005-2010 Google, Inc. 16// Author: allauzen@google.com (Cyril Allauzen) 17// 18// \file 19// Functions and classes for various Fst state queues with 20// a unified interface. 21 22#ifndef FST_LIB_QUEUE_H__ 23#define FST_LIB_QUEUE_H__ 24 25#include <deque> 26#include <vector> 27using std::vector; 28 29#include <fst/arcfilter.h> 30#include <fst/connect.h> 31#include <fst/heap.h> 32#include <fst/topsort.h> 33 34 35namespace fst { 36 37// template <class S> 38// class Queue { 39// public: 40// typedef typename S StateId; 41// 42// // Ctr: may need args (e.g., Fst, comparator) for some queues 43// Queue(...); 44// // Returns the head of the queue 45// StateId Head() const; 46// // Inserts a state 47// void Enqueue(StateId s); 48// // Removes the head of the queue 49// void Dequeue(); 50// // Updates ordering of state s when weight changes, if necessary 51// void Update(StateId s); 52// // Does the queue contain no elements? 53// bool Empty() const; 54// // Remove all states from queue 55// void Clear(); 56// }; 57 58// State queue types. 
enum QueueType {
  TRIVIAL_QUEUE = 0,         // Single state queue
  FIFO_QUEUE = 1,            // First-in, first-out queue
  LIFO_QUEUE = 2,            // Last-in, first-out queue
  SHORTEST_FIRST_QUEUE = 3,  // Shortest-first queue
  TOP_ORDER_QUEUE = 4,       // Topologically-ordered queue
  STATE_ORDER_QUEUE = 5,     // State-ID ordered queue
  SCC_QUEUE = 6,             // Component graph top-ordered meta-queue
  AUTO_QUEUE = 7,            // Auto-selected queue
  OTHER_QUEUE = 8            // Any other discipline (e.g., pruning queues)
};


// QueueBase, templated on the StateId, is the base class shared by the
// queues considered by AutoQueue. Each public operation forwards to a
// private pure-virtual hook (Head_(), Enqueue_(), ...) that the concrete
// queue implements, so a queue can be used either directly (non-virtual
// calls) or polymorphically through a QueueBase pointer.
template <class S>
class QueueBase {
 public:
  typedef S StateId;

  // TYPE identifies the concrete discipline and is reported by Type().
  explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {}
  virtual ~QueueBase() {}

  // Returns the head of the queue.
  StateId Head() const { return Head_(); }
  // Inserts state S.
  void Enqueue(StateId s) { Enqueue_(s); }
  // Removes the head of the queue.
  void Dequeue() { Dequeue_(); }
  // Reorders state S after its weight changed, if the discipline needs it.
  void Update(StateId s) { Update_(s); }
  // Returns true if the queue holds no states.
  bool Empty() const { return Empty_(); }
  // Removes all states from the queue.
  void Clear() { Clear_(); }
  // Returns the discipline given at construction. Const so that it can be
  // queried through a const reference/pointer.
  QueueType Type() const { return queue_type_; }
  // Error flag, set e.g. when a queue cannot be built for its input.
  bool Error() const { return error_; }
  void SetError(bool error) { error_ = error; }

 private:
  // This allows base-class virtual access to non-virtual derived-
  // class members of the same name. It makes the derived class more
  // efficient to use but unsafe to further derive.
  virtual StateId Head_() const = 0;
  virtual void Enqueue_(StateId s) = 0;
  virtual void Dequeue_() = 0;
  virtual void Update_(StateId s) = 0;
  virtual bool Empty_() const = 0;
  virtual void Clear_() = 0;

  QueueType queue_type_;
  bool error_;
};


// Trivial queue discipline, templated on the StateId. You may enqueue
// at most one state at a time. It is used for strongly connected components
// with only one state and no self loops.
110template <class S> 111class TrivialQueue : public QueueBase<S> { 112public: 113 typedef S StateId; 114 115 TrivialQueue() : QueueBase<S>(TRIVIAL_QUEUE), front_(kNoStateId) {} 116 StateId Head() const { return front_; } 117 void Enqueue(StateId s) { front_ = s; } 118 void Dequeue() { front_ = kNoStateId; } 119 void Update(StateId s) {} 120 bool Empty() const { return front_ == kNoStateId; } 121 void Clear() { front_ = kNoStateId; } 122 123 124private: 125 // This allows base-class virtual access to non-virtual derived- 126 // class members of the same name. It makes the derived class more 127 // efficient to use but unsafe to further derive. 128 virtual StateId Head_() const { return Head(); } 129 virtual void Enqueue_(StateId s) { Enqueue(s); } 130 virtual void Dequeue_() { Dequeue(); } 131 virtual void Update_(StateId s) { Update(s); } 132 virtual bool Empty_() const { return Empty(); } 133 virtual void Clear_() { return Clear(); } 134 135 StateId front_; 136}; 137 138 139// First-in, first-out queue discipline, templated on the StateId. 140template <class S> 141class FifoQueue : public QueueBase<S>, public deque<S> { 142 public: 143 using deque<S>::back; 144 using deque<S>::push_front; 145 using deque<S>::pop_back; 146 using deque<S>::empty; 147 using deque<S>::clear; 148 149 typedef S StateId; 150 151 FifoQueue() : QueueBase<S>(FIFO_QUEUE) {} 152 StateId Head() const { return back(); } 153 void Enqueue(StateId s) { push_front(s); } 154 void Dequeue() { pop_back(); } 155 void Update(StateId s) {} 156 bool Empty() const { return empty(); } 157 void Clear() { clear(); } 158 159 private: 160 // This allows base-class virtual access to non-virtual derived- 161 // class members of the same name. It makes the derived class more 162 // efficient to use but unsafe to further derive. 
163 virtual StateId Head_() const { return Head(); } 164 virtual void Enqueue_(StateId s) { Enqueue(s); } 165 virtual void Dequeue_() { Dequeue(); } 166 virtual void Update_(StateId s) { Update(s); } 167 virtual bool Empty_() const { return Empty(); } 168 virtual void Clear_() { return Clear(); } 169}; 170 171 172// Last-in, first-out queue discipline, templated on the StateId. 173template <class S> 174class LifoQueue : public QueueBase<S>, public deque<S> { 175 public: 176 using deque<S>::front; 177 using deque<S>::push_front; 178 using deque<S>::pop_front; 179 using deque<S>::empty; 180 using deque<S>::clear; 181 182 typedef S StateId; 183 184 LifoQueue() : QueueBase<S>(LIFO_QUEUE) {} 185 StateId Head() const { return front(); } 186 void Enqueue(StateId s) { push_front(s); } 187 void Dequeue() { pop_front(); } 188 void Update(StateId s) {} 189 bool Empty() const { return empty(); } 190 void Clear() { clear(); } 191 192 private: 193 // This allows base-class virtual access to non-virtual derived- 194 // class members of the same name. It makes the derived class more 195 // efficient to use but unsafe to further derive. 196 virtual StateId Head_() const { return Head(); } 197 virtual void Enqueue_(StateId s) { Enqueue(s); } 198 virtual void Dequeue_() { Dequeue(); } 199 virtual void Update_(StateId s) { Update(s); } 200 virtual bool Empty_() const { return Empty(); } 201 virtual void Clear_() { return Clear(); } 202}; 203 204 205// Shortest-first queue discipline, templated on the StateId and 206// comparison function object. Comparison function object COMP is 207// used to compare two StateIds. If a (single) state's order changes, 208// it can be reordered in the queue with a call to Update(). 209// If 'update == false', call to Update() does not reorder the queue. 
210template <typename S, typename C, bool update = true> 211class ShortestFirstQueue : public QueueBase<S> { 212 public: 213 typedef S StateId; 214 typedef C Compare; 215 216 ShortestFirstQueue(C comp) 217 : QueueBase<S>(SHORTEST_FIRST_QUEUE), heap_(comp) {} 218 219 StateId Head() const { return heap_.Top(); } 220 221 void Enqueue(StateId s) { 222 if (update) { 223 for (StateId i = key_.size(); i <= s; ++i) 224 key_.push_back(kNoKey); 225 key_[s] = heap_.Insert(s); 226 } else { 227 heap_.Insert(s); 228 } 229 } 230 231 void Dequeue() { 232 if (update) 233 key_[heap_.Pop()] = kNoKey; 234 else 235 heap_.Pop(); 236 } 237 238 void Update(StateId s) { 239 if (!update) 240 return; 241 if (s >= key_.size() || key_[s] == kNoKey) { 242 Enqueue(s); 243 } else { 244 heap_.Update(key_[s], s); 245 } 246 } 247 248 bool Empty() const { return heap_.Empty(); } 249 250 void Clear() { 251 heap_.Clear(); 252 if (update) key_.clear(); 253 } 254 255 private: 256 Heap<S, C, false> heap_; 257 vector<ssize_t> key_; 258 259 // This allows base-class virtual access to non-virtual derived- 260 // class members of the same name. It makes the derived class more 261 // efficient to use but unsafe to further derive. 262 virtual StateId Head_() const { return Head(); } 263 virtual void Enqueue_(StateId s) { Enqueue(s); } 264 virtual void Dequeue_() { Dequeue(); } 265 virtual void Update_(StateId s) { Update(s); } 266 virtual bool Empty_() const { return Empty(); } 267 virtual void Clear_() { return Clear(); } 268}; 269 270 271// Given a vector that maps from states to weights and a Less 272// comparison function object between weights, this class defines a 273// comparison function object between states. 
274template <typename S, typename L> 275class StateWeightCompare { 276 public: 277 typedef L Less; 278 typedef typename L::Weight Weight; 279 typedef S StateId; 280 281 StateWeightCompare(const vector<Weight>& weights, const L &less) 282 : weights_(weights), less_(less) {} 283 284 bool operator()(const S x, const S y) const { 285 return less_(weights_[x], weights_[y]); 286 } 287 288 private: 289 const vector<Weight>& weights_; 290 L less_; 291}; 292 293 294// Shortest-first queue discipline, templated on the StateId and Weight, is 295// specialized to use the weight's natural order for the comparison function. 296template <typename S, typename W> 297class NaturalShortestFirstQueue : 298 public ShortestFirstQueue<S, StateWeightCompare<S, NaturalLess<W> > > { 299 public: 300 typedef StateWeightCompare<S, NaturalLess<W> > C; 301 302 NaturalShortestFirstQueue(const vector<W> &distance) : 303 ShortestFirstQueue<S, C>(C(distance, less_)) {} 304 305 private: 306 NaturalLess<W> less_; 307}; 308 309// Topological-order queue discipline, templated on the StateId. 310// States are ordered in the queue topologically. The FST must be acyclic. 311template <class S> 312class TopOrderQueue : public QueueBase<S> { 313 public: 314 typedef S StateId; 315 316 // This constructor computes the top. order. It accepts an arc filter 317 // to limit the transitions considered in that computation (e.g., only 318 // the epsilon graph). 319 template <class Arc, class ArcFilter> 320 TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter) 321 : QueueBase<S>(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), 322 order_(0), state_(0) { 323 bool acyclic; 324 TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic); 325 DfsVisit(fst, &top_order_visitor, filter); 326 if (!acyclic) { 327 FSTERROR() << "TopOrderQueue: fst is not acyclic."; 328 QueueBase<S>::SetError(true); 329 } 330 state_.resize(order_.size(), kNoStateId); 331 } 332 333 // This constructor is passed the top. 
order, useful when we know it 334 // beforehand. 335 TopOrderQueue(const vector<StateId> &order) 336 : QueueBase<S>(TOP_ORDER_QUEUE), front_(0), back_(kNoStateId), 337 order_(order), state_(order.size(), kNoStateId) {} 338 339 StateId Head() const { return state_[front_]; } 340 341 void Enqueue(StateId s) { 342 if (front_ > back_) front_ = back_ = order_[s]; 343 else if (order_[s] > back_) back_ = order_[s]; 344 else if (order_[s] < front_) front_ = order_[s]; 345 state_[order_[s]] = s; 346 } 347 348 void Dequeue() { 349 state_[front_] = kNoStateId; 350 while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_; 351 } 352 353 void Update(StateId s) {} 354 355 bool Empty() const { return front_ > back_; } 356 357 void Clear() { 358 for (StateId i = front_; i <= back_; ++i) state_[i] = kNoStateId; 359 back_ = kNoStateId; 360 front_ = 0; 361 } 362 363 private: 364 StateId front_; 365 StateId back_; 366 vector<StateId> order_; 367 vector<StateId> state_; 368 369 // This allows base-class virtual access to non-virtual derived- 370 // class members of the same name. It makes the derived class more 371 // efficient to use but unsafe to further derive. 372 virtual StateId Head_() const { return Head(); } 373 virtual void Enqueue_(StateId s) { Enqueue(s); } 374 virtual void Dequeue_() { Dequeue(); } 375 virtual void Update_(StateId s) { Update(s); } 376 virtual bool Empty_() const { return Empty(); } 377 virtual void Clear_() { return Clear(); } 378}; 379 380 381// State order queue discipline, templated on the StateId. 382// States are ordered in the queue by state Id. 
383template <class S> 384class StateOrderQueue : public QueueBase<S> { 385public: 386 typedef S StateId; 387 388 StateOrderQueue() 389 : QueueBase<S>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {} 390 391 StateId Head() const { return front_; } 392 393 void Enqueue(StateId s) { 394 if (front_ > back_) front_ = back_ = s; 395 else if (s > back_) back_ = s; 396 else if (s < front_) front_ = s; 397 while (enqueued_.size() <= s) enqueued_.push_back(false); 398 enqueued_[s] = true; 399 } 400 401 void Dequeue() { 402 enqueued_[front_] = false; 403 while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_; 404 } 405 406 void Update(StateId s) {} 407 408 bool Empty() const { return front_ > back_; } 409 410 void Clear() { 411 for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false; 412 front_ = 0; 413 back_ = kNoStateId; 414 } 415 416private: 417 StateId front_; 418 StateId back_; 419 vector<bool> enqueued_; 420 421 // This allows base-class virtual access to non-virtual derived- 422 // class members of the same name. It makes the derived class more 423 // efficient to use but unsafe to further derive. 424 virtual StateId Head_() const { return Head(); } 425 virtual void Enqueue_(StateId s) { Enqueue(s); } 426 virtual void Dequeue_() { Dequeue(); } 427 virtual void Update_(StateId s) { Update(s); } 428 virtual bool Empty_() const { return Empty(); } 429 virtual void Clear_() { return Clear(); } 430 431}; 432 433 434// SCC topological-order meta-queue discipline, templated on the StateId S 435// and a queue Q, which is used inside each SCC. It visits the SCC's 436// of an FST in topological order. Its constructor is passed the queues to 437// to use within an SCC. 438template <class S, class Q> 439class SccQueue : public QueueBase<S> { 440 public: 441 typedef S StateId; 442 typedef Q Queue; 443 444 // Constructor takes a vector specifying the SCC number per state 445 // and a vector giving the queue to use per SCC number. 
446 SccQueue(const vector<StateId> &scc, vector<Queue*> *queue) 447 : QueueBase<S>(SCC_QUEUE), queue_(queue), scc_(scc), front_(0), 448 back_(kNoStateId) {} 449 450 StateId Head() const { 451 while ((front_ <= back_) && 452 (((*queue_)[front_] && (*queue_)[front_]->Empty()) 453 || (((*queue_)[front_] == 0) && 454 ((front_ > trivial_queue_.size()) 455 || (trivial_queue_[front_] == kNoStateId))))) 456 ++front_; 457 if ((*queue_)[front_]) 458 return (*queue_)[front_]->Head(); 459 else 460 return trivial_queue_[front_]; 461 } 462 463 void Enqueue(StateId s) { 464 if (front_ > back_) front_ = back_ = scc_[s]; 465 else if (scc_[s] > back_) back_ = scc_[s]; 466 else if (scc_[s] < front_) front_ = scc_[s]; 467 if ((*queue_)[scc_[s]]) { 468 (*queue_)[scc_[s]]->Enqueue(s); 469 } else { 470 while (trivial_queue_.size() <= scc_[s]) 471 trivial_queue_.push_back(kNoStateId); 472 trivial_queue_[scc_[s]] = s; 473 } 474 } 475 476 void Dequeue() { 477 if ((*queue_)[front_]) 478 (*queue_)[front_]->Dequeue(); 479 else if (front_ < trivial_queue_.size()) 480 trivial_queue_[front_] = kNoStateId; 481 } 482 483 void Update(StateId s) { 484 if ((*queue_)[scc_[s]]) 485 (*queue_)[scc_[s]]->Update(s); 486 } 487 488 bool Empty() const { 489 if (front_ < back_) // Queue scc # back_ not empty unless back_==front_ 490 return false; 491 else if (front_ > back_) 492 return true; 493 else if ((*queue_)[front_]) 494 return (*queue_)[front_]->Empty(); 495 else 496 return (front_ > trivial_queue_.size()) 497 || (trivial_queue_[front_] == kNoStateId); 498 } 499 500 void Clear() { 501 for (StateId i = front_; i <= back_; ++i) 502 if ((*queue_)[i]) 503 (*queue_)[i]->Clear(); 504 else if (i < trivial_queue_.size()) 505 trivial_queue_[i] = kNoStateId; 506 front_ = 0; 507 back_ = kNoStateId; 508 } 509 510private: 511 vector<Queue*> *queue_; 512 const vector<StateId> &scc_; 513 mutable StateId front_; 514 StateId back_; 515 vector<StateId> trivial_queue_; 516 517 // This allows base-class virtual access to 
non-virtual derived- 518 // class members of the same name. It makes the derived class more 519 // efficient to use but unsafe to further derive. 520 virtual StateId Head_() const { return Head(); } 521 virtual void Enqueue_(StateId s) { Enqueue(s); } 522 virtual void Dequeue_() { Dequeue(); } 523 virtual void Update_(StateId s) { Update(s); } 524 virtual bool Empty_() const { return Empty(); } 525 virtual void Clear_() { return Clear(); } 526 527 DISALLOW_COPY_AND_ASSIGN(SccQueue); 528}; 529 530 531// Automatic queue discipline, templated on the StateId. It selects a 532// queue discipline for a given FST based on its properties. 533template <class S> 534class AutoQueue : public QueueBase<S> { 535public: 536 typedef S StateId; 537 538 // This constructor takes a state distance vector that, if non-null and if 539 // the Weight type has the path property, will entertain the 540 // shortest-first queue using the natural order w.r.t to the distance. 541 template <class Arc, class ArcFilter> 542 AutoQueue(const Fst<Arc> &fst, const vector<typename Arc::Weight> *distance, 543 ArcFilter filter) : QueueBase<S>(AUTO_QUEUE) { 544 typedef typename Arc::Weight Weight; 545 typedef StateWeightCompare< StateId, NaturalLess<Weight> > Compare; 546 547 // First check if the FST is known to have these properties. 548 uint64 props = fst.Properties(kAcyclic | kCyclic | 549 kTopSorted | kUnweighted, false); 550 if ((props & kTopSorted) || fst.Start() == kNoStateId) { 551 queue_ = new StateOrderQueue<StateId>(); 552 VLOG(2) << "AutoQueue: using state-order discipline"; 553 } else if (props & kAcyclic) { 554 queue_ = new TopOrderQueue<StateId>(fst, filter); 555 VLOG(2) << "AutoQueue: using top-order discipline"; 556 } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) { 557 queue_ = new LifoQueue<StateId>(); 558 VLOG(2) << "AutoQueue: using LIFO discipline"; 559 } else { 560 uint64 properties; 561 // Decompose into strongly-connected components. 
562 SccVisitor<Arc> scc_visitor(&scc_, 0, 0, &properties); 563 DfsVisit(fst, &scc_visitor, filter); 564 StateId nscc = *max_element(scc_.begin(), scc_.end()) + 1; 565 vector<QueueType> queue_types(nscc); 566 NaturalLess<Weight> *less = 0; 567 Compare *comp = 0; 568 if (distance && (Weight::Properties() & kPath)) { 569 less = new NaturalLess<Weight>; 570 comp = new Compare(*distance, *less); 571 } 572 // Find the queue type to use per SCC. 573 bool unweighted; 574 bool all_trivial; 575 SccQueueType(fst, scc_, &queue_types, filter, less, &all_trivial, 576 &unweighted); 577 // If unweighted and semiring is idempotent, use lifo queue. 578 if (unweighted) { 579 queue_ = new LifoQueue<StateId>(); 580 VLOG(2) << "AutoQueue: using LIFO discipline"; 581 delete comp; 582 delete less; 583 return; 584 } 585 // If all the scc are trivial, FST is acyclic and the scc# gives 586 // the topological order. 587 if (all_trivial) { 588 queue_ = new TopOrderQueue<StateId>(scc_); 589 VLOG(2) << "AutoQueue: using top-order discipline"; 590 delete comp; 591 delete less; 592 return; 593 } 594 VLOG(2) << "AutoQueue: using SCC meta-discipline"; 595 queues_.resize(nscc); 596 for (StateId i = 0; i < nscc; ++i) { 597 switch(queue_types[i]) { 598 case TRIVIAL_QUEUE: 599 queues_[i] = 0; 600 VLOG(3) << "AutoQueue: SCC #" << i 601 << ": using trivial discipline"; 602 break; 603 case SHORTEST_FIRST_QUEUE: 604 queues_[i] = new ShortestFirstQueue<StateId, Compare, false>(*comp); 605 VLOG(3) << "AutoQueue: SCC #" << i << 606 ": using shortest-first discipline"; 607 break; 608 case LIFO_QUEUE: 609 queues_[i] = new LifoQueue<StateId>(); 610 VLOG(3) << "AutoQueue: SCC #" << i 611 << ": using LIFO disciplle"; 612 break; 613 case FIFO_QUEUE: 614 default: 615 queues_[i] = new FifoQueue<StateId>(); 616 VLOG(3) << "AutoQueue: SCC #" << i 617 << ": using FIFO disciplle"; 618 break; 619 } 620 } 621 queue_ = new SccQueue< StateId, QueueBase<StateId> >(scc_, &queues_); 622 delete comp; 623 delete less; 624 } 625 } 
626 627 ~AutoQueue() { 628 for (StateId i = 0; i < queues_.size(); ++i) 629 delete queues_[i]; 630 delete queue_; 631 } 632 633 StateId Head() const { return queue_->Head(); } 634 635 void Enqueue(StateId s) { queue_->Enqueue(s); } 636 637 void Dequeue() { queue_->Dequeue(); } 638 639 void Update(StateId s) { queue_->Update(s); } 640 641 bool Empty() const { return queue_->Empty(); } 642 643 void Clear() { queue_->Clear(); } 644 645 646 private: 647 QueueBase<StateId> *queue_; 648 vector< QueueBase<StateId>* > queues_; 649 vector<StateId> scc_; 650 651 template <class Arc, class ArcFilter, class Less> 652 static void SccQueueType(const Fst<Arc> &fst, 653 const vector<StateId> &scc, 654 vector<QueueType> *queue_types, 655 ArcFilter filter, Less *less, 656 bool *all_trivial, bool *unweighted); 657 658 // This allows base-class virtual access to non-virtual derived- 659 // class members of the same name. It makes the derived class more 660 // efficient to use but unsafe to further derive. 661 virtual StateId Head_() const { return Head(); } 662 663 virtual void Enqueue_(StateId s) { Enqueue(s); } 664 665 virtual void Dequeue_() { Dequeue(); } 666 667 virtual void Update_(StateId s) { Update(s); } 668 669 virtual bool Empty_() const { return Empty(); } 670 671 virtual void Clear_() { return Clear(); } 672 673 DISALLOW_COPY_AND_ASSIGN(AutoQueue); 674}; 675 676 677// Examines the states in an Fst's strongly connected components and 678// determines which type of queue to use per SCC. Stores result in 679// vector QUEUE_TYPES, which is assumed to have length equal to the 680// number of SCCs. An arc filter is used to limit the transitions 681// considered (e.g., only the epsilon graph). ALL_TRIVIAL is set 682// to true if every queue is the trivial queue. UNWEIGHTED is set to 683// true if the semiring is idempotent and all the arc weights are equal to 684// Zero() or One(). 
685template <class StateId> 686template <class A, class ArcFilter, class Less> 687void AutoQueue<StateId>::SccQueueType(const Fst<A> &fst, 688 const vector<StateId> &scc, 689 vector<QueueType> *queue_type, 690 ArcFilter filter, Less *less, 691 bool *all_trivial, bool *unweighted) { 692 typedef A Arc; 693 typedef typename A::StateId StateId; 694 typedef typename A::Weight Weight; 695 696 *all_trivial = true; 697 *unweighted = true; 698 699 for (StateId i = 0; i < queue_type->size(); ++i) 700 (*queue_type)[i] = TRIVIAL_QUEUE; 701 702 for (StateIterator< Fst<Arc> > sit(fst); !sit.Done(); sit.Next()) { 703 StateId state = sit.Value(); 704 for (ArcIterator< Fst<Arc> > ait(fst, state); 705 !ait.Done(); 706 ait.Next()) { 707 const Arc &arc = ait.Value(); 708 if (!filter(arc)) continue; 709 if (scc[state] == scc[arc.nextstate]) { 710 QueueType &type = (*queue_type)[scc[state]]; 711 if (!less || ((*less)(arc.weight, Weight::One()))) 712 type = FIFO_QUEUE; 713 else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) { 714 if (!(Weight::Properties() & kIdempotent) || 715 (arc.weight != Weight::Zero() && arc.weight != Weight::One())) 716 type = SHORTEST_FIRST_QUEUE; 717 else 718 type = LIFO_QUEUE; 719 } 720 if (type != TRIVIAL_QUEUE) *all_trivial = false; 721 } 722 if (!(Weight::Properties() & kIdempotent) || 723 (arc.weight != Weight::Zero() && arc.weight != Weight::One())) 724 *unweighted = false; 725 } 726 } 727} 728 729 730// An A* estimate is a function object that maps from a state ID to a 731// an estimate of the shortest distance to the final states. 732// The trivial A* estimate is always One(). 
733template <typename S, typename W> 734struct TrivialAStarEstimate { 735 W operator()(S s) const { return W::One(); } 736}; 737 738 739// Given a vector that maps from states to weights representing the 740// shortest distance from the initial state, a Less comparison 741// function object between weights, and an estimate E of the 742// shortest distance to the final states, this class defines a 743// comparison function object between states. 744template <typename S, typename L, typename E> 745class AStarWeightCompare { 746 public: 747 typedef L Less; 748 typedef typename L::Weight Weight; 749 typedef S StateId; 750 751 AStarWeightCompare(const vector<Weight>& weights, const L &less, 752 const E &estimate) 753 : weights_(weights), less_(less), estimate_(estimate) {} 754 755 bool operator()(const S x, const S y) const { 756 Weight wx = Times(weights_[x], estimate_(x)); 757 Weight wy = Times(weights_[y], estimate_(y)); 758 return less_(wx, wy); 759 } 760 761 private: 762 const vector<Weight>& weights_; 763 L less_; 764 const E &estimate_; 765}; 766 767 768// A* queue discipline, templated on the StateId, Weight and an 769// estimate E of the shortest distance to the final states, is specialized 770// to use the weight's natural order for the comparison function. 771template <typename S, typename W, typename E> 772class NaturalAStarQueue : 773 public ShortestFirstQueue<S, AStarWeightCompare<S, NaturalLess<W>, E> > { 774 public: 775 typedef AStarWeightCompare<S, NaturalLess<W>, E> C; 776 777 NaturalAStarQueue(const vector<W> &distance, const E &estimate) : 778 ShortestFirstQueue<S, C>(C(distance, less_, estimate)) {} 779 780 private: 781 NaturalLess<W> less_; 782}; 783 784 785// A state equivalence class is a function object that 786// maps from a state ID to an equivalence class (state) ID. 787// The trivial equivalence class maps a state to itself. 
788template <typename S> 789struct TrivialStateEquivClass { 790 S operator()(S s) const { return s; } 791}; 792 793 794// Pruning queue discipline: Enqueues a state 's' only when its 795// shortest distance (so far), as specified by 'distance', is less 796// than (as specified by 'comp') the shortest distance Times() the 797// 'threshold' to any state in the same equivalence class, as 798// specified by the function object 'class_func'. The underlying 799// queue discipline is specified by 'queue'. The ownership of 'queue' 800// is given to this class. 801template <typename Q, typename L, typename C> 802class PruneQueue : public QueueBase<typename Q::StateId> { 803 public: 804 typedef typename Q::StateId StateId; 805 typedef typename L::Weight Weight; 806 807 PruneQueue(const vector<Weight> &distance, Q *queue, L comp, 808 const C &class_func, Weight threshold) 809 : QueueBase<StateId>(OTHER_QUEUE), 810 distance_(distance), 811 queue_(queue), 812 less_(comp), 813 class_func_(class_func), 814 threshold_(threshold) {} 815 816 ~PruneQueue() { delete queue_; } 817 818 StateId Head() const { return queue_->Head(); } 819 820 void Enqueue(StateId s) { 821 StateId c = class_func_(s); 822 if (c >= class_distance_.size()) 823 class_distance_.resize(c + 1, Weight::Zero()); 824 if (less_(distance_[s], class_distance_[c])) 825 class_distance_[c] = distance_[s]; 826 827 // Enqueue only if below threshold limit 828 Weight limit = Times(class_distance_[c], threshold_); 829 if (less_(distance_[s], limit)) 830 queue_->Enqueue(s); 831 } 832 833 void Dequeue() { queue_->Dequeue(); } 834 835 void Update(StateId s) { 836 StateId c = class_func_(s); 837 if (less_(distance_[s], class_distance_[c])) 838 class_distance_[c] = distance_[s]; 839 queue_->Update(s); 840 } 841 842 bool Empty() const { return queue_->Empty(); } 843 void Clear() { queue_->Clear(); } 844 845 private: 846 // This allows base-class virtual access to non-virtual derived- 847 // class members of the same name. 
It makes the derived class more 848 // efficient to use but unsafe to further derive. 849 virtual StateId Head_() const { return Head(); } 850 virtual void Enqueue_(StateId s) { Enqueue(s); } 851 virtual void Dequeue_() { Dequeue(); } 852 virtual void Update_(StateId s) { Update(s); } 853 virtual bool Empty_() const { return Empty(); } 854 virtual void Clear_() { return Clear(); } 855 856 const vector<Weight> &distance_; // shortest distance to state 857 Q *queue_; 858 L less_; 859 const C &class_func_; // eqv. class function object 860 Weight threshold_; // pruning weight threshold 861 vector<Weight> class_distance_; // shortest distance to class 862 863 DISALLOW_COPY_AND_ASSIGN(PruneQueue); 864}; 865 866 867// Pruning queue discipline (see above) using the weight's natural 868// order for the comparison function. The ownership of 'queue' is 869// given to this class. 870template <typename Q, typename W, typename C> 871class NaturalPruneQueue : 872 public PruneQueue<Q, NaturalLess<W>, C> { 873 public: 874 typedef typename Q::StateId StateId; 875 typedef W Weight; 876 877 NaturalPruneQueue(const vector<W> &distance, Q *queue, 878 const C &class_func_, Weight threshold) : 879 PruneQueue<Q, NaturalLess<W>, C>(distance, queue, less_, 880 class_func_, threshold) {} 881 882 private: 883 NaturalLess<W> less_; 884}; 885 886 887} // namespace fst 888 889#endif 890