1// string-weight.h
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15//
16// \file
17// String weight set and associated semiring operation definitions.
18
19#ifndef FST_LIB_STRING_WEIGHT_H__
20#define FST_LIB_STRING_WEIGHT_H__
21
22#include <list>
23
24#include "fst/lib/product-weight.h"
25#include "fst/lib/weight.h"
26
27namespace fst {
28
// Reserved label values and the separator used when printing string weights.
const int kStringInfinity = -1;      // Label marking the infinite string
const int kStringBad = -2;           // Label marking a non-string
const char kStringSeparator = '_';   // Placed between labels in printed strings
32
// Selects the left or the right string semiring.  The *_RESTRICT
// variants signal an error if proper prefixes (suffixes) would
// otherwise be returned by Plus; they are useful with various
// algorithms that require functional transducer input with the
// string semirings.
enum StringType {
  STRING_LEFT = 0,
  STRING_RIGHT = 1,
  STRING_LEFT_RESTRICT = 2,
  STRING_RIGHT_RESTRICT = 3
};

// Maps a StringType to its mirror image: left <-> right and
// left-restricted <-> right-restricted.  Any other value maps to
// STRING_LEFT_RESTRICT.  The tests are mutually exclusive, so their
// order does not matter.
#define REVERSE_STRING_TYPE(S)                                  \
   ((S) == STRING_LEFT_RESTRICT ? STRING_RIGHT_RESTRICT :       \
    ((S) == STRING_RIGHT ? STRING_LEFT :                        \
     ((S) == STRING_LEFT ? STRING_RIGHT :                       \
      STRING_LEFT_RESTRICT)))
46
47template <typename L, StringType S = STRING_LEFT>
48class StringWeight;
49
50template <typename L, StringType S = STRING_LEFT>
51class StringWeightIterator;
52
53template <typename L, StringType S = STRING_LEFT>
54class StringWeightReverseIterator;
55
56template <typename L, StringType S>
57bool operator==(const StringWeight<L, S> &,  const StringWeight<L, S> &);
58
59
60// String semiring: (longest_common_prefix/suffix, ., Infinity, Epsilon)
61template <typename L, StringType S>
62class StringWeight {
63 public:
64  typedef L Label;
65  typedef StringWeight<L, REVERSE_STRING_TYPE(S)> ReverseWeight;
66
67  friend class StringWeightIterator<L, S>;
68  friend class StringWeightReverseIterator<L, S>;
69  friend bool operator==<>(const StringWeight<L, S> &,
70                           const StringWeight<L, S> &);
71
72  StringWeight() { Init(); }
73
74  template <typename Iter>
75  StringWeight(const Iter &begin, const Iter &end) {
76    Init();
77    for (Iter iter = begin; iter != end; ++iter)
78      PushBack(*iter);
79  }
80
81  explicit StringWeight(L l) { Init(); PushBack(l); }
82
83  static const StringWeight<L, S> &Zero() {
84    static const StringWeight<L, S> zero(kStringInfinity);
85    return zero;
86  }
87
88  static const StringWeight<L, S> &One() {
89    static const StringWeight<L, S> one;
90    return one;
91  }
92
93  static const string &Type() {
94    static const string type =
95        S == STRING_LEFT ? "string" :
96        (S == STRING_RIGHT ? "right_string" :
97         (S == STRING_LEFT_RESTRICT ? "restricted_string" :
98          "right_restricted_string"));
99    return type;
100  }
101
102  bool Member() const;
103
104  istream &Read(istream &strm);
105
106  ostream &Write(ostream &strm) const;
107
108  ssize_t Hash() const;
109
110  StringWeight<L, S> Quantize(float delta = kDelta) const {
111    return *this;
112  }
113
114  ReverseWeight Reverse() const;
115
116  static uint64 Properties() {
117    return (S == STRING_LEFT || S == STRING_LEFT_RESTRICT ?
118            kLeftSemiring : kRightSemiring) | kIdempotent;
119  }
120
121  // NB: This needs to be uncommented only if default fails for this impl.
122  // StringWeight<L, S> &operator=(const StringWeight<L, S> &w);
123
124  // These operations combined with the StringWeightIterator and
125  // StringWeightReverseIterator provide the access and mutation of
126  // the string internal elements.
127
128  // Common initializer among constructors.
129  void Init() { first_ = 0; }
130
131  // Clear existing StringWeight.
132  void Clear() { first_ = 0; rest_.clear(); }
133
134  Label Size() const { return first_ ? rest_.size() + 1 : 0; }
135
136  void PushFront(L l) {
137    if (first_)
138      rest_.push_front(first_);
139    first_ = l;
140  }
141
142  void PushBack(L l) {
143    if (!first_)
144      first_ = l;
145    else
146      rest_.push_back(l);
147  }
148
149 private:
150  L first_;         // first label in string (0 if empty)
151  list<L> rest_;    // remaining labels in string
152};
153
154
155// Traverses string in forward direction.
156template <typename L, StringType S>
157class StringWeightIterator {
158 public:
159  explicit StringWeightIterator(const StringWeight<L, S>& w)
160      : first_(w.first_), rest_(w.rest_), init_(true),
161        iter_(rest_.begin()) {}
162
163  bool Done() const {
164    if (init_) return first_ == 0;
165    else return iter_ == rest_.end();
166  }
167
168  const L& Value() const { return init_ ? first_ : *iter_; }
169
170  void Next() {
171    if (init_) init_ = false;
172    else  ++iter_;
173  }
174
175  void Reset() {
176    init_ = true;
177    iter_ = rest_.begin();
178  }
179
180 private:
181  const L &first_;
182  const list<L> &rest_;
183  bool init_;   // in the initialized state?
184  typename list<L>::const_iterator iter_;
185
186  DISALLOW_EVIL_CONSTRUCTORS(StringWeightIterator);
187};
188
189
190// Traverses string in forward direction.
191template <typename L, StringType S>
192class StringWeightReverseIterator {
193 public:
194  explicit StringWeightReverseIterator(const StringWeight<L, S>& w)
195      : first_(w.first_), rest_(w.rest_), fin_(first_ == 0),
196        iter_(rest_.rbegin()) {}
197
198  bool Done() const { return fin_; }
199
200  const L& Value() const { return iter_ == rest_.rend() ? first_ : *iter_; }
201
202  void Next() {
203    if (iter_ == rest_.rend()) fin_ = true;
204    else  ++iter_;
205  }
206
207  void Reset() {
208    fin_ = false;
209    iter_ = rest_.rbegin();
210  }
211
212 private:
213  const L &first_;
214  const list<L> &rest_;
215  bool fin_;   // in the final state?
216  typename list<L>::const_reverse_iterator iter_;
217
218  DISALLOW_EVIL_CONSTRUCTORS(StringWeightReverseIterator);
219};
220
221
222// StringWeight member functions follow that require
223// StringWeightIterator or StringWeightReverseIterator.
224
225template <typename L, StringType S>
226inline istream &StringWeight<L, S>::Read(istream &strm) {
227  Clear();
228  int32 size = 0;
229  ReadType(strm, &size);
230  for (int i = 0; i < size; ++i) {
231    L label;
232    ReadType(strm, &label);
233    PushBack(label);
234  }
235  return strm;
236}
237
238template <typename L, StringType S>
239inline ostream &StringWeight<L, S>::Write(ostream &strm) const {
240  int32 size =  Size();
241  WriteType(strm, size);
242  for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next()) {
243    L label = iter.Value();
244    WriteType(strm, label);
245  }
246  return strm;
247}
248
249template <typename L, StringType S>
250inline bool StringWeight<L, S>::Member() const {
251  if (Size() != 1)
252    return true;
253  StringWeightIterator<L, S> iter(*this);
254  return iter.Value() != kStringBad;
255}
256
257template <typename L, StringType S>
258inline typename StringWeight<L, S>::ReverseWeight
259StringWeight<L, S>::Reverse() const {
260  ReverseWeight rw;
261  for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next())
262    rw.PushFront(iter.Value());
263  return rw;
264}
265
266template <typename L, StringType S>
267inline ssize_t StringWeight<L, S>::Hash() const {
268  size_t h = 0;
269  for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next())
270    h ^= h<<1 ^ iter.Value();
271  return static_cast<ssize_t>(h);
272}
273
// NB: This needs to be uncommented only if default fails for this impl.
275//
276// template <typename L, StringType S>
277// inline StringWeight<L, S>
278// &StringWeight<L, S>::operator=(const StringWeight<L, S> &w) {
279//   if (this != &w) {
280//     Clear();
281//     for (StringWeightIterator<L, S> iter(w); !iter.Done(); iter.Next())
282//       PushBack(iter.Value());
283//   }
284//   return *this;
285// }
286
287template <typename L, StringType S>
288inline bool operator==(const StringWeight<L, S> &w1,
289                       const StringWeight<L, S> &w2) {
290  if (w1.Size() != w2.Size())
291    return false;
292
293  StringWeightIterator<L, S> iter1(w1);
294  StringWeightIterator<L, S> iter2(w2);
295
296  for (; !iter1.Done() ; iter1.Next(), iter2.Next())
297    if (iter1.Value() != iter2.Value())
298      return false;
299
300  return true;
301}
302
303template <typename L, StringType S>
304inline bool operator!=(const StringWeight<L, S> &w1,
305                       const StringWeight<L, S> &w2) {
306  return !(w1 == w2);
307}
308
309template <typename L, StringType S>
310inline bool ApproxEqual(const StringWeight<L, S> &w1,
311                        const StringWeight<L, S> &w2,
312                        float delta = kDelta) {
313  return w1 == w2;
314}
315
316template <typename L, StringType S>
317inline ostream &operator<<(ostream &strm, const StringWeight<L, S> &w) {
318  StringWeightIterator<L, S> iter(w);
319  if (iter.Done())
320    return strm << "Epsilon";
321  else if (iter.Value() == kStringInfinity)
322    return strm << "Infinity";
323  else if (iter.Value() == kStringBad)
324    return strm << "BadString";
325  else
326    for (size_t i = 0; !iter.Done(); ++i, iter.Next()) {
327      if (i > 0)
328        strm << kStringSeparator;
329      strm << iter.Value();
330    }
331  return strm;
332}
333
334template <typename L, StringType S>
335inline istream &operator>>(istream &strm, StringWeight<L, S> &w) {
336  string s;
337  strm >> s;
338  if (s == "Infinity") {
339    w = StringWeight<L, S>::Zero();
340  } else if (s == "Epsilon") {
341    w = StringWeight<L, S>::One();
342  } else {
343    w.Clear();
344    char *p = 0;
345    for (const char *cs = s.c_str(); !p || *p != '\0'; cs = p + 1) {
346      int l = strtoll(cs, &p, 10);
347      if (p == cs || *p != 0 && *p != kStringSeparator) {
348        strm.clear(std::ios::badbit);
349        break;
350      }
351      w.PushBack(l);
352    }
353  }
354  return strm;
355}
356
357
358// Default is for the restricted left and right semirings.  String
359// equality is required (for non-Zero() input. This restriction
360// is used in e.g. Determinize to ensure functional input.
361template <typename L, StringType S>  inline StringWeight<L, S>
362Plus(const StringWeight<L, S> &w1,
363     const StringWeight<L, S> &w2) {
364  if (w1 == StringWeight<L, S>::Zero())
365    return w2;
366  if (w2 == StringWeight<L, S>::Zero())
367    return w1;
368
369  if (w1 != w2)
370    LOG(FATAL) << "StringWeight::Plus: unequal arguments "
371               << "(non-functional FST?)";
372
373  return w1;
374}
375
376
377// Longest common prefix for left string semiring.
378template <typename L>  inline StringWeight<L, STRING_LEFT>
379Plus(const StringWeight<L, STRING_LEFT> &w1,
380     const StringWeight<L, STRING_LEFT> &w2) {
381  if (w1 == StringWeight<L, STRING_LEFT>::Zero())
382    return w2;
383  if (w2 == StringWeight<L, STRING_LEFT>::Zero())
384    return w1;
385
386  StringWeight<L, STRING_LEFT> sum;
387  StringWeightIterator<L, STRING_LEFT> iter1(w1);
388  StringWeightIterator<L, STRING_LEFT> iter2(w2);
389  for (; !iter1.Done() && !iter2.Done() && iter1.Value() == iter2.Value();
390       iter1.Next(), iter2.Next())
391    sum.PushBack(iter1.Value());
392  return sum;
393}
394
395
396// Longest common suffix for right string semiring.
397template <typename L>  inline StringWeight<L, STRING_RIGHT>
398Plus(const StringWeight<L, STRING_RIGHT> &w1,
399     const StringWeight<L, STRING_RIGHT> &w2) {
400  if (w1 == StringWeight<L, STRING_RIGHT>::Zero())
401    return w2;
402  if (w2 == StringWeight<L, STRING_RIGHT>::Zero())
403    return w1;
404
405  StringWeight<L, STRING_RIGHT> sum;
406  StringWeightReverseIterator<L, STRING_RIGHT> iter1(w1);
407  StringWeightReverseIterator<L, STRING_RIGHT> iter2(w2);
408  for (; !iter1.Done() && !iter2.Done() && iter1.Value() == iter2.Value();
409       iter1.Next(), iter2.Next())
410    sum.PushFront(iter1.Value());
411  return sum;
412}
413
414
415template <typename L, StringType S>
416inline StringWeight<L, S> Times(const StringWeight<L, S> &w1,
417                             const StringWeight<L, S> &w2) {
418  if (w1 == StringWeight<L, S>::Zero() || w2 == StringWeight<L, S>::Zero())
419    return StringWeight<L, S>::Zero();
420
421  StringWeight<L, S> prod(w1);
422  for (StringWeightIterator<L, S> iter(w2); !iter.Done(); iter.Next())
423    prod.PushBack(iter.Value());
424
425  return prod;
426}
427
428
429// Default is for left division in the left string and the
430// left restricted string semirings.
431template <typename L, StringType S> inline StringWeight<L, S>
432Divide(const StringWeight<L, S> &w1,
433       const StringWeight<L, S> &w2,
434       DivideType typ) {
435
436  if (typ != DIVIDE_LEFT)
437    LOG(FATAL) << "StringWeight::Divide: only left division is defined "
438               << "for the " << StringWeight<L, S>::Type() << " semiring";
439
440  if (w2 == StringWeight<L, S>::Zero())
441    return StringWeight<L, S>(kStringBad);
442  else if (w1 == StringWeight<L, S>::Zero())
443    return StringWeight<L, S>::Zero();
444
445  StringWeight<L, S> div;
446  StringWeightIterator<L, S> iter(w1);
447  for (int i = 0; !iter.Done(); iter.Next(), ++i) {
448    if (i >= w2.Size())
449      div.PushBack(iter.Value());
450  }
451  return div;
452}
453
454
455// Right division in the right string semiring.
456template <typename L> inline StringWeight<L, STRING_RIGHT>
457Divide(const StringWeight<L, STRING_RIGHT> &w1,
458       const StringWeight<L, STRING_RIGHT> &w2,
459       DivideType typ) {
460
461  if (typ != DIVIDE_RIGHT)
462    LOG(FATAL) << "StringWeight::Divide: only right division is defined "
463               << "for the right string semiring";
464
465  if (w2 == StringWeight<L, STRING_RIGHT>::Zero())
466    return StringWeight<L, STRING_RIGHT>(kStringBad);
467  else if (w1 == StringWeight<L, STRING_RIGHT>::Zero())
468    return StringWeight<L, STRING_RIGHT>::Zero();
469
470  StringWeight<L, STRING_RIGHT> div;
471  StringWeightReverseIterator<L, STRING_RIGHT> iter(w1);
472  for (int i = 0; !iter.Done(); iter.Next(), ++i) {
473    if (i >= w2.Size())
474      div.PushFront(iter.Value());
475  }
476  return div;
477}
478
479
480// Right division in the right restricted string semiring.
481template <typename L> inline StringWeight<L, STRING_RIGHT_RESTRICT>
482Divide(const StringWeight<L, STRING_RIGHT_RESTRICT> &w1,
483       const StringWeight<L, STRING_RIGHT_RESTRICT> &w2,
484       DivideType typ) {
485
486  if (typ != DIVIDE_RIGHT)
487    LOG(FATAL) << "StringWeight::Divide: only right division is defined "
488               << "for the right restricted string semiring";
489
490  if (w2 == StringWeight<L, STRING_RIGHT_RESTRICT>::Zero())
491    return StringWeight<L, STRING_RIGHT_RESTRICT>(kStringBad);
492  else if (w1 == StringWeight<L, STRING_RIGHT_RESTRICT>::Zero())
493    return StringWeight<L, STRING_RIGHT_RESTRICT>::Zero();
494
495  StringWeight<L, STRING_RIGHT_RESTRICT> div;
496  StringWeightReverseIterator<L, STRING_RIGHT_RESTRICT> iter(w1);
497  for (int i = 0; !iter.Done(); iter.Next(), ++i) {
498    if (i >= w2.Size())
499      div.PushFront(iter.Value());
500  }
501  return div;
502}
503
504
505// Product of string weight and an arbitray weight.
506template <class L, class W, StringType S = STRING_LEFT>
507struct GallicWeight : public ProductWeight<StringWeight<L, S>, W> {
508  typedef GallicWeight<L, typename W::ReverseWeight, REVERSE_STRING_TYPE(S)>
509  ReverseWeight;
510
511  GallicWeight() {}
512
513  GallicWeight(StringWeight<L, S> w1, W w2)
514      : ProductWeight<StringWeight<L, S>, W>(w1, w2) {}
515
516  explicit GallicWeight(const string &s, int *nread = 0)
517      : ProductWeight<StringWeight<L, S>, W>(s, nread) {}
518
519  GallicWeight(const ProductWeight<StringWeight<L, S>, W> &w)
520      : ProductWeight<StringWeight<L, S>, W>(w) {}
521};
522
523}  // namespace fst;
524
525#endif  // FST_LIB_STRING_WEIGHT_H__
526