1// intersect.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Class to compute the intersection of two FSAs
20
21#ifndef FST_LIB_INTERSECT_H__
22#define FST_LIB_INTERSECT_H__
23
24#include <algorithm>
25#include <vector>
26using std::vector;
27
28#include <fst/cache.h>
29#include <fst/compose.h>
30
31
32namespace fst {
33
34template <class A,
35          class M = Matcher<Fst<A> >,
36          class F = SequenceComposeFilter<M>,
37          class T = GenericComposeStateTable<A, typename F::FilterState> >
38struct IntersectFstOptions : public ComposeFstOptions<A, M, F, T> {
39  explicit IntersectFstOptions(const CacheOptions &opts,
40                               M *mat1 = 0, M *mat2 = 0,
41                               F *filt = 0, T *sttable= 0)
42      : ComposeFstOptions<A, M, F, T>(opts, mat1, mat2, filt, sttable) { }
43
44  IntersectFstOptions() {}
45};
46
47// Computes the intersection (Hadamard product) of two FSAs. This
48// version is a delayed Fst.  Only strings that are in both automata
49// are retained in the result.
50//
51// The two arguments must be acceptors. One of the arguments must be
52// label-sorted.
53//
54// Complexity: same as ComposeFst.
55//
56// Caveats:  same as ComposeFst.
57template <class A>
58class IntersectFst : public ComposeFst<A> {
59 public:
60  using ComposeFst<A>::CreateBase;
61  using ComposeFst<A>::CreateBase1;
62  using ComposeFst<A>::Properties;
63  using ImplToFst< ComposeFstImplBase<A> >::GetImpl;
64  using ImplToFst< ComposeFstImplBase<A> >::SetImpl;
65
66  typedef A Arc;
67  typedef typename A::Weight Weight;
68  typedef typename A::StateId StateId;
69
70  IntersectFst(const Fst<A> &fst1, const Fst<A> &fst2,
71               const CacheOptions opts = CacheOptions()) {
72    bool acceptors = fst1.Properties(kAcceptor, true) &&
73        fst2.Properties(kAcceptor, true);
74    SetImpl(CreateBase(fst1, fst2, opts));
75    if (!acceptors) {
76      FSTERROR() << "IntersectFst: input FSTs are not acceptors";
77      GetImpl()->SetProperties(kError);
78    }
79  }
80
81  template <class M, class F, class T>
82  IntersectFst(const Fst<A> &fst1, const Fst<A> &fst2,
83               const IntersectFstOptions<A, M, F, T> &opts) {
84    bool acceptors = fst1.Properties(kAcceptor, true) &&
85        fst2.Properties(kAcceptor, true);
86    SetImpl(CreateBase1(fst1, fst2, opts));
87    if (!acceptors) {
88      FSTERROR() << "IntersectFst: input FSTs are not acceptors";
89      GetImpl()->SetProperties(kError);
90    }
91  }
92
93  // See Fst<>::Copy() for doc.
94  IntersectFst(const IntersectFst<A> &fst, bool safe = false) :
95      ComposeFst<A>(fst, safe) {}
96
97  // Get a copy of this IntersectFst. See Fst<>::Copy() for further doc.
98  virtual IntersectFst<A> *Copy(bool safe = false) const {
99    return new IntersectFst<A>(*this, safe);
100  }
101};
102
103
104// Specialization for IntersectFst.
105template <class A>
106class StateIterator< IntersectFst<A> >
107    : public StateIterator< ComposeFst<A> > {
108 public:
109  explicit StateIterator(const IntersectFst<A> &fst)
110      : StateIterator< ComposeFst<A> >(fst) {}
111};
112
113
114// Specialization for IntersectFst.
115template <class A>
116class ArcIterator< IntersectFst<A> >
117    : public ArcIterator< ComposeFst<A> > {
118 public:
119  typedef typename A::StateId StateId;
120
121  ArcIterator(const IntersectFst<A> &fst, StateId s)
122      : ArcIterator< ComposeFst<A> >(fst, s) {}
123};
124
// Useful alias when using StdArc.
typedef IntersectFst<StdArc> StdIntersectFst;


// Options type accepted by Intersect(); simply another name for
// ComposeOptions. Intersect() reads its 'filter_type' and 'connect' fields.
typedef ComposeOptions IntersectOptions;
130
131
// Computes the intersection (Hadamard product) of two FSAs. This
// version writes the intersection to an output MutableFst. Only
// strings that are in both automata are retained in the result.
135//
136// The two arguments must be acceptors. One of the arguments must be
137// label-sorted.
138//
139// Complexity: same as Compose.
140//
141// Caveats:  same as Compose.
142template<class Arc>
143void Intersect(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
144             MutableFst<Arc> *ofst,
145             const IntersectOptions &opts = IntersectOptions()) {
146  typedef Matcher< Fst<Arc> > M;
147
148  if (opts.filter_type == AUTO_FILTER) {
149    CacheOptions nopts;
150    nopts.gc_limit = 0;  // Cache only the last state for fastest copy.
151    *ofst = IntersectFst<Arc>(ifst1, ifst2, nopts);
152  } else if (opts.filter_type == SEQUENCE_FILTER) {
153    IntersectFstOptions<Arc> iopts;
154    iopts.gc_limit = 0;  // Cache only the last state for fastest copy.
155    *ofst = IntersectFst<Arc>(ifst1, ifst2, iopts);
156  } else if (opts.filter_type == ALT_SEQUENCE_FILTER) {
157    IntersectFstOptions<Arc, M, AltSequenceComposeFilter<M> > iopts;
158    iopts.gc_limit = 0;  // Cache only the last state for fastest copy.
159    *ofst = IntersectFst<Arc>(ifst1, ifst2, iopts);
160  } else if (opts.filter_type == MATCH_FILTER) {
161    IntersectFstOptions<Arc, M, MatchComposeFilter<M> > iopts;
162    iopts.gc_limit = 0;  // Cache only the last state for fastest copy.
163    *ofst = IntersectFst<Arc>(ifst1, ifst2, iopts);
164  }
165
166  if (opts.connect)
167    Connect(ofst);
168}
169
170}  // namespace fst
171
172#endif  // FST_LIB_INTERSECT_H__
173