1// product-weight.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// Product weight set and associated semiring operation definitions.
18
19#ifndef FST_LIB_PRODUCT_WEIGHT_H__
20#define FST_LIB_PRODUCT_WEIGHT_H__
21
#include <climits>

#include "fst/lib/weight.h"
23
24DECLARE_string(fst_product_separator);
25
26namespace fst {
27
28// Product semiring: W1 * W2
29template<class W1, class W2>
30class ProductWeight {
31 public:
32  typedef ProductWeight<typename W1::ReverseWeight, typename W2::ReverseWeight>
33  ReverseWeight;
34
35  ProductWeight() {}
36
37  ProductWeight(W1 w1, W2 w2) : value1_(w1), value2_(w2) {}
38
39  static const ProductWeight<W1, W2> &Zero() {
40    static const ProductWeight<W1, W2> zero(W1::Zero(), W2::Zero());
41    return zero;
42  }
43
44  static const ProductWeight<W1, W2> &One() {
45    static const ProductWeight<W1, W2> one(W1::One(), W2::One());
46    return one;
47  }
48
49  static const string &Type() {
50    static const string type = W1::Type() + "_X_" + W2::Type();
51    return type;
52  }
53
54  istream &Read(istream &strm) {
55    value1_.Read(strm);
56    return value2_.Read(strm);
57  }
58
59  ostream &Write(ostream &strm) const {
60    value1_.Write(strm);
61    return value2_.Write(strm);
62  }
63
64  ProductWeight<W1, W2> &operator=(const ProductWeight<W1, W2> &w) {
65    value1_ = w.Value1();
66    value2_ = w.Value2();
67    return *this;
68  }
69
70  bool Member() const { return value1_.Member() && value2_.Member(); }
71
72  ssize_t Hash() const {
73    ssize_t h1 = value1_.Hash();
74    ssize_t h2 = value2_.Hash();
75    int lshift = 5;
76    int rshift = sizeof(ssize_t) - 5;
77    return h1 << lshift ^ h1 >> rshift ^ h2;
78  }
79
80  ProductWeight<W1, W2> Quantize(float delta = kDelta) const {
81    return ProductWeight<W1, W2>(value1_.Quantize(), value2_.Quantize());
82  }
83
84  ReverseWeight Reverse() const {
85    return ReverseWeight(value1_.Reverse(), value2_.Reverse());
86  }
87
88  static uint64 Properties() {
89    uint64 props1 = W1::Properties();
90    uint64 props2 = W2::Properties();
91    return props1 & props2 & (kLeftSemiring | kRightSemiring |
92                              kCommutative | kIdempotent);
93  }
94
95  W1 Value1() const { return value1_; }
96
97  W2 Value2() const { return value2_; }
98
99 private:
100  W1 value1_;
101  W2 value2_;
102};
103
104template <class W1, class W2>
105inline bool operator==(const ProductWeight<W1, W2> &w,
106                       const ProductWeight<W1, W2> &v) {
107  return w.Value1() == v.Value1() && w.Value2() == v.Value2();
108}
109
110template <class W1, class W2>
111inline bool operator!=(const ProductWeight<W1, W2> &w1,
112                       const ProductWeight<W1, W2> &w2) {
113  return w1.Value1() != w2.Value1() || w1.Value2() != w2.Value2();
114}
115
116
117template <class W1, class W2>
118inline bool ApproxEqual(const ProductWeight<W1, W2> &w1,
119                        const ProductWeight<W1, W2> &w2,
120                        float delta = kDelta) {
121  return w1 == w2;
122}
123
124template <class W1, class W2>
125inline ostream &operator<<(ostream &strm, const ProductWeight<W1, W2> &w) {
126  CHECK(FLAGS_fst_product_separator.size() == 1);
127  char separator = FLAGS_fst_product_separator[0];
128  return strm << w.Value1() << separator << w.Value2();
129}
130
131template <class W1, class W2>
132inline istream &operator>>(istream &strm, ProductWeight<W1, W2> &w) {
133  CHECK(FLAGS_fst_product_separator.size() == 1);
134  char separator = FLAGS_fst_product_separator[0];
135  int c;
136
137  // read any initial whitespapce
138  while (true) {
139    c = strm.get();
140    if (c == EOF || c == separator) {
141      strm.clear(std::ios::badbit);
142      return strm;
143    }
144    if (!isspace(c))
145      break;
146  }
147
148  // read first element
149  string s1;
150  do {
151    s1 += c;
152    c = strm.get();
153    if (c == EOF || isspace(c)) {
154      strm.clear(std::ios::badbit);
155      return strm;
156    }
157  } while (c != separator);
158  istringstream strm1(s1);
159  W1 w1 = W1::Zero();
160  strm1 >> w1;
161
162  // read second element
163  W2 w2 = W2::Zero();
164  strm >> w2;
165
166  w = ProductWeight<W1, W2>(w1, w2);
167  return strm;
168}
169
170template <class W1, class W2>
171inline ProductWeight<W1, W2> Plus(const ProductWeight<W1, W2> &w,
172                                  const ProductWeight<W1, W2> &v) {
173  return ProductWeight<W1, W2>(Plus(w.Value1(), v.Value1()),
174                               Plus(w.Value2(), v.Value2()));
175}
176
177template <class W1, class W2>
178inline ProductWeight<W1, W2> Times(const ProductWeight<W1, W2> &w,
179                                   const ProductWeight<W1, W2> &v) {
180  return ProductWeight<W1, W2>(Times(w.Value1(), v.Value1()),
181                               Times(w.Value2(), v.Value2()));
182}
183
184template <class W1, class W2>
185inline ProductWeight<W1, W2> Divide(const ProductWeight<W1, W2> &w,
186                                    const ProductWeight<W1, W2> &v,
187                                    DivideType typ = DIVIDE_ANY) {
188  return ProductWeight<W1, W2>(Divide(w.Value1(), v.Value1(), typ),
189                               Divide(w.Value2(), v.Value2(), typ));
190}
191
192}  // namespace fst;
193
194#endif  // FST_LIB_PRODUCT_WEIGHT_H__
195