1// randgen.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Classes and functions to generate random paths through an FST.
20
21#ifndef FST_LIB_RANDGEN_H__
22#define FST_LIB_RANDGEN_H__
23
24#include <cmath>
25#include <cstdlib>
26#include <ctime>
27#include <map>
28
29#include <fst/accumulator.h>
30#include <fst/cache.h>
31#include <fst/dfs-visit.h>
32#include <fst/mutable-fst.h>
33
34namespace fst {
35
36//
37// ARC SELECTORS - these function objects are used to select a random
38// transition to take from an FST's state. They should return a number
39// N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th
40// transition is selected. If N == NumArcs(), then the final weight at
41// that state is selected (i.e., the 'super-final' transition is selected).
42// It can be assumed these will not be called unless either there
43// are transitions leaving the state and/or the state is final.
44//
45
46// Randomly selects a transition using the uniform distribution.
47template <class A>
48struct UniformArcSelector {
49  typedef typename A::StateId StateId;
50  typedef typename A::Weight Weight;
51
52  UniformArcSelector(int seed = time(0)) { srand(seed); }
53
54  size_t operator()(const Fst<A> &fst, StateId s) const {
55    double r = rand()/(RAND_MAX + 1.0);
56    size_t n = fst.NumArcs(s);
57    if (fst.Final(s) != Weight::Zero())
58      ++n;
59    return static_cast<size_t>(r * n);
60  }
61};
62
63
64// Randomly selects a transition w.r.t. the weights treated as negative
65// log probabilities after normalizing for the total weight leaving
66// the state. Weight::zero transitions are disregarded.
67// Assumes Weight::Value() accesses the floating point
68// representation of the weight.
69template <class A>
70class LogProbArcSelector {
71 public:
72  typedef typename A::StateId StateId;
73  typedef typename A::Weight Weight;
74
75  LogProbArcSelector(int seed = time(0)) { srand(seed); }
76
77  size_t operator()(const Fst<A> &fst, StateId s) const {
78    // Find total weight leaving state
79    double sum = 0.0;
80    for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
81         aiter.Next()) {
82      const A &arc = aiter.Value();
83      sum += exp(-to_log_weight_(arc.weight).Value());
84    }
85    sum += exp(-to_log_weight_(fst.Final(s)).Value());
86
87    double r = rand()/(RAND_MAX + 1.0);
88    double p = 0.0;
89    int n = 0;
90    for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
91         aiter.Next(), ++n) {
92      const A &arc = aiter.Value();
93      p += exp(-to_log_weight_(arc.weight).Value());
94      if (p > r * sum) return n;
95    }
96    return n;
97  }
98
99 private:
100  WeightConvert<Weight, Log64Weight> to_log_weight_;
101};
102
103// Convenience definitions
104typedef LogProbArcSelector<StdArc> StdArcSelector;
105typedef LogProbArcSelector<LogArc> LogArcSelector;
106
107
108// Same as LogProbArcSelector but use CacheLogAccumulator to cache
109// the cummulative weight computations.
110template <class A>
111class FastLogProbArcSelector : public LogProbArcSelector<A> {
112 public:
113  typedef typename A::StateId StateId;
114  typedef typename A::Weight Weight;
115  using LogProbArcSelector<A>::operator();
116
117  FastLogProbArcSelector(int seed = time(0))
118      : LogProbArcSelector<A>(seed),
119        seed_(seed) {}
120
121  size_t operator()(const Fst<A> &fst, StateId s,
122                    CacheLogAccumulator<A> *accumulator) const {
123    accumulator->SetState(s);
124    ArcIterator< Fst<A> > aiter(fst, s);
125    // Find total weight leaving state
126    double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0,
127                                                 fst.NumArcs(s))).Value();
128    double r = -log(rand()/(RAND_MAX + 1.0));
129    return accumulator->LowerBound(r + sum, &aiter);
130  }
131
132  int Seed() const { return seed_; }
133 private:
134  int seed_;
135  WeightConvert<Weight, Log64Weight> to_log_weight_;
136};
137
138// Random path state info maintained by RandGenFst and passed to samplers.
139template <typename A>
140struct RandState {
141  typedef typename A::StateId StateId;
142
143  StateId state_id;              // current input FST state
144  size_t nsamples;               // # of samples to be sampled at this state
145  size_t length;                 // length of path to this random state
146  size_t select;                 // previous sample arc selection
147  const RandState<A> *parent;    // previous random state on this path
148
149  RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p)
150      : state_id(s), nsamples(n), length(l), select(k), parent(p) {}
151
152  RandState()
153      : state_id(kNoStateId), nsamples(0), length(0), select(0), parent(0) {}
154};
155
156// This class, given an arc selector, samples, with raplacement,
157// multiple random transitions from an FST's state. This is a generic
158// version with a straight-forward use of the arc selector.
159// Specializations may be defined for arc selectors for greater
160// efficiency or special behavior.
161template <class A, class S>
162class ArcSampler {
163 public:
164  typedef typename A::StateId StateId;
165  typedef typename A::Weight Weight;
166
167  // The 'max_length' may be interpreted (including ignored) by a
168  // sampler as it chooses. This generic version interprets this literally.
169  ArcSampler(const Fst<A> &fst, const S &arc_selector,
170             int max_length = INT_MAX)
171      : fst_(fst),
172        arc_selector_(arc_selector),
173        max_length_(max_length) {}
174
175  // Allow updating Fst argument; pass only if changed.
176  ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
177      : fst_(fst ? *fst : sampler.fst_),
178        arc_selector_(sampler.arc_selector_),
179        max_length_(sampler.max_length_) {
180    Reset();
181  }
182
183  // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is
184  // the length of the path to 'rstate'. Returns true if samples were
185  // collected.  No samples may be collected if either there are no (including
186  // 'super-final') transitions leaving that state or if the
187  // 'max_length' has been deemed reached. Use the iterator members to
188  // read the samples. The samples will be in their original order.
189  bool Sample(const RandState<A> &rstate) {
190    sample_map_.clear();
191    if ((fst_.NumArcs(rstate.state_id) == 0 &&
192         fst_.Final(rstate.state_id) == Weight::Zero()) ||
193        rstate.length == max_length_) {
194      Reset();
195      return false;
196    }
197
198    for (size_t i = 0; i < rstate.nsamples; ++i)
199      ++sample_map_[arc_selector_(fst_, rstate.state_id)];
200    Reset();
201    return true;
202  }
203
204  // More samples?
205  bool Done() const { return sample_iter_ == sample_map_.end(); }
206
207  // Gets the next sample.
208  void Next() { ++sample_iter_; }
209
210  // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples.
211  // If N < NumArcs(s), then the N-th transition is specified.
212  // If N == NumArcs(s), then the final weight at that state is
213  // specified (i.e., the 'super-final' transition is specified).
214  // For the specified transition, K repetitions have been sampled.
215  pair<size_t, size_t> Value() const { return *sample_iter_; }
216
217  void Reset() { sample_iter_ = sample_map_.begin(); }
218
219  bool Error() const { return false; }
220
221 private:
222  const Fst<A> &fst_;
223  const S &arc_selector_;
224  int max_length_;
225
226  // Stores (N, K) as described for Value().
227  map<size_t, size_t> sample_map_;
228  map<size_t, size_t>::const_iterator sample_iter_;
229
230  // disallow
231  ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
232};
233
234
235// Specialization for FastLogProbArcSelector.
236template <class A>
237class ArcSampler<A, FastLogProbArcSelector<A> > {
238 public:
239  typedef FastLogProbArcSelector<A> S;
240  typedef typename A::StateId StateId;
241  typedef typename A::Weight Weight;
242  typedef CacheLogAccumulator<A> C;
243
244  ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX)
245      : fst_(fst),
246        arc_selector_(arc_selector),
247        max_length_(max_length),
248        accumulator_(new C()) {
249    accumulator_->Init(fst);
250  }
251
252  ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
253      : fst_(fst ? *fst : sampler.fst_),
254        arc_selector_(sampler.arc_selector_),
255        max_length_(sampler.max_length_) {
256    if (fst) {
257      accumulator_ = new C();
258      accumulator_->Init(*fst);
259    } else {  // shallow copy
260      accumulator_ = new C(*sampler.accumulator_);
261    }
262  }
263
264  ~ArcSampler() {
265    delete accumulator_;
266  }
267
268  bool Sample(const RandState<A> &rstate) {
269    sample_map_.clear();
270    if ((fst_.NumArcs(rstate.state_id) == 0 &&
271         fst_.Final(rstate.state_id) == Weight::Zero()) ||
272        rstate.length == max_length_) {
273      Reset();
274      return false;
275    }
276
277    for (size_t i = 0; i < rstate.nsamples; ++i)
278      ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)];
279    Reset();
280    return true;
281  }
282
283  bool Done() const { return sample_iter_ == sample_map_.end(); }
284  void Next() { ++sample_iter_; }
285  pair<size_t, size_t> Value() const { return *sample_iter_; }
286  void Reset() { sample_iter_ = sample_map_.begin(); }
287
288  bool Error() const { return accumulator_->Error(); }
289
290 private:
291  const Fst<A> &fst_;
292  const S &arc_selector_;
293  int max_length_;
294
295  // Stores (N, K) as described for Value().
296  map<size_t, size_t> sample_map_;
297  map<size_t, size_t>::const_iterator sample_iter_;
298  C *accumulator_;
299
300  // disallow
301  ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
302};
303
304
305// Options for random path generation with RandGenFst. The template argument
306// is an arc sampler, typically class 'ArcSampler' above.  Ownership of
307// the sampler is taken by RandGenFst.
308template <class S>
309struct RandGenFstOptions : public CacheOptions {
310  S *arc_sampler;            // How to sample transitions at a state
311  size_t npath;              // # of paths to generate
312  bool weighted;             // Output tree weighted by path count; o.w.
313                             // output unweighted DAG
314  bool remove_total_weight;  // Remove total weight when output is weighted.
315
316  RandGenFstOptions(const CacheOptions &copts, S *samp,
317                    size_t n = 1, bool w = true, bool rw = false)
318      : CacheOptions(copts),
319        arc_sampler(samp),
320        npath(n),
321        weighted(w),
322        remove_total_weight(rw) {}
323};
324
325
326// Implementation of RandGenFst.
327template <class A, class B, class S>
328class RandGenFstImpl : public CacheImpl<B> {
329 public:
330  using FstImpl<B>::SetType;
331  using FstImpl<B>::SetProperties;
332  using FstImpl<B>::SetInputSymbols;
333  using FstImpl<B>::SetOutputSymbols;
334
335  using CacheBaseImpl< CacheState<B> >::AddArc;
336  using CacheBaseImpl< CacheState<B> >::HasArcs;
337  using CacheBaseImpl< CacheState<B> >::HasFinal;
338  using CacheBaseImpl< CacheState<B> >::HasStart;
339  using CacheBaseImpl< CacheState<B> >::SetArcs;
340  using CacheBaseImpl< CacheState<B> >::SetFinal;
341  using CacheBaseImpl< CacheState<B> >::SetStart;
342
343  typedef B Arc;
344  typedef typename A::Label Label;
345  typedef typename A::Weight Weight;
346  typedef typename A::StateId StateId;
347
348  RandGenFstImpl(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
349      : CacheImpl<B>(opts),
350        fst_(fst.Copy()),
351        arc_sampler_(opts.arc_sampler),
352        npath_(opts.npath),
353        weighted_(opts.weighted),
354        remove_total_weight_(opts.remove_total_weight),
355        superfinal_(kNoLabel) {
356    SetType("randgen");
357
358    uint64 props = fst.Properties(kFstProperties, false);
359    SetProperties(RandGenProperties(props, weighted_), kCopyProperties);
360
361    SetInputSymbols(fst.InputSymbols());
362    SetOutputSymbols(fst.OutputSymbols());
363  }
364
365  RandGenFstImpl(const RandGenFstImpl &impl)
366    : CacheImpl<B>(impl),
367      fst_(impl.fst_->Copy(true)),
368      arc_sampler_(new S(*impl.arc_sampler_, fst_)),
369      npath_(impl.npath_),
370      weighted_(impl.weighted_),
371      superfinal_(kNoLabel) {
372    SetType("randgen");
373    SetProperties(impl.Properties(), kCopyProperties);
374    SetInputSymbols(impl.InputSymbols());
375    SetOutputSymbols(impl.OutputSymbols());
376  }
377
378  ~RandGenFstImpl() {
379    for (int i = 0; i < state_table_.size(); ++i)
380      delete state_table_[i];
381    delete fst_;
382    delete arc_sampler_;
383  }
384
385  StateId Start() {
386    if (!HasStart()) {
387      StateId s = fst_->Start();
388      if (s == kNoStateId)
389        return kNoStateId;
390      StateId start = state_table_.size();
391      SetStart(start);
392      RandState<A> *rstate = new RandState<A>(s, npath_, 0, 0, 0);
393      state_table_.push_back(rstate);
394    }
395    return CacheImpl<B>::Start();
396  }
397
398  Weight Final(StateId s) {
399    if (!HasFinal(s)) {
400      Expand(s);
401    }
402    return CacheImpl<B>::Final(s);
403  }
404
405  size_t NumArcs(StateId s) {
406    if (!HasArcs(s)) {
407      Expand(s);
408    }
409    return CacheImpl<B>::NumArcs(s);
410  }
411
412  size_t NumInputEpsilons(StateId s) {
413    if (!HasArcs(s))
414      Expand(s);
415    return CacheImpl<B>::NumInputEpsilons(s);
416  }
417
418  size_t NumOutputEpsilons(StateId s) {
419    if (!HasArcs(s))
420      Expand(s);
421    return CacheImpl<B>::NumOutputEpsilons(s);
422  }
423
424  uint64 Properties() const { return Properties(kFstProperties); }
425
426  // Set error if found; return FST impl properties.
427  uint64 Properties(uint64 mask) const {
428    if ((mask & kError) &&
429        (fst_->Properties(kError, false) || arc_sampler_->Error())) {
430      SetProperties(kError, kError);
431    }
432    return FstImpl<Arc>::Properties(mask);
433  }
434
435  void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
436    if (!HasArcs(s))
437      Expand(s);
438    CacheImpl<B>::InitArcIterator(s, data);
439  }
440
441  // Computes the outgoing transitions from a state, creating new destination
442  // states as needed.
443  void Expand(StateId s) {
444    if (s == superfinal_) {
445      SetFinal(s, Weight::One());
446      SetArcs(s);
447      return;
448    }
449
450    SetFinal(s, Weight::Zero());
451    const RandState<A> &rstate = *state_table_[s];
452    arc_sampler_->Sample(rstate);
453    ArcIterator< Fst<A> > aiter(*fst_, rstate.state_id);
454    size_t narcs = fst_->NumArcs(rstate.state_id);
455    for (;!arc_sampler_->Done(); arc_sampler_->Next()) {
456      const pair<size_t, size_t> &sample_pair = arc_sampler_->Value();
457      size_t pos = sample_pair.first;
458      size_t count = sample_pair.second;
459      double prob = static_cast<double>(count)/rstate.nsamples;
460      if (pos < narcs) {  // regular transition
461        aiter.Seek(sample_pair.first);
462        const A &aarc = aiter.Value();
463        Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One();
464        B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size());
465        AddArc(s, barc);
466        RandState<A> *nrstate =
467            new RandState<A>(aarc.nextstate, count, rstate.length + 1,
468                             pos, &rstate);
469        state_table_.push_back(nrstate);
470      } else {            // super-final transition
471        if (weighted_) {
472          Weight weight = remove_total_weight_ ?
473              to_weight_(-log(prob)) : to_weight_(-log(prob * npath_));
474          SetFinal(s, weight);
475        } else {
476          if (superfinal_ == kNoLabel) {
477            superfinal_ = state_table_.size();
478            RandState<A> *nrstate = new RandState<A>(kNoStateId, 0, 0, 0, 0);
479            state_table_.push_back(nrstate);
480          }
481          for (size_t n = 0; n < count; ++n) {
482            B barc(0, 0, Weight::One(), superfinal_);
483            AddArc(s, barc);
484          }
485        }
486      }
487    }
488    SetArcs(s);
489  }
490
491 private:
492  Fst<A> *fst_;
493  S *arc_sampler_;
494  size_t npath_;
495  vector<RandState<A> *> state_table_;
496  bool weighted_;
497  bool remove_total_weight_;
498  StateId superfinal_;
499  WeightConvert<Log64Weight, Weight> to_weight_;
500
501  void operator=(const RandGenFstImpl<A, B, S> &);  // disallow
502};
503
504
505// Fst class to randomly generate paths through an FST; details controlled
506// by RandGenOptionsFst. Output format is a tree weighted by the
507// path count.
508template <class A, class B, class S>
509class RandGenFst : public ImplToFst< RandGenFstImpl<A, B, S> > {
510 public:
511  friend class ArcIterator< RandGenFst<A, B, S> >;
512  friend class StateIterator< RandGenFst<A, B, S> >;
513  typedef B Arc;
514  typedef S Sampler;
515  typedef typename A::Label Label;
516  typedef typename A::Weight Weight;
517  typedef typename A::StateId StateId;
518  typedef CacheState<B> State;
519  typedef RandGenFstImpl<A, B, S> Impl;
520
521  RandGenFst(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
522    : ImplToFst<Impl>(new Impl(fst, opts)) {}
523
524  // See Fst<>::Copy() for doc.
525 RandGenFst(const RandGenFst<A, B, S> &fst, bool safe = false)
526    : ImplToFst<Impl>(fst, safe) {}
527
528  // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
529  virtual RandGenFst<A, B, S> *Copy(bool safe = false) const {
530    return new RandGenFst<A, B, S>(*this, safe);
531  }
532
533  virtual inline void InitStateIterator(StateIteratorData<B> *data) const;
534
535  virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
536    GetImpl()->InitArcIterator(s, data);
537  }
538
539 private:
540  // Makes visible to friends.
541  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
542
543  void operator=(const RandGenFst<A, B, S> &fst);  // Disallow
544};
545
546
547
548// Specialization for RandGenFst.
549template <class A, class B, class S>
550class StateIterator< RandGenFst<A, B, S> >
551    : public CacheStateIterator< RandGenFst<A, B, S> > {
552 public:
553  explicit StateIterator(const RandGenFst<A, B, S> &fst)
554    : CacheStateIterator< RandGenFst<A, B, S> >(fst, fst.GetImpl()) {}
555
556 private:
557  DISALLOW_COPY_AND_ASSIGN(StateIterator);
558};
559
560
561// Specialization for RandGenFst.
562template <class A, class B, class S>
563class ArcIterator< RandGenFst<A, B, S> >
564    : public CacheArcIterator< RandGenFst<A, B, S> > {
565 public:
566  typedef typename A::StateId StateId;
567
568  ArcIterator(const RandGenFst<A, B, S> &fst, StateId s)
569      : CacheArcIterator< RandGenFst<A, B, S> >(fst.GetImpl(), s) {
570    if (!fst.GetImpl()->HasArcs(s))
571      fst.GetImpl()->Expand(s);
572  }
573
574 private:
575  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
576};
577
578
579template <class A, class B, class S> inline
580void RandGenFst<A, B, S>::InitStateIterator(StateIteratorData<B> *data) const
581{
582  data->base = new StateIterator< RandGenFst<A, B, S> >(*this);
583}
584
// Options for RandGen(): how arcs are selected, how long paths may grow,
// how many paths to draw, and the shape of the output.
template <class S>
struct RandGenOptions {
  const S &arc_selector;     // Picks an arc at each state (not owned).
  int max_length;            // Maximum path length.
  size_t npath;              // Number of paths to generate.
  bool weighted;             // true: tree weighted by path count;
                             // false: unweighted union of paths.
  bool remove_total_weight;  // With weighted output, drop the total weight.

  RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1,
                 bool w = false, bool rw = false)
      : arc_selector(sel), max_length(len), npath(n), weighted(w),
        remove_total_weight(rw) {}
};
603
604
605template <class IArc, class OArc>
606class RandGenVisitor {
607 public:
608  typedef typename IArc::Weight Weight;
609  typedef typename IArc::StateId StateId;
610
611  RandGenVisitor(MutableFst<OArc> *ofst) : ofst_(ofst) {}
612
613  void InitVisit(const Fst<IArc> &ifst) {
614    ifst_ = &ifst;
615
616    ofst_->DeleteStates();
617    ofst_->SetInputSymbols(ifst.InputSymbols());
618    ofst_->SetOutputSymbols(ifst.OutputSymbols());
619    if (ifst.Properties(kError, false))
620      ofst_->SetProperties(kError, kError);
621    path_.clear();
622  }
623
624  bool InitState(StateId s, StateId root) { return true; }
625
626  bool TreeArc(StateId s, const IArc &arc) {
627    if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
628      path_.push_back(arc);
629    } else {
630      OutputPath();
631    }
632    return true;
633  }
634
635  bool BackArc(StateId s, const IArc &arc) {
636    FSTERROR() << "RandGenVisitor: cyclic input";
637    ofst_->SetProperties(kError, kError);
638    return false;
639  }
640
641  bool ForwardOrCrossArc(StateId s, const IArc &arc) {
642    OutputPath();
643    return true;
644  }
645
646  void FinishState(StateId s, StateId p, const IArc *) {
647    if (p != kNoStateId && ifst_->Final(s) == Weight::Zero())
648      path_.pop_back();
649  }
650
651  void FinishVisit() {}
652
653 private:
654  void OutputPath() {
655    if (ofst_->Start() == kNoStateId) {
656      StateId start = ofst_->AddState();
657      ofst_->SetStart(start);
658    }
659
660    StateId src = ofst_->Start();
661    for (size_t i = 0; i < path_.size(); ++i) {
662      StateId dest = ofst_->AddState();
663      OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
664      ofst_->AddArc(src, arc);
665      src = dest;
666    }
667    ofst_->SetFinal(src, Weight::One());
668  }
669
670  const Fst<IArc> *ifst_;
671  MutableFst<OArc> *ofst_;
672  vector<OArc> path_;
673
674  DISALLOW_COPY_AND_ASSIGN(RandGenVisitor);
675};
676
677
678// Randomly generate paths through an FST; details controlled by
679// RandGenOptions.
680template<class IArc, class OArc, class Selector>
681void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst,
682             const RandGenOptions<Selector> &opts) {
683  typedef ArcSampler<IArc, Selector> Sampler;
684  typedef RandGenFst<IArc, OArc, Sampler> RandFst;
685  typedef typename OArc::StateId StateId;
686  typedef typename OArc::Weight Weight;
687
688  Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length);
689  RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), arc_sampler,
690                                   opts.npath, opts.weighted,
691                                   opts.remove_total_weight);
692  RandFst rfst(ifst, fopts);
693  if (opts.weighted) {
694    *ofst = rfst;
695  } else {
696    RandGenVisitor<IArc, OArc> rand_visitor(ofst);
697    DfsVisit(rfst, &rand_visitor);
698  }
699}
700
701// Randomly generate a path through an FST with the uniform distribution
702// over the transitions.
703template<class IArc, class OArc>
704void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst) {
705  UniformArcSelector<IArc> uniform_selector;
706  RandGenOptions< UniformArcSelector<IArc> > opts(uniform_selector);
707  RandGen(ifst, ofst, opts);
708}
709
710}  // namespace fst
711
712#endif  // FST_LIB_RANDGEN_H__
713