equal.h revision 8fc5a7f51e62cb4ae44a27bdf4176d04adc80ede
// equal.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// \file
// Function to test equality of two Fsts.
18
19#ifndef FST_LIB_EQUAL_H__
20#define FST_LIB_EQUAL_H__
21
22#include "fst/lib/fst.h"
23
24namespace fst {
25
26// Tests if two Fsts have the same states and arcs in the same order.
27template<class Arc>
28bool Equal(const Fst<Arc> &fst1, const Fst<Arc> &fst2) {
29  typedef typename Arc::StateId StateId;
30  typedef typename Arc::Weight Weight;
31
32  if (fst1.Start() != fst2.Start()) {
33    VLOG(1) << "Equal: mismatched start states";
34    return false;
35  }
36
37  StateIterator< Fst<Arc> > siter1(fst1);
38  StateIterator< Fst<Arc> > siter2(fst2);
39
40  while (!siter1.Done() || !siter2.Done()) {
41    if (siter1.Done() || siter2.Done()) {
42      VLOG(1) << "Equal: mismatched # of states";
43      return false;
44    }
45    StateId s1 = siter1.Value();
46    StateId s2 = siter2.Value();
47    if (s1 != s2) {
48      VLOG(1) << "Equal: mismatched states:"
49              << ", state1 = " << s1
50              << ", state2 = " << s2;
51      return false;
52    }
53    Weight final1 = fst1.Final(s1);
54    Weight final2 = fst2.Final(s2);
55    if (!ApproxEqual(final1, final2)) {
56      VLOG(1) << "Equal: mismatched final weights:"
57              << " state = " << s1
58              << ", final1 = " << final1
59              << ", final2 = " << final2;
60      return false;
61     }
62    ArcIterator< Fst<Arc> > aiter1(fst1, s1);
63    ArcIterator< Fst<Arc> > aiter2(fst2, s2);
64    for (size_t a = 0; !aiter1.Done() || !aiter2.Done(); ++a) {
65      if (aiter1.Done() || aiter2.Done()) {
66        VLOG(1) << "Equal: mismatched # of arcs"
67                << " state = " << s1;
68        return false;
69      }
70      Arc arc1 = aiter1.Value();
71      Arc arc2 = aiter2.Value();
72      if (arc1.ilabel != arc2.ilabel) {
73        VLOG(1) << "Equal: mismatched arc input labels:"
74                << " state = " << s1
75                << ", arc = " << a
76                << ", ilabel1 = " << arc1.ilabel
77                << ", ilabel2 = " << arc2.ilabel;
78        return false;
79      } else  if (arc1.olabel != arc2.olabel) {
80        VLOG(1) << "Equal: mismatched arc output labels:"
81                << " state = " << s1
82                << ", arc = " << a
83                << ", olabel1 = " << arc1.olabel
84                << ", olabel2 = " << arc2.olabel;
85        return false;
86      } else  if (!ApproxEqual(arc1.weight, arc2.weight)) {
87        VLOG(1) << "Equal: mismatched arc weights:"
88                << " state = " << s1
89                << ", arc = " << a
90                << ", weight1 = " << arc1.weight
91                << ", weight2 = " << arc2.weight;
92        return false;
93      } else  if (arc1.nextstate != arc2.nextstate) {
94        VLOG(1) << "Equal: mismatched input label:"
95                << " state = " << s1
96                << ", arc = " << a
97                << ", nextstate1 = " << arc1.nextstate
98                << ", nextstate2 = " << arc2.nextstate;
99        return false;
100      }
101      aiter1.Next();
102      aiter2.Next();
103
104    }
105    // Sanity checks
106    CHECK_EQ(fst1.NumArcs(s1), fst2.NumArcs(s2));
107    CHECK_EQ(fst1.NumInputEpsilons(s1), fst2.NumInputEpsilons(s2));
108    CHECK_EQ(fst1.NumOutputEpsilons(s1), fst2.NumOutputEpsilons(s2));
109
110    siter1.Next();
111    siter2.Next();
112  }
113  return true;
114}
115
116}  // namespace fst
117
118
119#endif  // FST_LIB_EQUAL_H__
120