// relabel.h revision 3da1eb108d36da35333b2d655202791af854996b
1// relabel.h 2 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// Copyright 2005-2010 Google, Inc. 16// Author: johans@google.com (Johan Schalkwyk) 17// 18// \file 19// Functions and classes to relabel an Fst (either on input or output) 20// 21#ifndef FST_LIB_RELABEL_H__ 22#define FST_LIB_RELABEL_H__ 23 24#include <tr1/unordered_map> 25using std::tr1::unordered_map; 26using std::tr1::unordered_multimap; 27#include <string> 28#include <utility> 29using std::pair; using std::make_pair; 30#include <vector> 31using std::vector; 32 33#include <fst/cache.h> 34#include <fst/test-properties.h> 35 36 37namespace fst { 38 39// 40// Relabels either the input labels or output labels. The old to 41// new labels are specified using a vector of pair<Label,Label>. 42// Any label associations not specified are assumed to be identity 43// mapping. 44// 45// \param fst input fst, must be mutable 46// \param ipairs vector of input label pairs indicating old to new mapping 47// \param opairs vector of output label pairs indicating old to new mapping 48// 49template <class A> 50void Relabel( 51 MutableFst<A> *fst, 52 const vector<pair<typename A::Label, typename A::Label> >& ipairs, 53 const vector<pair<typename A::Label, typename A::Label> >& opairs) { 54 typedef typename A::StateId StateId; 55 typedef typename A::Label Label; 56 57 uint64 props = fst->Properties(kFstProperties, false); 58 59 // construct label to label hash. 
60 unordered_map<Label, Label> input_map; 61 for (size_t i = 0; i < ipairs.size(); ++i) { 62 input_map[ipairs[i].first] = ipairs[i].second; 63 } 64 65 unordered_map<Label, Label> output_map; 66 for (size_t i = 0; i < opairs.size(); ++i) { 67 output_map[opairs[i].first] = opairs[i].second; 68 } 69 70 for (StateIterator<MutableFst<A> > siter(*fst); 71 !siter.Done(); siter.Next()) { 72 StateId s = siter.Value(); 73 for (MutableArcIterator<MutableFst<A> > aiter(fst, s); 74 !aiter.Done(); aiter.Next()) { 75 A arc = aiter.Value(); 76 77 // relabel input 78 // only relabel if relabel pair defined 79 typename unordered_map<Label, Label>::iterator it = 80 input_map.find(arc.ilabel); 81 if (it != input_map.end()) { 82 if (it->second == kNoLabel) { 83 FSTERROR() << "Input symbol id " << arc.ilabel 84 << " missing from target vocabulary"; 85 fst->SetProperties(kError, kError); 86 return; 87 } 88 arc.ilabel = it->second; 89 } 90 91 // relabel output 92 it = output_map.find(arc.olabel); 93 if (it != output_map.end()) { 94 if (it->second == kNoLabel) { 95 FSTERROR() << "Output symbol id " << arc.olabel 96 << " missing from target vocabulary"; 97 fst->SetProperties(kError, kError); 98 return; 99 } 100 arc.olabel = it->second; 101 } 102 103 aiter.SetValue(arc); 104 } 105 } 106 107 fst->SetProperties(RelabelProperties(props), kFstProperties); 108} 109 110// 111// Relabels either the input labels or output labels. The old to 112// new labels mappings are specified using an input Symbol set. 113// Any label associations not specified are assumed to be identity 114// mapping. 
115// 116// \param fst input fst, must be mutable 117// \param new_isymbols symbol set indicating new mapping of input symbols 118// \param new_osymbols symbol set indicating new mapping of output symbols 119// 120template<class A> 121void Relabel(MutableFst<A> *fst, 122 const SymbolTable* new_isymbols, 123 const SymbolTable* new_osymbols) { 124 Relabel(fst, 125 fst->InputSymbols(), new_isymbols, true, 126 fst->OutputSymbols(), new_osymbols, true); 127} 128 129template<class A> 130void Relabel(MutableFst<A> *fst, 131 const SymbolTable* old_isymbols, 132 const SymbolTable* new_isymbols, 133 bool attach_new_isymbols, 134 const SymbolTable* old_osymbols, 135 const SymbolTable* new_osymbols, 136 bool attach_new_osymbols) { 137 typedef typename A::StateId StateId; 138 typedef typename A::Label Label; 139 140 vector<pair<Label, Label> > ipairs; 141 if (old_isymbols && new_isymbols) { 142 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); 143 syms_iter.Next()) { 144 string isymbol = syms_iter.Symbol(); 145 int isymbol_val = syms_iter.Value(); 146 int new_isymbol_val = new_isymbols->Find(isymbol); 147 ipairs.push_back(make_pair(isymbol_val, new_isymbol_val)); 148 } 149 if (attach_new_isymbols) 150 fst->SetInputSymbols(new_isymbols); 151 } 152 153 vector<pair<Label, Label> > opairs; 154 if (old_osymbols && new_osymbols) { 155 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); 156 syms_iter.Next()) { 157 string osymbol = syms_iter.Symbol(); 158 int osymbol_val = syms_iter.Value(); 159 int new_osymbol_val = new_osymbols->Find(osymbol); 160 opairs.push_back(make_pair(osymbol_val, new_osymbol_val)); 161 } 162 if (attach_new_osymbols) 163 fst->SetOutputSymbols(new_osymbols); 164 } 165 166 // call relabel using vector of relabel pairs. 
167 Relabel(fst, ipairs, opairs); 168} 169 170 171typedef CacheOptions RelabelFstOptions; 172 173template <class A> class RelabelFst; 174 175// 176// \class RelabelFstImpl 177// \brief Implementation for delayed relabeling 178// 179// Relabels an FST from one symbol set to another. Relabeling 180// can either be on input or output space. RelabelFst implements 181// a delayed version of the relabel. Arcs are relabeled on the fly 182// and not cached. I.e each request is recomputed. 183// 184template<class A> 185class RelabelFstImpl : public CacheImpl<A> { 186 friend class StateIterator< RelabelFst<A> >; 187 public: 188 using FstImpl<A>::SetType; 189 using FstImpl<A>::SetProperties; 190 using FstImpl<A>::WriteHeader; 191 using FstImpl<A>::SetInputSymbols; 192 using FstImpl<A>::SetOutputSymbols; 193 194 using CacheImpl<A>::PushArc; 195 using CacheImpl<A>::HasArcs; 196 using CacheImpl<A>::HasFinal; 197 using CacheImpl<A>::HasStart; 198 using CacheImpl<A>::SetArcs; 199 using CacheImpl<A>::SetFinal; 200 using CacheImpl<A>::SetStart; 201 202 typedef A Arc; 203 typedef typename A::Label Label; 204 typedef typename A::Weight Weight; 205 typedef typename A::StateId StateId; 206 typedef CacheState<A> State; 207 208 RelabelFstImpl(const Fst<A>& fst, 209 const vector<pair<Label, Label> >& ipairs, 210 const vector<pair<Label, Label> >& opairs, 211 const RelabelFstOptions &opts) 212 : CacheImpl<A>(opts), fst_(fst.Copy()), 213 relabel_input_(false), relabel_output_(false) { 214 uint64 props = fst.Properties(kCopyProperties, false); 215 SetProperties(RelabelProperties(props)); 216 SetType("relabel"); 217 218 // create input label map 219 if (ipairs.size() > 0) { 220 for (size_t i = 0; i < ipairs.size(); ++i) { 221 input_map_[ipairs[i].first] = ipairs[i].second; 222 } 223 relabel_input_ = true; 224 } 225 226 // create output label map 227 if (opairs.size() > 0) { 228 for (size_t i = 0; i < opairs.size(); ++i) { 229 output_map_[opairs[i].first] = opairs[i].second; 230 } 231 
relabel_output_ = true; 232 } 233 } 234 235 RelabelFstImpl(const Fst<A>& fst, 236 const SymbolTable* old_isymbols, 237 const SymbolTable* new_isymbols, 238 const SymbolTable* old_osymbols, 239 const SymbolTable* new_osymbols, 240 const RelabelFstOptions &opts) 241 : CacheImpl<A>(opts), fst_(fst.Copy()), 242 relabel_input_(false), relabel_output_(false) { 243 SetType("relabel"); 244 245 uint64 props = fst.Properties(kCopyProperties, false); 246 SetProperties(RelabelProperties(props)); 247 SetInputSymbols(old_isymbols); 248 SetOutputSymbols(old_osymbols); 249 250 if (old_isymbols && new_isymbols && 251 old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) { 252 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done(); 253 syms_iter.Next()) { 254 input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol()); 255 } 256 SetInputSymbols(new_isymbols); 257 relabel_input_ = true; 258 } 259 260 if (old_osymbols && new_osymbols && 261 old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) { 262 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done(); 263 syms_iter.Next()) { 264 output_map_[syms_iter.Value()] = 265 new_osymbols->Find(syms_iter.Symbol()); 266 } 267 SetOutputSymbols(new_osymbols); 268 relabel_output_ = true; 269 } 270 } 271 272 RelabelFstImpl(const RelabelFstImpl<A>& impl) 273 : CacheImpl<A>(impl), 274 fst_(impl.fst_->Copy(true)), 275 input_map_(impl.input_map_), 276 output_map_(impl.output_map_), 277 relabel_input_(impl.relabel_input_), 278 relabel_output_(impl.relabel_output_) { 279 SetType("relabel"); 280 SetProperties(impl.Properties(), kCopyProperties); 281 SetInputSymbols(impl.InputSymbols()); 282 SetOutputSymbols(impl.OutputSymbols()); 283 } 284 285 ~RelabelFstImpl() { delete fst_; } 286 287 StateId Start() { 288 if (!HasStart()) { 289 StateId s = fst_->Start(); 290 SetStart(s); 291 } 292 return CacheImpl<A>::Start(); 293 } 294 295 Weight Final(StateId s) { 296 if (!HasFinal(s)) { 297 
SetFinal(s, fst_->Final(s)); 298 } 299 return CacheImpl<A>::Final(s); 300 } 301 302 size_t NumArcs(StateId s) { 303 if (!HasArcs(s)) { 304 Expand(s); 305 } 306 return CacheImpl<A>::NumArcs(s); 307 } 308 309 size_t NumInputEpsilons(StateId s) { 310 if (!HasArcs(s)) { 311 Expand(s); 312 } 313 return CacheImpl<A>::NumInputEpsilons(s); 314 } 315 316 size_t NumOutputEpsilons(StateId s) { 317 if (!HasArcs(s)) { 318 Expand(s); 319 } 320 return CacheImpl<A>::NumOutputEpsilons(s); 321 } 322 323 uint64 Properties() const { return Properties(kFstProperties); } 324 325 // Set error if found; return FST impl properties. 326 uint64 Properties(uint64 mask) const { 327 if ((mask & kError) && fst_->Properties(kError, false)) 328 SetProperties(kError, kError); 329 return FstImpl<Arc>::Properties(mask); 330 } 331 332 void InitArcIterator(StateId s, ArcIteratorData<A>* data) { 333 if (!HasArcs(s)) { 334 Expand(s); 335 } 336 CacheImpl<A>::InitArcIterator(s, data); 337 } 338 339 void Expand(StateId s) { 340 for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) { 341 A arc = aiter.Value(); 342 343 // relabel input 344 if (relabel_input_) { 345 typename unordered_map<Label, Label>::iterator it = 346 input_map_.find(arc.ilabel); 347 if (it != input_map_.end()) { arc.ilabel = it->second; } 348 } 349 350 // relabel output 351 if (relabel_output_) { 352 typename unordered_map<Label, Label>::iterator it = 353 output_map_.find(arc.olabel); 354 if (it != output_map_.end()) { arc.olabel = it->second; } 355 } 356 357 PushArc(s, arc); 358 } 359 SetArcs(s); 360 } 361 362 363 private: 364 const Fst<A> *fst_; 365 366 unordered_map<Label, Label> input_map_; 367 unordered_map<Label, Label> output_map_; 368 bool relabel_input_; 369 bool relabel_output_; 370 371 void operator=(const RelabelFstImpl<A> &); // disallow 372}; 373 374 375// 376// \class RelabelFst 377// \brief Delayed implementation of arc relabeling 378// 379// This class attaches interface to implementation and handles 
380// reference counting, delegating most methods to ImplToFst. 381template <class A> 382class RelabelFst : public ImplToFst< RelabelFstImpl<A> > { 383 public: 384 friend class ArcIterator< RelabelFst<A> >; 385 friend class StateIterator< RelabelFst<A> >; 386 387 typedef A Arc; 388 typedef typename A::Label Label; 389 typedef typename A::Weight Weight; 390 typedef typename A::StateId StateId; 391 typedef CacheState<A> State; 392 typedef RelabelFstImpl<A> Impl; 393 394 RelabelFst(const Fst<A>& fst, 395 const vector<pair<Label, Label> >& ipairs, 396 const vector<pair<Label, Label> >& opairs) 397 : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {} 398 399 RelabelFst(const Fst<A>& fst, 400 const vector<pair<Label, Label> >& ipairs, 401 const vector<pair<Label, Label> >& opairs, 402 const RelabelFstOptions &opts) 403 : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {} 404 405 RelabelFst(const Fst<A>& fst, 406 const SymbolTable* new_isymbols, 407 const SymbolTable* new_osymbols) 408 : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, 409 fst.OutputSymbols(), new_osymbols, 410 RelabelFstOptions())) {} 411 412 RelabelFst(const Fst<A>& fst, 413 const SymbolTable* new_isymbols, 414 const SymbolTable* new_osymbols, 415 const RelabelFstOptions &opts) 416 : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols, 417 fst.OutputSymbols(), new_osymbols, opts)) {} 418 419 RelabelFst(const Fst<A>& fst, 420 const SymbolTable* old_isymbols, 421 const SymbolTable* new_isymbols, 422 const SymbolTable* old_osymbols, 423 const SymbolTable* new_osymbols) 424 : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols, 425 new_osymbols, RelabelFstOptions())) {} 426 427 RelabelFst(const Fst<A>& fst, 428 const SymbolTable* old_isymbols, 429 const SymbolTable* new_isymbols, 430 const SymbolTable* old_osymbols, 431 const SymbolTable* new_osymbols, 432 const RelabelFstOptions &opts) 433 : ImplToFst<Impl>(new Impl(fst, old_isymbols, 
new_isymbols, old_osymbols, 434 new_osymbols, opts)) {} 435 436 // See Fst<>::Copy() for doc. 437 RelabelFst(const RelabelFst<A> &fst, bool safe = false) 438 : ImplToFst<Impl>(fst, safe) {} 439 440 // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc. 441 virtual RelabelFst<A> *Copy(bool safe = false) const { 442 return new RelabelFst<A>(*this, safe); 443 } 444 445 virtual void InitStateIterator(StateIteratorData<A> *data) const; 446 447 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { 448 return GetImpl()->InitArcIterator(s, data); 449 } 450 451 private: 452 // Makes visible to friends. 453 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } 454 455 void operator=(const RelabelFst<A> &fst); // disallow 456}; 457 458// Specialization for RelabelFst. 459template<class A> 460class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> { 461 public: 462 typedef typename A::StateId StateId; 463 464 explicit StateIterator(const RelabelFst<A> &fst) 465 : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {} 466 467 bool Done() const { return siter_.Done(); } 468 469 StateId Value() const { return s_; } 470 471 void Next() { 472 if (!siter_.Done()) { 473 ++s_; 474 siter_.Next(); 475 } 476 } 477 478 void Reset() { 479 s_ = 0; 480 siter_.Reset(); 481 } 482 483 private: 484 bool Done_() const { return Done(); } 485 StateId Value_() const { return Value(); } 486 void Next_() { Next(); } 487 void Reset_() { Reset(); } 488 489 const RelabelFstImpl<A> *impl_; 490 StateIterator< Fst<A> > siter_; 491 StateId s_; 492 493 DISALLOW_COPY_AND_ASSIGN(StateIterator); 494}; 495 496 497// Specialization for RelabelFst. 
498template <class A> 499class ArcIterator< RelabelFst<A> > 500 : public CacheArcIterator< RelabelFst<A> > { 501 public: 502 typedef typename A::StateId StateId; 503 504 ArcIterator(const RelabelFst<A> &fst, StateId s) 505 : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) { 506 if (!fst.GetImpl()->HasArcs(s)) 507 fst.GetImpl()->Expand(s); 508 } 509 510 private: 511 DISALLOW_COPY_AND_ASSIGN(ArcIterator); 512}; 513 514template <class A> inline 515void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const { 516 data->base = new StateIterator< RelabelFst<A> >(*this); 517} 518 519// Useful alias when using StdArc. 520typedef RelabelFst<StdArc> StdRelabelFst; 521 522} // namespace fst 523 524#endif // FST_LIB_RELABEL_H__ 525