factor-weight.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// factor-weight.h 2 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Copyright 2005-2010 Google, Inc. 16// Author: allauzen@google.com (Cyril Allauzen) 17// 18// \file 19// Classes to factor weights in an FST. 20 21#ifndef FST_LIB_FACTOR_WEIGHT_H__ 22#define FST_LIB_FACTOR_WEIGHT_H__ 23 24#include <algorithm> 25#include <unordered_map> 26using std::tr1::unordered_map; 27using std::tr1::unordered_multimap; 28#include <fst/slist.h> 29#include <string> 30#include <utility> 31using std::pair; using std::make_pair; 32#include <vector> 33using std::vector; 34 35#include <fst/cache.h> 36#include <fst/test-properties.h> 37 38 39namespace fst { 40 41const uint32 kFactorFinalWeights = 0x00000001; 42const uint32 kFactorArcWeights = 0x00000002; 43 44template <class Arc> 45struct FactorWeightOptions : CacheOptions { 46 typedef typename Arc::Label Label; 47 float delta; 48 uint32 mode; // factor arc weights and/or final weights 49 Label final_ilabel; // input label of arc created when factoring final w's 50 Label final_olabel; // output label of arc created when factoring final w's 51 52 FactorWeightOptions(const CacheOptions &opts, float d, 53 uint32 m = kFactorArcWeights | kFactorFinalWeights, 54 Label il = 0, Label ol = 0) 55 : CacheOptions(opts), delta(d), mode(m), final_ilabel(il), 56 final_olabel(ol) {} 57 58 explicit FactorWeightOptions( 59 float d, uint32 m = kFactorArcWeights | kFactorFinalWeights, 60 Label il = 0, Label ol = 0) 61 : delta(d), mode(m), final_ilabel(il), final_olabel(ol) {} 62 63 FactorWeightOptions(uint32 m = kFactorArcWeights | kFactorFinalWeights, 64 Label il = 0, Label ol = 0) 65 : delta(kDelta), mode(m), final_ilabel(il), final_olabel(ol) {} 66}; 67 68 69// A factor iterator takes as argument a weight w and returns a 70// sequence of pairs of weights (xi,yi) such that the sum of the 71// products xi times yi is equal to w. If w is fully factored, 72// the iterator should return nothing. 73// 74// template <class W> 75// class FactorIterator { 76// public: 77// FactorIterator(W w); 78// bool Done() const; 79// void Next(); 80// pair<W, W> Value() const; 81// void Reset(); 82// } 83 84 85// Factor trivially. 86template <class W> 87class IdentityFactor { 88 public: 89 IdentityFactor(const W &w) {} 90 bool Done() const { return true; } 91 void Next() {} 92 pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused 93 void Reset() {} 94}; 95 96 97// Factor a StringWeight w as 'ab' where 'a' is a label. 98template <typename L, StringType S = STRING_LEFT> 99class StringFactor { 100 public: 101 StringFactor(const StringWeight<L, S> &w) 102 : weight_(w), done_(w.Size() <= 1) {} 103 104 bool Done() const { return done_; } 105 106 void Next() { done_ = true; } 107 108 pair< StringWeight<L, S>, StringWeight<L, S> > Value() const { 109 StringWeightIterator<L, S> iter(weight_); 110 StringWeight<L, S> w1(iter.Value()); 111 StringWeight<L, S> w2; 112 for (iter.Next(); !iter.Done(); iter.Next()) 113 w2.PushBack(iter.Value()); 114 return make_pair(w1, w2); 115 } 116 117 void Reset() { done_ = weight_.Size() <= 1; } 118 119 private: 120 StringWeight<L, S> weight_; 121 bool done_; 122}; 123 124 125// Factor a GallicWeight using StringFactor. 126template <class L, class W, StringType S = STRING_LEFT> 127class GallicFactor { 128 public: 129 GallicFactor(const GallicWeight<L, W, S> &w) 130 : weight_(w), done_(w.Value1().Size() <= 1) {} 131 132 bool Done() const { return done_; } 133 134 void Next() { done_ = true; } 135 136 pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const { 137 StringFactor<L, S> iter(weight_.Value1()); 138 GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2()); 139 GallicWeight<L, W, S> w2(iter.Value().second, W::One()); 140 return make_pair(w1, w2); 141 } 142 143 void Reset() { done_ = weight_.Value1().Size() <= 1; } 144 145 private: 146 GallicWeight<L, W, S> weight_; 147 bool done_; 148}; 149 150 151// Implementation class for FactorWeight 152template <class A, class F> 153class FactorWeightFstImpl 154 : public CacheImpl<A> { 155 public: 156 using FstImpl<A>::SetType; 157 using FstImpl<A>::SetProperties; 158 using FstImpl<A>::SetInputSymbols; 159 using FstImpl<A>::SetOutputSymbols; 160 161 using CacheBaseImpl< CacheState<A> >::PushArc; 162 using CacheBaseImpl< CacheState<A> >::HasStart; 163 using CacheBaseImpl< CacheState<A> >::HasFinal; 164 using CacheBaseImpl< CacheState<A> >::HasArcs; 165 using CacheBaseImpl< CacheState<A> >::SetArcs; 166 using CacheBaseImpl< CacheState<A> >::SetFinal; 167 using CacheBaseImpl< CacheState<A> >::SetStart; 168 169 typedef A Arc; 170 typedef typename A::Label Label; 171 typedef typename A::Weight Weight; 172 typedef typename A::StateId StateId; 173 typedef F FactorIterator; 174 175 struct Element { 176 Element() {} 177 178 Element(StateId s, Weight w) : state(s), weight(w) {} 179 180 StateId state; // Input state Id 181 Weight weight; // Residual weight 182 }; 183 184 FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions<A> &opts) 185 : CacheImpl<A>(opts), 186 fst_(fst.Copy()), 187 delta_(opts.delta), 188 mode_(opts.mode), 189 final_ilabel_(opts.final_ilabel), 190 final_olabel_(opts.final_olabel) { 191 SetType("factor_weight"); 192 uint64 props = fst.Properties(kFstProperties, false); 193 SetProperties(FactorWeightProperties(props), kCopyProperties); 194 195 SetInputSymbols(fst.InputSymbols()); 196 SetOutputSymbols(fst.OutputSymbols()); 197 198 if (mode_ == 0) 199 LOG(WARNING) << "FactorWeightFst: factor mode is set to 0: " 200 << "factoring neither arc weights nor final weights."; 201 } 202 203 FactorWeightFstImpl(const FactorWeightFstImpl<A, F> &impl) 204 : CacheImpl<A>(impl), 205 fst_(impl.fst_->Copy(true)), 206 delta_(impl.delta_), 207 mode_(impl.mode_), 208 final_ilabel_(impl.final_ilabel_), 209 final_olabel_(impl.final_olabel_) { 210 SetType("factor_weight"); 211 SetProperties(impl.Properties(), kCopyProperties); 212 SetInputSymbols(impl.InputSymbols()); 213 SetOutputSymbols(impl.OutputSymbols()); 214 } 215 216 ~FactorWeightFstImpl() { 217 delete fst_; 218 } 219 220 StateId Start() { 221 if (!HasStart()) { 222 StateId s = fst_->Start(); 223 if (s == kNoStateId) 224 return kNoStateId; 225 StateId start = FindState(Element(fst_->Start(), Weight::One())); 226 SetStart(start); 227 } 228 return CacheImpl<A>::Start(); 229 } 230 231 Weight Final(StateId s) { 232 if (!HasFinal(s)) { 233 const Element &e = elements_[s]; 234 // TODO: fix so cast is unnecessary 235 Weight w = e.state == kNoStateId 236 ? e.weight 237 : (Weight) Times(e.weight, fst_->Final(e.state)); 238 FactorIterator f(w); 239 if (!(mode_ & kFactorFinalWeights) || f.Done()) 240 SetFinal(s, w); 241 else 242 SetFinal(s, Weight::Zero()); 243 } 244 return CacheImpl<A>::Final(s); 245 } 246 247 size_t NumArcs(StateId s) { 248 if (!HasArcs(s)) 249 Expand(s); 250 return CacheImpl<A>::NumArcs(s); 251 } 252 253 size_t NumInputEpsilons(StateId s) { 254 if (!HasArcs(s)) 255 Expand(s); 256 return CacheImpl<A>::NumInputEpsilons(s); 257 } 258 259 size_t NumOutputEpsilons(StateId s) { 260 if (!HasArcs(s)) 261 Expand(s); 262 return CacheImpl<A>::NumOutputEpsilons(s); 263 } 264 265 uint64 Properties() const { return Properties(kFstProperties); } 266 267 // Set error if found; return FST impl properties. 268 uint64 Properties(uint64 mask) const { 269 if ((mask & kError) && fst_->Properties(kError, false)) 270 SetProperties(kError, kError); 271 return FstImpl<Arc>::Properties(mask); 272 } 273 274 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 275 if (!HasArcs(s)) 276 Expand(s); 277 CacheImpl<A>::InitArcIterator(s, data); 278 } 279 280 281 // Find state corresponding to an element. Create new state 282 // if element not found. 283 StateId FindState(const Element &e) { 284 if (!(mode_ & kFactorArcWeights) && e.weight == Weight::One()) { 285 while (unfactored_.size() <= e.state) 286 unfactored_.push_back(kNoStateId); 287 if (unfactored_[e.state] == kNoStateId) { 288 unfactored_[e.state] = elements_.size(); 289 elements_.push_back(e); 290 } 291 return unfactored_[e.state]; 292 } else { 293 typename ElementMap::iterator eit = element_map_.find(e); 294 if (eit != element_map_.end()) { 295 return (*eit).second; 296 } else { 297 StateId s = elements_.size(); 298 elements_.push_back(e); 299 element_map_.insert(pair<const Element, StateId>(e, s)); 300 return s; 301 } 302 } 303 } 304 305 // Computes the outgoing transitions from a state, creating new destination 306 // states as needed. 307 void Expand(StateId s) { 308 Element e = elements_[s]; 309 if (e.state != kNoStateId) { 310 for (ArcIterator< Fst<A> > ait(*fst_, e.state); 311 !ait.Done(); 312 ait.Next()) { 313 const A &arc = ait.Value(); 314 Weight w = Times(e.weight, arc.weight); 315 FactorIterator fit(w); 316 if (!(mode_ & kFactorArcWeights) || fit.Done()) { 317 StateId d = FindState(Element(arc.nextstate, Weight::One())); 318 PushArc(s, Arc(arc.ilabel, arc.olabel, w, d)); 319 } else { 320 for (; !fit.Done(); fit.Next()) { 321 const pair<Weight, Weight> &p = fit.Value(); 322 StateId d = FindState(Element(arc.nextstate, 323 p.second.Quantize(delta_))); 324 PushArc(s, Arc(arc.ilabel, arc.olabel, p.first, d)); 325 } 326 } 327 } 328 } 329 330 if ((mode_ & kFactorFinalWeights) && 331 ((e.state == kNoStateId) || 332 (fst_->Final(e.state) != Weight::Zero()))) { 333 Weight w = e.state == kNoStateId 334 ? e.weight 335 : Times(e.weight, fst_->Final(e.state)); 336 for (FactorIterator fit(w); 337 !fit.Done(); 338 fit.Next()) { 339 const pair<Weight, Weight> &p = fit.Value(); 340 StateId d = FindState(Element(kNoStateId, 341 p.second.Quantize(delta_))); 342 PushArc(s, Arc(final_ilabel_, final_olabel_, p.first, d)); 343 } 344 } 345 SetArcs(s); 346 } 347 348 private: 349 static const size_t kPrime = 7853; 350 351 // Equality function for Elements, assume weights have been quantized. 352 class ElementEqual { 353 public: 354 bool operator()(const Element &x, const Element &y) const { 355 return x.state == y.state && x.weight == y.weight; 356 } 357 }; 358 359 // Hash function for Elements to Fst states. 360 class ElementKey { 361 public: 362 size_t operator()(const Element &x) const { 363 return static_cast<size_t>(x.state * kPrime + x.weight.Hash()); 364 } 365 private: 366 }; 367 368 typedef unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap; 369 370 const Fst<A> *fst_; 371 float delta_; 372 uint32 mode_; // factoring arc and/or final weights 373 Label final_ilabel_; // ilabel of arc created when factoring final w's 374 Label final_olabel_; // olabel of arc created when factoring final w's 375 vector<Element> elements_; // mapping Fst state to Elements 376 ElementMap element_map_; // mapping Elements to Fst state 377 // mapping between old/new 'StateId' for states that do not need to 378 // be factored when 'mode_' is '0' or 'kFactorFinalWeights' 379 vector<StateId> unfactored_; 380 381 void operator=(const FactorWeightFstImpl<A, F> &); // disallow 382}; 383 384template <class A, class F> const size_t FactorWeightFstImpl<A, F>::kPrime; 385 386 387// FactorWeightFst takes as template parameter a FactorIterator as 388// defined above. The result of weight factoring is a transducer 389// equivalent to the input whose path weights have been factored 390// according to the FactorIterator. States and transitions will be 391// added as necessary. The algorithm is a generalization to arbitrary 392// weights of the second step of the input epsilon-normalization 393// algorithm due to Mohri, "Generic epsilon-removal and input 394// epsilon-normalization algorithms for weighted transducers", 395// International Journal of Computer Science 13(1): 129-143 (2002). 396// 397// This class attaches interface to implementation and handles 398// reference counting, delegating most methods to ImplToFst. 399template <class A, class F> 400class FactorWeightFst : public ImplToFst< FactorWeightFstImpl<A, F> > { 401 public: 402 friend class ArcIterator< FactorWeightFst<A, F> >; 403 friend class StateIterator< FactorWeightFst<A, F> >; 404 405 typedef A Arc; 406 typedef typename A::Weight Weight; 407 typedef typename A::StateId StateId; 408 typedef CacheState<A> State; 409 typedef FactorWeightFstImpl<A, F> Impl; 410 411 FactorWeightFst(const Fst<A> &fst) 412 : ImplToFst<Impl>(new Impl(fst, FactorWeightOptions<A>())) {} 413 414 FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions<A> &opts) 415 : ImplToFst<Impl>(new Impl(fst, opts)) {} 416 417 // See Fst<>::Copy() for doc. 418 FactorWeightFst(const FactorWeightFst<A, F> &fst, bool copy) 419 : ImplToFst<Impl>(fst, copy) {} 420 421 // Get a copy of this FactorWeightFst. See Fst<>::Copy() for further doc. 422 virtual FactorWeightFst<A, F> *Copy(bool copy = false) const { 423 return new FactorWeightFst<A, F>(*this, copy); 424 } 425 426 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 427 428 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 429 GetImpl()->InitArcIterator(s, data); 430 } 431 432 private: 433 // Makes visible to friends. 434 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } 435 436 void operator=(const FactorWeightFst<A, F> &fst); // Disallow 437}; 438 439 440// Specialization for FactorWeightFst. 441template<class A, class F> 442class StateIterator< FactorWeightFst<A, F> > 443 : public CacheStateIterator< FactorWeightFst<A, F> > { 444 public: 445 explicit StateIterator(const FactorWeightFst<A, F> &fst) 446 : CacheStateIterator< FactorWeightFst<A, F> >(fst, fst.GetImpl()) {} 447}; 448 449 450// Specialization for FactorWeightFst. 451template <class A, class F> 452class ArcIterator< FactorWeightFst<A, F> > 453 : public CacheArcIterator< FactorWeightFst<A, F> > { 454 public: 455 typedef typename A::StateId StateId; 456 457 ArcIterator(const FactorWeightFst<A, F> &fst, StateId s) 458 : CacheArcIterator< FactorWeightFst<A, F> >(fst.GetImpl(), s) { 459 if (!fst.GetImpl()->HasArcs(s)) 460 fst.GetImpl()->Expand(s); 461 } 462 463 private: 464 DISALLOW_COPY_AND_ASSIGN(ArcIterator); 465}; 466 467template <class A, class F> inline 468void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const 469{ 470 data->base = new StateIterator< FactorWeightFst<A, F> >(*this); 471} 472 473 474} // namespace fst 475 476#endif // FST_LIB_FACTOR_WEIGHT_H__ 477