1// string-weight.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17//
18// \file
19// String weight set and associated semiring operation definitions.
20
21#ifndef FST_LIB_STRING_WEIGHT_H__
22#define FST_LIB_STRING_WEIGHT_H__
23
24#include <list>
25#include <string>
26
27#include <fst/product-weight.h>
28#include <fst/weight.h>
29
30namespace fst {
31
// Reserved values used by the string weights defined below.
const char kStringSeparator = '_';   // Separates labels in the text form
const int kStringBad = -2;           // Label marking a non-string
const int kStringInfinity = -1;      // Label marking the infinite string
35
// Selects the left or the right string semiring.  The *_RESTRICT
// variants signal an error if proper prefixes (suffixes) would
// otherwise be returned by Plus; they are useful with various
// algorithms that require functional transducer input with the
// string semirings.
enum StringType {
  STRING_LEFT = 0,
  STRING_RIGHT = 1,
  STRING_LEFT_RESTRICT = 2,
  STRING_RIGHT_RESTRICT = 3
};

// Maps a string type to its mirror image (left <-> right), keeping the
// restricted/unrestricted flavor.  Kept as a macro so the result can be
// used as a template argument.  Any value other than the first three
// enumerators maps to STRING_LEFT_RESTRICT.
#define REVERSE_STRING_TYPE(S)                                  \
  ((S) == STRING_LEFT                                           \
       ? STRING_RIGHT                                           \
       : ((S) == STRING_RIGHT                                   \
              ? STRING_LEFT                                     \
              : ((S) == STRING_LEFT_RESTRICT                    \
                     ? STRING_RIGHT_RESTRICT                    \
                     : STRING_LEFT_RESTRICT)))
49
50template <typename L, StringType S = STRING_LEFT>
51class StringWeight;
52
53template <typename L, StringType S = STRING_LEFT>
54class StringWeightIterator;
55
56template <typename L, StringType S = STRING_LEFT>
57class StringWeightReverseIterator;
58
59template <typename L, StringType S>
60bool operator==(const StringWeight<L, S> &,  const StringWeight<L, S> &);
61
62
63// String semiring: (longest_common_prefix/suffix, ., Infinity, Epsilon)
64template <typename L, StringType S>
65class StringWeight {
66 public:
67  typedef L Label;
68  typedef StringWeight<L, REVERSE_STRING_TYPE(S)> ReverseWeight;
69
70  friend class StringWeightIterator<L, S>;
71  friend class StringWeightReverseIterator<L, S>;
72  friend bool operator==<>(const StringWeight<L, S> &,
73                           const StringWeight<L, S> &);
74
75  StringWeight() { Init(); }
76
77  template <typename Iter>
78  StringWeight(const Iter &begin, const Iter &end) {
79    Init();
80    for (Iter iter = begin; iter != end; ++iter)
81      PushBack(*iter);
82  }
83
84  explicit StringWeight(L l) { Init(); PushBack(l); }
85
86  static const StringWeight<L, S> &Zero() {
87    static const StringWeight<L, S> zero(kStringInfinity);
88    return zero;
89  }
90
91  static const StringWeight<L, S> &One() {
92    static const StringWeight<L, S> one;
93    return one;
94  }
95
96  static const StringWeight<L, S> &NoWeight() {
97    static const StringWeight<L, S> no_weight(kStringBad);
98    return no_weight;
99  }
100
101  static const string &Type() {
102    static const string type =
103        S == STRING_LEFT ? "string" :
104        (S == STRING_RIGHT ? "right_string" :
105         (S == STRING_LEFT_RESTRICT ? "restricted_string" :
106          "right_restricted_string"));
107    return type;
108  }
109
110  bool Member() const;
111
112  istream &Read(istream &strm);
113
114  ostream &Write(ostream &strm) const;
115
116  size_t Hash() const;
117
118  StringWeight<L, S> Quantize(float delta = kDelta) const {
119    return *this;
120  }
121
122  ReverseWeight Reverse() const;
123
124  static uint64 Properties() {
125    return (S == STRING_LEFT || S == STRING_LEFT_RESTRICT ?
126            kLeftSemiring : kRightSemiring) | kIdempotent;
127  }
128
129  // NB: This needs to be uncommented only if default fails for this impl.
130  // StringWeight<L, S> &operator=(const StringWeight<L, S> &w);
131
132  // These operations combined with the StringWeightIterator and
133  // StringWeightReverseIterator provide the access and mutation of
134  // the string internal elements.
135
136  // Common initializer among constructors.
137  void Init() { first_ = 0; }
138
139  // Clear existing StringWeight.
140  void Clear() { first_ = 0; rest_.clear(); }
141
142  size_t Size() const { return first_ ? rest_.size() + 1 : 0; }
143
144  void PushFront(L l) {
145    if (first_)
146      rest_.push_front(first_);
147    first_ = l;
148  }
149
150  void PushBack(L l) {
151    if (!first_)
152      first_ = l;
153    else
154      rest_.push_back(l);
155  }
156
157 private:
158  L first_;         // first label in string (0 if empty)
159  list<L> rest_;    // remaining labels in string
160};
161
162
163// Traverses string in forward direction.
164template <typename L, StringType S>
165class StringWeightIterator {
166 public:
167  explicit StringWeightIterator(const StringWeight<L, S>& w)
168      : first_(w.first_), rest_(w.rest_), init_(true),
169        iter_(rest_.begin()) {}
170
171  bool Done() const {
172    if (init_) return first_ == 0;
173    else return iter_ == rest_.end();
174  }
175
176  const L& Value() const { return init_ ? first_ : *iter_; }
177
178  void Next() {
179    if (init_) init_ = false;
180    else  ++iter_;
181  }
182
183  void Reset() {
184    init_ = true;
185    iter_ = rest_.begin();
186  }
187
188 private:
189  const L &first_;
190  const list<L> &rest_;
191  bool init_;   // in the initialized state?
192  typename list<L>::const_iterator iter_;
193
194  DISALLOW_COPY_AND_ASSIGN(StringWeightIterator);
195};
196
197
198// Traverses string in backward direction.
199template <typename L, StringType S>
200class StringWeightReverseIterator {
201 public:
202  explicit StringWeightReverseIterator(const StringWeight<L, S>& w)
203      : first_(w.first_), rest_(w.rest_), fin_(first_ == 0),
204        iter_(rest_.rbegin()) {}
205
206  bool Done() const { return fin_; }
207
208  const L& Value() const { return iter_ == rest_.rend() ? first_ : *iter_; }
209
210  void Next() {
211    if (iter_ == rest_.rend()) fin_ = true;
212    else  ++iter_;
213  }
214
215  void Reset() {
216    fin_ = false;
217    iter_ = rest_.rbegin();
218  }
219
220 private:
221  const L &first_;
222  const list<L> &rest_;
223  bool fin_;   // in the final state?
224  typename list<L>::const_reverse_iterator iter_;
225
226  DISALLOW_COPY_AND_ASSIGN(StringWeightReverseIterator);
227};
228
229
230// StringWeight member functions follow that require
231// StringWeightIterator or StringWeightReverseIterator.
232
233template <typename L, StringType S>
234inline istream &StringWeight<L, S>::Read(istream &strm) {
235  Clear();
236  int32 size;
237  ReadType(strm, &size);
238  for (int i = 0; i < size; ++i) {
239    L label;
240    ReadType(strm, &label);
241    PushBack(label);
242  }
243  return strm;
244}
245
246template <typename L, StringType S>
247inline ostream &StringWeight<L, S>::Write(ostream &strm) const {
248  int32 size =  Size();
249  WriteType(strm, size);
250  for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next()) {
251    L label = iter.Value();
252    WriteType(strm, label);
253  }
254  return strm;
255}
256
257template <typename L, StringType S>
258inline bool StringWeight<L, S>::Member() const {
259  if (Size() != 1)
260    return true;
261  StringWeightIterator<L, S> iter(*this);
262  return iter.Value() != kStringBad;
263}
264
265template <typename L, StringType S>
266inline typename StringWeight<L, S>::ReverseWeight
267StringWeight<L, S>::Reverse() const {
268  ReverseWeight rw;
269  for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next())
270    rw.PushFront(iter.Value());
271  return rw;
272}
273
274template <typename L, StringType S>
275inline size_t StringWeight<L, S>::Hash() const {
276  size_t h = 0;
277  for (StringWeightIterator<L, S> iter(*this); !iter.Done(); iter.Next())
278    h ^= h<<1 ^ iter.Value();
279  return h;
280}
281
// NB: This needs to be uncommented only if default fails for this impl.
283//
284// template <typename L, StringType S>
285// inline StringWeight<L, S>
286// &StringWeight<L, S>::operator=(const StringWeight<L, S> &w) {
287//   if (this != &w) {
288//     Clear();
289//     for (StringWeightIterator<L, S> iter(w); !iter.Done(); iter.Next())
290//       PushBack(iter.Value());
291//   }
292//   return *this;
293// }
294
295template <typename L, StringType S>
296inline bool operator==(const StringWeight<L, S> &w1,
297                       const StringWeight<L, S> &w2) {
298  if (w1.Size() != w2.Size())
299    return false;
300
301  StringWeightIterator<L, S> iter1(w1);
302  StringWeightIterator<L, S> iter2(w2);
303
304  for (; !iter1.Done() ; iter1.Next(), iter2.Next())
305    if (iter1.Value() != iter2.Value())
306      return false;
307
308  return true;
309}
310
311template <typename L, StringType S>
312inline bool operator!=(const StringWeight<L, S> &w1,
313                       const StringWeight<L, S> &w2) {
314  return !(w1 == w2);
315}
316
317template <typename L, StringType S>
318inline bool ApproxEqual(const StringWeight<L, S> &w1,
319                        const StringWeight<L, S> &w2,
320                        float delta = kDelta) {
321  return w1 == w2;
322}
323
324template <typename L, StringType S>
325inline ostream &operator<<(ostream &strm, const StringWeight<L, S> &w) {
326  StringWeightIterator<L, S> iter(w);
327  if (iter.Done())
328    return strm << "Epsilon";
329  else if (iter.Value() == kStringInfinity)
330    return strm << "Infinity";
331  else if (iter.Value() == kStringBad)
332    return strm << "BadString";
333  else
334    for (size_t i = 0; !iter.Done(); ++i, iter.Next()) {
335      if (i > 0)
336        strm << kStringSeparator;
337      strm << iter.Value();
338    }
339  return strm;
340}
341
342template <typename L, StringType S>
343inline istream &operator>>(istream &strm, StringWeight<L, S> &w) {
344  string s;
345  strm >> s;
346  if (s == "Infinity") {
347    w = StringWeight<L, S>::Zero();
348  } else if (s == "Epsilon") {
349    w = StringWeight<L, S>::One();
350  } else {
351    w.Clear();
352    char *p = 0;
353    for (const char *cs = s.c_str(); !p || *p != '\0'; cs = p + 1) {
354      int l = strtoll(cs, &p, 10);
355      if (p == cs || (*p != 0 && *p != kStringSeparator)) {
356        strm.clear(std::ios::badbit);
357        break;
358      }
359      w.PushBack(l);
360    }
361  }
362  return strm;
363}
364
365
366// Default is for the restricted left and right semirings.  String
367// equality is required (for non-Zero() input. This restriction
368// is used in e.g. Determinize to ensure functional input.
369template <typename L, StringType S>  inline StringWeight<L, S>
370Plus(const StringWeight<L, S> &w1,
371     const StringWeight<L, S> &w2) {
372  if (!w1.Member() || !w2.Member())
373    return StringWeight<L, S>::NoWeight();
374  if (w1 == StringWeight<L, S>::Zero())
375    return w2;
376  if (w2 == StringWeight<L, S>::Zero())
377    return w1;
378
379  if (w1 != w2) {
380    FSTERROR() << "StringWeight::Plus: unequal arguments "
381               << "(non-functional FST?)"
382               << " w1 = " << w1
383               << " w2 = " << w2;
384    return StringWeight<L, S>::NoWeight();
385  }
386
387  return w1;
388}
389
390
391// Longest common prefix for left string semiring.
392template <typename L>  inline StringWeight<L, STRING_LEFT>
393Plus(const StringWeight<L, STRING_LEFT> &w1,
394     const StringWeight<L, STRING_LEFT> &w2) {
395  if (!w1.Member() || !w2.Member())
396    return StringWeight<L, STRING_LEFT>::NoWeight();
397  if (w1 == StringWeight<L, STRING_LEFT>::Zero())
398    return w2;
399  if (w2 == StringWeight<L, STRING_LEFT>::Zero())
400    return w1;
401
402  StringWeight<L, STRING_LEFT> sum;
403  StringWeightIterator<L, STRING_LEFT> iter1(w1);
404  StringWeightIterator<L, STRING_LEFT> iter2(w2);
405  for (; !iter1.Done() && !iter2.Done() && iter1.Value() == iter2.Value();
406       iter1.Next(), iter2.Next())
407    sum.PushBack(iter1.Value());
408  return sum;
409}
410
411
412// Longest common suffix for right string semiring.
413template <typename L>  inline StringWeight<L, STRING_RIGHT>
414Plus(const StringWeight<L, STRING_RIGHT> &w1,
415     const StringWeight<L, STRING_RIGHT> &w2) {
416  if (!w1.Member() || !w2.Member())
417    return StringWeight<L, STRING_RIGHT>::NoWeight();
418  if (w1 == StringWeight<L, STRING_RIGHT>::Zero())
419    return w2;
420  if (w2 == StringWeight<L, STRING_RIGHT>::Zero())
421    return w1;
422
423  StringWeight<L, STRING_RIGHT> sum;
424  StringWeightReverseIterator<L, STRING_RIGHT> iter1(w1);
425  StringWeightReverseIterator<L, STRING_RIGHT> iter2(w2);
426  for (; !iter1.Done() && !iter2.Done() && iter1.Value() == iter2.Value();
427       iter1.Next(), iter2.Next())
428    sum.PushFront(iter1.Value());
429  return sum;
430}
431
432
433template <typename L, StringType S>
434inline StringWeight<L, S> Times(const StringWeight<L, S> &w1,
435                             const StringWeight<L, S> &w2) {
436  if (!w1.Member() || !w2.Member())
437    return StringWeight<L, S>::NoWeight();
438  if (w1 == StringWeight<L, S>::Zero() || w2 == StringWeight<L, S>::Zero())
439    return StringWeight<L, S>::Zero();
440
441  StringWeight<L, S> prod(w1);
442  for (StringWeightIterator<L, S> iter(w2); !iter.Done(); iter.Next())
443    prod.PushBack(iter.Value());
444
445  return prod;
446}
447
448
449// Default is for left division in the left string and the
450// left restricted string semirings.
451template <typename L, StringType S> inline StringWeight<L, S>
452Divide(const StringWeight<L, S> &w1,
453       const StringWeight<L, S> &w2,
454       DivideType typ) {
455
456  if (typ != DIVIDE_LEFT) {
457    FSTERROR() << "StringWeight::Divide: only left division is defined "
458               << "for the " << StringWeight<L, S>::Type() << " semiring";
459    return StringWeight<L, S>::NoWeight();
460  }
461
462  if (!w1.Member() || !w2.Member())
463    return StringWeight<L, S>::NoWeight();
464
465  if (w2 == StringWeight<L, S>::Zero())
466    return StringWeight<L, S>(kStringBad);
467  else if (w1 == StringWeight<L, S>::Zero())
468    return StringWeight<L, S>::Zero();
469
470  StringWeight<L, S> div;
471  StringWeightIterator<L, S> iter(w1);
472  for (int i = 0; !iter.Done(); iter.Next(), ++i) {
473    if (i >= w2.Size())
474      div.PushBack(iter.Value());
475  }
476  return div;
477}
478
479
480// Right division in the right string semiring.
481template <typename L> inline StringWeight<L, STRING_RIGHT>
482Divide(const StringWeight<L, STRING_RIGHT> &w1,
483       const StringWeight<L, STRING_RIGHT> &w2,
484       DivideType typ) {
485
486  if (typ != DIVIDE_RIGHT) {
487    FSTERROR() << "StringWeight::Divide: only right division is defined "
488               << "for the right string semiring";
489    return StringWeight<L, STRING_RIGHT>::NoWeight();
490  }
491
492  if (!w1.Member() || !w2.Member())
493    return StringWeight<L, STRING_RIGHT>::NoWeight();
494
495  if (w2 == StringWeight<L, STRING_RIGHT>::Zero())
496    return StringWeight<L, STRING_RIGHT>(kStringBad);
497  else if (w1 == StringWeight<L, STRING_RIGHT>::Zero())
498    return StringWeight<L, STRING_RIGHT>::Zero();
499
500  StringWeight<L, STRING_RIGHT> div;
501  StringWeightReverseIterator<L, STRING_RIGHT> iter(w1);
502  for (int i = 0; !iter.Done(); iter.Next(), ++i) {
503    if (i >= w2.Size())
504      div.PushFront(iter.Value());
505  }
506  return div;
507}
508
509
510// Right division in the right restricted string semiring.
511template <typename L> inline StringWeight<L, STRING_RIGHT_RESTRICT>
512Divide(const StringWeight<L, STRING_RIGHT_RESTRICT> &w1,
513       const StringWeight<L, STRING_RIGHT_RESTRICT> &w2,
514       DivideType typ) {
515
516  if (typ != DIVIDE_RIGHT) {
517    FSTERROR() << "StringWeight::Divide: only right division is defined "
518               << "for the right restricted string semiring";
519    return StringWeight<L, STRING_RIGHT_RESTRICT>::NoWeight();
520  }
521
522  if (!w1.Member() || !w2.Member())
523    return StringWeight<L, STRING_RIGHT_RESTRICT>::NoWeight();
524
525  if (w2 == StringWeight<L, STRING_RIGHT_RESTRICT>::Zero())
526    return StringWeight<L, STRING_RIGHT_RESTRICT>(kStringBad);
527  else if (w1 == StringWeight<L, STRING_RIGHT_RESTRICT>::Zero())
528    return StringWeight<L, STRING_RIGHT_RESTRICT>::Zero();
529
530  StringWeight<L, STRING_RIGHT_RESTRICT> div;
531  StringWeightReverseIterator<L, STRING_RIGHT_RESTRICT> iter(w1);
532  for (int i = 0; !iter.Done(); iter.Next(), ++i) {
533    if (i >= w2.Size())
534      div.PushFront(iter.Value());
535  }
536  return div;
537}
538
539
540// Product of string weight and an arbitray weight.
541template <class L, class W, StringType S = STRING_LEFT>
542struct GallicWeight : public ProductWeight<StringWeight<L, S>, W> {
543  typedef GallicWeight<L, typename W::ReverseWeight, REVERSE_STRING_TYPE(S)>
544  ReverseWeight;
545
546  GallicWeight() {}
547
548  GallicWeight(StringWeight<L, S> w1, W w2)
549      : ProductWeight<StringWeight<L, S>, W>(w1, w2) {}
550
551  explicit GallicWeight(const string &s, int *nread = 0)
552      : ProductWeight<StringWeight<L, S>, W>(s, nread) {}
553
554  GallicWeight(const ProductWeight<StringWeight<L, S>, W> &w)
555      : ProductWeight<StringWeight<L, S>, W>(w) {}
556};
557
558}  // namespace fst
559
560#endif  // FST_LIB_STRING_WEIGHT_H__
561