1// relabel.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: johans@google.com (Johan Schalkwyk)
17//
18// \file
// Functions and classes to relabel an Fst on its input labels, its output
// labels, or both.
20//
21#ifndef FST_LIB_RELABEL_H__
22#define FST_LIB_RELABEL_H__
23
24#include <tr1/unordered_map>
25using std::tr1::unordered_map;
26using std::tr1::unordered_multimap;
27#include <string>
28#include <utility>
29using std::pair; using std::make_pair;
30#include <vector>
31using std::vector;
32
33#include <fst/cache.h>
34#include <fst/test-properties.h>
35
36
37#include <tr1/unordered_map>
38using std::tr1::unordered_map;
39using std::tr1::unordered_multimap;
40
41namespace fst {
42
43//
44// Relabels either the input labels or output labels. The old to
45// new labels are specified using a vector of pair<Label,Label>.
46// Any label associations not specified are assumed to be identity
47// mapping.
48//
49// \param fst input fst, must be mutable
50// \param ipairs vector of input label pairs indicating old to new mapping
51// \param opairs vector of output label pairs indicating old to new mapping
52//
53template <class A>
54void Relabel(
55    MutableFst<A> *fst,
56    const vector<pair<typename A::Label, typename A::Label> >& ipairs,
57    const vector<pair<typename A::Label, typename A::Label> >& opairs) {
58  typedef typename A::StateId StateId;
59  typedef typename A::Label   Label;
60
61  uint64 props = fst->Properties(kFstProperties, false);
62
63  // construct label to label hash.
64  unordered_map<Label, Label> input_map;
65  for (size_t i = 0; i < ipairs.size(); ++i) {
66    input_map[ipairs[i].first] = ipairs[i].second;
67  }
68
69  unordered_map<Label, Label> output_map;
70  for (size_t i = 0; i < opairs.size(); ++i) {
71    output_map[opairs[i].first] = opairs[i].second;
72  }
73
74  for (StateIterator<MutableFst<A> > siter(*fst);
75       !siter.Done(); siter.Next()) {
76    StateId s = siter.Value();
77    for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
78         !aiter.Done(); aiter.Next()) {
79      A arc = aiter.Value();
80
81      // relabel input
82      // only relabel if relabel pair defined
83      typename unordered_map<Label, Label>::iterator it =
84        input_map.find(arc.ilabel);
85      if (it != input_map.end()) {
86        if (it->second == kNoLabel) {
87          FSTERROR() << "Input symbol id " << arc.ilabel
88                     << " missing from target vocabulary";
89          fst->SetProperties(kError, kError);
90          return;
91        }
92        arc.ilabel = it->second;
93      }
94
95      // relabel output
96      it = output_map.find(arc.olabel);
97      if (it != output_map.end()) {
98        if (it->second == kNoLabel) {
99          FSTERROR() << "Output symbol id " << arc.olabel
100                     << " missing from target vocabulary";
101          fst->SetProperties(kError, kError);
102          return;
103        }
104        arc.olabel = it->second;
105      }
106
107      aiter.SetValue(arc);
108    }
109  }
110
111  fst->SetProperties(RelabelProperties(props), kFstProperties);
112}
113
114//
115// Relabels either the input labels or output labels. The old to
116// new labels mappings are specified using an input Symbol set.
117// Any label associations not specified are assumed to be identity
118// mapping.
119//
120// \param fst input fst, must be mutable
121// \param new_isymbols symbol set indicating new mapping of input symbols
122// \param new_osymbols symbol set indicating new mapping of output symbols
123//
124template<class A>
125void Relabel(MutableFst<A> *fst,
126             const SymbolTable* new_isymbols,
127             const SymbolTable* new_osymbols) {
128  Relabel(fst,
129          fst->InputSymbols(), new_isymbols, true,
130          fst->OutputSymbols(), new_osymbols, true);
131}
132
133template<class A>
134void Relabel(MutableFst<A> *fst,
135             const SymbolTable* old_isymbols,
136             const SymbolTable* new_isymbols,
137             bool attach_new_isymbols,
138             const SymbolTable* old_osymbols,
139             const SymbolTable* new_osymbols,
140             bool attach_new_osymbols) {
141  typedef typename A::StateId StateId;
142  typedef typename A::Label   Label;
143
144  vector<pair<Label, Label> > ipairs;
145  if (old_isymbols && new_isymbols) {
146    for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
147         syms_iter.Next()) {
148      string isymbol = syms_iter.Symbol();
149      int isymbol_val = syms_iter.Value();
150      int new_isymbol_val = new_isymbols->Find(isymbol);
151      ipairs.push_back(make_pair(isymbol_val, new_isymbol_val));
152    }
153    if (attach_new_isymbols)
154      fst->SetInputSymbols(new_isymbols);
155  }
156
157  vector<pair<Label, Label> > opairs;
158  if (old_osymbols && new_osymbols) {
159    for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
160         syms_iter.Next()) {
161      string osymbol = syms_iter.Symbol();
162      int osymbol_val = syms_iter.Value();
163      int new_osymbol_val = new_osymbols->Find(osymbol);
164      opairs.push_back(make_pair(osymbol_val, new_osymbol_val));
165    }
166    if (attach_new_osymbols)
167      fst->SetOutputSymbols(new_osymbols);
168  }
169
170  // call relabel using vector of relabel pairs.
171  Relabel(fst, ipairs, opairs);
172}
173
174
175typedef CacheOptions RelabelFstOptions;
176
177template <class A> class RelabelFst;
178
179//
180// \class RelabelFstImpl
181// \brief Implementation for delayed relabeling
182//
183// Relabels an FST from one symbol set to another. Relabeling
184// can either be on input or output space. RelabelFst implements
185// a delayed version of the relabel. Arcs are relabeled on the fly
186// and not cached. I.e each request is recomputed.
187//
188template<class A>
189class RelabelFstImpl : public CacheImpl<A> {
190  friend class StateIterator< RelabelFst<A> >;
191 public:
192  using FstImpl<A>::SetType;
193  using FstImpl<A>::SetProperties;
194  using FstImpl<A>::WriteHeader;
195  using FstImpl<A>::SetInputSymbols;
196  using FstImpl<A>::SetOutputSymbols;
197
198  using CacheImpl<A>::PushArc;
199  using CacheImpl<A>::HasArcs;
200  using CacheImpl<A>::HasFinal;
201  using CacheImpl<A>::HasStart;
202  using CacheImpl<A>::SetArcs;
203  using CacheImpl<A>::SetFinal;
204  using CacheImpl<A>::SetStart;
205
206  typedef A Arc;
207  typedef typename A::Label   Label;
208  typedef typename A::Weight  Weight;
209  typedef typename A::StateId StateId;
210  typedef CacheState<A> State;
211
212  RelabelFstImpl(const Fst<A>& fst,
213                 const vector<pair<Label, Label> >& ipairs,
214                 const vector<pair<Label, Label> >& opairs,
215                 const RelabelFstOptions &opts)
216      : CacheImpl<A>(opts), fst_(fst.Copy()),
217        relabel_input_(false), relabel_output_(false) {
218    uint64 props = fst.Properties(kCopyProperties, false);
219    SetProperties(RelabelProperties(props));
220    SetType("relabel");
221
222    // create input label map
223    if (ipairs.size() > 0) {
224      for (size_t i = 0; i < ipairs.size(); ++i) {
225        input_map_[ipairs[i].first] = ipairs[i].second;
226      }
227      relabel_input_ = true;
228    }
229
230    // create output label map
231    if (opairs.size() > 0) {
232      for (size_t i = 0; i < opairs.size(); ++i) {
233        output_map_[opairs[i].first] = opairs[i].second;
234      }
235      relabel_output_ = true;
236    }
237  }
238
239  RelabelFstImpl(const Fst<A>& fst,
240                 const SymbolTable* old_isymbols,
241                 const SymbolTable* new_isymbols,
242                 const SymbolTable* old_osymbols,
243                 const SymbolTable* new_osymbols,
244                 const RelabelFstOptions &opts)
245      : CacheImpl<A>(opts), fst_(fst.Copy()),
246        relabel_input_(false), relabel_output_(false) {
247    SetType("relabel");
248
249    uint64 props = fst.Properties(kCopyProperties, false);
250    SetProperties(RelabelProperties(props));
251    SetInputSymbols(old_isymbols);
252    SetOutputSymbols(old_osymbols);
253
254    if (old_isymbols && new_isymbols &&
255        old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
256      for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
257           syms_iter.Next()) {
258        input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
259      }
260      SetInputSymbols(new_isymbols);
261      relabel_input_ = true;
262    }
263
264    if (old_osymbols && new_osymbols &&
265        old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
266      for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
267           syms_iter.Next()) {
268        output_map_[syms_iter.Value()] =
269          new_osymbols->Find(syms_iter.Symbol());
270      }
271      SetOutputSymbols(new_osymbols);
272      relabel_output_ = true;
273    }
274  }
275
276  RelabelFstImpl(const RelabelFstImpl<A>& impl)
277      : CacheImpl<A>(impl),
278        fst_(impl.fst_->Copy(true)),
279        input_map_(impl.input_map_),
280        output_map_(impl.output_map_),
281        relabel_input_(impl.relabel_input_),
282        relabel_output_(impl.relabel_output_) {
283    SetType("relabel");
284    SetProperties(impl.Properties(), kCopyProperties);
285    SetInputSymbols(impl.InputSymbols());
286    SetOutputSymbols(impl.OutputSymbols());
287  }
288
289  ~RelabelFstImpl() { delete fst_; }
290
291  StateId Start() {
292    if (!HasStart()) {
293      StateId s = fst_->Start();
294      SetStart(s);
295    }
296    return CacheImpl<A>::Start();
297  }
298
299  Weight Final(StateId s) {
300    if (!HasFinal(s)) {
301      SetFinal(s, fst_->Final(s));
302    }
303    return CacheImpl<A>::Final(s);
304  }
305
306  size_t NumArcs(StateId s) {
307    if (!HasArcs(s)) {
308      Expand(s);
309    }
310    return CacheImpl<A>::NumArcs(s);
311  }
312
313  size_t NumInputEpsilons(StateId s) {
314    if (!HasArcs(s)) {
315      Expand(s);
316    }
317    return CacheImpl<A>::NumInputEpsilons(s);
318  }
319
320  size_t NumOutputEpsilons(StateId s) {
321    if (!HasArcs(s)) {
322      Expand(s);
323    }
324    return CacheImpl<A>::NumOutputEpsilons(s);
325  }
326
327  uint64 Properties() const { return Properties(kFstProperties); }
328
329  // Set error if found; return FST impl properties.
330  uint64 Properties(uint64 mask) const {
331    if ((mask & kError) && fst_->Properties(kError, false))
332      SetProperties(kError, kError);
333    return FstImpl<Arc>::Properties(mask);
334  }
335
336  void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
337    if (!HasArcs(s)) {
338      Expand(s);
339    }
340    CacheImpl<A>::InitArcIterator(s, data);
341  }
342
343  void Expand(StateId s) {
344    for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
345      A arc = aiter.Value();
346
347      // relabel input
348      if (relabel_input_) {
349        typename unordered_map<Label, Label>::iterator it =
350          input_map_.find(arc.ilabel);
351        if (it != input_map_.end()) { arc.ilabel = it->second; }
352      }
353
354      // relabel output
355      if (relabel_output_) {
356        typename unordered_map<Label, Label>::iterator it =
357          output_map_.find(arc.olabel);
358        if (it != output_map_.end()) { arc.olabel = it->second; }
359      }
360
361      PushArc(s, arc);
362    }
363    SetArcs(s);
364  }
365
366
367 private:
368  const Fst<A> *fst_;
369
370  unordered_map<Label, Label> input_map_;
371  unordered_map<Label, Label> output_map_;
372  bool relabel_input_;
373  bool relabel_output_;
374
375  void operator=(const RelabelFstImpl<A> &);  // disallow
376};
377
378
379//
380// \class RelabelFst
381// \brief Delayed implementation of arc relabeling
382//
383// This class attaches interface to implementation and handles
384// reference counting, delegating most methods to ImplToFst.
385template <class A>
386class RelabelFst : public ImplToFst< RelabelFstImpl<A> > {
387 public:
388  friend class ArcIterator< RelabelFst<A> >;
389  friend class StateIterator< RelabelFst<A> >;
390
391  typedef A Arc;
392  typedef typename A::Label   Label;
393  typedef typename A::Weight  Weight;
394  typedef typename A::StateId StateId;
395  typedef CacheState<A> State;
396  typedef RelabelFstImpl<A> Impl;
397
398  RelabelFst(const Fst<A>& fst,
399             const vector<pair<Label, Label> >& ipairs,
400             const vector<pair<Label, Label> >& opairs)
401      : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {}
402
403  RelabelFst(const Fst<A>& fst,
404             const vector<pair<Label, Label> >& ipairs,
405             const vector<pair<Label, Label> >& opairs,
406             const RelabelFstOptions &opts)
407      : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {}
408
409  RelabelFst(const Fst<A>& fst,
410             const SymbolTable* new_isymbols,
411             const SymbolTable* new_osymbols)
412      : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
413                                 fst.OutputSymbols(), new_osymbols,
414                                 RelabelFstOptions())) {}
415
416  RelabelFst(const Fst<A>& fst,
417             const SymbolTable* new_isymbols,
418             const SymbolTable* new_osymbols,
419             const RelabelFstOptions &opts)
420      : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
421                                 fst.OutputSymbols(), new_osymbols, opts)) {}
422
423  RelabelFst(const Fst<A>& fst,
424             const SymbolTable* old_isymbols,
425             const SymbolTable* new_isymbols,
426             const SymbolTable* old_osymbols,
427             const SymbolTable* new_osymbols)
428    : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
429                               new_osymbols, RelabelFstOptions())) {}
430
431  RelabelFst(const Fst<A>& fst,
432             const SymbolTable* old_isymbols,
433             const SymbolTable* new_isymbols,
434             const SymbolTable* old_osymbols,
435             const SymbolTable* new_osymbols,
436             const RelabelFstOptions &opts)
437    : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
438                               new_osymbols, opts)) {}
439
440  // See Fst<>::Copy() for doc.
441  RelabelFst(const RelabelFst<A> &fst, bool safe = false)
442    : ImplToFst<Impl>(fst, safe) {}
443
444  // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc.
445  virtual RelabelFst<A> *Copy(bool safe = false) const {
446    return new RelabelFst<A>(*this, safe);
447  }
448
449  virtual void InitStateIterator(StateIteratorData<A> *data) const;
450
451  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
452    return GetImpl()->InitArcIterator(s, data);
453  }
454
455 private:
456  // Makes visible to friends.
457  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
458
459  void operator=(const RelabelFst<A> &fst);  // disallow
460};
461
462// Specialization for RelabelFst.
463template<class A>
464class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
465 public:
466  typedef typename A::StateId StateId;
467
468  explicit StateIterator(const RelabelFst<A> &fst)
469      : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
470
471  bool Done() const { return siter_.Done(); }
472
473  StateId Value() const { return s_; }
474
475  void Next() {
476    if (!siter_.Done()) {
477      ++s_;
478      siter_.Next();
479    }
480  }
481
482  void Reset() {
483    s_ = 0;
484    siter_.Reset();
485  }
486
487 private:
488  bool Done_() const { return Done(); }
489  StateId Value_() const { return Value(); }
490  void Next_() { Next(); }
491  void Reset_() { Reset(); }
492
493  const RelabelFstImpl<A> *impl_;
494  StateIterator< Fst<A> > siter_;
495  StateId s_;
496
497  DISALLOW_COPY_AND_ASSIGN(StateIterator);
498};
499
500
501// Specialization for RelabelFst.
502template <class A>
503class ArcIterator< RelabelFst<A> >
504    : public CacheArcIterator< RelabelFst<A> > {
505 public:
506  typedef typename A::StateId StateId;
507
508  ArcIterator(const RelabelFst<A> &fst, StateId s)
509      : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) {
510    if (!fst.GetImpl()->HasArcs(s))
511      fst.GetImpl()->Expand(s);
512  }
513
514 private:
515  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
516};
517
518template <class A> inline
519void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
520  data->base = new StateIterator< RelabelFst<A> >(*this);
521}
522
523// Useful alias when using StdArc.
524typedef RelabelFst<StdArc> StdRelabelFst;
525
526}  // namespace fst
527
528#endif  // FST_LIB_RELABEL_H__
529