1// randgen.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// Function to generate random paths through an FST.
18
19#ifndef FST_LIB_RANDGEN_H__
20#define FST_LIB_RANDGEN_H__
21
22#include <cmath>
23#include <cstdlib>
24#include <ctime>
25
26#include "fst/lib/mutable-fst.h"
27
28namespace fst {
29
30//
31// ARC SELECTORS - these function objects are used to select a random
32// transition to take from an FST's state. They should return a number
33// N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th
34// transition is selected. If N == NumArcs(), then the final weight at
35// that state is selected (i.e., the 'super-final' transition is selected).
// It can be assumed these will not be called unless there are
// transitions leaving the state, the state is final, or both.
38//
39
40// Randomly selects a transition using the uniform distribution.
41template <class A>
42struct UniformArcSelector {
43  typedef typename A::StateId StateId;
44  typedef typename A::Weight Weight;
45
46  UniformArcSelector(int seed = time(0)) { srand(seed); }
47
48  size_t operator()(const Fst<A> &fst, StateId s) const {
49    double r = rand()/(RAND_MAX + 1.0);
50    size_t n = fst.NumArcs(s);
51    if (fst.Final(s) != Weight::Zero())
52      ++n;
53    return static_cast<size_t>(r * n);
54  }
55};
56
57// Randomly selects a transition w.r.t. the weights treated as negative
58// log probabilities after normalizing for the total weight leaving
59// the state). Weight::zero transitions are disregarded.
60// Assumes Weight::Value() accesses the floating point
61// representation of the weight.
62template <class A>
63struct LogProbArcSelector {
64  typedef typename A::StateId StateId;
65  typedef typename A::Weight Weight;
66
67  LogProbArcSelector(int seed = time(0)) { srand(seed); }
68
69  size_t operator()(const Fst<A> &fst, StateId s) const {
70    // Find total weight leaving state
71    double sum = 0.0;
72    for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
73         aiter.Next()) {
74      const A &arc = aiter.Value();
75      sum += exp(-arc.weight.Value());
76    }
77    sum += exp(-fst.Final(s).Value());
78
79    double r = rand()/(RAND_MAX + 1.0);
80    double p = 0.0;
81    int n = 0;
82    for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
83         aiter.Next(), ++n) {
84      const A &arc = aiter.Value();
85      p += exp(-arc.weight.Value());
86      if (p > r * sum) return n;
87    }
88    return n;
89  }
90};
91
// Convenience definitions: the log-probability arc selector instantiated
// for the arc types StdArc and LogArc (both declared outside this file).
typedef LogProbArcSelector<StdArc> StdArcSelector;
typedef LogProbArcSelector<LogArc> LogArcSelector;
95
96
97// Options for random path generation.
98template <class S>
99struct RandGenOptions {
100  const S &arc_selector;  // How an arc is selected at a state
101  int max_length;         // Maximum path length
102  size_t npath;           // # of paths to generate
103
104  // These are used internally by RandGen
105  int64 source;           // 'ifst' state to expand
106  int64 dest;             // 'ofst' state to append
107
108  RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1)
109    : arc_selector(sel), max_length(len), npath(n),
110       source(kNoStateId), dest(kNoStateId) {}
111};
112
113
114// Randomly generate paths through an FST; details controlled by
115// RandGenOptions.
116template<class Arc, class ArcSelector>
117void RandGen(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
118	     const RandGenOptions<ArcSelector> &opts) {
119  typedef typename Arc::Weight Weight;
120
121  if (opts.npath == 0 || opts.max_length == 0 || ifst.Start() == kNoStateId)
122    return;
123
124  if (opts.source == kNoStateId) {   // first call
125    ofst->DeleteStates();
126    ofst->SetInputSymbols(ifst.InputSymbols());
127    ofst->SetOutputSymbols(ifst.OutputSymbols());
128    ofst->SetStart(ofst->AddState());
129    RandGenOptions<ArcSelector> nopts(opts);
130    nopts.source = ifst.Start();
131    nopts.dest = ofst->Start();
132    for (; nopts.npath > 0; --nopts.npath)
133      RandGen(ifst, ofst, nopts);
134  } else {
135    if (ifst.NumArcs(opts.source) == 0 &&
136	ifst.Final(opts.source) == Weight::Zero())  // Non-coaccessible
137      return;
138    // Pick a random transition from the source state
139    size_t n = opts.arc_selector(ifst, opts.source);
140    if (n == ifst.NumArcs(opts.source)) {  // Take 'super-final' transition
141      ofst->SetFinal(opts.dest, Weight::One());
142    } else {
143      ArcIterator< Fst<Arc> > aiter(ifst, opts.source);
144      aiter.Seek(n);
145      const Arc &iarc = aiter.Value();
146      Arc oarc(iarc.ilabel, iarc.olabel, Weight::One(), ofst->AddState());
147      ofst->AddArc(opts.dest, oarc);
148
149      RandGenOptions<ArcSelector> nopts(opts);
150      nopts.source = iarc.nextstate;
151      nopts.dest = oarc.nextstate;
152      --nopts.max_length;
153      RandGen(ifst, ofst, nopts);
154    }
155  }
156}
157
158// Randomly generate a path through an FST with the uniform distribution
159// over the transitions.
160template<class Arc>
161void RandGen(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) {
162  UniformArcSelector<Arc> uniform_selector;
163  RandGenOptions< UniformArcSelector<Arc> > opts(uniform_selector);
164  RandGen(ifst, ofst, opts);
165}
166
167}  // namespace fst
168
169#endif  // FST_LIB_RANDGEN_H__
170