// union.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Functions and classes to compute the union of two FSTs.

#ifndef FST_LIB_UNION_H__
#define FST_LIB_UNION_H__

#include <vector>
using std::vector;
#include <algorithm>

#include <fst/mutable-fst.h>
#include <fst/rational.h>


namespace fst {

// Computes the union (sum) of two FSTs. This version writes the
// union to an output MutableFst. If A transduces string x to y with
// weight a and B transduces string w to v with weight b, then their
// union transduces x to y with weight a and w to v with weight b.
//
// Complexity:
// - Time: O(V2 + E2)
// - Space: O(V2 + E2)
// where Vi = # of states and Ei = # of arcs of the ith FST.
43template <class Arc> 44void Union(MutableFst<Arc> *fst1, const Fst<Arc> &fst2) { 45 typedef typename Arc::StateId StateId; 46 typedef typename Arc::Label Label; 47 typedef typename Arc::Weight Weight; 48 49 // TODO(riley): restore when voice actions issues fixed 50 // Check that the symbol table are compatible 51 if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) || 52 !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) { 53 LOG(ERROR) << "Union: input/output symbol tables of 1st argument " 54 << "do not match input/output symbol tables of 2nd argument"; 55 // fst1->SetProperties(kError, kError); 56 // return; 57 } 58 59 StateId numstates1 = fst1->NumStates(); 60 bool initial_acyclic1 = fst1->Properties(kInitialAcyclic, true); 61 uint64 props1 = fst1->Properties(kFstProperties, false); 62 uint64 props2 = fst2.Properties(kFstProperties, false); 63 64 StateId start2 = fst2.Start(); 65 if (start2 == kNoStateId) { 66 if (props2 & kError) fst1->SetProperties(kError, kError); 67 return; 68 } 69 70 if (fst2.Properties(kExpanded, false)) { 71 fst1->ReserveStates( 72 numstates1 + CountStates(fst2) + (initial_acyclic1 ? 
0 : 1)); 73 } 74 75 for (StateIterator< Fst<Arc> > siter(fst2); 76 !siter.Done(); 77 siter.Next()) { 78 StateId s1 = fst1->AddState(); 79 StateId s2 = siter.Value(); 80 fst1->SetFinal(s1, fst2.Final(s2)); 81 fst1->ReserveArcs(s1, fst2.NumArcs(s2)); 82 for (ArcIterator< Fst<Arc> > aiter(fst2, s2); 83 !aiter.Done(); 84 aiter.Next()) { 85 Arc arc = aiter.Value(); 86 arc.nextstate += numstates1; 87 fst1->AddArc(s1, arc); 88 } 89 } 90 StateId start1 = fst1->Start(); 91 if (start1 == kNoStateId) { 92 fst1->SetStart(start2); 93 fst1->SetProperties(props2, kCopyProperties); 94 return; 95 } 96 97 if (initial_acyclic1) { 98 fst1->AddArc(start1, Arc(0, 0, Weight::One(), start2 + numstates1)); 99 } else { 100 StateId nstart1 = fst1->AddState(); 101 fst1->SetStart(nstart1); 102 fst1->AddArc(nstart1, Arc(0, 0, Weight::One(), start1)); 103 fst1->AddArc(nstart1, Arc(0, 0, Weight::One(), start2 + numstates1)); 104 } 105 fst1->SetProperties(UnionProperties(props1, props2), kFstProperties); 106} 107 108 109// Computes the union of two FSTs; this version modifies its 110// RationalFst argument. 111template<class Arc> 112void Union(RationalFst<Arc> *fst1, const Fst<Arc> &fst2) { 113 fst1->GetImpl()->AddUnion(fst2); 114} 115 116 117typedef RationalFstOptions UnionFstOptions; 118 119 120// Computes the union (sum) of two FSTs. This version is a delayed 121// Fst. If A transduces string x to y with weight a and B transduces 122// string w to v with weight b, then their union transduces x to y 123// with weight a and w to v with weight b. 124// 125// Complexity: 126// - Time: O(v1 + e1 + v2 + e2) 127// - Sapce: O(v1 + v2) 128// where vi = # of states visited and ei = # of arcs visited of the 129// ith FST. Constant time and space to visit an input state or arc 130// is assumed and exclusive of caching. 
131template <class A> 132class UnionFst : public RationalFst<A> { 133 public: 134 using ImplToFst< RationalFstImpl<A> >::GetImpl; 135 136 typedef A Arc; 137 typedef typename A::Weight Weight; 138 typedef typename A::StateId StateId; 139 140 UnionFst(const Fst<A> &fst1, const Fst<A> &fst2) { 141 GetImpl()->InitUnion(fst1, fst2); 142 } 143 144 UnionFst(const Fst<A> &fst1, const Fst<A> &fst2, const UnionFstOptions &opts) 145 : RationalFst<A>(opts) { 146 GetImpl()->InitUnion(fst1, fst2); 147 } 148 149 // See Fst<>::Copy() for doc. 150 UnionFst(const UnionFst<A> &fst, bool safe = false) 151 : RationalFst<A>(fst, safe) {} 152 153 // Get a copy of this UnionFst. See Fst<>::Copy() for further doc. 154 virtual UnionFst<A> *Copy(bool safe = false) const { 155 return new UnionFst<A>(*this, safe); 156 } 157}; 158 159 160// Specialization for UnionFst. 161template <class A> 162class StateIterator< UnionFst<A> > : public StateIterator< RationalFst<A> > { 163 public: 164 explicit StateIterator(const UnionFst<A> &fst) 165 : StateIterator< RationalFst<A> >(fst) {} 166}; 167 168 169// Specialization for UnionFst. 170template <class A> 171class ArcIterator< UnionFst<A> > : public ArcIterator< RationalFst<A> > { 172 public: 173 typedef typename A::StateId StateId; 174 175 ArcIterator(const UnionFst<A> &fst, StateId s) 176 : ArcIterator< RationalFst<A> >(fst, s) {} 177}; 178 179 180// Useful alias when using StdArc. 181typedef UnionFst<StdArc> StdUnionFst; 182 183} // namespace fst 184 185#endif // FST_LIB_UNION_H__ 186