relabel.h revision 3da1eb108d36da35333b2d655202791af854996b
1// relabel.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: johans@google.com (Johan Schalkwyk)
17//
18// \file
19// Functions and classes to relabel an Fst (either on input or output)
20//
21#ifndef FST_LIB_RELABEL_H__
22#define FST_LIB_RELABEL_H__
23
24#include <tr1/unordered_map>
25using std::tr1::unordered_map;
26using std::tr1::unordered_multimap;
27#include <string>
28#include <utility>
29using std::pair; using std::make_pair;
30#include <vector>
31using std::vector;
32
33#include <fst/cache.h>
34#include <fst/test-properties.h>
35
36
37namespace fst {
38
39//
40// Relabels either the input labels or output labels. The old to
41// new labels are specified using a vector of pair<Label,Label>.
42// Any label associations not specified are assumed to be identity
43// mapping.
44//
45// \param fst input fst, must be mutable
46// \param ipairs vector of input label pairs indicating old to new mapping
47// \param opairs vector of output label pairs indicating old to new mapping
48//
49template <class A>
50void Relabel(
51    MutableFst<A> *fst,
52    const vector<pair<typename A::Label, typename A::Label> >& ipairs,
53    const vector<pair<typename A::Label, typename A::Label> >& opairs) {
54  typedef typename A::StateId StateId;
55  typedef typename A::Label   Label;
56
57  uint64 props = fst->Properties(kFstProperties, false);
58
59  // construct label to label hash.
60  unordered_map<Label, Label> input_map;
61  for (size_t i = 0; i < ipairs.size(); ++i) {
62    input_map[ipairs[i].first] = ipairs[i].second;
63  }
64
65  unordered_map<Label, Label> output_map;
66  for (size_t i = 0; i < opairs.size(); ++i) {
67    output_map[opairs[i].first] = opairs[i].second;
68  }
69
70  for (StateIterator<MutableFst<A> > siter(*fst);
71       !siter.Done(); siter.Next()) {
72    StateId s = siter.Value();
73    for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
74         !aiter.Done(); aiter.Next()) {
75      A arc = aiter.Value();
76
77      // relabel input
78      // only relabel if relabel pair defined
79      typename unordered_map<Label, Label>::iterator it =
80        input_map.find(arc.ilabel);
81      if (it != input_map.end()) {
82        if (it->second == kNoLabel) {
83          FSTERROR() << "Input symbol id " << arc.ilabel
84                     << " missing from target vocabulary";
85          fst->SetProperties(kError, kError);
86          return;
87        }
88        arc.ilabel = it->second;
89      }
90
91      // relabel output
92      it = output_map.find(arc.olabel);
93      if (it != output_map.end()) {
94        if (it->second == kNoLabel) {
95          FSTERROR() << "Output symbol id " << arc.olabel
96                     << " missing from target vocabulary";
97          fst->SetProperties(kError, kError);
98          return;
99        }
100        arc.olabel = it->second;
101      }
102
103      aiter.SetValue(arc);
104    }
105  }
106
107  fst->SetProperties(RelabelProperties(props), kFstProperties);
108}
109
110//
111// Relabels either the input labels or output labels. The old to
112// new labels mappings are specified using an input Symbol set.
113// Any label associations not specified are assumed to be identity
114// mapping.
115//
116// \param fst input fst, must be mutable
117// \param new_isymbols symbol set indicating new mapping of input symbols
118// \param new_osymbols symbol set indicating new mapping of output symbols
119//
120template<class A>
121void Relabel(MutableFst<A> *fst,
122             const SymbolTable* new_isymbols,
123             const SymbolTable* new_osymbols) {
124  Relabel(fst,
125          fst->InputSymbols(), new_isymbols, true,
126          fst->OutputSymbols(), new_osymbols, true);
127}
128
129template<class A>
130void Relabel(MutableFst<A> *fst,
131             const SymbolTable* old_isymbols,
132             const SymbolTable* new_isymbols,
133             bool attach_new_isymbols,
134             const SymbolTable* old_osymbols,
135             const SymbolTable* new_osymbols,
136             bool attach_new_osymbols) {
137  typedef typename A::StateId StateId;
138  typedef typename A::Label   Label;
139
140  vector<pair<Label, Label> > ipairs;
141  if (old_isymbols && new_isymbols) {
142    for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
143         syms_iter.Next()) {
144      string isymbol = syms_iter.Symbol();
145      int isymbol_val = syms_iter.Value();
146      int new_isymbol_val = new_isymbols->Find(isymbol);
147      ipairs.push_back(make_pair(isymbol_val, new_isymbol_val));
148    }
149    if (attach_new_isymbols)
150      fst->SetInputSymbols(new_isymbols);
151  }
152
153  vector<pair<Label, Label> > opairs;
154  if (old_osymbols && new_osymbols) {
155    for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
156         syms_iter.Next()) {
157      string osymbol = syms_iter.Symbol();
158      int osymbol_val = syms_iter.Value();
159      int new_osymbol_val = new_osymbols->Find(osymbol);
160      opairs.push_back(make_pair(osymbol_val, new_osymbol_val));
161    }
162    if (attach_new_osymbols)
163      fst->SetOutputSymbols(new_osymbols);
164  }
165
166  // call relabel using vector of relabel pairs.
167  Relabel(fst, ipairs, opairs);
168}
169
170
171typedef CacheOptions RelabelFstOptions;
172
173template <class A> class RelabelFst;
174
175//
176// \class RelabelFstImpl
177// \brief Implementation for delayed relabeling
178//
179// Relabels an FST from one symbol set to another. Relabeling
180// can either be on input or output space. RelabelFst implements
181// a delayed version of the relabel. Arcs are relabeled on the fly
182// and not cached. I.e each request is recomputed.
183//
184template<class A>
185class RelabelFstImpl : public CacheImpl<A> {
186  friend class StateIterator< RelabelFst<A> >;
187 public:
188  using FstImpl<A>::SetType;
189  using FstImpl<A>::SetProperties;
190  using FstImpl<A>::WriteHeader;
191  using FstImpl<A>::SetInputSymbols;
192  using FstImpl<A>::SetOutputSymbols;
193
194  using CacheImpl<A>::PushArc;
195  using CacheImpl<A>::HasArcs;
196  using CacheImpl<A>::HasFinal;
197  using CacheImpl<A>::HasStart;
198  using CacheImpl<A>::SetArcs;
199  using CacheImpl<A>::SetFinal;
200  using CacheImpl<A>::SetStart;
201
202  typedef A Arc;
203  typedef typename A::Label   Label;
204  typedef typename A::Weight  Weight;
205  typedef typename A::StateId StateId;
206  typedef CacheState<A> State;
207
208  RelabelFstImpl(const Fst<A>& fst,
209                 const vector<pair<Label, Label> >& ipairs,
210                 const vector<pair<Label, Label> >& opairs,
211                 const RelabelFstOptions &opts)
212      : CacheImpl<A>(opts), fst_(fst.Copy()),
213        relabel_input_(false), relabel_output_(false) {
214    uint64 props = fst.Properties(kCopyProperties, false);
215    SetProperties(RelabelProperties(props));
216    SetType("relabel");
217
218    // create input label map
219    if (ipairs.size() > 0) {
220      for (size_t i = 0; i < ipairs.size(); ++i) {
221        input_map_[ipairs[i].first] = ipairs[i].second;
222      }
223      relabel_input_ = true;
224    }
225
226    // create output label map
227    if (opairs.size() > 0) {
228      for (size_t i = 0; i < opairs.size(); ++i) {
229        output_map_[opairs[i].first] = opairs[i].second;
230      }
231      relabel_output_ = true;
232    }
233  }
234
235  RelabelFstImpl(const Fst<A>& fst,
236                 const SymbolTable* old_isymbols,
237                 const SymbolTable* new_isymbols,
238                 const SymbolTable* old_osymbols,
239                 const SymbolTable* new_osymbols,
240                 const RelabelFstOptions &opts)
241      : CacheImpl<A>(opts), fst_(fst.Copy()),
242        relabel_input_(false), relabel_output_(false) {
243    SetType("relabel");
244
245    uint64 props = fst.Properties(kCopyProperties, false);
246    SetProperties(RelabelProperties(props));
247    SetInputSymbols(old_isymbols);
248    SetOutputSymbols(old_osymbols);
249
250    if (old_isymbols && new_isymbols &&
251        old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
252      for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
253           syms_iter.Next()) {
254        input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
255      }
256      SetInputSymbols(new_isymbols);
257      relabel_input_ = true;
258    }
259
260    if (old_osymbols && new_osymbols &&
261        old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
262      for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
263           syms_iter.Next()) {
264        output_map_[syms_iter.Value()] =
265          new_osymbols->Find(syms_iter.Symbol());
266      }
267      SetOutputSymbols(new_osymbols);
268      relabel_output_ = true;
269    }
270  }
271
272  RelabelFstImpl(const RelabelFstImpl<A>& impl)
273      : CacheImpl<A>(impl),
274        fst_(impl.fst_->Copy(true)),
275        input_map_(impl.input_map_),
276        output_map_(impl.output_map_),
277        relabel_input_(impl.relabel_input_),
278        relabel_output_(impl.relabel_output_) {
279    SetType("relabel");
280    SetProperties(impl.Properties(), kCopyProperties);
281    SetInputSymbols(impl.InputSymbols());
282    SetOutputSymbols(impl.OutputSymbols());
283  }
284
285  ~RelabelFstImpl() { delete fst_; }
286
287  StateId Start() {
288    if (!HasStart()) {
289      StateId s = fst_->Start();
290      SetStart(s);
291    }
292    return CacheImpl<A>::Start();
293  }
294
295  Weight Final(StateId s) {
296    if (!HasFinal(s)) {
297      SetFinal(s, fst_->Final(s));
298    }
299    return CacheImpl<A>::Final(s);
300  }
301
302  size_t NumArcs(StateId s) {
303    if (!HasArcs(s)) {
304      Expand(s);
305    }
306    return CacheImpl<A>::NumArcs(s);
307  }
308
309  size_t NumInputEpsilons(StateId s) {
310    if (!HasArcs(s)) {
311      Expand(s);
312    }
313    return CacheImpl<A>::NumInputEpsilons(s);
314  }
315
316  size_t NumOutputEpsilons(StateId s) {
317    if (!HasArcs(s)) {
318      Expand(s);
319    }
320    return CacheImpl<A>::NumOutputEpsilons(s);
321  }
322
323  uint64 Properties() const { return Properties(kFstProperties); }
324
325  // Set error if found; return FST impl properties.
326  uint64 Properties(uint64 mask) const {
327    if ((mask & kError) && fst_->Properties(kError, false))
328      SetProperties(kError, kError);
329    return FstImpl<Arc>::Properties(mask);
330  }
331
332  void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
333    if (!HasArcs(s)) {
334      Expand(s);
335    }
336    CacheImpl<A>::InitArcIterator(s, data);
337  }
338
339  void Expand(StateId s) {
340    for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
341      A arc = aiter.Value();
342
343      // relabel input
344      if (relabel_input_) {
345        typename unordered_map<Label, Label>::iterator it =
346          input_map_.find(arc.ilabel);
347        if (it != input_map_.end()) { arc.ilabel = it->second; }
348      }
349
350      // relabel output
351      if (relabel_output_) {
352        typename unordered_map<Label, Label>::iterator it =
353          output_map_.find(arc.olabel);
354        if (it != output_map_.end()) { arc.olabel = it->second; }
355      }
356
357      PushArc(s, arc);
358    }
359    SetArcs(s);
360  }
361
362
363 private:
364  const Fst<A> *fst_;
365
366  unordered_map<Label, Label> input_map_;
367  unordered_map<Label, Label> output_map_;
368  bool relabel_input_;
369  bool relabel_output_;
370
371  void operator=(const RelabelFstImpl<A> &);  // disallow
372};
373
374
375//
376// \class RelabelFst
377// \brief Delayed implementation of arc relabeling
378//
379// This class attaches interface to implementation and handles
380// reference counting, delegating most methods to ImplToFst.
381template <class A>
382class RelabelFst : public ImplToFst< RelabelFstImpl<A> > {
383 public:
384  friend class ArcIterator< RelabelFst<A> >;
385  friend class StateIterator< RelabelFst<A> >;
386
387  typedef A Arc;
388  typedef typename A::Label   Label;
389  typedef typename A::Weight  Weight;
390  typedef typename A::StateId StateId;
391  typedef CacheState<A> State;
392  typedef RelabelFstImpl<A> Impl;
393
394  RelabelFst(const Fst<A>& fst,
395             const vector<pair<Label, Label> >& ipairs,
396             const vector<pair<Label, Label> >& opairs)
397      : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {}
398
399  RelabelFst(const Fst<A>& fst,
400             const vector<pair<Label, Label> >& ipairs,
401             const vector<pair<Label, Label> >& opairs,
402             const RelabelFstOptions &opts)
403      : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {}
404
405  RelabelFst(const Fst<A>& fst,
406             const SymbolTable* new_isymbols,
407             const SymbolTable* new_osymbols)
408      : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
409                                 fst.OutputSymbols(), new_osymbols,
410                                 RelabelFstOptions())) {}
411
412  RelabelFst(const Fst<A>& fst,
413             const SymbolTable* new_isymbols,
414             const SymbolTable* new_osymbols,
415             const RelabelFstOptions &opts)
416      : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
417                                 fst.OutputSymbols(), new_osymbols, opts)) {}
418
419  RelabelFst(const Fst<A>& fst,
420             const SymbolTable* old_isymbols,
421             const SymbolTable* new_isymbols,
422             const SymbolTable* old_osymbols,
423             const SymbolTable* new_osymbols)
424    : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
425                               new_osymbols, RelabelFstOptions())) {}
426
427  RelabelFst(const Fst<A>& fst,
428             const SymbolTable* old_isymbols,
429             const SymbolTable* new_isymbols,
430             const SymbolTable* old_osymbols,
431             const SymbolTable* new_osymbols,
432             const RelabelFstOptions &opts)
433    : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
434                               new_osymbols, opts)) {}
435
436  // See Fst<>::Copy() for doc.
437  RelabelFst(const RelabelFst<A> &fst, bool safe = false)
438    : ImplToFst<Impl>(fst, safe) {}
439
440  // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc.
441  virtual RelabelFst<A> *Copy(bool safe = false) const {
442    return new RelabelFst<A>(*this, safe);
443  }
444
445  virtual void InitStateIterator(StateIteratorData<A> *data) const;
446
447  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
448    return GetImpl()->InitArcIterator(s, data);
449  }
450
451 private:
452  // Makes visible to friends.
453  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
454
455  void operator=(const RelabelFst<A> &fst);  // disallow
456};
457
458// Specialization for RelabelFst.
459template<class A>
460class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
461 public:
462  typedef typename A::StateId StateId;
463
464  explicit StateIterator(const RelabelFst<A> &fst)
465      : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
466
467  bool Done() const { return siter_.Done(); }
468
469  StateId Value() const { return s_; }
470
471  void Next() {
472    if (!siter_.Done()) {
473      ++s_;
474      siter_.Next();
475    }
476  }
477
478  void Reset() {
479    s_ = 0;
480    siter_.Reset();
481  }
482
483 private:
484  bool Done_() const { return Done(); }
485  StateId Value_() const { return Value(); }
486  void Next_() { Next(); }
487  void Reset_() { Reset(); }
488
489  const RelabelFstImpl<A> *impl_;
490  StateIterator< Fst<A> > siter_;
491  StateId s_;
492
493  DISALLOW_COPY_AND_ASSIGN(StateIterator);
494};
495
496
497// Specialization for RelabelFst.
498template <class A>
499class ArcIterator< RelabelFst<A> >
500    : public CacheArcIterator< RelabelFst<A> > {
501 public:
502  typedef typename A::StateId StateId;
503
504  ArcIterator(const RelabelFst<A> &fst, StateId s)
505      : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) {
506    if (!fst.GetImpl()->HasArcs(s))
507      fst.GetImpl()->Expand(s);
508  }
509
510 private:
511  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
512};
513
514template <class A> inline
515void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
516  data->base = new StateIterator< RelabelFst<A> >(*this);
517}
518
// Useful alias when using StdArc: RelabelFst instantiated on StdArc.
typedef RelabelFst<StdArc> StdRelabelFst;
521
522}  // namespace fst
523
524#endif  // FST_LIB_RELABEL_H__
525