1// weight-tester.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Utility class for regression testing of Fst weights.
20
21#ifndef FST_TEST_WEIGHT_TESTER_H_
22#define FST_TEST_WEIGHT_TESTER_H_
23
24#include <iostream>
25#include <sstream>
26
27#include <fst/random-weight.h>
28
29namespace fst {
30
31// This class tests a variety of identities and properties that must
32// hold for the Weight class to be well-defined. It calls function object
33// WEIGHT_GENERATOR to select weights that are used in the tests.
34template<class Weight, class WeightGenerator>
35class WeightTester {
36 public:
37  WeightTester(WeightGenerator generator) : weight_generator_(generator) {}
38
39  void Test(int iterations, bool test_division = true) {
40    for (int i = 0; i < iterations; ++i) {
41      // Selects the test weights.
42      Weight w1 = weight_generator_();
43      Weight w2 = weight_generator_();
44      Weight w3 = weight_generator_();
45
46      VLOG(1) << "weight type = " << Weight::Type();
47      VLOG(1) << "w1 = " << w1;
48      VLOG(1) << "w2 = " << w2;
49      VLOG(1) << "w3 = " << w3;
50
51      TestSemiring(w1, w2, w3);
52      if (test_division)
53        TestDivision(w1, w2);
54      TestReverse(w1, w2);
55      TestEquality(w1, w2, w3);
56      TestIO(w1);
57      TestCopy(w1);
58    }
59  }
60
61 private:
62  // Note in the tests below we use ApproxEqual rather than == and add
63  // kDelta to inequalities where the weights might be inexact.
64
65  // Tests (Plus, Times, Zero, One) defines a commutative semiring.
66  void TestSemiring(Weight w1, Weight w2, Weight w3) {
67    // Checks that the operations are closed.
68    CHECK(Plus(w1, w2).Member());
69    CHECK(Times(w1, w2).Member());
70
71    // Checks that the operations are associative.
72    CHECK(ApproxEqual(Plus(w1, Plus(w2, w3)), Plus(Plus(w1, w2), w3)));
73    CHECK(ApproxEqual(Times(w1, Times(w2, w3)), Times(Times(w1, w2), w3)));
74
75    // Checks the identity elements.
76    CHECK(Plus(w1, Weight::Zero()) == w1);
77    CHECK(Plus(Weight::Zero(), w1) == w1);
78    CHECK(Times(w1, Weight::One()) == w1);
79    CHECK(Times(Weight::One(), w1) == w1);
80
81    // Check the no weight element.
82    CHECK(!Weight::NoWeight().Member());
83    CHECK(!Plus(w1, Weight::NoWeight()).Member());
84    CHECK(!Plus(Weight::NoWeight(), w1).Member());
85    CHECK(!Times(w1, Weight::NoWeight()).Member());
86    CHECK(!Times(Weight::NoWeight(), w1).Member());
87
88    // Checks that the operations commute.
89    CHECK(ApproxEqual(Plus(w1, w2), Plus(w2, w1)));
90    if (Weight::Properties() & kCommutative)
91      CHECK(ApproxEqual(Times(w1, w2), Times(w2, w1)));
92
93    // Checks Zero() is the annihilator.
94    CHECK(Times(w1, Weight::Zero()) == Weight::Zero());
95    CHECK(Times(Weight::Zero(), w1) == Weight::Zero());
96
97    // Check Power(w, 0) is Weight::One()
98    CHECK(Power(w1, 0) == Weight::One());
99
100    // Check Power(w, 1) is w
101    CHECK(Power(w1, 1) == w1);
102
103    // Check Power(w, 3) is Times(w, Times(w, w))
104    CHECK(Power(w1, 3) == Times(w1, Times(w1, w1)));
105
106    // Checks distributivity.
107    if (Weight::Properties() & kLeftSemiring)
108      CHECK(ApproxEqual(Times(w1, Plus(w2, w3)),
109                        Plus(Times(w1, w2), Times(w1, w3))));
110    if (Weight::Properties() & kRightSemiring)
111      CHECK(ApproxEqual(Times(Plus(w1, w2), w3),
112                        Plus(Times(w1, w3), Times(w2, w3))));
113
114    if (Weight::Properties() & kIdempotent)
115      CHECK(Plus(w1, w1) == w1);
116
117    if (Weight::Properties() & kPath)
118      CHECK(Plus(w1, w2) == w1 || Plus(w1, w2) == w2);
119
120    // Ensure weights form a left or right semiring.
121    CHECK(Weight::Properties() & (kLeftSemiring | kRightSemiring));
122
123    // Check when Times() is commutative that it is marked as a semiring.
124    if (Weight::Properties() & kCommutative)
125      CHECK(Weight::Properties() & kSemiring);
126  }
127
128  // Tests division operation.
129  void TestDivision(Weight w1, Weight w2) {
130    Weight p = Times(w1, w2);
131
132    if (Weight::Properties() & kLeftSemiring) {
133      Weight d = Divide(p, w1, DIVIDE_LEFT);
134      if (d.Member())
135        CHECK(ApproxEqual(p, Times(w1, d)));
136      CHECK(!Divide(w1, Weight::NoWeight(), DIVIDE_LEFT).Member());
137      CHECK(!Divide(Weight::NoWeight(), w1, DIVIDE_LEFT).Member());
138    }
139
140    if (Weight::Properties() & kRightSemiring) {
141      Weight d = Divide(p, w2, DIVIDE_RIGHT);
142      if (d.Member())
143        CHECK(ApproxEqual(p, Times(d, w2)));
144      CHECK(!Divide(w1, Weight::NoWeight(), DIVIDE_RIGHT).Member());
145      CHECK(!Divide(Weight::NoWeight(), w1, DIVIDE_RIGHT).Member());
146    }
147
148    if (Weight::Properties() & kCommutative) {
149      Weight d = Divide(p, w1, DIVIDE_RIGHT);
150      if (d.Member())
151        CHECK(ApproxEqual(p, Times(d, w1)));
152    }
153  }
154
155  // Tests reverse operation.
156  void TestReverse(Weight w1, Weight w2) {
157    typedef typename Weight::ReverseWeight ReverseWeight;
158
159    ReverseWeight rw1 = w1.Reverse();
160    ReverseWeight rw2 = w2.Reverse();
161
162    CHECK(rw1.Reverse() == w1);
163    CHECK(Plus(w1, w2).Reverse() == Plus(rw1, rw2));
164    CHECK(Times(w1, w2).Reverse() == Times(rw2, rw1));
165  }
166
167  // Tests == is an equivalence relation.
168  void TestEquality(Weight w1, Weight w2, Weight w3) {
169    // Checks reflexivity.
170    CHECK(w1 == w1);
171
172    // Checks symmetry.
173    CHECK((w1 == w2) == (w2 == w1));
174
175    // Checks transitivity.
176    if (w1 == w2 && w2 == w3)
177      CHECK(w1 == w3);
178  }
179
180  // Tests binary serialization and textual I/O.
181  void TestIO(Weight w) {
182    // Tests binary I/O
183    {
184    ostringstream os;
185    w.Write(os);
186    os.flush();
187    istringstream is(os.str());
188    Weight v;
189    v.Read(is);
190    CHECK_EQ(w, v);
191    }
192
193    // Tests textual I/O.
194    {
195      ostringstream os;
196      os << w;
197      istringstream is(os.str());
198      Weight v(Weight::One());
199      is >> v;
200      CHECK(ApproxEqual(w, v));
201    }
202  }
203
204  // Tests copy constructor and assignment operator
205  void TestCopy(Weight w) {
206    Weight x = w;
207    CHECK(w == x);
208
209    x = Weight(w);
210    CHECK(w == x);
211
212    x.operator=(x);
213    CHECK(w == x);
214
215  }
216
217  // Generates weights used in testing.
218  WeightGenerator weight_generator_;
219
220  DISALLOW_COPY_AND_ASSIGN(WeightTester);
221};
222
223}  // namespace fst
224
225#endif  // FST_TEST_WEIGHT_TESTER_H_
226