synchronize.h revision 4a68b3365c8c50aa93505e99ead2565ab73dcdb0
1// synchronize.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Author: allauzen@cs.nyu.edu (Cyril Allauzen)
16//
17// \file
18// Synchronize an FST with bounded delay.
19
20#ifndef FST_LIB_SYNCHRONIZE_H__
21#define FST_LIB_SYNCHRONIZE_H__
22
#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include <ext/hash_map>
using __gnu_cxx::hash_map;
#include <ext/hash_set>
using __gnu_cxx::hash_set;

#include "fst/lib/cache.h"
#include "fst/lib/test-properties.h"
30
31namespace fst {
32
// Options accepted by SynchronizeFst. This is a plain alias of
// CacheOptions; SynchronizeFstImpl forwards the options object to its
// CacheImpl base class unchanged.
33typedef CacheOptions SynchronizeFstOptions;
34
35
36// Implementation class for SynchronizeFst
37template <class A>
38class SynchronizeFstImpl
39    : public CacheImpl<A> {
40 public:
41  using FstImpl<A>::SetType;
42  using FstImpl<A>::SetProperties;
43  using FstImpl<A>::Properties;
44  using FstImpl<A>::SetInputSymbols;
45  using FstImpl<A>::SetOutputSymbols;
46
47  using CacheBaseImpl< CacheState<A> >::HasStart;
48  using CacheBaseImpl< CacheState<A> >::HasFinal;
49  using CacheBaseImpl< CacheState<A> >::HasArcs;
50
51  typedef A Arc;
52  typedef typename A::Label Label;
53  typedef typename A::Weight Weight;
54  typedef typename A::StateId StateId;
55
56  typedef basic_string<Label> String;
57
58  struct Element {
59    Element() {}
60
61    Element(StateId s, const String *i, const String *o)
62        : state(s), istring(i), ostring(o) {}
63
64    StateId state;     // Input state Id
65    const String *istring;     // Residual input labels
66    const String *ostring;     // Residual output labels
67    // Residual strings are represented by const pointers to
68    // basic_string<Label> and are stored in a hash_set. The pointed
69    // memory is owned by the hash_set string_set_.
70  };
71
72  SynchronizeFstImpl(const Fst<A> &fst, const SynchronizeFstOptions &opts)
73      : CacheImpl<A>(opts), fst_(fst.Copy()) {
74    SetType("synchronize");
75    uint64 props = fst.Properties(kFstProperties, false);
76    SetProperties(SynchronizeProperties(props), kCopyProperties);
77
78    SetInputSymbols(fst.InputSymbols());
79    SetOutputSymbols(fst.OutputSymbols());
80  }
81
82  ~SynchronizeFstImpl() {
83    delete fst_;
84    // Extract pointers from the hash set
85    vector<const String*> strings;
86    typename StringSet::iterator it = string_set_.begin();
87    for (; it != string_set_.end(); ++it)
88      strings.push_back(*it);
89    // Free the extracted pointers
90    for (size_t i = 0; i < strings.size(); ++i)
91      delete strings[i];
92  }
93
94  StateId Start() {
95    if (!HasStart()) {
96      StateId s = fst_->Start();
97      if (s == kNoStateId)
98        return kNoStateId;
99      const String *empty = FindString(new String());
100      StateId start = FindState(Element(fst_->Start(), empty, empty));
101      SetStart(start);
102    }
103    return CacheImpl<A>::Start();
104  }
105
106  Weight Final(StateId s) {
107    if (!HasFinal(s)) {
108      const Element &e = elements_[s];
109      Weight w = e.state == kNoStateId ? Weight::One() : fst_->Final(e.state);
110      if ((w != Weight::Zero()) && (e.istring)->empty() && (e.ostring)->empty())
111        SetFinal(s, w);
112      else
113        SetFinal(s, Weight::Zero());
114    }
115    return CacheImpl<A>::Final(s);
116  }
117
118  size_t NumArcs(StateId s) {
119    if (!HasArcs(s))
120      Expand(s);
121    return CacheImpl<A>::NumArcs(s);
122  }
123
124  size_t NumInputEpsilons(StateId s) {
125    if (!HasArcs(s))
126      Expand(s);
127    return CacheImpl<A>::NumInputEpsilons(s);
128  }
129
130  size_t NumOutputEpsilons(StateId s) {
131    if (!HasArcs(s))
132      Expand(s);
133    return CacheImpl<A>::NumOutputEpsilons(s);
134  }
135
136  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
137    if (!HasArcs(s))
138      Expand(s);
139    CacheImpl<A>::InitArcIterator(s, data);
140  }
141
142  // Returns the first character of the string obtained by
143  // concatenating s and l.
144  Label Car(const String *s, Label l = 0) const {
145    if (!s->empty())
146      return (*s)[0];
147    else
148      return l;
149  }
150
151  // Computes the residual string obtained by removing the first
152  // character in the concatenation of s and l.
153  const String *Cdr(const String *s, Label l = 0) {
154    String *r = new String();
155    for (int i = 1; i < s->size(); ++i)
156      r->push_back((*s)[i]);
157    if (l && !(s->empty())) r->push_back(l);
158    return FindString(r);
159  }
160
161  // Computes the concatenation of s and l.
162  const String *Concat(const String *s, Label l = 0) {
163    String *r = new String();
164    for (int i = 0; i < s->size(); ++i)
165      r->push_back((*s)[i]);
166    if (l) r->push_back(l);
167    return FindString(r);
168  }
169
170  // Tests if the concatenation of s and l is empty
171  bool Empty(const String *s, Label l = 0) const {
172    if (s->empty())
173      return l == 0;
174    else
175      return false;
176  }
177
178  // Finds the string pointed by s in the hash set. Transfers the
179  // pointer ownership to the hash set.
180  const String *FindString(const String *s) {
181    typename StringSet::iterator it = string_set_.find(s);
182    if (it != string_set_.end()) {
183      delete s;
184      return (*it);
185    } else {
186      string_set_.insert(s);
187      return s;
188    }
189  }
190
191  // Finds state corresponding to an element. Creates new state
192  // if element not found.
193  StateId FindState(const Element &e) {
194    typename ElementMap::iterator eit = element_map_.find(e);
195    if (eit != element_map_.end()) {
196      return (*eit).second;
197    } else {
198      StateId s = elements_.size();
199      elements_.push_back(e);
200      element_map_.insert(pair<const Element, StateId>(e, s));
201      return s;
202    }
203  }
204
205
206  // Computes the outgoing transitions from a state, creating new destination
207  // states as needed.
208  void Expand(StateId s) {
209    Element e = elements_[s];
210
211    if (e.state != kNoStateId)
212      for (ArcIterator< Fst<A> > ait(*fst_, e.state);
213           !ait.Done();
214           ait.Next()) {
215        const A &arc = ait.Value();
216        if (!Empty(e.istring, arc.ilabel)  && !Empty(e.ostring, arc.olabel)) {
217          const String *istring = Cdr(e.istring, arc.ilabel);
218          const String *ostring = Cdr(e.ostring, arc.olabel);
219          StateId d = FindState(Element(arc.nextstate, istring, ostring));
220          AddArc(s, Arc(Car(e.istring, arc.ilabel),
221                        Car(e.ostring, arc.olabel), arc.weight, d));
222        } else {
223          const String *istring = Concat(e.istring, arc.ilabel);
224          const String *ostring = Concat(e.ostring, arc.olabel);
225          StateId d = FindState(Element(arc.nextstate, istring, ostring));
226          AddArc(s, Arc(0 , 0, arc.weight, d));
227        }
228      }
229
230    Weight w = e.state == kNoStateId ? Weight::One() : fst_->Final(e.state);
231    if ((w != Weight::Zero()) &&
232        ((e.istring)->size() + (e.ostring)->size() > 0)) {
233      const String *istring = Cdr(e.istring);
234      const String *ostring = Cdr(e.ostring);
235      StateId d = FindState(Element(kNoStateId, istring, ostring));
236      AddArc(s, Arc(Car(e.istring), Car(e.ostring), w, d));
237    }
238    SetArcs(s);
239  }
240
241 private:
242  // Equality function for Elements, assume strings have been hashed.
243  class ElementEqual {
244   public:
245    bool operator()(const Element &x, const Element &y) const {
246      return x.state == y.state &&
247              x.istring == y.istring &&
248              x.ostring == y.ostring;
249    }
250  };
251
252  // Hash function for Elements to Fst states.
253  class ElementKey {
254   public:
255    size_t operator()(const Element &x) const {
256      size_t key = x.state;
257      key = (key << 1) ^ (x.istring)->size();
258      for (size_t i = 0; i < (x.istring)->size(); ++i)
259        key = (key << 1) ^ (*x.istring)[i];
260      key = (key << 1) ^ (x.ostring)->size();
261      for (size_t i = 0; i < (x.ostring)->size(); ++i)
262        key = (key << 1) ^ (*x.ostring)[i];
263      return key;
264    }
265  };
266
267  // Equality function for strings
268  class StringEqual {
269   public:
270    bool operator()(const String * const &x, const String * const &y) const {
271      if (x->size() != y->size()) return false;
272      for (size_t i = 0; i < x->size(); ++i)
273        if ((*x)[i] != (*y)[i]) return false;
274      return true;
275    }
276  };
277
278  // Hash function for set of strings
279  class StringKey{
280   public:
281    size_t operator()(const String * const & x) const {
282      size_t key = x->size();
283      for (size_t i = 0; i < x->size(); ++i)
284        key = (key << 1) ^ (*x)[i];
285      return key;
286    }
287  };
288
289
290  typedef hash_map<Element, StateId, ElementKey, ElementEqual> ElementMap;
291  typedef hash_set<const String*, StringKey, StringEqual> StringSet;
292
293  const Fst<A> *fst_;
294  vector<Element> elements_;  // mapping Fst state to Elements
295  ElementMap element_map_;    // mapping Elements to Fst state
296  StringSet string_set_;
297
298  DISALLOW_EVIL_CONSTRUCTORS(SynchronizeFstImpl);
299};
300
301
302// Synchronizes a transducer. This version is a delayed Fst.  The
303// result will be an equivalent FST that has the property that during
304// the traversal of a path, the delay is either zero or strictly
305// increasing, where the delay is the difference between the number of
306// non-epsilon output labels and input labels along the path.
307//
308// For the algorithm to terminate, the input transducer must have
309// bounded delay, i.e., the delay of every cycle must be zero.
310//
311// Complexity:
312// - A has bounded delay: exponential
313// - A does not have bounded delay: does not terminate
314//
315// References:
// - Mehryar Mohri. Edit-Distance of Weighted Automata: General
//   Definitions and Algorithms, International Journal of Foundations
//   of Computer Science, 14(6): 957-982 (2003).
319template <class A>
320class SynchronizeFst : public Fst<A> {
321 public:
322  friend class ArcIterator< SynchronizeFst<A> >;
323  friend class CacheStateIterator< SynchronizeFst<A> >;
324  friend class CacheArcIterator< SynchronizeFst<A> >;
325
326  typedef A Arc;
327  typedef typename A::Weight Weight;
328  typedef typename A::StateId StateId;
329  typedef CacheState<A> State;
330
331  SynchronizeFst(const Fst<A> &fst)
332      : impl_(new SynchronizeFstImpl<A>(fst, SynchronizeFstOptions())) {}
333
334  SynchronizeFst(const Fst<A> &fst,  const SynchronizeFstOptions &opts)
335      : impl_(new SynchronizeFstImpl<A>(fst, opts)) {}
336
337  SynchronizeFst(const SynchronizeFst<A> &fst) : impl_(fst.impl_) {
338    impl_->IncrRefCount();
339  }
340
341  virtual ~SynchronizeFst() { if (!impl_->DecrRefCount()) delete impl_; }
342
343  virtual StateId Start() const { return impl_->Start(); }
344
345  virtual Weight Final(StateId s) const { return impl_->Final(s); }
346
347  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }
348
349  virtual size_t NumInputEpsilons(StateId s) const {
350    return impl_->NumInputEpsilons(s);
351  }
352
353  virtual size_t NumOutputEpsilons(StateId s) const {
354    return impl_->NumOutputEpsilons(s);
355  }
356
357  virtual uint64 Properties(uint64 mask, bool test) const {
358    if (test) {
359      uint64 known, test = TestProperties(*this, mask, &known);
360      impl_->SetProperties(test, known);
361      return test & mask;
362    } else {
363      return impl_->Properties(mask);
364    }
365  }
366
367  virtual const string& Type() const { return impl_->Type(); }
368
369  virtual SynchronizeFst<A> *Copy() const {
370    return new SynchronizeFst<A>(*this);
371  }
372
373  virtual const SymbolTable* InputSymbols() const {
374    return impl_->InputSymbols();
375  }
376
377  virtual const SymbolTable* OutputSymbols() const {
378    return impl_->OutputSymbols();
379  }
380
381  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
382
383  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
384    impl_->InitArcIterator(s, data);
385  }
386
387 private:
388  SynchronizeFstImpl<A> *Impl() { return impl_; }
389
390  SynchronizeFstImpl<A> *impl_;
391
392  void operator=(const SynchronizeFst<A> &fst);  // Disallow
393};
394
395
396// Specialization for SynchronizeFst.
397template<class A>
398class StateIterator< SynchronizeFst<A> >
399    : public CacheStateIterator< SynchronizeFst<A> > {
400 public:
401  explicit StateIterator(const SynchronizeFst<A> &fst)
402      : CacheStateIterator< SynchronizeFst<A> >(fst) {}
403};
404
405
406// Specialization for SynchronizeFst.
407template <class A>
408class ArcIterator< SynchronizeFst<A> >
409    : public CacheArcIterator< SynchronizeFst<A> > {
410 public:
411  typedef typename A::StateId StateId;
412
413  ArcIterator(const SynchronizeFst<A> &fst, StateId s)
414      : CacheArcIterator< SynchronizeFst<A> >(fst, s) {
415    if (!fst.impl_->HasArcs(s))
416      fst.impl_->Expand(s);
417  }
418
419 private:
420  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
421};
422
423
424template <class A> inline
425void SynchronizeFst<A>::InitStateIterator(StateIteratorData<A> *data) const
426{
427  data->base = new StateIterator< SynchronizeFst<A> >(*this);
428}
429
430
431
432// Synchronizes a transducer. This version writes the synchronized
433// result to a MutableFst.  The result will be an equivalent FST that
434// has the property that during the traversal of a path, the delay is
435// either zero or strictly increasing, where the delay is the
436// difference between the number of non-epsilon output labels and
437// input labels along the path.
438//
439// For the algorithm to terminate, the input transducer must have
440// bounded delay, i.e., the delay of every cycle must be zero.
441//
442// Complexity:
443// - A has bounded delay: exponential
444// - A does not have bounded delay: does not terminate
445//
446// References:
// - Mehryar Mohri. Edit-Distance of Weighted Automata: General
//   Definitions and Algorithms, International Journal of Foundations
//   of Computer Science, 14(6): 957-982 (2003).
450template<class Arc>
451void Synchronize(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) {
452  SynchronizeFstOptions opts;
453  opts.gc_limit = 0;  // Cache only the last state for fastest copy.
454  *ofst = SynchronizeFst<Arc>(ifst, opts);
455}
456
457}  // namespace fst
458
459#endif // FST_LIB_SYNCHRONIZE_H__
460