// relabel.h revision 8fc5a7f51e62cb4ae44a27bdf4176d04adc80ede
1// relabel.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
// Functions and classes to relabel an Fst on its input labels, output labels, or both.
18//
19#ifndef FST_LIB_RELABEL_H__
20#define FST_LIB_RELABEL_H__
21
22#include <ext/hash_map>
23using __gnu_cxx::hash_map;
24
25#include "fst/lib/cache.h"
26#include "fst/lib/test-properties.h"
27
28
29namespace fst {
30
31//
32// Relabels either the input labels or output labels. The old to
33// new labels are specified using a vector of pair<Label,Label>.
34// Any label associations not specified are assumed to be identity
35// mapping.
36//
37// \param fst input fst, must be mutable
38// \param relabel_pairs vector of pairs indicating old to new mapping
39// \param relabel_flags whether to relabel input or output
40//
41template <class A>
42void Relabel(
43    MutableFst<A> *fst,
44    const vector<pair<typename A::Label, typename A::Label> >& ipairs,
45    const vector<pair<typename A::Label, typename A::Label> >& opairs) {
46  typedef typename A::StateId StateId;
47  typedef typename A::Label   Label;
48
49  uint64 props = fst->Properties(kFstProperties, false);
50
51  // construct label to label hash. Could
52  hash_map<Label, Label> input_map;
53  for (size_t i = 0; i < ipairs.size(); ++i) {
54    input_map[ipairs[i].first] = ipairs[i].second;
55  }
56
57  hash_map<Label, Label> output_map;
58  for (size_t i = 0; i < opairs.size(); ++i) {
59    output_map[opairs[i].first] = opairs[i].second;
60  }
61
62  for (StateIterator<MutableFst<A> > siter(*fst);
63       !siter.Done(); siter.Next()) {
64    StateId s = siter.Value();
65    for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
66         !aiter.Done(); aiter.Next()) {
67      A arc = aiter.Value();
68
69      // relabel input
70      // only relabel if relabel pair defined
71      typename hash_map<Label, Label>::iterator it =
72        input_map.find(arc.ilabel);
73      if (it != input_map.end()) {arc.ilabel = it->second; }
74
75      // relabel output
76      it = output_map.find(arc.olabel);
77      if (it != output_map.end()) { arc.olabel = it->second; }
78
79      aiter.SetValue(arc);
80    }
81  }
82
83  fst->SetProperties(RelabelProperties(props), kFstProperties);
84}
85
86
87
88//
89// Relabels either the input labels or output labels. The old to
90// new labels mappings are specified using an input Symbol set.
91// Any label associations not specified are assumed to be identity
92// mapping.
93//
94// \param fst input fst, must be mutable
95// \param new_symbols symbol set indicating new mapping
96// \param relabel_flags whether to relabel input or output
97//
98template<class A>
99void Relabel(MutableFst<A> *fst,
100             const SymbolTable* new_isymbols,
101             const SymbolTable* new_osymbols) {
102  typedef typename A::StateId StateId;
103  typedef typename A::Label   Label;
104
105  const SymbolTable* old_isymbols = fst->InputSymbols();
106  const SymbolTable* old_osymbols = fst->OutputSymbols();
107
108  vector<pair<Label, Label> > ipairs;
109  if (old_isymbols && new_isymbols) {
110    for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
111         syms_iter.Next()) {
112      ipairs.push_back(make_pair(syms_iter.Value(),
113                                 new_isymbols->Find(syms_iter.Symbol())));
114    }
115    fst->SetInputSymbols(new_isymbols);
116  }
117
118  vector<pair<Label, Label> > opairs;
119  if (old_osymbols && new_osymbols) {
120    for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
121         syms_iter.Next()) {
122      opairs.push_back(make_pair(syms_iter.Value(),
123                                 new_osymbols->Find(syms_iter.Symbol())));
124    }
125    fst->SetOutputSymbols(new_osymbols);
126  }
127
128  // call relabel using vector of relabel pairs.
129  Relabel(fst, ipairs, opairs);
130}
131
132
133typedef CacheOptions RelabelFstOptions;
134
135template <class A> class RelabelFst;
136
137//
138// \class RelabelFstImpl
139// \brief Implementation for delayed relabeling
140//
141// Relabels an FST from one symbol set to another. Relabeling
142// can either be on input or output space. RelabelFst implements
143// a delayed version of the relabel. Arcs are relabeled on the fly
144// and not cached. I.e each request is recomputed.
145//
146template<class A>
147class RelabelFstImpl : public CacheImpl<A> {
148  friend class StateIterator< RelabelFst<A> >;
149 public:
150  using FstImpl<A>::SetType;
151  using FstImpl<A>::SetProperties;
152  using FstImpl<A>::Properties;
153  using FstImpl<A>::SetInputSymbols;
154  using FstImpl<A>::SetOutputSymbols;
155
156  using CacheImpl<A>::HasStart;
157  using CacheImpl<A>::HasArcs;
158
159  typedef typename A::Label   Label;
160  typedef typename A::Weight  Weight;
161  typedef typename A::StateId StateId;
162  typedef CacheState<A> State;
163
164  RelabelFstImpl(const Fst<A>& fst,
165                 const vector<pair<Label, Label> >& ipairs,
166                 const vector<pair<Label, Label> >& opairs,
167                 const RelabelFstOptions &opts)
168      : CacheImpl<A>(opts), fst_(fst.Copy()),
169        relabel_input_(false), relabel_output_(false) {
170    uint64 props = fst.Properties(kCopyProperties, false);
171    SetProperties(RelabelProperties(props));
172    SetType("relabel");
173
174    // create input label map
175    if (ipairs.size() > 0) {
176      for (size_t i = 0; i < ipairs.size(); ++i) {
177        input_map_[ipairs[i].first] = ipairs[i].second;
178      }
179      relabel_input_ = true;
180    }
181
182    // create output label map
183    if (opairs.size() > 0) {
184      for (size_t i = 0; i < opairs.size(); ++i) {
185        output_map_[opairs[i].first] = opairs[i].second;
186      }
187      relabel_output_ = true;
188    }
189  }
190
191  RelabelFstImpl(const Fst<A>& fst,
192                 const SymbolTable* new_isymbols,
193                 const SymbolTable* new_osymbols,
194                 const RelabelFstOptions &opts)
195      : CacheImpl<A>(opts), fst_(fst.Copy()),
196        relabel_input_(false), relabel_output_(false) {
197    SetType("relabel");
198
199    uint64 props = fst.Properties(kCopyProperties, false);
200    SetProperties(RelabelProperties(props));
201    SetInputSymbols(fst.InputSymbols());
202    SetOutputSymbols(fst.OutputSymbols());
203
204    const SymbolTable* old_isymbols = fst.InputSymbols();
205    const SymbolTable* old_osymbols = fst.OutputSymbols();
206
207    if (old_isymbols && new_isymbols &&
208        old_isymbols->CheckSum() != new_isymbols->CheckSum()) {
209      for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
210           syms_iter.Next()) {
211        input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
212      }
213      SetInputSymbols(new_isymbols);
214      relabel_input_ = true;
215    }
216
217    if (old_osymbols && new_osymbols &&
218        old_osymbols->CheckSum() != new_osymbols->CheckSum()) {
219      for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
220           syms_iter.Next()) {
221        output_map_[syms_iter.Value()] =
222          new_osymbols->Find(syms_iter.Symbol());
223      }
224      SetOutputSymbols(new_osymbols);
225      relabel_output_ = true;
226    }
227  }
228
229  ~RelabelFstImpl() { delete fst_; }
230
231  StateId Start() {
232    if (!HasStart()) {
233      StateId s = fst_->Start();
234      SetStart(s);
235    }
236    return CacheImpl<A>::Start();
237  }
238
239  Weight Final(StateId s) {
240    if (!HasFinal(s)) {
241      SetFinal(s, fst_->Final(s));
242    }
243    return CacheImpl<A>::Final(s);
244  }
245
246  size_t NumArcs(StateId s) {
247    if (!HasArcs(s)) {
248      Expand(s);
249    }
250    return CacheImpl<A>::NumArcs(s);
251  }
252
253  size_t NumInputEpsilons(StateId s) {
254    if (!HasArcs(s)) {
255      Expand(s);
256    }
257    return CacheImpl<A>::NumInputEpsilons(s);
258  }
259
260  size_t NumOutputEpsilons(StateId s) {
261    if (!HasArcs(s)) {
262      Expand(s);
263    }
264    return CacheImpl<A>::NumOutputEpsilons(s);
265  }
266
267  void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
268    if (!HasArcs(s)) {
269      Expand(s);
270    }
271    CacheImpl<A>::InitArcIterator(s, data);
272  }
273
274  void Expand(StateId s) {
275    for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
276      A arc = aiter.Value();
277
278      // relabel input
279      if (relabel_input_) {
280        typename hash_map<Label, Label>::iterator it =
281          input_map_.find(arc.ilabel);
282        if (it != input_map_.end()) { arc.ilabel = it->second; }
283      }
284
285      // relabel output
286      if (relabel_output_) {
287        typename hash_map<Label, Label>::iterator it =
288          output_map_.find(arc.olabel);
289        if (it != output_map_.end()) { arc.olabel = it->second; }
290      }
291
292      AddArc(s, arc);
293    }
294    SetArcs(s);
295  }
296
297
298 private:
299  const Fst<A> *fst_;
300
301  hash_map<Label, Label> input_map_;
302  hash_map<Label, Label> output_map_;
303  bool relabel_input_;
304  bool relabel_output_;
305
306  DISALLOW_EVIL_CONSTRUCTORS(RelabelFstImpl);
307};
308
309
310//
311// \class RelabelFst
312// \brief Delayed implementation of arc relabeling
313//
314// This class attaches interface to implementation and handles
315// reference counting.
316template <class A>
317class RelabelFst : public Fst<A> {
318 public:
319  friend class ArcIterator< RelabelFst<A> >;
320  friend class StateIterator< RelabelFst<A> >;
321  friend class CacheArcIterator< RelabelFst<A> >;
322
323  typedef A Arc;
324  typedef typename A::Label   Label;
325  typedef typename A::Weight  Weight;
326  typedef typename A::StateId StateId;
327  typedef CacheState<A> State;
328
329  RelabelFst(const Fst<A>& fst,
330             const vector<pair<Label, Label> >& ipairs,
331             const vector<pair<Label, Label> >& opairs) :
332      impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, RelabelFstOptions())) {}
333
334  RelabelFst(const Fst<A>& fst,
335             const vector<pair<Label, Label> >& ipairs,
336             const vector<pair<Label, Label> >& opairs,
337             const RelabelFstOptions &opts)
338      : impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, opts)) {}
339
340  RelabelFst(const Fst<A>& fst,
341             const SymbolTable* new_isymbols,
342             const SymbolTable* new_osymbols) :
343      impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols,
344                                  RelabelFstOptions())) {}
345
346  RelabelFst(const Fst<A>& fst,
347             const SymbolTable* new_isymbols,
348             const SymbolTable* new_osymbols,
349             const RelabelFstOptions &opts)
350    : impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols, opts)) {}
351
352  RelabelFst(const RelabelFst<A> &fst) : impl_(fst.impl_) {
353    impl_->IncrRefCount();
354  }
355
356  virtual ~RelabelFst() { if (!impl_->DecrRefCount()) delete impl_;  }
357
358  virtual StateId Start() const { return impl_->Start(); }
359
360  virtual Weight Final(StateId s) const { return impl_->Final(s); }
361
362  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
363
364  virtual size_t NumInputEpsilons(StateId s) const {
365    return impl_->NumInputEpsilons(s);
366  }
367
368  virtual size_t NumOutputEpsilons(StateId s) const {
369    return impl_->NumOutputEpsilons(s);
370  }
371
372  virtual uint64 Properties(uint64 mask, bool test) const {
373    if (test) {
374      uint64 known, test = TestProperties(*this, mask, &known);
375      impl_->SetProperties(test, known);
376      return test & mask;
377    } else {
378      return impl_->Properties(mask);
379    }
380  }
381
382  virtual const string& Type() const { return impl_->Type(); }
383
384  virtual RelabelFst<A> *Copy() const {
385    return new RelabelFst<A>(*this);
386  }
387
388  virtual const SymbolTable* InputSymbols() const {
389    return impl_->InputSymbols();
390  }
391
392  virtual const SymbolTable* OutputSymbols() const {
393    return impl_->OutputSymbols();
394  }
395
396  virtual void InitStateIterator(StateIteratorData<A> *data) const;
397
398  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
399    return impl_->InitArcIterator(s, data);
400  }
401
402 private:
403  RelabelFstImpl<A> *impl_;
404
405  void operator=(const RelabelFst<A> &fst);  // disallow
406};
407
408// Specialization for RelabelFst.
409template<class A>
410class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
411 public:
412  typedef typename A::StateId StateId;
413
414  explicit StateIterator(const RelabelFst<A> &fst)
415      : impl_(fst.impl_), siter_(*impl_->fst_), s_(0) {}
416
417  bool Done() const { return siter_.Done(); }
418
419  StateId Value() const { return s_; }
420
421  void Next() {
422    if (!siter_.Done()) {
423      ++s_;
424      siter_.Next();
425    }
426  }
427
428  void Reset() {
429    s_ = 0;
430    siter_.Reset();
431  }
432
433 private:
434  const RelabelFstImpl<A> *impl_;
435  StateIterator< Fst<A> > siter_;
436  StateId s_;
437
438  DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
439};
440
441
442// Specialization for RelabelFst.
443template <class A>
444class ArcIterator< RelabelFst<A> >
445    : public CacheArcIterator< RelabelFst<A> > {
446 public:
447  typedef typename A::StateId StateId;
448
449  ArcIterator(const RelabelFst<A> &fst, StateId s)
450      : CacheArcIterator< RelabelFst<A> >(fst, s) {
451    if (!fst.impl_->HasArcs(s))
452      fst.impl_->Expand(s);
453  }
454
455 private:
456  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
457};
458
459template <class A> inline
460void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
461  data->base = new StateIterator< RelabelFst<A> >(*this);
462}
463
464// Useful alias when using StdArc.
465typedef RelabelFst<StdArc> StdRelabelFst;
466
467}  // namespace fst
468
469#endif  // FST_LIB_RELABEL_H__
470