1// relabel.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// Functions and classes to relabel an Fst (either on input or output)
18//
19#ifndef FST_LIB_RELABEL_H__
20#define FST_LIB_RELABEL_H__
21
22#include <unordered_map>
23
24#include "fst/lib/cache.h"
25#include "fst/lib/test-properties.h"
26
27
28namespace fst {
29
30//
31// Relabels either the input labels or output labels. The old to
32// new labels are specified using a vector of pair<Label,Label>.
33// Any label associations not specified are assumed to be identity
34// mapping.
35//
36// \param fst input fst, must be mutable
37// \param relabel_pairs vector of pairs indicating old to new mapping
38// \param relabel_flags whether to relabel input or output
39//
40template <class A>
41void Relabel(
42    MutableFst<A> *fst,
43    const vector<pair<typename A::Label, typename A::Label> >& ipairs,
44    const vector<pair<typename A::Label, typename A::Label> >& opairs) {
45  typedef typename A::StateId StateId;
46  typedef typename A::Label   Label;
47
48  uint64 props = fst->Properties(kFstProperties, false);
49
50  // construct label to label hash. Could
51  std::unordered_map<Label, Label> input_map;
52  for (size_t i = 0; i < ipairs.size(); ++i) {
53    input_map[ipairs[i].first] = ipairs[i].second;
54  }
55
56  std::unordered_map<Label, Label> output_map;
57  for (size_t i = 0; i < opairs.size(); ++i) {
58    output_map[opairs[i].first] = opairs[i].second;
59  }
60
61  for (StateIterator<MutableFst<A> > siter(*fst);
62       !siter.Done(); siter.Next()) {
63    StateId s = siter.Value();
64    for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
65         !aiter.Done(); aiter.Next()) {
66      A arc = aiter.Value();
67
68      // relabel input
69      // only relabel if relabel pair defined
70      typename std::unordered_map<Label, Label>::iterator it =
71        input_map.find(arc.ilabel);
72      if (it != input_map.end()) {arc.ilabel = it->second; }
73
74      // relabel output
75      it = output_map.find(arc.olabel);
76      if (it != output_map.end()) { arc.olabel = it->second; }
77
78      aiter.SetValue(arc);
79    }
80  }
81
82  fst->SetProperties(RelabelProperties(props), kFstProperties);
83}
84
85
86
87//
88// Relabels either the input labels or output labels. The old to
89// new labels mappings are specified using an input Symbol set.
90// Any label associations not specified are assumed to be identity
91// mapping.
92//
93// \param fst input fst, must be mutable
94// \param new_symbols symbol set indicating new mapping
95// \param relabel_flags whether to relabel input or output
96//
97template<class A>
98void Relabel(MutableFst<A> *fst,
99             const SymbolTable* new_isymbols,
100             const SymbolTable* new_osymbols) {
101  typedef typename A::StateId StateId;
102  typedef typename A::Label   Label;
103
104  const SymbolTable* old_isymbols = fst->InputSymbols();
105  const SymbolTable* old_osymbols = fst->OutputSymbols();
106
107  vector<pair<Label, Label> > ipairs;
108  if (old_isymbols && new_isymbols) {
109    for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
110         syms_iter.Next()) {
111      ipairs.push_back(make_pair(syms_iter.Value(),
112                                 new_isymbols->Find(syms_iter.Symbol())));
113    }
114    fst->SetInputSymbols(new_isymbols);
115  }
116
117  vector<pair<Label, Label> > opairs;
118  if (old_osymbols && new_osymbols) {
119    for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
120         syms_iter.Next()) {
121      opairs.push_back(make_pair(syms_iter.Value(),
122                                 new_osymbols->Find(syms_iter.Symbol())));
123    }
124    fst->SetOutputSymbols(new_osymbols);
125  }
126
127  // call relabel using vector of relabel pairs.
128  Relabel(fst, ipairs, opairs);
129}
130
131
// Options accepted by RelabelFst; an alias for CacheOptions.
typedef CacheOptions RelabelFstOptions;

