1// randgen.h 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// 15// 16// \file 17// Function to generate random paths through an FST. 18 19#ifndef FST_LIB_RANDGEN_H__ 20#define FST_LIB_RANDGEN_H__ 21 22#include <cmath> 23#include <cstdlib> 24#include <ctime> 25 26#include "fst/lib/mutable-fst.h" 27 28namespace fst { 29 30// 31// ARC SELECTORS - these function objects are used to select a random 32// transition to take from an FST's state. They should return a number 33// N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th 34// transition is selected. If N == NumArcs(), then the final weight at 35// that state is selected (i.e., the 'super-final' transition is selected). 36// It can be assumed these will not be called unless either there 37// are transitions leaving the state and/or the state is final. 38// 39 40// Randomly selects a transition using the uniform distribution. 41template <class A> 42struct UniformArcSelector { 43 typedef typename A::StateId StateId; 44 typedef typename A::Weight Weight; 45 46 UniformArcSelector(int seed = time(0)) { srand(seed); } 47 48 size_t operator()(const Fst<A> &fst, StateId s) const { 49 double r = rand()/(RAND_MAX + 1.0); 50 size_t n = fst.NumArcs(s); 51 if (fst.Final(s) != Weight::Zero()) 52 ++n; 53 return static_cast<size_t>(r * n); 54 } 55}; 56 57// Randomly selects a transition w.r.t. the weights treated as negative 58// log probabilities after normalizing for the total weight leaving 59// the state). Weight::zero transitions are disregarded. 60// Assumes Weight::Value() accesses the floating point 61// representation of the weight. 62template <class A> 63struct LogProbArcSelector { 64 typedef typename A::StateId StateId; 65 typedef typename A::Weight Weight; 66 67 LogProbArcSelector(int seed = time(0)) { srand(seed); } 68 69 size_t operator()(const Fst<A> &fst, StateId s) const { 70 // Find total weight leaving state 71 double sum = 0.0; 72 for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); 73 aiter.Next()) { 74 const A &arc = aiter.Value(); 75 sum += exp(-arc.weight.Value()); 76 } 77 sum += exp(-fst.Final(s).Value()); 78 79 double r = rand()/(RAND_MAX + 1.0); 80 double p = 0.0; 81 int n = 0; 82 for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); 83 aiter.Next(), ++n) { 84 const A &arc = aiter.Value(); 85 p += exp(-arc.weight.Value()); 86 if (p > r * sum) return n; 87 } 88 return n; 89 } 90}; 91 92// Convenience definitions 93typedef LogProbArcSelector<StdArc> StdArcSelector; 94typedef LogProbArcSelector<LogArc> LogArcSelector; 95 96 97// Options for random path generation. 98template <class S> 99struct RandGenOptions { 100 const S &arc_selector; // How an arc is selected at a state 101 int max_length; // Maximum path length 102 size_t npath; // # of paths to generate 103 104 // These are used internally by RandGen 105 int64 source; // 'ifst' state to expand 106 int64 dest; // 'ofst' state to append 107 108 RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1) 109 : arc_selector(sel), max_length(len), npath(n), 110 source(kNoStateId), dest(kNoStateId) {} 111}; 112 113 114// Randomly generate paths through an FST; details controlled by 115// RandGenOptions. 116template<class Arc, class ArcSelector> 117void RandGen(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, 118 const RandGenOptions<ArcSelector> &opts) { 119 typedef typename Arc::Weight Weight; 120 121 if (opts.npath == 0 || opts.max_length == 0 || ifst.Start() == kNoStateId) 122 return; 123 124 if (opts.source == kNoStateId) { // first call 125 ofst->DeleteStates(); 126 ofst->SetInputSymbols(ifst.InputSymbols()); 127 ofst->SetOutputSymbols(ifst.OutputSymbols()); 128 ofst->SetStart(ofst->AddState()); 129 RandGenOptions<ArcSelector> nopts(opts); 130 nopts.source = ifst.Start(); 131 nopts.dest = ofst->Start(); 132 for (; nopts.npath > 0; --nopts.npath) 133 RandGen(ifst, ofst, nopts); 134 } else { 135 if (ifst.NumArcs(opts.source) == 0 && 136 ifst.Final(opts.source) == Weight::Zero()) // Non-coaccessible 137 return; 138 // Pick a random transition from the source state 139 size_t n = opts.arc_selector(ifst, opts.source); 140 if (n == ifst.NumArcs(opts.source)) { // Take 'super-final' transition 141 ofst->SetFinal(opts.dest, Weight::One()); 142 } else { 143 ArcIterator< Fst<Arc> > aiter(ifst, opts.source); 144 aiter.Seek(n); 145 const Arc &iarc = aiter.Value(); 146 Arc oarc(iarc.ilabel, iarc.olabel, Weight::One(), ofst->AddState()); 147 ofst->AddArc(opts.dest, oarc); 148 149 RandGenOptions<ArcSelector> nopts(opts); 150 nopts.source = iarc.nextstate; 151 nopts.dest = oarc.nextstate; 152 --nopts.max_length; 153 RandGen(ifst, ofst, nopts); 154 } 155 } 156} 157 158// Randomly generate a path through an FST with the uniform distribution 159// over the transitions. 160template<class Arc> 161void RandGen(const Fst<Arc> &ifst, MutableFst<Arc> *ofst) { 162 UniformArcSelector<Arc> uniform_selector; 163 RandGenOptions< UniformArcSelector<Arc> > opts(uniform_selector); 164 RandGen(ifst, ofst, opts); 165} 166 167} // namespace fst 168 169#endif // FST_LIB_RANDGEN_H__ 170