state-map.h revision 3da1eb108d36da35333b2d655202791af854996b
// state-map.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Class to map over/transform states e.g., sort transitions
20// Consider using when operation does not change the number of states.
21
22#ifndef FST_LIB_STATE_MAP_H__
23#define FST_LIB_STATE_MAP_H__
24
25#include <algorithm>
26#include <tr1/unordered_map>
27using std::tr1::unordered_map;
28using std::tr1::unordered_multimap;
29#include <string>
30#include <utility>
31using std::pair; using std::make_pair;
32
33#include <fst/cache.h>
34#include <fst/arc-map.h>
35#include <fst/mutable-fst.h>
36
37
38namespace fst {
39
// StateMapper Interface - class determines how states are mapped.
41// Useful for implementing operations that do not change the number of states.
42//
43// class StateMapper {
44//  public:
45//   typedef A FromArc;
46//   typedef B ToArc;
47//
48//   // Typical constructor
49//   StateMapper(const Fst<A> &fst);
50//   // Required copy constructor that allows updating Fst argument;
51//   // pass only if relevant and changed.
52//   StateMapper(const StateMapper &mapper, const Fst<A> *fst = 0);
53//
54//   // Specifies initial state of result
55//   B::StateId Start() const;
56//   // Specifies state's final weight in result
57//   B::Weight Final(B::StateId s) const;
58//
59//   // These methods iterate through a state's arcs in result
60//   // Specifies state to iterate over
61//   void SetState(B::StateId s);
62//   // End of arcs?
63//   bool Done() const;
//   // Current arc
//   const B &Value() const;
67//   // Advance to next arc (when !Done)
68//   void Next();
69//
70//   // Specifies input symbol table action the mapper requires (see above).
71//   MapSymbolsAction InputSymbolsAction() const;
72//   // Specifies output symbol table action the mapper requires (see above).
73//   MapSymbolsAction OutputSymbolsAction() const;
74//   // This specifies the known properties of an Fst mapped by this
75//   // mapper. It takes as argument the input Fst's known properties.
76//   uint64 Properties(uint64 props) const;
77// };
78//
// We include various state map versions below. One dimension of
// variation is whether the mapping mutates its input, writes to a
// new result Fst, or is an on-the-fly Fst. Another dimension is how
// we pass the mapper. We allow passing the mapper by pointer
// for cases where we need to change the state of the user's mapper.
// We also include map versions that pass the mapper
// by value or const reference when this suffices.
86
87// Maps an arc type A using a mapper function object C, passed
88// by pointer.  This version modifies its Fst input.
89template<class A, class C>
90void StateMap(MutableFst<A> *fst, C* mapper) {
91  typedef typename A::StateId StateId;
92  typedef typename A::Weight Weight;
93
94  if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
95    fst->SetInputSymbols(0);
96
97  if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
98    fst->SetOutputSymbols(0);
99
100  if (fst->Start() == kNoStateId)
101    return;
102
103  uint64 props = fst->Properties(kFstProperties, false);
104
105  fst->SetStart(mapper->Start());
106
107  for (StateId s = 0; s < fst->NumStates(); ++s) {
108    mapper->SetState(s);
109    fst->DeleteArcs(s);
110    for (; !mapper->Done(); mapper->Next())
111      fst->AddArc(s, mapper->Value());
112    fst->SetFinal(s, mapper->Final(s));
113  }
114
115  fst->SetProperties(mapper->Properties(props), kFstProperties);
116}
117
118// Maps an arc type A using a mapper function object C, passed
119// by value.  This version modifies its Fst input.
120template<class A, class C>
121void StateMap(MutableFst<A> *fst, C mapper) {
122  StateMap(fst, &mapper);
123}
124
125
126// Maps an arc type A to an arc type B using mapper function
127// object C, passed by pointer. This version writes the mapped
128// input Fst to an output MutableFst.
129template<class A, class B, class C>
130void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) {
131  typedef typename A::StateId StateId;
132  typedef typename A::Weight Weight;
133
134  ofst->DeleteStates();
135
136  if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS)
137    ofst->SetInputSymbols(ifst.InputSymbols());
138  else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
139    ofst->SetInputSymbols(0);
140
141  if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS)
142    ofst->SetOutputSymbols(ifst.OutputSymbols());
143  else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
144    ofst->SetOutputSymbols(0);
145
146  uint64 iprops = ifst.Properties(kCopyProperties, false);
147
148  if (ifst.Start() == kNoStateId) {
149    if (iprops & kError) ofst->SetProperties(kError, kError);
150    return;
151  }
152
153  // Add all states.
154  if (ifst.Properties(kExpanded, false))
155    ofst->ReserveStates(CountStates(ifst));
156  for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next())
157    ofst->AddState();
158
159  ofst->SetStart(mapper->Start());
160
161  for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) {
162    StateId s = siter.Value();
163    mapper->SetState(s);
164    for (; !mapper->Done(); mapper->Next())
165      ofst->AddArc(s, mapper->Value());
166    ofst->SetFinal(s, mapper->Final(s));
167  }
168
169  uint64 oprops = ofst->Properties(kFstProperties, false);
170  ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties);
171}
172
173// Maps an arc type A to an arc type B using mapper function
174// object C, passed by value. This version writes the mapped input
175// Fst to an output MutableFst.
176template<class A, class B, class C>
177void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) {
178  StateMap(ifst, ofst, &mapper);
179}
180
181typedef CacheOptions StateMapFstOptions;
182
183template <class A, class B, class C> class StateMapFst;
184
185// Implementation of delayed StateMapFst.
186template <class A, class B, class C>
187class StateMapFstImpl : public CacheImpl<B> {
188 public:
189  using FstImpl<B>::SetType;
190  using FstImpl<B>::SetProperties;
191  using FstImpl<B>::SetInputSymbols;
192  using FstImpl<B>::SetOutputSymbols;
193
194  using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates;
195
196  using CacheImpl<B>::PushArc;
197  using CacheImpl<B>::HasArcs;
198  using CacheImpl<B>::HasFinal;
199  using CacheImpl<B>::HasStart;
200  using CacheImpl<B>::SetArcs;
201  using CacheImpl<B>::SetFinal;
202  using CacheImpl<B>::SetStart;
203
204  friend class StateIterator< StateMapFst<A, B, C> >;
205
206  typedef B Arc;
207  typedef typename B::Weight Weight;
208  typedef typename B::StateId StateId;
209
210  StateMapFstImpl(const Fst<A> &fst, const C &mapper,
211                 const StateMapFstOptions& opts)
212      : CacheImpl<B>(opts),
213        fst_(fst.Copy()),
214        mapper_(new C(mapper, fst_)),
215        own_mapper_(true) {
216    Init();
217  }
218
219  StateMapFstImpl(const Fst<A> &fst, C *mapper,
220                 const StateMapFstOptions& opts)
221      : CacheImpl<B>(opts),
222        fst_(fst.Copy()),
223        mapper_(mapper),
224        own_mapper_(false) {
225    Init();
226  }
227
228  StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl)
229      : CacheImpl<B>(impl),
230        fst_(impl.fst_->Copy(true)),
231        mapper_(new C(*impl.mapper_, fst_)),
232        own_mapper_(true) {
233    Init();
234  }
235
236  ~StateMapFstImpl() {
237    delete fst_;
238    if (own_mapper_) delete mapper_;
239  }
240
241  StateId Start() {
242    if (!HasStart())
243      SetStart(mapper_->Start());
244    return CacheImpl<B>::Start();
245  }
246
247  Weight Final(StateId s) {
248    if (!HasFinal(s))
249      SetFinal(s, mapper_->Final(s));
250    return CacheImpl<B>::Final(s);
251  }
252
253  size_t NumArcs(StateId s) {
254    if (!HasArcs(s))
255      Expand(s);
256    return CacheImpl<B>::NumArcs(s);
257  }
258
259  size_t NumInputEpsilons(StateId s) {
260    if (!HasArcs(s))
261      Expand(s);
262    return CacheImpl<B>::NumInputEpsilons(s);
263  }
264
265  size_t NumOutputEpsilons(StateId s) {
266    if (!HasArcs(s))
267      Expand(s);
268    return CacheImpl<B>::NumOutputEpsilons(s);
269  }
270
271  void InitStateIterator(StateIteratorData<A> *data) const {
272    fst_->InitStateIterator(data);
273  }
274
275  void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
276    if (!HasArcs(s))
277      Expand(s);
278    CacheImpl<B>::InitArcIterator(s, data);
279  }
280
281  uint64 Properties() const { return Properties(kFstProperties); }
282
283  // Set error if found; return FST impl properties.
284  uint64 Properties(uint64 mask) const {
285    if ((mask & kError) && (fst_->Properties(kError, false) ||
286                            (mapper_->Properties(0) & kError)))
287      SetProperties(kError, kError);
288    return FstImpl<Arc>::Properties(mask);
289  }
290
291  void Expand(StateId s) {
292    // Add exiting arcs.
293    for (mapper_->SetState(s); !mapper_->Done(); mapper_->Next())
294      PushArc(s, mapper_->Value());
295    SetArcs(s);
296  }
297
298  const Fst<A> &GetFst() const {
299    return *fst_;
300  }
301
302 private:
303  void Init() {
304    SetType("statemap");
305
306    if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS)
307      SetInputSymbols(fst_->InputSymbols());
308    else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS)
309      SetInputSymbols(0);
310
311    if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS)
312      SetOutputSymbols(fst_->OutputSymbols());
313    else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS)
314      SetOutputSymbols(0);
315
316    uint64 props = fst_->Properties(kCopyProperties, false);
317    SetProperties(mapper_->Properties(props));
318  }
319
320  const Fst<A> *fst_;
321  C*  mapper_;
322  bool own_mapper_;
323
324  void operator=(const StateMapFstImpl<A, B, C> &);  // disallow
325};
326
327
328// Maps an arc type A to an arc type B using Mapper function object
329// C. This version is a delayed Fst.
330template <class A, class B, class C>
331class StateMapFst : public ImplToFst< StateMapFstImpl<A, B, C> > {
332 public:
333  friend class ArcIterator< StateMapFst<A, B, C> >;
334
335  typedef B Arc;
336  typedef typename B::Weight Weight;
337  typedef typename B::StateId StateId;
338  typedef CacheState<B> State;
339  typedef StateMapFstImpl<A, B, C> Impl;
340
341  StateMapFst(const Fst<A> &fst, const C &mapper,
342              const StateMapFstOptions& opts)
343      : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {}
344
345  StateMapFst(const Fst<A> &fst, C* mapper, const StateMapFstOptions& opts)
346      : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {}
347
348  StateMapFst(const Fst<A> &fst, const C &mapper)
349      : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {}
350
351  StateMapFst(const Fst<A> &fst, C* mapper)
352      : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {}
353
354  // See Fst<>::Copy() for doc.
355  StateMapFst(const StateMapFst<A, B, C> &fst, bool safe = false)
356    : ImplToFst<Impl>(fst, safe) {}
357
358  // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc.
359  virtual StateMapFst<A, B, C> *Copy(bool safe = false) const {
360    return new StateMapFst<A, B, C>(*this, safe);
361  }
362
363  virtual void InitStateIterator(StateIteratorData<A> *data) const {
364    GetImpl()->InitStateIterator(data);
365  }
366
367  virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
368    GetImpl()->InitArcIterator(s, data);
369  }
370
371 protected:
372  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
373
374 private:
375  void operator=(const StateMapFst<A, B, C> &fst);  // disallow
376};
377
378
379// Specialization for StateMapFst.
380template <class A, class B, class C>
381class ArcIterator< StateMapFst<A, B, C> >
382    : public CacheArcIterator< StateMapFst<A, B, C> > {
383 public:
384  typedef typename A::StateId StateId;
385
386  ArcIterator(const StateMapFst<A, B, C> &fst, StateId s)
387      : CacheArcIterator< StateMapFst<A, B, C> >(fst.GetImpl(), s) {
388    if (!fst.GetImpl()->HasArcs(s))
389      fst.GetImpl()->Expand(s);
390  }
391
392 private:
393  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
394};
395
396//
397// Utility Mappers
398//
399
400// Mapper that returns its input.
401template <class A>
402class IdentityStateMapper {
403 public:
404  typedef A FromArc;
405  typedef A ToArc;
406
407  typedef typename A::StateId StateId;
408  typedef typename A::Weight Weight;
409
410  explicit IdentityStateMapper(const Fst<A> &fst) : fst_(fst), aiter_(0) {}
411
412  // Allows updating Fst argument; pass only if changed.
413  IdentityStateMapper(const IdentityStateMapper<A> &mapper,
414                      const Fst<A> *fst = 0)
415      : fst_(fst ? *fst : mapper.fst_), aiter_(0) {}
416
417  ~IdentityStateMapper() { delete aiter_; }
418
419  StateId Start() const { return fst_.Start(); }
420
421  Weight Final(StateId s) const { return fst_.Final(s); }
422
423  void SetState(StateId s) {
424    if (aiter_) delete aiter_;
425    aiter_ = new ArcIterator< Fst<A> >(fst_, s);
426  }
427
428  bool Done() const { return aiter_->Done(); }
429  const A &Value() const { return aiter_->Value(); }
430  void Next() { aiter_->Next(); }
431
432  MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
433  MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;}
434
435  uint64 Properties(uint64 props) const { return props; }
436
437 private:
438  const Fst<A> &fst_;
439  ArcIterator< Fst<A> > *aiter_;
440};
441
442template <class A>
443class ArcSumMapper {
444 public:
445  typedef A FromArc;
446  typedef A ToArc;
447
448  typedef typename A::StateId StateId;
449  typedef typename A::Weight Weight;
450
451  explicit ArcSumMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
452
453  // Allows updating Fst argument; pass only if changed.
454  ArcSumMapper(const ArcSumMapper<A> &mapper,
455               const Fst<A> *fst = 0)
456      : fst_(fst ? *fst : mapper.fst_), i_(0) {}
457
458  StateId Start() const { return fst_.Start(); }
459  Weight Final(StateId s) const { return fst_.Final(s); }
460
461  void SetState(StateId s) {
462    i_ = 0;
463    arcs_.clear();
464    arcs_.reserve(fst_.NumArcs(s));
465    for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next())
466      arcs_.push_back(aiter.Value());
467
468    // First sorts the exiting arcs by input label, output label
469    // and destination state and then sums weights of arcs with
470    // the same input label, output label, and destination state.
471    sort(arcs_.begin(), arcs_.end(), comp_);
472    size_t narcs = 0;
473    for (size_t i = 0; i < arcs_.size(); ++i) {
474      if (narcs > 0 && equal_(arcs_[i], arcs_[narcs - 1])) {
475        arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight,
476                                       arcs_[i].weight);
477      } else {
478        arcs_[narcs++] = arcs_[i];
479      }
480    }
481    arcs_.resize(narcs);
482  }
483
484  bool Done() const { return i_ >= arcs_.size(); }
485  const A &Value() const { return arcs_[i_]; }
486  void Next() { ++i_; }
487
488  MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
489  MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
490
491  uint64 Properties(uint64 props) const {
492    return props & kArcSortProperties &
493        kDeleteArcsProperties & kWeightInvariantProperties;
494  }
495
496 private:
497  struct Compare {
498    bool operator()(const A& x, const A& y) {
499      if (x.ilabel < y.ilabel) return true;
500      if (x.ilabel > y.ilabel) return false;
501      if (x.olabel < y.olabel) return true;
502      if (x.olabel > y.olabel) return false;
503      if (x.nextstate < y.nextstate) return true;
504      if (x.nextstate > y.nextstate) return false;
505      return false;
506    }
507  };
508
509  struct Equal {
510    bool operator()(const A& x, const A& y) {
511      return (x.ilabel == y.ilabel &&
512              x.olabel == y.olabel &&
513              x.nextstate == y.nextstate);
514    }
515  };
516
517  const Fst<A> &fst_;
518  Compare comp_;
519  Equal equal_;
520  vector<A> arcs_;
521  ssize_t i_;               // current arc position
522
523  void operator=(const ArcSumMapper<A> &);  // disallow
524};
525
526template <class A>
527class ArcUniqueMapper {
528 public:
529  typedef A FromArc;
530  typedef A ToArc;
531
532  typedef typename A::StateId StateId;
533  typedef typename A::Weight Weight;
534
535  explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {}
536
537  // Allows updating Fst argument; pass only if changed.
538  ArcUniqueMapper(const ArcSumMapper<A> &mapper,
539                  const Fst<A> *fst = 0)
540      : fst_(fst ? *fst : mapper.fst_), i_(0) {}
541
542  StateId Start() const { return fst_.Start(); }
543  Weight Final(StateId s) const { return fst_.Final(s); }
544
545  void SetState(StateId s) {
546    i_ = 0;
547    arcs_.clear();
548    arcs_.reserve(fst_.NumArcs(s));
549    for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next())
550      arcs_.push_back(aiter.Value());
551
552    // First sorts the exiting arcs by input label, output label
553    // and destination state and then uniques identical arcs
554    sort(arcs_.begin(), arcs_.end(), comp_);
555    typename vector<A>::iterator unique_end =
556        unique(arcs_.begin(), arcs_.end(), equal_);
557    arcs_.resize(unique_end - arcs_.begin());
558  }
559
560  bool Done() const { return i_ >= arcs_.size(); }
561  const A &Value() const { return arcs_[i_]; }
562  void Next() { ++i_; }
563
564  MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
565  MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; }
566
567  uint64 Properties(uint64 props) const {
568    return props & kArcSortProperties & kDeleteArcsProperties;
569  }
570
571 private:
572  struct Compare {
573    bool operator()(const A& x, const A& y) {
574      if (x.ilabel < y.ilabel) return true;
575      if (x.ilabel > y.ilabel) return false;
576      if (x.olabel < y.olabel) return true;
577      if (x.olabel > y.olabel) return false;
578      if (x.nextstate < y.nextstate) return true;
579      if (x.nextstate > y.nextstate) return false;
580      return false;
581    }
582  };
583
584  struct Equal {
585    bool operator()(const A& x, const A& y) {
586      return (x.ilabel == y.ilabel &&
587              x.olabel == y.olabel &&
588              x.nextstate == y.nextstate &&
589              x.weight == y.weight);
590    }
591  };
592
593  const Fst<A> &fst_;
594  Compare comp_;
595  Equal equal_;
596  vector<A> arcs_;
597  ssize_t i_;               // current arc position
598
599  void operator=(const ArcUniqueMapper<A> &);  // disallow
600};
601
602
603}  // namespace fst
604
605#endif  // FST_LIB_STATE_MAP_H__
606