mutable-fst.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// mutable-fst.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Expanded FST augmented with mutators - interface class definition
20// and mutable arc iterator interface.
21//
22
23#ifndef FST_LIB_MUTABLE_FST_H__
24#define FST_LIB_MUTABLE_FST_H__
25
26#include <stddef.h>
27#include <sys/types.h>
28#include <string>
29#include <vector>
30using std::vector;
31
32#include <fst/expanded-fst.h>
33
34
35namespace fst {
36
37template <class A> class MutableArcIteratorData;
38
39// An expanded FST plus mutators (use MutableArcIterator to modify arcs).
40template <class A>
41class MutableFst : public ExpandedFst<A> {
42 public:
43  typedef A Arc;
44  typedef typename A::Weight Weight;
45  typedef typename A::StateId StateId;
46
47  virtual MutableFst<A> &operator=(const Fst<A> &fst) = 0;
48
49  MutableFst<A> &operator=(const MutableFst<A> &fst) {
50    return operator=(static_cast<const Fst<A> &>(fst));
51  }
52
53  virtual void SetStart(StateId) = 0;           // Set the initial state
54  virtual void SetFinal(StateId, Weight) = 0;   // Set a state's final weight
55  virtual void SetProperties(uint64 props,
56                             uint64 mask) = 0;  // Set property bits wrt mask
57
58  virtual StateId AddState() = 0;               // Add a state, return its ID
59  virtual void AddArc(StateId, const A &arc) = 0;   // Add an arc to state
60
61  virtual void DeleteStates(const vector<StateId>&) = 0;  // Delete some states
62  virtual void DeleteStates() = 0;              // Delete all states
63  virtual void DeleteArcs(StateId, size_t n) = 0;  // Delete some arcs at state
64  virtual void DeleteArcs(StateId) = 0;         // Delete all arcs at state
65
66  virtual void ReserveStates(StateId n) { }  // Optional, best effort only.
67  virtual void ReserveArcs(StateId s, size_t n) { }  // Optional, Best effort.
68
69  // Return input label symbol table; return NULL if not specified
70  virtual const SymbolTable* InputSymbols() const = 0;
71  // Return output label symbol table; return NULL if not specified
72  virtual const SymbolTable* OutputSymbols() const = 0;
73
74  // Return input label symbol table; return NULL if not specified
75  virtual SymbolTable* MutableInputSymbols() = 0;
76  // Return output label symbol table; return NULL if not specified
77  virtual SymbolTable* MutableOutputSymbols() = 0;
78
79  // Set input label symbol table; NULL signifies not unspecified
80  virtual void SetInputSymbols(const SymbolTable* isyms) = 0;
81  // Set output label symbol table; NULL signifies not unspecified
82  virtual void SetOutputSymbols(const SymbolTable* osyms) = 0;
83
84  // Get a copy of this MutableFst. See Fst<>::Copy() for further doc.
85  virtual MutableFst<A> *Copy(bool safe = false) const = 0;
86
87  // Read an MutableFst from an input stream; return NULL on error.
88  static MutableFst<A> *Read(istream &strm, const FstReadOptions &opts) {
89    FstReadOptions ropts(opts);
90    FstHeader hdr;
91    if (ropts.header)
92      hdr = *opts.header;
93    else {
94      if (!hdr.Read(strm, opts.source))
95        return 0;
96      ropts.header = &hdr;
97    }
98    if (!(hdr.Properties() & kMutable)) {
99      LOG(ERROR) << "MutableFst::Read: Not an MutableFst: " << ropts.source;
100      return 0;
101    }
102    FstRegister<A> *registr = FstRegister<A>::GetRegister();
103    const typename FstRegister<A>::Reader reader =
104      registr->GetReader(hdr.FstType());
105    if (!reader) {
106      LOG(ERROR) << "MutableFst::Read: Unknown FST type \"" << hdr.FstType()
107                 << "\" (arc type = \"" << A::Type()
108                 << "\"): " << ropts.source;
109      return 0;
110    }
111    Fst<A> *fst = reader(strm, ropts);
112    if (!fst) return 0;
113    return static_cast<MutableFst<A> *>(fst);
114  }
115
116  // Read a MutableFst from a file; return NULL on error.
117  // Empty filename reads from standard input. If 'convert' is true,
118  // convert to a mutable FST of type 'convert_type' if file is
119  // a non-mutable FST.
120  static MutableFst<A> *Read(const string &filename, bool convert = false,
121                             const string &convert_type = "vector") {
122    if (convert == false) {
123      if (!filename.empty()) {
124        ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
125        if (!strm) {
126          LOG(ERROR) << "MutableFst::Read: Can't open file: " << filename;
127          return 0;
128        }
129        return Read(strm, FstReadOptions(filename));
130      } else {
131        return Read(std::cin, FstReadOptions("standard input"));
132      }
133    } else {  // Converts to 'convert_type' if not mutable.
134      Fst<A> *ifst = Fst<A>::Read(filename);
135      if (!ifst) return 0;
136      if (ifst->Properties(kMutable, false)) {
137        return static_cast<MutableFst *>(ifst);
138      } else {
139        Fst<A> *ofst = Convert(*ifst, convert_type);
140        delete ifst;
141        if (!ofst) return 0;
142        if (!ofst->Properties(kMutable, false))
143          LOG(ERROR) << "MutableFst: bad convert type: " << convert_type;
144        return static_cast<MutableFst *>(ofst);
145      }
146    }
147  }
148
149  // For generic mutuble arc iterator construction; not normally called
150  // directly by users.
151  virtual void InitMutableArcIterator(StateId s,
152                                      MutableArcIteratorData<A> *) = 0;
153};
154
155// Mutable arc iterator interface, templated on the Arc definition; used
156// for mutable Arc iterator specializations that are returned by
157// the InitMutableArcIterator MutableFst method.
158template <class A>
159class MutableArcIteratorBase : public ArcIteratorBase<A> {
160 public:
161  typedef A Arc;
162
163  void SetValue(const A &arc) { SetValue_(arc); }  // Set current arc's content
164
165 private:
166  virtual void SetValue_(const A &arc) = 0;
167};
168
169template <class A>
170struct MutableArcIteratorData {
171  MutableArcIteratorBase<A> *base;  // Specific iterator
172};
173
174// Generic mutable arc iterator, templated on the FST definition
175// - a wrapper around pointer to specific one.
176// Here is a typical use: \code
177//   for (MutableArcIterator<StdFst> aiter(&fst, s));
178//        !aiter.Done();
179//         aiter.Next()) {
180//     StdArc arc = aiter.Value();
181//     arc.ilabel = 7;
182//     aiter.SetValue(arc);
183//     ...
184//   } \endcode
185// This version requires function calls.
186template <class F>
187class MutableArcIterator {
188 public:
189  typedef F FST;
190  typedef typename F::Arc Arc;
191  typedef typename Arc::StateId StateId;
192
193  MutableArcIterator(F *fst, StateId s) {
194    fst->InitMutableArcIterator(s, &data_);
195  }
196  ~MutableArcIterator() { delete data_.base; }
197
198  bool Done() const { return data_.base->Done(); }
199  const Arc& Value() const { return data_.base->Value(); }
200  void Next() { data_.base->Next(); }
201  size_t Position() const { return data_.base->Position(); }
202  void Reset() { data_.base->Reset(); }
203  void Seek(size_t a) { data_.base->Seek(a); }
204  void SetValue(const Arc &a) { data_.base->SetValue(a); }
205  uint32 Flags() const { return data_.base->Flags(); }
206  void SetFlags(uint32 f, uint32 m) {
207    return data_.base->SetFlags(f, m);
208  }
209
210 private:
211  MutableArcIteratorData<Arc> data_;
212  DISALLOW_COPY_AND_ASSIGN(MutableArcIterator);
213};
214
215
216namespace internal {
217
218//  MutableFst<A> case - abstract methods.
219template <class A> inline
220typename A::Weight Final(const MutableFst<A> &fst, typename A::StateId s) {
221  return fst.Final(s);
222}
223
224template <class A> inline
225ssize_t NumArcs(const MutableFst<A> &fst, typename A::StateId s) {
226  return fst.NumArcs(s);
227}
228
229template <class A> inline
230ssize_t NumInputEpsilons(const MutableFst<A> &fst, typename A::StateId s) {
231  return fst.NumInputEpsilons(s);
232}
233
234template <class A> inline
235ssize_t NumOutputEpsilons(const MutableFst<A> &fst, typename A::StateId s) {
236  return fst.NumOutputEpsilons(s);
237}
238
239}  // namespace internal
240
241
242// A useful alias when using StdArc.
243typedef MutableFst<StdArc> StdMutableFst;
244
245
246// This is a helper class template useful for attaching a MutableFst
247// interface to its implementation, handling reference counting and
248// copy-on-write.
249template <class I, class F = MutableFst<typename I::Arc> >
250class ImplToMutableFst : public ImplToExpandedFst<I, F> {
251 public:
252  typedef typename I::Arc Arc;
253  typedef typename Arc::Weight Weight;
254  typedef typename Arc::StateId StateId;
255
256  using ImplToFst<I, F>::GetImpl;
257  using ImplToFst<I, F>::SetImpl;
258
259  virtual void SetStart(StateId s) {
260    MutateCheck();
261    GetImpl()->SetStart(s);
262  }
263
264  virtual void SetFinal(StateId s, Weight w) {
265    MutateCheck();
266    GetImpl()->SetFinal(s, w);
267  }
268
269  virtual void SetProperties(uint64 props, uint64 mask) {
270    // Can skip mutate check if extrinsic properties don't change,
271    // since it is then safe to update all (shallow) copies
272    uint64 exprops = kExtrinsicProperties & mask;
273    if (GetImpl()->Properties(exprops) != (props & exprops))
274      MutateCheck();
275    GetImpl()->SetProperties(props, mask);
276  }
277
278  virtual StateId AddState() {
279    MutateCheck();
280    return GetImpl()->AddState();
281  }
282
283  virtual void AddArc(StateId s, const Arc &arc) {
284    MutateCheck();
285    GetImpl()->AddArc(s, arc);
286  }
287
288  virtual void DeleteStates(const vector<StateId> &dstates) {
289    MutateCheck();
290    GetImpl()->DeleteStates(dstates);
291  }
292
293  virtual void DeleteStates() {
294    MutateCheck();
295    GetImpl()->DeleteStates();
296  }
297
298  virtual void DeleteArcs(StateId s, size_t n) {
299    MutateCheck();
300    GetImpl()->DeleteArcs(s, n);
301  }
302
303  virtual void DeleteArcs(StateId s) {
304    MutateCheck();
305    GetImpl()->DeleteArcs(s);
306  }
307
308  virtual void ReserveStates(StateId s) {
309    MutateCheck();
310    GetImpl()->ReserveStates(s);
311  }
312
313  virtual void ReserveArcs(StateId s, size_t n) {
314    MutateCheck();
315    GetImpl()->ReserveArcs(s, n);
316  }
317
318  virtual const SymbolTable* InputSymbols() const {
319    return GetImpl()->InputSymbols();
320  }
321
322  virtual const SymbolTable* OutputSymbols() const {
323    return GetImpl()->OutputSymbols();
324  }
325
326  virtual SymbolTable* MutableInputSymbols() {
327    MutateCheck();
328    return GetImpl()->InputSymbols();
329  }
330
331  virtual SymbolTable* MutableOutputSymbols() {
332    MutateCheck();
333    return GetImpl()->OutputSymbols();
334  }
335
336  virtual void SetInputSymbols(const SymbolTable* isyms) {
337    MutateCheck();
338    GetImpl()->SetInputSymbols(isyms);
339  }
340
341  virtual void SetOutputSymbols(const SymbolTable* osyms) {
342    MutateCheck();
343    GetImpl()->SetOutputSymbols(osyms);
344  }
345
346 protected:
347  ImplToMutableFst() : ImplToExpandedFst<I, F>() {}
348
349  ImplToMutableFst(I *impl) : ImplToExpandedFst<I, F>(impl) {}
350
351
352  ImplToMutableFst(const ImplToMutableFst<I, F> &fst)
353      : ImplToExpandedFst<I, F>(fst) {}
354
355  ImplToMutableFst(const ImplToMutableFst<I, F> &fst, bool safe)
356      : ImplToExpandedFst<I, F>(fst, safe) {}
357
358  void MutateCheck() {
359    // Copy on write
360    if (GetImpl()->RefCount() > 1)
361      SetImpl(new I(*this));
362  }
363
364 private:
365  // Disallow
366  ImplToMutableFst<I, F>  &operator=(const ImplToMutableFst<I, F> &fst);
367
368  ImplToMutableFst<I, F> &operator=(const Fst<Arc> &fst) {
369    FSTERROR() << "ImplToMutableFst: Assignment operator disallowed";
370    GetImpl()->SetProperties(kError, kError);
371    return *this;
372  }
373};
374
375
376}  // namespace fst
377
378#endif  // FST_LIB_MUTABLE_FST_H__
379