1// relabel.h 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// 16// \file 17// Functions and classes to relabel an Fst (either on input or output) 18// 19#ifndef FST_LIB_RELABEL_H__ 20#define FST_LIB_RELABEL_H__ 21 22#include <ext/hash_map> 23using __gnu_cxx::hash_map; 24 25#include "fst/lib/cache.h" 26#include "fst/lib/test-properties.h" 27 28 29namespace fst { 30 31// 32// Relabels either the input labels or output labels. The old to 33// new labels are specified using a vector of pair<Label,Label>. 34// Any label associations not specified are assumed to be identity 35// mapping. 36// 37// \param fst input fst, must be mutable 38// \param relabel_pairs vector of pairs indicating old to new mapping 39// \param relabel_flags whether to relabel input or output 40// 41template <class A> 42void Relabel( 43 MutableFst<A> *fst, 44 const vector<pair<typename A::Label, typename A::Label> >& ipairs, 45 const vector<pair<typename A::Label, typename A::Label> >& opairs) { 46 typedef typename A::StateId StateId; 47 typedef typename A::Label Label; 48 49 uint64 props = fst->Properties(kFstProperties, false); 50 51 // construct label to label hash. Could 52 hash_map<Label, Label> input_map; 53 for (size_t i = 0; i < ipairs.size(); ++i) { 54 input_map[ipairs[i].first] = ipairs[i].second; 55 } 56 57 hash_map<Label, Label> output_map; 58 for (size_t i = 0; i < opairs.size(); ++i) { 59 output_map[opairs[i].first] = opairs[i].second; 60 } 61 62 for (StateIterator<MutableFst<A> > siter(*fst); 63 !siter.Done(); siter.Next()) { 64 StateId s = siter.Value(); 65 for (MutableArcIterator<MutableFst<A> > aiter(fst, s); 66 !aiter.Done(); aiter.Next()) { 67 A arc = aiter.Value(); 68 69 // relabel input 70 // only relabel if relabel pair defined 71 typename hash_map<Label, Label>::iterator it = 72 input_map.find(arc.ilabel); 73 if (it != input_map.end()) {arc.ilabel = it->second; } 74 75 // relabel output 76 it = output_map.find(arc.olabel); 77 if (it != output_map.end()) { arc.olabel = it->second; } 78 79 aiter.SetValue(arc); 80 } 81 } 82 83 fst->SetProperties(RelabelProperties(props), kFstProperties); 84} 85 86 87 88// 89// Relabels either the input labels or output labels. The old to 90// new labels mappings are specified using an input Symbol set. 91// Any label associations not specified are assumed to be identity 92// mapping. 93// 94// \param fst input fst, must be mutable 95// \param new_symbols symbol set indicating new mapping 96// \param relabel_flags whether to relabel input or output 97// 98template<class A> 99void Relabel(MutableFst<A> *fst, 100 const SymbolTable* new_isymbols, 101 const SymbolTable* new_osymbols) { 102 typedef typename A::StateId StateId; 103 typedef typename A::Label Label; 104 105 const SymbolTable* old_isymbols = fst->InputSymbols(); 106 const SymbolTable* old_osymbols = fst->OutputSymbols(); 107 108 vector<pair<Label, Label> > ipairs; 109 if (old_isymbols && new_isymbols) { 110 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); 111 syms_iter.Next()) { 112 ipairs.push_back(make_pair(syms_iter.Value(), 113 new_isymbols->Find(syms_iter.Symbol()))); 114 } 115 fst->SetInputSymbols(new_isymbols); 116 } 117 118 vector<pair<Label, Label> > opairs; 119 if (old_osymbols && new_osymbols) { 120 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); 121 syms_iter.Next()) { 122 opairs.push_back(make_pair(syms_iter.Value(), 123 new_osymbols->Find(syms_iter.Symbol()))); 124 } 125 fst->SetOutputSymbols(new_osymbols); 126 } 127 128 // call relabel using vector of relabel pairs. 129 Relabel(fst, ipairs, opairs); 130} 131 132 133typedef CacheOptions RelabelFstOptions; 134 135template <class A> class RelabelFst; 136 137// 138// \class RelabelFstImpl 139// \brief Implementation for delayed relabeling 140// 141// Relabels an FST from one symbol set to another. Relabeling 142// can either be on input or output space. RelabelFst implements 143// a delayed version of the relabel. Arcs are relabeled on the fly 144// and not cached. I.e each request is recomputed. 145// 146template<class A> 147class RelabelFstImpl : public CacheImpl<A> { 148 friend class StateIterator< RelabelFst<A> >; 149 public: 150 using FstImpl<A>::SetType; 151 using FstImpl<A>::SetProperties; 152 using FstImpl<A>::Properties; 153 using FstImpl<A>::SetInputSymbols; 154 using FstImpl<A>::SetOutputSymbols; 155 156 using CacheImpl<A>::HasStart; 157 using CacheImpl<A>::HasArcs; 158 159 typedef typename A::Label Label; 160 typedef typename A::Weight Weight; 161 typedef typename A::StateId StateId; 162 typedef CacheState<A> State; 163 164 RelabelFstImpl(const Fst<A>& fst, 165 const vector<pair<Label, Label> >& ipairs, 166 const vector<pair<Label, Label> >& opairs, 167 const RelabelFstOptions &opts) 168 : CacheImpl<A>(opts), fst_(fst.Copy()), 169 relabel_input_(false), relabel_output_(false) { 170 uint64 props = fst.Properties(kCopyProperties, false); 171 SetProperties(RelabelProperties(props)); 172 SetType("relabel"); 173 174 // create input label map 175 if (ipairs.size() > 0) { 176 for (size_t i = 0; i < ipairs.size(); ++i) { 177 input_map_[ipairs[i].first] = ipairs[i].second; 178 } 179 relabel_input_ = true; 180 } 181 182 // create output label map 183 if (opairs.size() > 0) { 184 for (size_t i = 0; i < opairs.size(); ++i) { 185 output_map_[opairs[i].first] = opairs[i].second; 186 } 187 relabel_output_ = true; 188 } 189 } 190 191 RelabelFstImpl(const Fst<A>& fst, 192 const SymbolTable* new_isymbols, 193 const SymbolTable* new_osymbols, 194 const RelabelFstOptions &opts) 195 : CacheImpl<A>(opts), fst_(fst.Copy()), 196 relabel_input_(false), relabel_output_(false) { 197 SetType("relabel"); 198 199 uint64 props = fst.Properties(kCopyProperties, false); 200 SetProperties(RelabelProperties(props)); 201 SetInputSymbols(fst.InputSymbols()); 202 SetOutputSymbols(fst.OutputSymbols()); 203 204 const SymbolTable* old_isymbols = fst.InputSymbols(); 205 const SymbolTable* old_osymbols = fst.OutputSymbols(); 206 207 if (old_isymbols && new_isymbols && 208 old_isymbols->CheckSum() != new_isymbols->CheckSum()) { 209 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); 210 syms_iter.Next()) { 211 input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol()); 212 } 213 SetInputSymbols(new_isymbols); 214 relabel_input_ = true; 215 } 216 217 if (old_osymbols && new_osymbols && 218 old_osymbols->CheckSum() != new_osymbols->CheckSum()) { 219 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); 220 syms_iter.Next()) { 221 output_map_[syms_iter.Value()] = 222 new_osymbols->Find(syms_iter.Symbol()); 223 } 224 SetOutputSymbols(new_osymbols); 225 relabel_output_ = true; 226 } 227 } 228 229 ~RelabelFstImpl() { delete fst_; } 230 231 StateId Start() { 232 if (!HasStart()) { 233 StateId s = fst_->Start(); 234 SetStart(s); 235 } 236 return CacheImpl<A>::Start(); 237 } 238 239 Weight Final(StateId s) { 240 if (!HasFinal(s)) { 241 SetFinal(s, fst_->Final(s)); 242 } 243 return CacheImpl<A>::Final(s); 244 } 245 246 size_t NumArcs(StateId s) { 247 if (!HasArcs(s)) { 248 Expand(s); 249 } 250 return CacheImpl<A>::NumArcs(s); 251 } 252 253 size_t NumInputEpsilons(StateId s) { 254 if (!HasArcs(s)) { 255 Expand(s); 256 } 257 return CacheImpl<A>::NumInputEpsilons(s); 258 } 259 260 size_t NumOutputEpsilons(StateId s) { 261 if (!HasArcs(s)) { 262 Expand(s); 263 } 264 return CacheImpl<A>::NumOutputEpsilons(s); 265 } 266 267 void InitArcIterator(StateId s, ArcIteratorData<A>* data) { 268 if (!HasArcs(s)) { 269 Expand(s); 270 } 271 CacheImpl<A>::InitArcIterator(s, data); 272 } 273 274 void Expand(StateId s) { 275 for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) { 276 A arc = aiter.Value(); 277 278 // relabel input 279 if (relabel_input_) { 280 typename hash_map<Label, Label>::iterator it = 281 input_map_.find(arc.ilabel); 282 if (it != input_map_.end()) { arc.ilabel = it->second; } 283 } 284 285 // relabel output 286 if (relabel_output_) { 287 typename hash_map<Label, Label>::iterator it = 288 output_map_.find(arc.olabel); 289 if (it != output_map_.end()) { arc.olabel = it->second; } 290 } 291 292 AddArc(s, arc); 293 } 294 SetArcs(s); 295 } 296 297 298 private: 299 const Fst<A> *fst_; 300 301 hash_map<Label, Label> input_map_; 302 hash_map<Label, Label> output_map_; 303 bool relabel_input_; 304 bool relabel_output_; 305 306 DISALLOW_EVIL_CONSTRUCTORS(RelabelFstImpl); 307}; 308 309 310// 311// \class RelabelFst 312// \brief Delayed implementation of arc relabeling 313// 314// This class attaches interface to implementation and handles 315// reference counting. 316template <class A> 317class RelabelFst : public Fst<A> { 318 public: 319 friend class ArcIterator< RelabelFst<A> >; 320 friend class StateIterator< RelabelFst<A> >; 321 friend class CacheArcIterator< RelabelFst<A> >; 322 323 typedef A Arc; 324 typedef typename A::Label Label; 325 typedef typename A::Weight Weight; 326 typedef typename A::StateId StateId; 327 typedef CacheState<A> State; 328 329 RelabelFst(const Fst<A>& fst, 330 const vector<pair<Label, Label> >& ipairs, 331 const vector<pair<Label, Label> >& opairs) : 332 impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, RelabelFstOptions())) {} 333 334 RelabelFst(const Fst<A>& fst, 335 const vector<pair<Label, Label> >& ipairs, 336 const vector<pair<Label, Label> >& opairs, 337 const RelabelFstOptions &opts) 338 : impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, opts)) {} 339 340 RelabelFst(const Fst<A>& fst, 341 const SymbolTable* new_isymbols, 342 const SymbolTable* new_osymbols) : 343 impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols, 344 RelabelFstOptions())) {} 345 346 RelabelFst(const Fst<A>& fst, 347 const SymbolTable* new_isymbols, 348 const SymbolTable* new_osymbols, 349 const RelabelFstOptions &opts) 350 : impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols, opts)) {} 351 352 RelabelFst(const RelabelFst<A> &fst) : impl_(fst.impl_) { 353 impl_->IncrRefCount(); 354 } 355 356 virtual ~RelabelFst() { if (!impl_->DecrRefCount()) delete impl_; } 357 358 virtual StateId Start() const { return impl_->Start(); } 359 360 virtual Weight Final(StateId s) const { return impl_->Final(s); } 361 362 virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); } 363 364 virtual size_t NumInputEpsilons(StateId s) const { 365 return impl_->NumInputEpsilons(s); 366 } 367 368 virtual size_t NumOutputEpsilons(StateId s) const { 369 return impl_->NumOutputEpsilons(s); 370 } 371 372 virtual uint64 Properties(uint64 mask, bool test) const { 373 if (test) { 374 uint64 known, test = TestProperties(*this, mask, &known); 375 impl_->SetProperties(test, known); 376 return test & mask; 377 } else { 378 return impl_->Properties(mask); 379 } 380 } 381 382 virtual const string& Type() const { return impl_->Type(); } 383 384 virtual RelabelFst<A> *Copy() const { 385 return new RelabelFst<A>(*this); 386 } 387 388 virtual const SymbolTable* InputSymbols() const { 389 return impl_->InputSymbols(); 390 } 391 392 virtual const SymbolTable* OutputSymbols() const { 393 return impl_->OutputSymbols(); 394 } 395 396 virtual void InitStateIterator(StateIteratorData<A> *data) const; 397 398 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 399 return impl_->InitArcIterator(s, data); 400 } 401 402 private: 403 RelabelFstImpl<A> *impl_; 404 405 void operator=(const RelabelFst<A> &fst); // disallow 406}; 407 408// Specialization for RelabelFst. 409template<class A> 410class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> { 411 public: 412 typedef typename A::StateId StateId; 413 414 explicit StateIterator(const RelabelFst<A> &fst) 415 : impl_(fst.impl_), siter_(*impl_->fst_), s_(0) {} 416 417 bool Done() const { return siter_.Done(); } 418 419 StateId Value() const { return s_; } 420 421 void Next() { 422 if (!siter_.Done()) { 423 ++s_; 424 siter_.Next(); 425 } 426 } 427 428 void Reset() { 429 s_ = 0; 430 siter_.Reset(); 431 } 432 433 private: 434 const RelabelFstImpl<A> *impl_; 435 StateIterator< Fst<A> > siter_; 436 StateId s_; 437 438 DISALLOW_EVIL_CONSTRUCTORS(StateIterator); 439}; 440 441 442// Specialization for RelabelFst. 443template <class A> 444class ArcIterator< RelabelFst<A> > 445 : public CacheArcIterator< RelabelFst<A> > { 446 public: 447 typedef typename A::StateId StateId; 448 449 ArcIterator(const RelabelFst<A> &fst, StateId s) 450 : CacheArcIterator< RelabelFst<A> >(fst, s) { 451 if (!fst.impl_->HasArcs(s)) 452 fst.impl_->Expand(s); 453 } 454 455 private: 456 DISALLOW_EVIL_CONSTRUCTORS(ArcIterator); 457}; 458 459template <class A> inline 460void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const { 461 data->base = new StateIterator< RelabelFst<A> >(*this); 462} 463 464// Useful alias when using StdArc. 465typedef RelabelFst<StdArc> StdRelabelFst; 466 467} // namespace fst 468 469#endif // FST_LIB_RELABEL_H__ 470