1// factor-weight.h 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Author: allauzen@cs.nyu.edu (Cyril Allauzen) 16// 17// \file 18// Classes to factor weights in an FST. 19 20#ifndef FST_LIB_FACTOR_WEIGHT_H__ 21#define FST_LIB_FACTOR_WEIGHT_H__ 22 23#include <algorithm> 24 25#include <unordered_map> 26#include <forward_list> 27 28#include "fst/lib/cache.h" 29#include "fst/lib/test-properties.h" 30 31namespace fst { 32 33struct FactorWeightOptions : CacheOptions { 34 float delta; 35 bool final_only; // only factor final weights when true 36 37 FactorWeightOptions(const CacheOptions &opts, float d, bool of) 38 : CacheOptions(opts), delta(d), final_only(of) {} 39 40 explicit FactorWeightOptions(float d, bool of = false) 41 : delta(d), final_only(of) {} 42 43 FactorWeightOptions(bool of = false) 44 : delta(kDelta), final_only(of) {} 45}; 46 47 48// A factor iterator takes as argument a weight w and returns a 49// sequence of pairs of weights (xi,yi) such that the sum of the 50// products xi times yi is equal to w. If w is fully factored, 51// the iterator should return nothing. 52// 53// template <class W> 54// class FactorIterator { 55// public: 56// FactorIterator(W w); 57// bool Done() const; 58// void Next(); 59// pair<W, W> Value() const; 60// void Reset(); 61// } 62 63 64// Factor trivially. 65template <class W> 66class IdentityFactor { 67 public: 68 IdentityFactor(const W &w) {} 69 bool Done() const { return true; } 70 void Next() {} 71 pair<W, W> Value() const { return make_pair(W::One(), W::One()); } // unused 72 void Reset() {} 73}; 74 75 76// Factor a StringWeight w as 'ab' where 'a' is a label. 77template <typename L, StringType S = STRING_LEFT> 78class StringFactor { 79 public: 80 StringFactor(const StringWeight<L, S> &w) 81 : weight_(w), done_(w.Size() <= 1) {} 82 83 bool Done() const { return done_; } 84 85 void Next() { done_ = true; } 86 87 pair< StringWeight<L, S>, StringWeight<L, S> > Value() const { 88 StringWeightIterator<L, S> iter(weight_); 89 StringWeight<L, S> w1(iter.Value()); 90 StringWeight<L, S> w2; 91 for (iter.Next(); !iter.Done(); iter.Next()) 92 w2.PushBack(iter.Value()); 93 return make_pair(w1, w2); 94 } 95 96 void Reset() { done_ = weight_.Size() <= 1; } 97 98 private: 99 StringWeight<L, S> weight_; 100 bool done_; 101}; 102 103 104// Factor a GallicWeight using StringFactor. 105template <class L, class W, StringType S = STRING_LEFT> 106class GallicFactor { 107 public: 108 GallicFactor(const GallicWeight<L, W, S> &w) 109 : weight_(w), done_(w.Value1().Size() <= 1) {} 110 111 bool Done() const { return done_; } 112 113 void Next() { done_ = true; } 114 115 pair< GallicWeight<L, W, S>, GallicWeight<L, W, S> > Value() const { 116 StringFactor<L, S> iter(weight_.Value1()); 117 GallicWeight<L, W, S> w1(iter.Value().first, weight_.Value2()); 118 GallicWeight<L, W, S> w2(iter.Value().second, W::One()); 119 return make_pair(w1, w2); 120 } 121 122 void Reset() { done_ = weight_.Value1().Size() <= 1; } 123 124 private: 125 GallicWeight<L, W, S> weight_; 126 bool done_; 127}; 128 129 130// Implementation class for FactorWeight 131template <class A, class F> 132class FactorWeightFstImpl 133 : public CacheImpl<A> { 134 public: 135 using FstImpl<A>::SetType; 136 using FstImpl<A>::SetProperties; 137 using FstImpl<A>::Properties; 138 using FstImpl<A>::SetInputSymbols; 139 using FstImpl<A>::SetOutputSymbols; 140 141 using CacheBaseImpl< CacheState<A> >::HasStart; 142 using CacheBaseImpl< CacheState<A> >::HasFinal; 143 using CacheBaseImpl< CacheState<A> >::HasArcs; 144 145 typedef A Arc; 146 typedef typename A::Label Label; 147 typedef typename A::Weight Weight; 148 typedef typename A::StateId StateId; 149 typedef F FactorIterator; 150 151 struct Element { 152 Element() {} 153 154 Element(StateId s, Weight w) : state(s), weight(w) {} 155 156 StateId state; // Input state Id 157 Weight weight; // Residual weight 158 }; 159 160 FactorWeightFstImpl(const Fst<A> &fst, const FactorWeightOptions &opts) 161 : CacheImpl<A>(opts), fst_(fst.Copy()), delta_(opts.delta), 162 final_only_(opts.final_only) { 163 SetType("factor-weight"); 164 uint64 props = fst.Properties(kFstProperties, false); 165 SetProperties(FactorWeightProperties(props), kCopyProperties); 166 167 SetInputSymbols(fst.InputSymbols()); 168 SetOutputSymbols(fst.OutputSymbols()); 169 } 170 171 ~FactorWeightFstImpl() { 172 delete fst_; 173 } 174 175 StateId Start() { 176 if (!HasStart()) { 177 StateId s = fst_->Start(); 178 if (s == kNoStateId) 179 return kNoStateId; 180 StateId start = FindState(Element(fst_->Start(), Weight::One())); 181 this->SetStart(start); 182 } 183 return CacheImpl<A>::Start(); 184 } 185 186 Weight Final(StateId s) { 187 if (!HasFinal(s)) { 188 const Element &e = elements_[s]; 189 // TODO: fix so cast is unnecessary 190 Weight w = e.state == kNoStateId 191 ? e.weight 192 : (Weight) Times(e.weight, fst_->Final(e.state)); 193 FactorIterator f(w); 194 if (w != Weight::Zero() && f.Done()) 195 this->SetFinal(s, w); 196 else 197 this->SetFinal(s, Weight::Zero()); 198 } 199 return CacheImpl<A>::Final(s); 200 } 201 202 size_t NumArcs(StateId s) { 203 if (!HasArcs(s)) 204 Expand(s); 205 return CacheImpl<A>::NumArcs(s); 206 } 207 208 size_t NumInputEpsilons(StateId s) { 209 if (!HasArcs(s)) 210 Expand(s); 211 return CacheImpl<A>::NumInputEpsilons(s); 212 } 213 214 size_t NumOutputEpsilons(StateId s) { 215 if (!HasArcs(s)) 216 Expand(s); 217 return CacheImpl<A>::NumOutputEpsilons(s); 218 } 219 220 void InitArcIterator(StateId s, ArcIteratorData<A> *data) { 221 if (!HasArcs(s)) 222 Expand(s); 223 CacheImpl<A>::InitArcIterator(s, data); 224 } 225 226 227 // Find state corresponding to an element. Create new state 228 // if element not found. 229 StateId FindState(const Element &e) { 230 if (final_only_ && e.weight == Weight::One()) { 231 while (unfactored_.size() <= (unsigned int)e.state) 232 unfactored_.push_back(kNoStateId); 233 if (unfactored_[e.state] == kNoStateId) { 234 unfactored_[e.state] = elements_.size(); 235 elements_.push_back(e); 236 } 237 return unfactored_[e.state]; 238 } else { 239 typename ElementMap::iterator eit = element_map_.find(e); 240 if (eit != element_map_.end()) { 241 return (*eit).second; 242 } else { 243 StateId s = elements_.size(); 244 elements_.push_back(e); 245 element_map_.insert(pair<const Element, StateId>(e, s)); 246 return s; 247 } 248 } 249 } 250 251 // Computes the outgoing transitions from a state, creating new destination 252 // states as needed. 253 void Expand(StateId s) { 254 Element e = elements_[s]; 255 if (e.state != kNoStateId) { 256 for (ArcIterator< Fst<A> > ait(*fst_, e.state); 257 !ait.Done(); 258 ait.Next()) { 259 const A &arc = ait.Value(); 260 Weight w = Times(e.weight, arc.weight); 261 FactorIterator fit(w); 262 if (final_only_ || fit.Done()) { 263 StateId d = FindState(Element(arc.nextstate, Weight::One())); 264 this->AddArc(s, Arc(arc.ilabel, arc.olabel, w, d)); 265 } else { 266 for (; !fit.Done(); fit.Next()) { 267 const pair<Weight, Weight> &p = fit.Value(); 268 StateId d = FindState(Element(arc.nextstate, 269 p.second.Quantize(delta_))); 270 this->AddArc(s, Arc(arc.ilabel, arc.olabel, p.first, d)); 271 } 272 } 273 } 274 } 275 if ((e.state == kNoStateId) || 276 (fst_->Final(e.state) != Weight::Zero())) { 277 Weight w = e.state == kNoStateId 278 ? e.weight 279 : Times(e.weight, fst_->Final(e.state)); 280 for (FactorIterator fit(w); 281 !fit.Done(); 282 fit.Next()) { 283 const pair<Weight, Weight> &p = fit.Value(); 284 StateId d = FindState(Element(kNoStateId, 285 p.second.Quantize(delta_))); 286 this->AddArc(s, Arc(0, 0, p.first, d)); 287 } 288 } 289 this->SetArcs(s); 290 } 291 292 private: 293 // Equality function for Elements, assume weights have been quantized. 294 class ElementEqual { 295 public: 296 bool operator()(const Element &x, const Element &y) const { 297 return x.state == y.state && x.weight == y.weight; 298 } 299 }; 300 301 // Hash function for Elements to Fst states. 302 class ElementKey { 303 public: 304 size_t operator()(const Element &x) const { 305 return static_cast<size_t>(x.state * kPrime + x.weight.Hash()); 306 } 307 private: 308 static const int kPrime = 7853; 309 }; 310 311 typedef std::unordered_map<Element, StateId, ElementKey, ElementEqual> ElementMap; 312 313 const Fst<A> *fst_; 314 float delta_; 315 bool final_only_; 316 vector<Element> elements_; // mapping Fst state to Elements 317 ElementMap element_map_; // mapping Elements to Fst state 318 // mapping between old/new 'StateId' for states that do not need to 319 // be factored when 'final_only_' is true 320 vector<StateId> unfactored_; 321 322 DISALLOW_EVIL_CONSTRUCTORS(FactorWeightFstImpl); 323}; 324 325 326// FactorWeightFst takes as template parameter a FactorIterator as 327// defined above. The result of weight factoring is a transducer 328// equivalent to the input whose path weights have been factored 329// according to the FactorIterator. States and transitions will be 330// added as necessary. The algorithm is a generalization to arbitrary 331// weights of the second step of the input epsilon-normalization 332// algorithm due to Mohri, "Generic epsilon-removal and input 333// epsilon-normalization algorithms for weighted transducers", 334// International Journal of Computer Science 13(1): 129-143 (2002). 335template <class A, class F> 336class FactorWeightFst : public Fst<A> { 337 public: 338 friend class ArcIterator< FactorWeightFst<A, F> >; 339 friend class CacheStateIterator< FactorWeightFst<A, F> >; 340 friend class CacheArcIterator< FactorWeightFst<A, F> >; 341 342 typedef A Arc; 343 typedef typename A::Weight Weight; 344 typedef typename A::StateId StateId; 345 typedef CacheState<A> State; 346 347 FactorWeightFst(const Fst<A> &fst) 348 : impl_(new FactorWeightFstImpl<A, F>(fst, FactorWeightOptions())) {} 349 350 FactorWeightFst(const Fst<A> &fst, const FactorWeightOptions &opts) 351 : impl_(new FactorWeightFstImpl<A, F>(fst, opts)) {} 352 FactorWeightFst(const FactorWeightFst<A, F> &fst) : Fst<A>(fst), impl_(fst.impl_) { 353 impl_->IncrRefCount(); 354 } 355 356 virtual ~FactorWeightFst() { if (!impl_->DecrRefCount()) delete impl_; } 357 358 virtual StateId Start() const { return impl_->Start(); } 359 360 virtual Weight Final(StateId s) const { return impl_->Final(s); } 361 362 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } 363 364 virtual size_t NumInputEpsilons(StateId s) const { 365 return impl_->NumInputEpsilons(s); 366 } 367 368 virtual size_t NumOutputEpsilons(StateId s) const { 369 return impl_->NumOutputEpsilons(s); 370 } 371 372 virtual uint64 Properties(uint64 mask, bool test) const { 373 if (test) { 374 uint64 known, test = TestProperties(*this, mask, &known); 375 impl_->SetProperties(test, known); 376 return test & mask; 377 } else { 378 return impl_->Properties(mask); 379 } 380 } 381 382 virtual const string& Type() const { return impl_->Type(); } 383 384 virtual FactorWeightFst<A, F> *Copy() const { 385 return new FactorWeightFst<A, F>(*this); 386 } 387 388 virtual const SymbolTable* InputSymbols() const { 389 return impl_->InputSymbols(); 390 } 391 392 virtual const SymbolTable* OutputSymbols() const { 393 return impl_->OutputSymbols(); 394 } 395 396 virtual inline void InitStateIterator(StateIteratorData<A> *data) const; 397 398 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 399 impl_->InitArcIterator(s, data); 400 } 401 402 private: 403 FactorWeightFstImpl<A, F> *Impl() { return impl_; } 404 405 FactorWeightFstImpl<A, F> *impl_; 406 407 void operator=(const FactorWeightFst<A, F> &fst); // Disallow 408}; 409 410 411// Specialization for FactorWeightFst. 412template<class A, class F> 413class StateIterator< FactorWeightFst<A, F> > 414 : public CacheStateIterator< FactorWeightFst<A, F> > { 415 public: 416 explicit StateIterator(const FactorWeightFst<A, F> &fst) 417 : CacheStateIterator< FactorWeightFst<A, F> >(fst) {} 418}; 419 420 421// Specialization for FactorWeightFst. 422template <class A, class F> 423class ArcIterator< FactorWeightFst<A, F> > 424 : public CacheArcIterator< FactorWeightFst<A, F> > { 425 public: 426 typedef typename A::StateId StateId; 427 428 ArcIterator(const FactorWeightFst<A, F> &fst, StateId s) 429 : CacheArcIterator< FactorWeightFst<A, F> >(fst, s) { 430 if (!fst.impl_->HasArcs(s)) 431 fst.impl_->Expand(s); 432 } 433 434 private: 435 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); 436}; 437 438template <class A, class F> inline 439void FactorWeightFst<A, F>::InitStateIterator(StateIteratorData<A> *data) const 440{ 441 data->base = new StateIterator< FactorWeightFst<A, F> >(*this); 442} 443 444 445} // namespace fst 446 447#endif // FST_LIB_FACTOR_WEIGHT_H__ 448