1// sparse-tuple-weight.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: krr@google.com (Kasturi Rangan Raghavan)
17// Inspiration: allauzen@google.com (Cyril Allauzen)
18// \file
// Sparse version of tuple-weight, based on tuple-weight.h.
//   Internally stores sparse (key, value) pairs in a linked list.
//   The default value element is the assumed value of unset keys.
//   Internal singleton implementation that stores the first (key, value)
//   pair as an initialized member variable to avoid
//   unnecessary allocation on the heap.
// Use SparseTupleWeightIterator to iterate through the (key, value) pairs.
// Note: this does NOT iterate through the default value.
27//
28// Sparse tuple weight set operation definitions.
29
30#ifndef FST_LIB_SPARSE_TUPLE_WEIGHT_H__
31#define FST_LIB_SPARSE_TUPLE_WEIGHT_H__
32
33#include<string>
34#include<list>
35#include<stack>
36#include<tr1/unordered_map>
37using std::tr1::unordered_map;
38using std::tr1::unordered_multimap;
39
40#include <fst/weight.h>
41
42
43DECLARE_string(fst_weight_parentheses);
44DECLARE_string(fst_weight_separator);
45
46namespace fst {
47
48template <class W, class K> class SparseTupleWeight;
49
50template<class W, class K>
51class SparseTupleWeightIterator;
52
53template <class W, class K>
54istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w);
55
56// Arbitrary dimension tuple weight, stored as a sorted linked-list
57// W is any weight class,
58// K is the key value type. kNoKey(-1) is reserved for internal use
59template <class W, class K = int>
60class SparseTupleWeight {
61 public:
62  typedef pair<K, W> Pair;
63  typedef SparseTupleWeight<typename W::ReverseWeight, K> ReverseWeight;
64
65  const static K kNoKey = -1;
66  SparseTupleWeight() {
67    Init();
68  }
69
70  template <class Iterator>
71  SparseTupleWeight(Iterator begin, Iterator end) {
72    Init();
73    // Assumes input iterator is sorted
74    for (Iterator it = begin; it != end; ++it)
75      Push(*it);
76  }
77
78
79  SparseTupleWeight(const K& key, const W &w) {
80    Init();
81    Push(key, w);
82  }
83
84  SparseTupleWeight(const W &w) {
85    Init(w);
86  }
87
88  SparseTupleWeight(const SparseTupleWeight<W, K> &w) {
89    Init(w.DefaultValue());
90    SetDefaultValue(w.DefaultValue());
91    for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
92      Push(it.Value());
93    }
94  }
95
96  static const SparseTupleWeight<W, K> &Zero() {
97    static SparseTupleWeight<W, K> zero;
98    return zero;
99  }
100
101  static const SparseTupleWeight<W, K> &One() {
102    static SparseTupleWeight<W, K> one(W::One());
103    return one;
104  }
105
106  static const SparseTupleWeight<W, K> &NoWeight() {
107    static SparseTupleWeight<W, K> no_weight(W::NoWeight());
108    return no_weight;
109  }
110
111  istream &Read(istream &strm) {
112    ReadType(strm, &default_);
113    ReadType(strm, &first_);
114    return ReadType(strm, &rest_);
115  }
116
117  ostream &Write(ostream &strm) const {
118    WriteType(strm, default_);
119    WriteType(strm, first_);
120    return WriteType(strm, rest_);
121  }
122
123  SparseTupleWeight<W, K> &operator=(const SparseTupleWeight<W, K> &w) {
124    if (this == &w) return *this; // check for w = w
125    Init(w.DefaultValue());
126    for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
127      Push(it.Value());
128    }
129    return *this;
130  }
131
132  bool Member() const {
133    if (!DefaultValue().Member()) return false;
134    for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
135      if (!it.Value().second.Member()) return false;
136    }
137    return true;
138  }
139
140  // Assumes H() function exists for the hash of the key value
141  size_t Hash() const {
142    uint64 h = 0;
143    std::tr1::hash<K> H;
144    for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
145      h = 5 * h + H(it.Value().first);
146      h = 13 * h + it.Value().second.Hash();
147    }
148    return size_t(h);
149  }
150
151  SparseTupleWeight<W, K> Quantize(float delta = kDelta) const {
152    SparseTupleWeight<W, K> w;
153    for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
154      w.Push(it.Value().first, it.Value().second.Quantize(delta));
155    }
156    return w;
157  }
158
159  ReverseWeight Reverse() const {
160    SparseTupleWeight<W, K> w;
161    for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
162      w.Push(it.Value().first, it.Value().second.Reverse());
163    }
164    return w;
165  }
166
167  // Common initializer among constructors.
168  void Init() {
169    Init(W::Zero());
170  }
171
172  void Init(const W& default_value) {
173    first_.first = kNoKey;
174    /* initialized to the reserved key value */
175    default_ = default_value;
176    rest_.clear();
177  }
178
179  size_t Size() const {
180    if (first_.first == kNoKey)
181      return 0;
182    else
183      return  rest_.size() + 1;
184  }
185
186  inline void Push(const K &k, const W &w, bool default_value_check = true) {
187    Push(make_pair(k, w), default_value_check);
188  }
189
190  inline void Push(const Pair &p, bool default_value_check = true) {
191    if (default_value_check && p.second == default_) return;
192    if (first_.first == kNoKey) {
193      first_ = p;
194    } else {
195      rest_.push_back(p);
196    }
197  }
198
199  void SetDefaultValue(const W& val) { default_ = val; }
200
201  const W& DefaultValue() const { return default_; }
202
203 protected:
204  static istream& ReadNoParen(
205    istream&, SparseTupleWeight<W, K>&, char separator);
206
207  static istream& ReadWithParen(
208    istream&, SparseTupleWeight<W, K>&,
209    char separator, char open_paren, char close_paren);
210
211 private:
212  // Assumed default value of uninitialized keys, by default W::Zero()
213  W default_;
214
215  // Key values pairs are first stored in first_, then fill rest_
216  // this way we can avoid dynamic allocation in the common case
217  // where the weight is a single key,val pair.
218  Pair first_;
219  list<Pair> rest_;
220
221  friend istream &operator>><W, K>(istream&, SparseTupleWeight<W, K>&);
222  friend class SparseTupleWeightIterator<W, K>;
223};
224
225template<class W, class K>
226class SparseTupleWeightIterator {
227 public:
228  typedef typename SparseTupleWeight<W, K>::Pair Pair;
229  typedef typename list<Pair>::const_iterator const_iterator;
230  typedef typename list<Pair>::iterator iterator;
231
232  explicit SparseTupleWeightIterator(const SparseTupleWeight<W, K>& w)
233    : first_(w.first_), rest_(w.rest_), init_(true),
234      iter_(rest_.begin()) {}
235
236  bool Done() const {
237    if (init_)
238      return first_.first == SparseTupleWeight<W, K>::kNoKey;
239    else
240      return iter_ == rest_.end();
241  }
242
243  const Pair& Value() const { return init_ ? first_ : *iter_; }
244
245  void Next() {
246    if (init_)
247      init_ = false;
248    else
249      ++iter_;
250  }
251
252  void Reset() {
253    init_ = true;
254    iter_ = rest_.begin();
255  }
256
257 private:
258  const Pair &first_;
259  const list<Pair> & rest_;
260  bool init_;  // in the initialized state?
261  typename list<Pair>::const_iterator iter_;
262
263  DISALLOW_COPY_AND_ASSIGN(SparseTupleWeightIterator);
264};
265
266template<class W, class K, class M>
267inline void SparseTupleWeightMap(
268  SparseTupleWeight<W, K>* ret,
269  const SparseTupleWeight<W, K>& w1,
270  const SparseTupleWeight<W, K>& w2,
271  const M& operator_mapper) {
272  SparseTupleWeightIterator<W, K> w1_it(w1);
273  SparseTupleWeightIterator<W, K> w2_it(w2);
274  const W& v1_def = w1.DefaultValue();
275  const W& v2_def = w2.DefaultValue();
276  ret->SetDefaultValue(operator_mapper.Map(0, v1_def, v2_def));
277  while (!w1_it.Done() || !w2_it.Done()) {
278    const K& k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first;
279    const K& k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first;
280    const W& v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second;
281    const W& v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second;
282    if (k1 == k2) {
283      ret->Push(k1, operator_mapper.Map(k1, v1, v2));
284      if (!w1_it.Done()) w1_it.Next();
285      if (!w2_it.Done()) w2_it.Next();
286    } else if (k1 < k2) {
287      ret->Push(k1, operator_mapper.Map(k1, v1, v2_def));
288      w1_it.Next();
289    } else {
290      ret->Push(k2, operator_mapper.Map(k2, v1_def, v2));
291      w2_it.Next();
292    }
293  }
294}
295
296template <class W, class K>
297inline bool operator==(const SparseTupleWeight<W, K> &w1,
298                       const SparseTupleWeight<W, K> &w2) {
299  const W& v1_def = w1.DefaultValue();
300  const W& v2_def = w2.DefaultValue();
301  if (v1_def != v2_def) return false;
302
303  SparseTupleWeightIterator<W, K> w1_it(w1);
304  SparseTupleWeightIterator<W, K> w2_it(w2);
305  while (!w1_it.Done() || !w2_it.Done()) {
306    const K& k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first;
307    const K& k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first;
308    const W& v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second;
309    const W& v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second;
310    if (k1 == k2) {
311      if (v1 != v2) return false;
312      if (!w1_it.Done()) w1_it.Next();
313      if (!w2_it.Done()) w2_it.Next();
314    } else if (k1 < k2) {
315      if (v1 != v2_def) return false;
316      w1_it.Next();
317    } else {
318      if (v1_def != v2) return false;
319      w2_it.Next();
320    }
321  }
322  return true;
323}
324
325template <class W, class K>
326inline bool operator!=(const SparseTupleWeight<W, K> &w1,
327                       const SparseTupleWeight<W, K> &w2) {
328  return !(w1 == w2);
329}
330
331template <class W, class K>
332inline ostream &operator<<(ostream &strm, const SparseTupleWeight<W, K> &w) {
333  if(FLAGS_fst_weight_separator.size() != 1) {
334    FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
335    strm.clear(std::ios::badbit);
336    return strm;
337  }
338  char separator = FLAGS_fst_weight_separator[0];
339  bool write_parens = false;
340  if (!FLAGS_fst_weight_parentheses.empty()) {
341    if (FLAGS_fst_weight_parentheses.size() != 2) {
342      FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
343      strm.clear(std::ios::badbit);
344      return strm;
345    }
346    write_parens = true;
347  }
348
349  if (write_parens)
350    strm << FLAGS_fst_weight_parentheses[0];
351
352  strm << w.DefaultValue();
353  strm << separator;
354
355  size_t n = w.Size();
356  strm << n;
357  strm << separator;
358
359  for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
360      strm << it.Value().first;
361      strm << separator;
362      strm << it.Value().second;
363      strm << separator;
364  }
365
366  if (write_parens)
367    strm << FLAGS_fst_weight_parentheses[1];
368
369  return strm;
370}
371
372template <class W, class K>
373inline istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w) {
374  if(FLAGS_fst_weight_separator.size() != 1) {
375    FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
376    strm.clear(std::ios::badbit);
377    return strm;
378  }
379  char separator = FLAGS_fst_weight_separator[0];
380
381  if (!FLAGS_fst_weight_parentheses.empty()) {
382    if (FLAGS_fst_weight_parentheses.size() != 2) {
383      FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
384      strm.clear(std::ios::badbit);
385      return strm;
386    }
387    return SparseTupleWeight<W, K>::ReadWithParen(
388        strm, w, separator, FLAGS_fst_weight_parentheses[0],
389        FLAGS_fst_weight_parentheses[1]);
390  } else {
391    return SparseTupleWeight<W, K>::ReadNoParen(strm, w, separator);
392  }
393}
394
395// Reads SparseTupleWeight when there are no parentheses around tuple terms
396template <class W, class K>
397inline istream& SparseTupleWeight<W, K>::ReadNoParen(
398    istream &strm,
399    SparseTupleWeight<W, K> &w,
400    char separator) {
401  int c;
402  size_t n;
403
404  do {
405    c = strm.get();
406  } while (isspace(c));
407
408
409  { // Read default weight
410    W default_value;
411    string s;
412    while (c != separator) {
413      if (c == EOF) {
414        strm.clear(std::ios::badbit);
415        return strm;
416      }
417      s += c;
418      c = strm.get();
419    }
420    istringstream sstrm(s);
421    sstrm >> default_value;
422    w.SetDefaultValue(default_value);
423  }
424
425  c = strm.get();
426
427  { // Read n
428    string s;
429    while (c != separator) {
430      if (c == EOF) {
431        strm.clear(std::ios::badbit);
432        return strm;
433      }
434      s += c;
435      c = strm.get();
436    }
437    istringstream sstrm(s);
438    sstrm >> n;
439  }
440
441  // Read n elements
442  for (size_t i = 0; i < n; ++i) {
443    // discard separator
444    c = strm.get();
445    K p;
446    W r;
447
448    { // read key
449      string s;
450      while (c != separator) {
451        if (c == EOF) {
452          strm.clear(std::ios::badbit);
453          return strm;
454        }
455        s += c;
456        c = strm.get();
457      }
458      istringstream sstrm(s);
459      sstrm >> p;
460    }
461
462    c = strm.get();
463
464    { // read weight
465      string s;
466      while (c != separator) {
467        if (c == EOF) {
468          strm.clear(std::ios::badbit);
469          return strm;
470        }
471        s += c;
472        c = strm.get();
473      }
474      istringstream sstrm(s);
475      sstrm >> r;
476    }
477
478    w.Push(p, r);
479  }
480
481  c = strm.get();
482  if (c != separator) {
483    strm.clear(std::ios::badbit);
484  }
485
486  return strm;
487}
488
489// Reads SparseTupleWeight when there are parentheses around tuple terms
490template <class W, class K>
491inline istream& SparseTupleWeight<W, K>::ReadWithParen(
492    istream &strm,
493    SparseTupleWeight<W, K> &w,
494    char separator,
495    char open_paren,
496    char close_paren) {
497  int c;
498  size_t n;
499
500  do {
501    c = strm.get();
502  } while (isspace(c));
503
504  if (c != open_paren) {
505    FSTERROR() << "is fst_weight_parentheses flag set correcty? ";
506    strm.clear(std::ios::badbit);
507    return strm;
508  }
509
510  c = strm.get();
511
512  { // Read weight
513    W default_value;
514    stack<int> parens;
515    string s;
516    while (c != separator || !parens.empty()) {
517      if (c == EOF) {
518        strm.clear(std::ios::badbit);
519        return strm;
520      }
521      s += c;
522      // If parens encountered before separator, they must be matched
523      if (c == open_paren) {
524        parens.push(1);
525      } else if (c == close_paren) {
526        // Fail for mismatched parens
527        if (parens.empty()) {
528          strm.clear(std::ios::failbit);
529          return strm;
530        }
531        parens.pop();
532      }
533      c = strm.get();
534    }
535    istringstream sstrm(s);
536    sstrm >> default_value;
537    w.SetDefaultValue(default_value);
538  }
539
540  c = strm.get();
541
542  { // Read n
543    string s;
544    while (c != separator) {
545      if (c == EOF) {
546        strm.clear(std::ios::badbit);
547        return strm;
548      }
549      s += c;
550      c = strm.get();
551    }
552    istringstream sstrm(s);
553    sstrm >> n;
554  }
555
556  // Read n elements
557  for (size_t i = 0; i < n; ++i) {
558    // discard separator
559    c = strm.get();
560    K p;
561    W r;
562
563    { // Read key
564      stack<int> parens;
565      string s;
566      while (c != separator || !parens.empty()) {
567        if (c == EOF) {
568          strm.clear(std::ios::badbit);
569          return strm;
570        }
571        s += c;
572        // If parens encountered before separator, they must be matched
573        if (c == open_paren) {
574          parens.push(1);
575        } else if (c == close_paren) {
576          // Fail for mismatched parens
577          if (parens.empty()) {
578            strm.clear(std::ios::failbit);
579            return strm;
580          }
581          parens.pop();
582        }
583        c = strm.get();
584      }
585      istringstream sstrm(s);
586      sstrm >> p;
587    }
588
589    c = strm.get();
590
591    { // Read weight
592      stack<int> parens;
593      string s;
594      while (c != separator || !parens.empty()) {
595        if (c == EOF) {
596          strm.clear(std::ios::badbit);
597          return strm;
598        }
599        s += c;
600        // If parens encountered before separator, they must be matched
601        if (c == open_paren) {
602          parens.push(1);
603        } else if (c == close_paren) {
604          // Fail for mismatched parens
605          if (parens.empty()) {
606            strm.clear(std::ios::failbit);
607            return strm;
608          }
609          parens.pop();
610        }
611        c = strm.get();
612      }
613      istringstream sstrm(s);
614      sstrm >> r;
615    }
616
617    w.Push(p, r);
618  }
619
620  if (c != separator) {
621    FSTERROR() << " separator expected, not found! ";
622    strm.clear(std::ios::badbit);
623    return strm;
624  }
625
626  c = strm.get();
627  if (c != close_paren) {
628    FSTERROR() << " is fst_weight_parentheses flag set correcty? ";
629    strm.clear(std::ios::badbit);
630    return strm;
631  }
632
633  return strm;
634}
635
636
637
638}  // namespace fst
639
640#endif  // FST_LIB_SPARSE_TUPLE_WEIGHT_H__
641