1// concat.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Functions and classes to compute the concat of two FSTs.
20
21#ifndef FST_LIB_CONCAT_H__
22#define FST_LIB_CONCAT_H__
23
24#include <vector>
25using std::vector;
26#include <algorithm>
27
28#include <fst/mutable-fst.h>
29#include <fst/rational.h>
30
31
32namespace fst {
33
34// Computes the concatenation (product) of two FSTs. If FST1
35// transduces string x to y with weight a and FST2 transduces string w
36// to v with weight b, then their concatenation transduces string xw
37// to yv with Times(a, b).
38//
39// This version modifies its MutableFst argument (in first position).
40//
41// Complexity:
42// - Time: O(V1 + V2 + E2)
43// - Space: O(V1 + V2 + E2)
44// where Vi = # of states and Ei = # of arcs of the ith FST.
45//
46template<class Arc>
47void Concat(MutableFst<Arc> *fst1, const Fst<Arc> &fst2) {
48  typedef typename Arc::StateId StateId;
49  typedef typename Arc::Label Label;
50  typedef typename Arc::Weight Weight;
51
52  // TODO(riley): restore when voice actions issues fixed
53  // Check that the symbol table are compatible
54  if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) ||
55      !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) {
56    LOG(ERROR) << "Concat: input/output symbol tables of 1st argument "
57               << "do not match input/output symbol tables of 2nd argument";
58    // fst1->SetProperties(kError, kError);
59    // return;
60  }
61
62  uint64 props1 = fst1->Properties(kFstProperties, false);
63  uint64 props2 = fst2.Properties(kFstProperties, false);
64
65  StateId start1 = fst1->Start();
66  if (start1 == kNoStateId) {
67    if (props2 & kError) fst1->SetProperties(kError, kError);
68    return;
69  }
70
71  StateId numstates1 = fst1->NumStates();
72  if (fst2.Properties(kExpanded, false))
73    fst1->ReserveStates(numstates1 + CountStates(fst2));
74
75  for (StateIterator< Fst<Arc> > siter2(fst2);
76       !siter2.Done();
77       siter2.Next()) {
78    StateId s1 = fst1->AddState();
79    StateId s2 = siter2.Value();
80    fst1->SetFinal(s1, fst2.Final(s2));
81    fst1->ReserveArcs(s1, fst2.NumArcs(s2));
82    for (ArcIterator< Fst<Arc> > aiter(fst2, s2);
83         !aiter.Done();
84         aiter.Next()) {
85      Arc arc = aiter.Value();
86      arc.nextstate += numstates1;
87      fst1->AddArc(s1, arc);
88    }
89  }
90
91  StateId start2 = fst2.Start();
92  for (StateId s1 = 0; s1 < numstates1; ++s1) {
93    Weight final = fst1->Final(s1);
94    if (final != Weight::Zero()) {
95      fst1->SetFinal(s1, Weight::Zero());
96      if (start2 != kNoStateId)
97        fst1->AddArc(s1, Arc(0, 0, final, start2 + numstates1));
98    }
99  }
100  if (start2 != kNoStateId)
101    fst1->SetProperties(ConcatProperties(props1, props2), kFstProperties);
102}
103
104// Computes the concatentation of two FSTs.  This version modifies its
105// MutableFst argument (in second position).
106//
107// Complexity:
108// - Time: O(V1 + E1)
109// - Space: O(V1 + E1)
110// where Vi = # of states and Ei = # of arcs of the ith FST.
111//
112template<class Arc>
113void Concat(const Fst<Arc> &fst1, MutableFst<Arc> *fst2) {
114  typedef typename Arc::StateId StateId;
115  typedef typename Arc::Label Label;
116  typedef typename Arc::Weight Weight;
117
118  // Check that the symbol table are compatible
119  if (!CompatSymbols(fst1.InputSymbols(), fst2->InputSymbols()) ||
120      !CompatSymbols(fst1.OutputSymbols(), fst2->OutputSymbols())) {
121    LOG(ERROR) << "Concat: input/output symbol tables of 1st argument "
122               << "do not match input/output symbol tables of 2nd argument";
123    // fst2->SetProperties(kError, kError);
124    // return;
125  }
126
127  uint64 props1 = fst1.Properties(kFstProperties, false);
128  uint64 props2 = fst2->Properties(kFstProperties, false);
129
130  StateId start2 = fst2->Start();
131  if (start2 == kNoStateId) {
132    if (props1 & kError) fst2->SetProperties(kError, kError);
133    return;
134  }
135
136  StateId numstates2 = fst2->NumStates();
137  if (fst1.Properties(kExpanded, false))
138    fst2->ReserveStates(numstates2 + CountStates(fst1));
139
140  for (StateIterator< Fst<Arc> > siter(fst1);
141       !siter.Done();
142       siter.Next()) {
143    StateId s1 = siter.Value();
144    StateId s2 = fst2->AddState();
145    Weight final = fst1.Final(s1);
146    fst2->ReserveArcs(s2, fst1.NumArcs(s1) + (final != Weight::Zero() ? 1 : 0));
147    if (final != Weight::Zero())
148      fst2->AddArc(s2, Arc(0, 0, final, start2));
149    for (ArcIterator< Fst<Arc> > aiter(fst1, s1);
150         !aiter.Done();
151         aiter.Next()) {
152      Arc arc = aiter.Value();
153      arc.nextstate += numstates2;
154      fst2->AddArc(s2, arc);
155    }
156  }
157  StateId start1 = fst1.Start();
158  fst2->SetStart(start1 == kNoStateId ? fst2->AddState() : start1 + numstates2);
159  if (start1 != kNoStateId)
160    fst2->SetProperties(ConcatProperties(props1, props2), kFstProperties);
161}
162
163
164// Computes the concatentation of two FSTs. This version modifies its
165// RationalFst input (in first position).
166template<class Arc>
167void Concat(RationalFst<Arc> *fst1, const Fst<Arc> &fst2) {
168  fst1->GetImpl()->AddConcat(fst2, true);
169}
170
171// Computes the concatentation of two FSTs. This version modifies its
172// RationalFst input (in second position).
173template<class Arc>
174void Concat(const Fst<Arc> &fst1, RationalFst<Arc> *fst2) {
175  fst2->GetImpl()->AddConcat(fst1, false);
176}
177
178typedef RationalFstOptions ConcatFstOptions;
179
180
181// Computes the concatenation (product) of two FSTs; this version is a
182// delayed Fst. If FST1 transduces string x to y with weight a and FST2
183// transduces string w to v with weight b, then their concatenation
184// transduces string xw to yv with Times(a, b).
185//
186// Complexity:
187// - Time: O(v1 + e1 + v2 + e2),
188// - Space: O(v1 + v2)
189// where vi = # of states visited and ei = # of arcs visited of the
190// ith FST. Constant time and space to visit an input state or arc is
191// assumed and exclusive of caching.
192template <class A>
193class ConcatFst : public RationalFst<A> {
194 public:
195  using ImplToFst< RationalFstImpl<A> >::GetImpl;
196
197  typedef A Arc;
198  typedef typename A::Weight Weight;
199  typedef typename A::StateId StateId;
200
201  ConcatFst(const Fst<A> &fst1, const Fst<A> &fst2) {
202    GetImpl()->InitConcat(fst1, fst2);
203  }
204
205  ConcatFst(const Fst<A> &fst1, const Fst<A> &fst2,
206            const ConcatFstOptions &opts) : RationalFst<A>(opts) {
207    GetImpl()->InitConcat(fst1, fst2);
208  }
209
210  // See Fst<>::Copy() for doc.
211  ConcatFst(const ConcatFst<A> &fst, bool safe = false)
212      : RationalFst<A>(fst, safe) {}
213
214  // Get a copy of this ConcatFst. See Fst<>::Copy() for further doc.
215  virtual ConcatFst<A> *Copy(bool safe = false) const {
216    return new ConcatFst<A>(*this, safe);
217  }
218};
219
220
221// Specialization for ConcatFst.
222template <class A>
223class StateIterator< ConcatFst<A> > : public StateIterator< RationalFst<A> > {
224 public:
225  explicit StateIterator(const ConcatFst<A> &fst)
226      : StateIterator< RationalFst<A> >(fst) {}
227};
228
229
230// Specialization for ConcatFst.
231template <class A>
232class ArcIterator< ConcatFst<A> > : public ArcIterator< RationalFst<A> > {
233 public:
234  typedef typename A::StateId StateId;
235
236  ArcIterator(const ConcatFst<A> &fst, StateId s)
237      : ArcIterator< RationalFst<A> >(fst, s) {}
238};
239
240
241// Useful alias when using StdArc.
242typedef ConcatFst<StdArc> StdConcatFst;
243
244}  // namespace fst
245
246#endif  // FST_LIB_CONCAT_H__
247