// Forward declaration: RelabelFstImpl names StateIterator< RelabelFst<A> >
// as a friend before RelabelFst itself is defined.
template <class A> class RelabelFst;
135
136//
137// \class RelabelFstImpl
138// \brief Implementation for delayed relabeling
139//
140// Relabels an FST from one symbol set to another. Relabeling
141// can either be on input or output space. RelabelFst implements
142// a delayed version of the relabel. Arcs are relabeled on the fly
143// and not cached. I.e each request is recomputed.
144//
145template<class A>
146class RelabelFstImpl : public CacheImpl<A> {
147  friend class StateIterator< RelabelFst<A> >;
148 public:
149  using FstImpl<A>::SetType;
150  using FstImpl<A>::SetProperties;
151  using FstImpl<A>::Properties;
152  using FstImpl<A>::SetInputSymbols;
153  using FstImpl<A>::SetOutputSymbols;
154
155  using CacheImpl<A>::HasStart;
156  using CacheImpl<A>::HasArcs;
157
158  typedef typename A::Label   Label;
159  typedef typename A::Weight  Weight;
160  typedef typename A::StateId StateId;
161  typedef CacheState<A> State;
162
163  RelabelFstImpl(const Fst<A>& fst,
164                 const vector<pair<Label, Label> >& ipairs,
165                 const vector<pair<Label, Label> >& opairs,
166                 const RelabelFstOptions &opts)
167      : CacheImpl<A>(opts), fst_(fst.Copy()),
168        relabel_input_(false), relabel_output_(false) {
169    uint64 props = fst.Properties(kCopyProperties, false);
170    SetProperties(RelabelProperties(props));
171    SetType("relabel");
172
173    // create input label map
174    if (ipairs.size() > 0) {
175      for (size_t i = 0; i < ipairs.size(); ++i) {
176        input_map_[ipairs[i].first] = ipairs[i].second;
177      }
178      relabel_input_ = true;
179    }
180
181    // create output label map
182    if (opairs.size() > 0) {
183      for (size_t i = 0; i < opairs.size(); ++i) {
184        output_map_[opairs[i].first] = opairs[i].second;
185      }
186      relabel_output_ = true;
187    }
188  }
189
190  RelabelFstImpl(const Fst<A>& fst,
191                 const SymbolTable* new_isymbols,
192                 const SymbolTable* new_osymbols,
193                 const RelabelFstOptions &opts)
194      : CacheImpl<A>(opts), fst_(fst.Copy()),
195        relabel_input_(false), relabel_output_(false) {
196    SetType("relabel");
197
198    uint64 props = fst.Properties(kCopyProperties, false);
199    SetProperties(RelabelProperties(props));
200    SetInputSymbols(fst.InputSymbols());
201    SetOutputSymbols(fst.OutputSymbols());
202
203    const SymbolTable* old_isymbols = fst.InputSymbols();
204    const SymbolTable* old_osymbols = fst.OutputSymbols();
205
206    if (old_isymbols && new_isymbols &&
207        old_isymbols->CheckSum() != new_isymbols->CheckSum()) {
208      for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
209           syms_iter.Next()) {
210        input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
211      }
212      SetInputSymbols(new_isymbols);
213      relabel_input_ = true;
214    }
215
216    if (old_osymbols && new_osymbols &&
217        old_osymbols->CheckSum() != new_osymbols->CheckSum()) {
218      for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
219           syms_iter.Next()) {
220        output_map_[syms_iter.Value()] =
221          new_osymbols->Find(syms_iter.Symbol());
222      }
223      SetOutputSymbols(new_osymbols);
224      relabel_output_ = true;
225    }
226  }
227
228  ~RelabelFstImpl() { delete fst_; }
229
230  StateId Start() {
231    if (!HasStart()) {
232      StateId s = fst_->Start();
233      SetStart(s);
234    }
235    return CacheImpl<A>::Start();
236  }
237
238  Weight Final(StateId s) {
239    if (!HasFinal(s)) {
240      SetFinal(s, fst_->Final(s));
241    }
242    return CacheImpl<A>::Final(s);
243  }
244
245  size_t NumArcs(StateId s) {
246    if (!HasArcs(s)) {
247      Expand(s);
248    }
249    return CacheImpl<A>::NumArcs(s);
250  }
251
252  size_t NumInputEpsilons(StateId s) {
253    if (!HasArcs(s)) {
254      Expand(s);
255    }
256    return CacheImpl<A>::NumInputEpsilons(s);
257  }
258
259  size_t NumOutputEpsilons(StateId s) {
260    if (!HasArcs(s)) {
261      Expand(s);
262    }
263    return CacheImpl<A>::NumOutputEpsilons(s);
264  }
265
266  void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
267    if (!HasArcs(s)) {
268      Expand(s);
269    }
270    CacheImpl<A>::InitArcIterator(s, data);
271  }
272
273  void Expand(StateId s) {
274    for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
275      A arc = aiter.Value();
276
277      // relabel input
278      if (relabel_input_) {
279        typename std::unordered_map<Label, Label>::iterator it =
280          input_map_.find(arc.ilabel);
281        if (it != input_map_.end()) { arc.ilabel = it->second; }
282      }
283
284      // relabel output
285      if (relabel_output_) {
286        typename std::unordered_map<Label, Label>::iterator it =
287          output_map_.find(arc.olabel);
288        if (it != output_map_.end()) { arc.olabel = it->second; }
289      }
290
291      AddArc(s, arc);
292    }
293    SetArcs(s);
294  }
295
296
297 private:
298  const Fst<A> *fst_;
299
300  std::unordered_map<Label, Label> input_map_;
301  std::unordered_map<Label, Label> output_map_;
302  bool relabel_input_;
303  bool relabel_output_;
304
305  DISALLOW_EVIL_CONSTRUCTORS(RelabelFstImpl);
306};
307
308
309//
310// \class RelabelFst
311// \brief Delayed implementation of arc relabeling
312//
313// This class attaches interface to implementation and handles
314// reference counting.
315template <class A>
316class RelabelFst : public Fst<A> {
317 public:
318  friend class ArcIterator< RelabelFst<A> >;
319  friend class StateIterator< RelabelFst<A> >;
320  friend class CacheArcIterator< RelabelFst<A> >;
321
322  typedef A Arc;
323  typedef typename A::Label   Label;
324  typedef typename A::Weight  Weight;
325  typedef typename A::StateId StateId;
326  typedef CacheState<A> State;
327
328  RelabelFst(const Fst<A>& fst,
329             const vector<pair<Label, Label> >& ipairs,
330             const vector<pair<Label, Label> >& opairs) :
331      impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, RelabelFstOptions())) {}
332
333  RelabelFst(const Fst<A>& fst,
334             const vector<pair<Label, Label> >& ipairs,
335             const vector<pair<Label, Label> >& opairs,
336             const RelabelFstOptions &opts)
337      : impl_(new RelabelFstImpl<A>(fst, ipairs, opairs, opts)) {}
338
339  RelabelFst(const Fst<A>& fst,
340             const SymbolTable* new_isymbols,
341             const SymbolTable* new_osymbols) :
342      impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols,
343                                  RelabelFstOptions())) {}
344
345  RelabelFst(const Fst<A>& fst,
346             const SymbolTable* new_isymbols,
347             const SymbolTable* new_osymbols,
348             const RelabelFstOptions &opts)
349    : impl_(new RelabelFstImpl<A>(fst, new_isymbols, new_osymbols, opts)) {}
350
351  RelabelFst(const RelabelFst<A> &fst) : impl_(fst.impl_) {
352    impl_->IncrRefCount();
353  }
354
355  virtual ~RelabelFst() { if (!impl_->DecrRefCount()) delete impl_;  }
356
357  virtual StateId Start() const { return impl_->Start(); }
358
359  virtual Weight Final(StateId s) const { return impl_->Final(s); }
360
361  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
362
363  virtual size_t NumInputEpsilons(StateId s) const {
364    return impl_->NumInputEpsilons(s);
365  }
366
367  virtual size_t NumOutputEpsilons(StateId s) const {
368    return impl_->NumOutputEpsilons(s);
369  }
370
371  virtual uint64 Properties(uint64 mask, bool test) const {
372    if (test) {
373      uint64 known, test = TestProperties(*this, mask, &known);
374      impl_->SetProperties(test, known);
375      return test & mask;
376    } else {
377      return impl_->Properties(mask);
378    }
379  }
380
381  virtual const string& Type() const { return impl_->Type(); }
382
383  virtual RelabelFst<A> *Copy() const {
384    return new RelabelFst<A>(*this);
385  }
386
387  virtual const SymbolTable* InputSymbols() const {
388    return impl_->InputSymbols();
389  }
390
391  virtual const SymbolTable* OutputSymbols() const {
392    return impl_->OutputSymbols();
393  }
394
395  virtual void InitStateIterator(StateIteratorData<A> *data) const;
396
397  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
398    return impl_->InitArcIterator(s, data);
399  }
400
401 private:
402  RelabelFstImpl<A> *impl_;
403
404  void operator=(const RelabelFst<A> &fst);  // disallow
405};
406
407// Specialization for RelabelFst.
408template<class A>
409class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
410 public:
411  typedef typename A::StateId StateId;
412
413  explicit StateIterator(const RelabelFst<A> &fst)
414      : impl_(fst.impl_), siter_(*impl_->fst_), s_(0) {}
415
416  bool Done() const { return siter_.Done(); }
417
418  StateId Value() const { return s_; }
419
420  void Next() {
421    if (!siter_.Done()) {
422      ++s_;
423      siter_.Next();
424    }
425  }
426
427  void Reset() {
428    s_ = 0;
429    siter_.Reset();
430  }
431
432 private:
433  const RelabelFstImpl<A> *impl_;
434  StateIterator< Fst<A> > siter_;
435  StateId s_;
436
437  DISALLOW_EVIL_CONSTRUCTORS(StateIterator);
438};
439
440
441// Specialization for RelabelFst.
442template <class A>
443class ArcIterator< RelabelFst<A> >
444    : public CacheArcIterator< RelabelFst<A> > {
445 public:
446  typedef typename A::StateId StateId;
447
448  ArcIterator(const RelabelFst<A> &fst, StateId s)
449      : CacheArcIterator< RelabelFst<A> >(fst, s) {
450    if (!fst.impl_->HasArcs(s))
451      fst.impl_->Expand(s);
452  }
453
454 private:
455  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
456};
457
458template <class A> inline
459void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
460  data->base = new StateIterator< RelabelFst<A> >(*this);
461}
462
// Convenience alias: RelabelFst instantiated for StdArc.
typedef RelabelFst<StdArc> StdRelabelFst;
465
466}  // namespace fst
467
468#endif  // FST_LIB_RELABEL_H__
469