1// randequivalent.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google.com (Cyril Allauzen)
17//
18// \file
// Tests if two FSTs are equivalent by checking if random
// strings from one FST are transduced the same way by both FSTs.
21
22#ifndef FST_RANDEQUIVALENT_H__
23#define FST_RANDEQUIVALENT_H__
24
25#include <fst/arcsort.h>
26#include <fst/compose.h>
27#include <fst/project.h>
28#include <fst/randgen.h>
29#include <fst/shortest-distance.h>
30#include <fst/vector-fst.h>
31
32
33namespace fst {
34
35// Test if two FSTs are equivalent by randomly generating 'num_paths'
36// paths (as specified by the RandGenOptions 'opts') in these FSTs.
37//
38// For each randomly generated path, the algorithm computes for each
39// of the two FSTs the sum of the weights of all the successful paths
40// sharing the same input and output labels as the considered randomly
41// generated path and checks that these two values are within
42// 'delta'. Returns optional error value (when FLAGS_error_fatal = false).
43template<class Arc, class ArcSelector>
44bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
45                    ssize_t num_paths, float delta,
46                    const RandGenOptions<ArcSelector> &opts,
47                    bool *error = 0) {
48  typedef typename Arc::Weight Weight;
49  if (error) *error = false;
50
51  // Check that the symbol table are compatible
52  if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) ||
53      !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) {
54    FSTERROR() << "RandEquivalent: input/output symbol tables of 1st "
55               << "argument do not match input/output symbol tables of 2nd "
56               << "argument";
57    if (error) *error = true;
58    return false;
59  }
60
61  ILabelCompare<Arc> icomp;
62  OLabelCompare<Arc> ocomp;
63  VectorFst<Arc> sfst1(fst1);
64  VectorFst<Arc> sfst2(fst2);
65  Connect(&sfst1);
66  Connect(&sfst2);
67  ArcSort(&sfst1, icomp);
68  ArcSort(&sfst2, icomp);
69
70  bool ret = true;
71  for (ssize_t n = 0; n < num_paths; ++n) {
72    VectorFst<Arc> path;
73    const Fst<Arc> &fst = rand() % 2 ? sfst1 : sfst2;
74    RandGen(fst, &path, opts);
75
76    VectorFst<Arc> ipath(path);
77    VectorFst<Arc> opath(path);
78    Project(&ipath, PROJECT_INPUT);
79    Project(&opath, PROJECT_OUTPUT);
80
81    VectorFst<Arc> cfst1, pfst1;
82    Compose(ipath, sfst1, &cfst1);
83    ArcSort(&cfst1, ocomp);
84    Compose(cfst1, opath, &pfst1);
85    // Give up if there are epsilon cycles in a non-idempotent semiring
86    if (!(Weight::Properties() & kIdempotent) &&
87        pfst1.Properties(kCyclic, true))
88      continue;
89    Weight sum1 = ShortestDistance(pfst1);
90
91    VectorFst<Arc> cfst2, pfst2;
92    Compose(ipath, sfst2, &cfst2);
93    ArcSort(&cfst2, ocomp);
94    Compose(cfst2, opath, &pfst2);
95    // Give up if there are epsilon cycles in a non-idempotent semiring
96    if (!(Weight::Properties() & kIdempotent) &&
97        pfst2.Properties(kCyclic, true))
98      continue;
99    Weight sum2 = ShortestDistance(pfst2);
100
101    if (!ApproxEqual(sum1, sum2, delta)) {
102        VLOG(1) << "Sum1 = " << sum1;
103        VLOG(1) << "Sum2 = " << sum2;
104        ret = false;
105        break;
106    }
107  }
108
109  if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) {
110    if (error) *error = true;
111    return false;
112  }
113
114  return ret;
115}
116
117
118// Test if two FSTs are equivalent by randomly generating 'num_paths' paths
119// of length no more than 'path_length' using the seed 'seed' in these FSTs.
120// Returns optional error value (when FLAGS_error_fatal = false).
121template <class Arc>
122bool RandEquivalent(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
123                    ssize_t num_paths, float delta = kDelta,
124                    int seed = time(0), int path_length = INT_MAX,
125                    bool *error = 0) {
126  UniformArcSelector<Arc> uniform_selector(seed);
127  RandGenOptions< UniformArcSelector<Arc> >
128      opts(uniform_selector, path_length);
129  return RandEquivalent(fst1, fst2, num_paths, delta, opts, error);
130}
131
132
133}  // namespace fst
134
#endif  // FST_RANDEQUIVALENT_H__
136