// weight_test.cc
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// Regression test for Fst weights.
20
21#include <cstdlib>
22#include <ctime>
23
24#include <fst/expectation-weight.h>
25#include <fst/float-weight.h>
26#include <fst/random-weight.h>
27#include "./weight-tester.h"
28
29DEFINE_int32(seed, -1, "random seed");
30DEFINE_int32(repeat, 100000, "number of test repetitions");
31
32using fst::TropicalWeight;
33using fst::TropicalWeightGenerator;
34using fst::TropicalWeightTpl;
35using fst::TropicalWeightGenerator_;
36
37using fst::LogWeight;
38using fst::LogWeightGenerator;
39using fst::LogWeightTpl;
40using fst::LogWeightGenerator_;
41
42using fst::MinMaxWeight;
43using fst::MinMaxWeightGenerator;
44using fst::MinMaxWeightTpl;
45using fst::MinMaxWeightGenerator_;
46
47using fst::StringWeight;
48using fst::StringWeightGenerator;
49
50using fst::GallicWeight;
51using fst::GallicWeightGenerator;
52
53using fst::LexicographicWeight;
54using fst::LexicographicWeightGenerator;
55
56using fst::ProductWeight;
57using fst::ProductWeightGenerator;
58
59using fst::PowerWeight;
60using fst::PowerWeightGenerator;
61
62using fst::SignedLogWeightTpl;
63using fst::SignedLogWeightGenerator_;
64
65using fst::ExpectationWeight;
66
67using fst::SparsePowerWeight;
68using fst::SparsePowerWeightGenerator;
69
70using fst::STRING_LEFT;
71using fst::STRING_RIGHT;
72
73using fst::WeightTester;
74
75template <class T>
76void TestTemplatedWeights(int repeat, int seed) {
77  TropicalWeightGenerator_<T> tropical_generator(seed);
78  WeightTester<TropicalWeightTpl<T>, TropicalWeightGenerator_<T> >
79      tropical_tester(tropical_generator);
80  tropical_tester.Test(repeat);
81
82  LogWeightGenerator_<T> log_generator(seed);
83  WeightTester<LogWeightTpl<T>, LogWeightGenerator_<T> >
84      log_tester(log_generator);
85  log_tester.Test(repeat);
86
87  MinMaxWeightGenerator_<T> minmax_generator(seed);
88  WeightTester<MinMaxWeightTpl<T>, MinMaxWeightGenerator_<T> >
89      minmax_tester(minmax_generator);
90  minmax_tester.Test(repeat);
91
92  SignedLogWeightGenerator_<T> signedlog_generator(seed);
93  WeightTester<SignedLogWeightTpl<T>, SignedLogWeightGenerator_<T> >
94      signedlog_tester(signedlog_generator);
95  signedlog_tester.Test(repeat);
96}
97
98int main(int argc, char **argv) {
99  std::set_new_handler(FailedNewHandler);
100  SET_FLAGS(argv[0], &argc, &argv, true);
101
102  int seed = FLAGS_seed >= 0 ? FLAGS_seed : time(0);
103  LOG(INFO) << "Seed = " << seed;
104
105  TestTemplatedWeights<float>(FLAGS_repeat, seed);
106  TestTemplatedWeights<double>(FLAGS_repeat, seed);
107  FLAGS_fst_weight_parentheses = "()";
108  TestTemplatedWeights<float>(FLAGS_repeat, seed);
109  TestTemplatedWeights<double>(FLAGS_repeat, seed);
110  FLAGS_fst_weight_parentheses = "";
111
112  // Make sure type names for templated weights are consistent
113  CHECK(TropicalWeight::Type() == "tropical");
114  CHECK(TropicalWeightTpl<double>::Type() != TropicalWeightTpl<float>::Type());
115  CHECK(LogWeight::Type() == "log");
116  CHECK(LogWeightTpl<double>::Type() != LogWeightTpl<float>::Type());
117  TropicalWeightTpl<double> w(15.0);
118  TropicalWeight tw(15.0);
119
120  StringWeightGenerator<int> left_string_generator(seed);
121  WeightTester<StringWeight<int>, StringWeightGenerator<int> >
122    left_string_tester(left_string_generator);
123  left_string_tester.Test(FLAGS_repeat);
124
125  StringWeightGenerator<int, STRING_RIGHT> right_string_generator(seed);
126  WeightTester<StringWeight<int, STRING_RIGHT>,
127    StringWeightGenerator<int, STRING_RIGHT> >
128    right_string_tester(right_string_generator);
129  right_string_tester.Test(FLAGS_repeat);
130
131  typedef GallicWeight<int, TropicalWeight> TropicalGallicWeight;
132  typedef GallicWeightGenerator<int, TropicalWeightGenerator>
133    TropicalGallicWeightGenerator;
134
135  TropicalGallicWeightGenerator tropical_gallic_generator(seed);
136  WeightTester<TropicalGallicWeight, TropicalGallicWeightGenerator>
137    tropical_gallic_tester(tropical_gallic_generator);
138  tropical_gallic_tester.Test(FLAGS_repeat);
139
140  typedef ProductWeight<TropicalWeight, TropicalWeight> TropicalProductWeight;
141  typedef ProductWeightGenerator<TropicalWeightGenerator,
142      TropicalWeightGenerator> TropicalProductWeightGenerator;
143
144  TropicalProductWeightGenerator tropical_product_generator(seed);
145  WeightTester<TropicalProductWeight, TropicalProductWeightGenerator>
146      tropical_product_weight_tester(tropical_product_generator);
147  tropical_product_weight_tester.Test(FLAGS_repeat);
148
149  typedef PowerWeight<TropicalWeight, 3> TropicalCubeWeight;
150  typedef PowerWeightGenerator<TropicalWeightGenerator, 3>
151      TropicalCubeWeightGenerator;
152
153  TropicalCubeWeightGenerator tropical_cube_generator(seed);
154  WeightTester<TropicalCubeWeight, TropicalCubeWeightGenerator>
155      tropical_cube_weight_tester(tropical_cube_generator);
156  tropical_cube_weight_tester.Test(FLAGS_repeat);
157
158  typedef ProductWeight<TropicalWeight, TropicalProductWeight>
159      SecondNestedProductWeight;
160  typedef ProductWeightGenerator<TropicalWeightGenerator,
161      TropicalProductWeightGenerator> SecondNestedProductWeightGenerator;
162
163  SecondNestedProductWeightGenerator second_nested_product_generator(seed);
164  WeightTester<SecondNestedProductWeight, SecondNestedProductWeightGenerator>
165      second_nested_product_weight_tester(second_nested_product_generator);
166  second_nested_product_weight_tester.Test(FLAGS_repeat);
167
168  // This only works with fst_weight_parentheses = "()"
169  typedef ProductWeight<TropicalProductWeight, TropicalWeight>
170      FirstNestedProductWeight;
171  typedef ProductWeightGenerator<TropicalProductWeightGenerator,
172      TropicalWeightGenerator> FirstNestedProductWeightGenerator;
173
174  FirstNestedProductWeightGenerator first_nested_product_generator(seed);
175  WeightTester<FirstNestedProductWeight, FirstNestedProductWeightGenerator>
176      first_nested_product_weight_tester(first_nested_product_generator);
177
178  typedef PowerWeight<FirstNestedProductWeight, 3> NestedProductCubeWeight;
179  typedef PowerWeightGenerator<FirstNestedProductWeightGenerator, 3>
180      NestedProductCubeWeightGenerator;
181
182  NestedProductCubeWeightGenerator nested_product_cube_generator(seed);
183  WeightTester<NestedProductCubeWeight, NestedProductCubeWeightGenerator>
184      nested_product_cube_weight_tester(nested_product_cube_generator);
185
186  typedef SparsePowerWeight<NestedProductCubeWeight,
187      size_t > SparseNestedProductCubeWeight;
188  typedef SparsePowerWeightGenerator<NestedProductCubeWeightGenerator,
189      size_t, 3> SparseNestedProductCubeWeightGenerator;
190
191  SparseNestedProductCubeWeightGenerator
192      sparse_nested_product_cube_generator(seed);
193  WeightTester<SparseNestedProductCubeWeight,
194      SparseNestedProductCubeWeightGenerator>
195      sparse_nested_product_cube_weight_tester(
196          sparse_nested_product_cube_generator);
197
198  typedef SparsePowerWeight<LogWeight, size_t > LogSparsePowerWeight;
199  typedef SparsePowerWeightGenerator<LogWeightGenerator,
200      size_t, 3> LogSparsePowerWeightGenerator;
201
202  LogSparsePowerWeightGenerator
203      log_sparse_power_weight_generator(seed);
204  WeightTester<LogSparsePowerWeight,
205      LogSparsePowerWeightGenerator>
206      log_sparse_power_weight_tester(
207          log_sparse_power_weight_generator);
208
209  typedef ExpectationWeight<LogWeight, LogWeight>
210      LogLogExpectWeight;
211  typedef ProductWeightGenerator<LogWeightGenerator, LogWeightGenerator,
212    LogLogExpectWeight> LogLogExpectWeightGenerator;
213
214  LogLogExpectWeightGenerator log_log_expect_weight_generator(seed);
215  WeightTester<LogLogExpectWeight, LogLogExpectWeightGenerator>
216      log_log_expect_weight_tester(log_log_expect_weight_generator);
217
218  typedef ExpectationWeight<LogWeight, LogSparsePowerWeight>
219      LogLogSparseExpectWeight;
220  typedef ProductWeightGenerator<
221    LogWeightGenerator,
222    LogSparsePowerWeightGenerator,
223    LogLogSparseExpectWeight> LogLogSparseExpectWeightGenerator;
224
225  LogLogSparseExpectWeightGenerator log_logsparse_expect_weight_generator(seed);
226  WeightTester<LogLogSparseExpectWeight, LogLogSparseExpectWeightGenerator>
227      log_logsparse_expect_weight_tester(log_logsparse_expect_weight_generator);
228
229  // Test all product weight I/O with parentheses
230  FLAGS_fst_weight_parentheses = "()";
231  first_nested_product_weight_tester.Test(FLAGS_repeat);
232  nested_product_cube_weight_tester.Test(FLAGS_repeat);
233  log_sparse_power_weight_tester.Test(1);
234  sparse_nested_product_cube_weight_tester.Test(1);
235  tropical_product_weight_tester.Test(5);
236  second_nested_product_weight_tester.Test(5);
237  tropical_gallic_tester.Test(5);
238  tropical_cube_weight_tester.Test(5);
239  FLAGS_fst_weight_parentheses = "";
240  log_sparse_power_weight_tester.Test(1);
241  log_log_expect_weight_tester.Test(1, false); // disables division
242  log_logsparse_expect_weight_tester.Test(1, false);
243
244  typedef LexicographicWeight<TropicalWeight, TropicalWeight>
245      TropicalLexicographicWeight;
246  typedef LexicographicWeightGenerator<TropicalWeightGenerator,
247      TropicalWeightGenerator> TropicalLexicographicWeightGenerator;
248
249  TropicalLexicographicWeightGenerator tropical_lexicographic_generator(seed);
250  WeightTester<TropicalLexicographicWeight,
251      TropicalLexicographicWeightGenerator>
252    tropical_lexicographic_tester(tropical_lexicographic_generator);
253  tropical_lexicographic_tester.Test(FLAGS_repeat);
254
255  cout << "PASS" << endl;
256
257  return 0;
258}
259