tuple-weight.h revision f4c12fce1ee58e670f9c3fce46c40296ba9ee8a2
1// tuple-weight.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: allauzen@google (Cyril Allauzen)
17//
18// \file
19// Tuple weight set operation definitions.
20
21#ifndef FST_LIB_TUPLE_WEIGHT_H__
22#define FST_LIB_TUPLE_WEIGHT_H__
23
24#include <string>
25#include <vector>
26using std::vector;
27
28#include <fst/weight.h>
29
30
31DECLARE_string(fst_weight_parentheses);
32DECLARE_string(fst_weight_separator);
33
34namespace fst {
35
36template<class W, unsigned int n> class TupleWeight;
37template <class W, unsigned int n>
38istream &operator>>(istream &strm, TupleWeight<W, n> &w);
39
40// n-tuple weight, element of the n-th catersian power of W
41template <class W, unsigned int n>
42class TupleWeight {
43 public:
44  typedef TupleWeight<typename W::ReverseWeight, n> ReverseWeight;
45
46  TupleWeight() {}
47
48  TupleWeight(const TupleWeight &w) {
49    for (size_t i = 0; i < n; ++i)
50      values_[i] = w.values_[i];
51  }
52
53  template <class Iterator>
54  TupleWeight(Iterator begin, Iterator end) {
55    for (Iterator iter = begin; iter != end; ++iter)
56      values_[iter - begin] = *iter;
57  }
58
59  TupleWeight(const W &w) {
60    for (size_t i = 0; i < n; ++i)
61      values_[i] = w;
62  }
63
64  static const TupleWeight<W, n> &Zero() {
65    static const TupleWeight<W, n> zero(W::Zero());
66    return zero;
67  }
68
69  static const TupleWeight<W, n> &One() {
70    static const TupleWeight<W, n> one(W::One());
71    return one;
72  }
73
74  static const TupleWeight<W, n> &NoWeight() {
75    static const TupleWeight<W, n> no_weight(W::NoWeight());
76    return no_weight;
77  }
78
79  static unsigned int Length() {
80    return n;
81  }
82
83  istream &Read(istream &strm) {
84    for (size_t i = 0; i < n; ++i)
85      values_[i].Read(strm);
86    return strm;
87  }
88
89  ostream &Write(ostream &strm) const {
90    for (size_t i = 0; i < n; ++i)
91      values_[i].Write(strm);
92    return strm;
93  }
94
95  TupleWeight<W, n> &operator=(const TupleWeight<W, n> &w) {
96    for (size_t i = 0; i < n; ++i)
97      values_[i] = w.values_[i];
98    return *this;
99  }
100
101  bool Member() const {
102    bool member = true;
103    for (size_t i = 0; i < n; ++i)
104      member = member && values_[i].Member();
105    return member;
106  }
107
108  size_t Hash() const {
109    uint64 hash = 0;
110    for (size_t i = 0; i < n; ++i)
111      hash = 5 * hash + values_[i].Hash();
112    return size_t(hash);
113  }
114
115  TupleWeight<W, n> Quantize(float delta = kDelta) const {
116    TupleWeight<W, n> w;
117    for (size_t i = 0; i < n; ++i)
118      w.values_[i] = values_[i].Quantize(delta);
119    return w;
120  }
121
122  ReverseWeight Reverse() const {
123    TupleWeight<W, n> w;
124    for (size_t i = 0; i < n; ++i)
125      w.values_[i] = values_[i].Reverse();
126    return w;
127  }
128
129  const W& Value(size_t i) const { return values_[i]; }
130
131  void SetValue(size_t i, const W &w) { values_[i] = w; }
132
133 protected:
134  // Reads TupleWeight when there are no parentheses around tuple terms
135  inline static istream &ReadNoParen(istream &strm,
136                                     TupleWeight<W, n> &w,
137                                     char separator) {
138    int c;
139    do {
140      c = strm.get();
141    } while (isspace(c));
142
143    for (size_t i = 0; i < n - 1; ++i) {
144      string s;
145      if (i)
146        c = strm.get();
147      while (c != separator) {
148        if (c == EOF) {
149          strm.clear(std::ios::badbit);
150          return strm;
151        }
152        s += c;
153        c = strm.get();
154      }
155      // read (i+1)-th element
156      istringstream sstrm(s);
157      W r = W::Zero();
158      sstrm >> r;
159      w.SetValue(i, r);
160    }
161
162    // read n-th element
163    W r = W::Zero();
164    strm >> r;
165    w.SetValue(n - 1, r);
166
167    return strm;
168  }
169
170  // Reads TupleWeight when there are parentheses around tuple terms
171  inline static istream &ReadWithParen(istream &strm,
172                                       TupleWeight<W, n> &w,
173                                       char separator,
174                                       char open_paren,
175                                       char close_paren) {
176    int c;
177    do {
178      c = strm.get();
179    } while (isspace(c));
180
181    if (c != open_paren) {
182      FSTERROR() << " is fst_weight_parentheses flag set correcty? ";
183      strm.clear(std::ios::badbit);
184      return strm;
185    }
186
187    for (size_t i = 0; i < n - 1; ++i) {
188      // read (i+1)-th element
189      stack<int> parens;
190      string s;
191      c = strm.get();
192      while (c != separator || !parens.empty()) {
193        if (c == EOF) {
194          strm.clear(std::ios::badbit);
195          return strm;
196        }
197        s += c;
198        // if parens encountered before separator, they must be matched
199        if (c == open_paren) {
200          parens.push(1);
201        } else if (c == close_paren) {
202          // Fail for mismatched parens
203          if (parens.empty()) {
204            strm.clear(std::ios::failbit);
205            return strm;
206          }
207          parens.pop();
208        }
209        c = strm.get();
210      }
211      istringstream sstrm(s);
212      W r = W::Zero();
213      sstrm >> r;
214      w.SetValue(i, r);
215    }
216
217    // read n-th element
218    string s;
219    c = strm.get();
220    while (c != EOF) {
221      s += c;
222      c = strm.get();
223    }
224    if (s.empty() || *s.rbegin() != close_paren) {
225      FSTERROR() << " is fst_weight_parentheses flag set correcty? ";
226      strm.clear(std::ios::failbit);
227      return strm;
228    }
229    s.erase(s.size() - 1, 1);
230    istringstream sstrm(s);
231    W r = W::Zero();
232    sstrm >> r;
233    w.SetValue(n - 1, r);
234
235    return strm;
236  }
237
238
239 private:
240  W values_[n];
241
242  friend istream &operator>><W, n>(istream&, TupleWeight<W, n>&);
243};
244
245template <class W, unsigned int n>
246inline bool operator==(const TupleWeight<W, n> &w1,
247                       const TupleWeight<W, n> &w2) {
248  bool equal = true;
249  for (size_t i = 0; i < n; ++i)
250    equal = equal && (w1.Value(i) == w2.Value(i));
251  return equal;
252}
253
254template <class W, unsigned int n>
255inline bool operator!=(const TupleWeight<W, n> &w1,
256                       const TupleWeight<W, n> &w2) {
257  bool not_equal = false;
258  for (size_t i = 0; (i < n) && !not_equal; ++i)
259    not_equal = not_equal || (w1.Value(i) != w2.Value(i));
260  return not_equal;
261}
262
263template <class W, unsigned int n>
264inline bool ApproxEqual(const TupleWeight<W, n> &w1,
265                        const TupleWeight<W, n> &w2,
266                        float delta = kDelta) {
267  bool approx_equal = true;
268  for (size_t i = 0; i < n; ++i)
269    approx_equal = approx_equal &&
270        ApproxEqual(w1.Value(i), w2.Value(i), delta);
271  return approx_equal;
272}
273
274template <class W, unsigned int n>
275inline ostream &operator<<(ostream &strm, const TupleWeight<W, n> &w) {
276  if(FLAGS_fst_weight_separator.size() != 1) {
277    FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
278    strm.clear(std::ios::badbit);
279    return strm;
280  }
281  char separator = FLAGS_fst_weight_separator[0];
282  bool write_parens = false;
283  if (!FLAGS_fst_weight_parentheses.empty()) {
284    if (FLAGS_fst_weight_parentheses.size() != 2) {
285      FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
286      strm.clear(std::ios::badbit);
287      return strm;
288    }
289    write_parens = true;
290  }
291
292  if (write_parens)
293    strm << FLAGS_fst_weight_parentheses[0];
294  for (size_t i  = 0; i < n; ++i) {
295    if(i)
296      strm << separator;
297    strm << w.Value(i);
298  }
299  if (write_parens)
300    strm << FLAGS_fst_weight_parentheses[1];
301
302  return strm;
303}
304
305template <class W, unsigned int n>
306inline istream &operator>>(istream &strm, TupleWeight<W, n> &w) {
307  if(FLAGS_fst_weight_separator.size() != 1) {
308    FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
309    strm.clear(std::ios::badbit);
310    return strm;
311  }
312  char separator = FLAGS_fst_weight_separator[0];
313
314  if (!FLAGS_fst_weight_parentheses.empty()) {
315    if (FLAGS_fst_weight_parentheses.size() != 2) {
316      FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
317      strm.clear(std::ios::badbit);
318      return strm;
319    }
320    return TupleWeight<W, n>::ReadWithParen(
321        strm, w, separator, FLAGS_fst_weight_parentheses[0],
322        FLAGS_fst_weight_parentheses[1]);
323  } else {
324    return TupleWeight<W, n>::ReadNoParen(strm, w, separator);
325  }
326}
327
328
329
330}  // namespace fst
331
332#endif  // FST_LIB_TUPLE_WEIGHT_H__
333