1
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13//
14// Copyright 2005-2010 Google, Inc.
15// Author: sorenj@google.com (Jeffrey Sorensen)
16//
17#ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
18#define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
19
20#include <stddef.h>
21#include <string.h>
22#include <algorithm>
23#include <string>
24#include <vector>
25using std::vector;
26
27#include <fst/compat.h>
28#include <fst/fstlib.h>
29#include <fst/mapped-file.h>
30#include <fst/extensions/ngram/bitmap-index.h>
31
// NGramFst implements an n-gram language model based upon the LOUDS data
// structure.  Please refer to "Unary Data Structures for Language Models"
// http://research.google.com/pubs/archive/37218.pdf
35
36namespace fst {
37template <class A> class NGramFst;
38template <class A> class NGramFstMatcher;
39
40// Instance data containing mutable state for bookkeeping repeated access to
41// the same state.
42template <class A>
43struct NGramFstInst {
44  typedef typename A::Label Label;
45  typedef typename A::StateId StateId;
46  typedef typename A::Weight Weight;
47  StateId state_;
48  size_t num_futures_;
49  size_t offset_;
50  size_t node_;
51  StateId node_state_;
52  vector<Label> context_;
53  StateId context_state_;
54  NGramFstInst()
55      : state_(kNoStateId), node_state_(kNoStateId),
56        context_state_(kNoStateId) { }
57};
58
59// Implementation class for LOUDS based NgramFst interface
60template <class A>
61class NGramFstImpl : public FstImpl<A> {
62  using FstImpl<A>::SetInputSymbols;
63  using FstImpl<A>::SetOutputSymbols;
64  using FstImpl<A>::SetType;
65  using FstImpl<A>::WriteHeader;
66
67  friend class ArcIterator<NGramFst<A> >;
68  friend class NGramFstMatcher<A>;
69
70 public:
71  using FstImpl<A>::InputSymbols;
72  using FstImpl<A>::SetProperties;
73  using FstImpl<A>::Properties;
74
75  typedef A Arc;
76  typedef typename A::Label Label;
77  typedef typename A::StateId StateId;
78  typedef typename A::Weight Weight;
79
80  NGramFstImpl() : data_region_(0), data_(0), owned_(false) {
81    SetType("ngram");
82    SetInputSymbols(NULL);
83    SetOutputSymbols(NULL);
84    SetProperties(kStaticProperties);
85  }
86
87  NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out);
88
89  ~NGramFstImpl() {
90    if (owned_) {
91      delete [] data_;
92    }
93    delete data_region_;
94  }
95
96  static NGramFstImpl<A>* Read(istream &strm,  // NOLINT
97                               const FstReadOptions &opts) {
98    NGramFstImpl<A>* impl = new NGramFstImpl();
99    FstHeader hdr;
100    if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0;
101    uint64 num_states, num_futures, num_final;
102    const size_t offset = sizeof(num_states) + sizeof(num_futures) +
103        sizeof(num_final);
104    // Peek at num_states and num_futures to see how much more needs to be read.
105    strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states));
106    strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures));
107    strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final));
108    size_t size = Storage(num_states, num_futures, num_final);
109    MappedFile *data_region = MappedFile::Allocate(size);
110    char *data = reinterpret_cast<char *>(data_region->mutable_data());
111    // Copy num_states, num_futures and num_final back into data.
112    memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states));
113    memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures),
114           sizeof(num_futures));
115    memcpy(data + sizeof(num_states) + sizeof(num_futures),
116           reinterpret_cast<char *>(&num_final), sizeof(num_final));
117    strm.read(data + offset, size - offset);
118    if (!strm) {
119      delete impl;
120      return NULL;
121    }
122    impl->Init(data, false, data_region);
123    return impl;
124  }
125
126  bool Write(ostream &strm,   // NOLINT
127             const FstWriteOptions &opts) const {
128    FstHeader hdr;
129    hdr.SetStart(Start());
130    hdr.SetNumStates(num_states_);
131    WriteHeader(strm, opts, kFileVersion, &hdr);
132    strm.write(data_, StorageSize());
133    return strm;
134  }
135
136  StateId Start() const {
137    return 1;
138  }
139
140  Weight Final(StateId state) const {
141    if (final_index_.Get(state)) {
142      return final_probs_[final_index_.Rank1(state)];
143    } else {
144      return Weight::Zero();
145    }
146  }
147
148  size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const {
149    if (inst == NULL) {
150      const size_t next_zero = future_index_.Select0(state + 1);
151      const size_t this_zero = future_index_.Select0(state);
152      return next_zero - this_zero - 1;
153    }
154    SetInstFuture(state, inst);
155    return inst->num_futures_ + ((state == 0) ? 0 : 1);
156  }
157
158  size_t NumInputEpsilons(StateId state) const {
159    // State 0 has no parent, thus no backoff.
160    if (state == 0) return 0;
161    return 1;
162  }
163
164  size_t NumOutputEpsilons(StateId state) const {
165    return NumInputEpsilons(state);
166  }
167
168  StateId NumStates() const {
169    return num_states_;
170  }
171
172  void InitStateIterator(StateIteratorData<A>* data) const {
173    data->base = 0;
174    data->nstates = num_states_;
175  }
176
177  static size_t Storage(uint64 num_states, uint64 num_futures,
178                        uint64 num_final) {
179    uint64 b64;
180    Weight weight;
181    Label label;
182    size_t offset = sizeof(num_states) + sizeof(num_futures) +
183        sizeof(num_final);
184    offset += sizeof(b64) * (
185        BitmapIndex::StorageSize(num_states * 2 + 1) +
186        BitmapIndex::StorageSize(num_futures + num_states + 1) +
187        BitmapIndex::StorageSize(num_states));
188    offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label);
189    // Pad for alignemnt, see
190    // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
191    offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
192    offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) +
193        (num_futures + 1) * sizeof(weight);
194    return offset;
195  }
196
197  void SetInstFuture(StateId state, NGramFstInst<A> *inst) const {
198    if (inst->state_ != state) {
199      inst->state_ = state;
200      const size_t next_zero = future_index_.Select0(state + 1);
201      const size_t this_zero = future_index_.Select0(state);
202      inst->num_futures_ = next_zero - this_zero - 1;
203      inst->offset_ = future_index_.Rank1(future_index_.Select0(state) + 1);
204    }
205  }
206
207  void SetInstNode(NGramFstInst<A> *inst) const {
208    if (inst->node_state_ != inst->state_) {
209      inst->node_state_ = inst->state_;
210      inst->node_ = context_index_.Select1(inst->state_);
211    }
212  }
213
214  void SetInstContext(NGramFstInst<A> *inst) const {
215    SetInstNode(inst);
216    if (inst->context_state_ != inst->state_) {
217      inst->context_state_ = inst->state_;
218      inst->context_.clear();
219      size_t node = inst->node_;
220      while (node != 0) {
221        inst->context_.push_back(context_words_[context_index_.Rank1(node)]);
222        node = context_index_.Select1(context_index_.Rank0(node) - 1);
223      }
224    }
225  }
226
227  // Access to the underlying representation
228  const char* GetData(size_t* data_size) const {
229    *data_size = StorageSize();
230    return data_;
231  }
232
233  void Init(const char* data, bool owned, MappedFile *file = 0);
234
235  const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const {
236    SetInstFuture(s, inst);
237    SetInstContext(inst);
238    return inst->context_;
239  }
240
241  size_t StorageSize() const {
242    return Storage(num_states_, num_futures_, num_final_);
243  }
244
245  void GetStates(const vector<Label>& context, vector<StateId> *states) const;
246
247 private:
248  StateId Transition(const vector<Label> &context, Label future) const;
249
250  // Properties always true for this Fst class.
251  static const uint64 kStaticProperties = kAcceptor | kIDeterministic |
252      kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted |
253      kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted |
254      kAccessible | kCoAccessible | kNotString | kExpanded;
255  // Current file format version.
256  static const int kFileVersion = 4;
257  // Minimum file format version supported.
258  static const int kMinFileVersion = 4;
259
260  MappedFile *data_region_;
261  const char* data_;
262  bool owned_;  // True if we own data_
263  uint64 num_states_, num_futures_, num_final_;
264  size_t root_num_children_;
265  const Label *root_children_;
266  size_t root_first_child_;
267  // borrowed references
268  const uint64 *context_, *future_, *final_;
269  const Label *context_words_, *future_words_;
270  const Weight *backoff_, *final_probs_, *future_probs_;
271  BitmapIndex context_index_;
272  BitmapIndex future_index_;
273  BitmapIndex final_index_;
274
275  void operator=(const NGramFstImpl<A> &);  // Disallow
276};
277
278template<typename A>
279NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out)
280    : data_region_(0), data_(0), owned_(false) {
281  typedef A Arc;
282  typedef typename Arc::Label Label;
283  typedef typename Arc::Weight Weight;
284  typedef typename Arc::StateId StateId;
285  SetType("ngram");
286  SetInputSymbols(fst.InputSymbols());
287  SetOutputSymbols(fst.OutputSymbols());
288  SetProperties(kStaticProperties);
289
290  // Check basic requirements for an OpenGRM language model Fst.
291  int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted;
292  if (fst.Properties(props, true) != props) {
293    FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input";
294    SetProperties(kError, kError);
295    return;
296  }
297
298  int64 num_states = CountStates(fst);
299  Label* context = new Label[num_states];
300
301  // Find the unigram state by starting from the start state, following
302  // epsilons.
303  StateId unigram = fst.Start();
304  while (1) {
305    if (unigram == kNoStateId) {
306      FSTERROR() << "Could not identify unigram state.";
307      SetProperties(kError, kError);
308      return;
309    }
310    ArcIterator<Fst<A> > aiter(fst, unigram);
311    if (aiter.Done()) {
312      LOG(WARNING) << "Unigram state " << unigram << " has no arcs.";
313      break;
314    }
315    if (aiter.Value().ilabel != 0) break;
316    unigram = aiter.Value().nextstate;
317  }
318
319  // Each state's context is determined by the subtree it is under from the
320  // unigram state.
321  queue<pair<StateId, Label> > label_queue;
322  vector<bool> visited(num_states);
323  // Force an epsilon link to the start state.
324  label_queue.push(make_pair(fst.Start(), 0));
325  for (ArcIterator<Fst<A> > aiter(fst, unigram);
326       !aiter.Done(); aiter.Next()) {
327    label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel));
328  }
329  // investigate states in breadth first fashion to assign context words.
330  while (!label_queue.empty()) {
331    pair<StateId, Label> &now = label_queue.front();
332    if (!visited[now.first]) {
333      context[now.first] = now.second;
334      visited[now.first] = true;
335      for (ArcIterator<Fst<A> > aiter(fst, now.first);
336           !aiter.Done(); aiter.Next()) {
337        const Arc &arc = aiter.Value();
338        if (arc.ilabel != 0) {
339          label_queue.push(make_pair(arc.nextstate, now.second));
340        }
341      }
342    }
343    label_queue.pop();
344  }
345  visited.clear();
346
347  // The arc from the start state should be assigned an epsilon to put it
348  // in front of the all other labels (which makes Start state 1 after
349  // unigram which is state 0).
350  context[fst.Start()] = 0;
351
352  // Build the tree of contexts fst by reversing the epsilon arcs from fst.
353  VectorFst<Arc> context_fst;
354  uint64 num_final = 0;
355  for (int i = 0; i < num_states; ++i) {
356    if (fst.Final(i) != Weight::Zero()) {
357      ++num_final;
358    }
359    context_fst.SetFinal(context_fst.AddState(), fst.Final(i));
360  }
361  context_fst.SetStart(unigram);
362  context_fst.SetInputSymbols(fst.InputSymbols());
363  context_fst.SetOutputSymbols(fst.OutputSymbols());
364  int64 num_context_arcs = 0;
365  int64 num_futures = 0;
366  for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) {
367    const StateId &state = siter.Value();
368    num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state);
369    ArcIterator<Fst<A> > aiter(fst, state);
370    if (!aiter.Done()) {
371      const Arc &arc = aiter.Value();
372      // this arc goes from state to arc.nextstate, so create an arc from
373      // arc.nextstate to state to reverse it.
374      if (arc.ilabel == 0) {
375        context_fst.AddArc(arc.nextstate, Arc(context[state], context[state],
376                                              arc.weight, state));
377        num_context_arcs++;
378      }
379    }
380  }
381  if (num_context_arcs != context_fst.NumStates() - 1) {
382    FSTERROR() << "Number of contexts arcs != number of states - 1";
383    SetProperties(kError, kError);
384    return;
385  }
386  if (context_fst.NumStates() != num_states) {
387    FSTERROR() << "Number of contexts != number of states";
388    SetProperties(kError, kError);
389    return;
390  }
391  int64 context_props = context_fst.Properties(kIDeterministic |
392                                               kILabelSorted, true);
393  if (!(context_props & kIDeterministic)) {
394    FSTERROR() << "Input fst is not structured properly";
395    SetProperties(kError, kError);
396    return;
397  }
398  if (!(context_props & kILabelSorted)) {
399     ArcSort(&context_fst, ILabelCompare<Arc>());
400  }
401
402  delete [] context;
403
404  uint64 b64;
405  Weight weight;
406  Label label = kNoLabel;
407  const size_t storage = Storage(num_states, num_futures, num_final);
408  MappedFile *data_region = MappedFile::Allocate(storage);
409  char *data = reinterpret_cast<char *>(data_region->mutable_data());
410  memset(data, 0, storage);
411  size_t offset = 0;
412  memcpy(data + offset, reinterpret_cast<char *>(&num_states),
413         sizeof(num_states));
414  offset += sizeof(num_states);
415  memcpy(data + offset, reinterpret_cast<char *>(&num_futures),
416         sizeof(num_futures));
417  offset += sizeof(num_futures);
418  memcpy(data + offset, reinterpret_cast<char *>(&num_final),
419         sizeof(num_final));
420  offset += sizeof(num_final);
421  uint64* context_bits = reinterpret_cast<uint64*>(data + offset);
422  offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64);
423  uint64* future_bits = reinterpret_cast<uint64*>(data + offset);
424  offset +=
425      BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64);
426  uint64* final_bits = reinterpret_cast<uint64*>(data + offset);
427  offset += BitmapIndex::StorageSize(num_states) * sizeof(b64);
428  Label* context_words = reinterpret_cast<Label*>(data + offset);
429  offset += (num_states + 1) * sizeof(label);
430  Label* future_words = reinterpret_cast<Label*>(data + offset);
431  offset += num_futures * sizeof(label);
432  offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1);
433  Weight* backoff = reinterpret_cast<Weight*>(data + offset);
434  offset += (num_states + 1) * sizeof(weight);
435  Weight* final_probs = reinterpret_cast<Weight*>(data + offset);
436  offset += num_final * sizeof(weight);
437  Weight* future_probs = reinterpret_cast<Weight*>(data + offset);
438  int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0,
439        final_bit = 0;
440
441  // pseudo-root bits
442  BitmapIndex::Set(context_bits, context_bit++);
443  ++context_bit;
444  context_words[context_arc] = label;
445  backoff[context_arc] = Weight::Zero();
446  context_arc++;
447
448  ++future_bit;
449  if (order_out) {
450    order_out->clear();
451    order_out->resize(num_states);
452  }
453
454  queue<StateId> context_q;
455  context_q.push(context_fst.Start());
456  StateId state_number = 0;
457  while (!context_q.empty()) {
458    const StateId &state = context_q.front();
459    if (order_out) {
460      (*order_out)[state] = state_number;
461    }
462
463    const Weight &final = context_fst.Final(state);
464    if (final != Weight::Zero()) {
465      BitmapIndex::Set(final_bits, state_number);
466      final_probs[final_bit] = final;
467      ++final_bit;
468    }
469
470    for (ArcIterator<VectorFst<A> > aiter(context_fst, state);
471         !aiter.Done(); aiter.Next()) {
472      const Arc &arc = aiter.Value();
473      context_words[context_arc] = arc.ilabel;
474      backoff[context_arc] = arc.weight;
475      ++context_arc;
476      BitmapIndex::Set(context_bits, context_bit++);
477      context_q.push(arc.nextstate);
478    }
479    ++context_bit;
480
481    for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) {
482      const Arc &arc = aiter.Value();
483      if (arc.ilabel != 0) {
484        future_words[future_arc] = arc.ilabel;
485        future_probs[future_arc] = arc.weight;
486        ++future_arc;
487        BitmapIndex::Set(future_bits, future_bit++);
488      }
489    }
490    ++future_bit;
491    ++state_number;
492    context_q.pop();
493  }
494
495  if ((state_number !=  num_states) ||
496      (context_bit != num_states * 2 + 1) ||
497      (context_arc != num_states) ||
498      (future_arc != num_futures) ||
499      (future_bit != num_futures + num_states + 1) ||
500      (final_bit != num_final)) {
501    FSTERROR() << "Structure problems detected during construction";
502    SetProperties(kError, kError);
503    return;
504  }
505
506  Init(data, false, data_region);
507}
508
509template<typename A>
510inline void NGramFstImpl<A>::Init(const char* data, bool owned,
511                                  MappedFile *data_region) {
512  if (owned_) {
513    delete [] data_;
514  }
515  delete data_region_;
516  data_region_ = data_region;
517  owned_ = owned;
518  data_ = data;
519  size_t offset = 0;
520  num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset));
521  offset += sizeof(num_states_);
522  num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset));
523  offset += sizeof(num_futures_);
524  num_final_ = *(reinterpret_cast<const uint64*>(data_ + offset));
525  offset += sizeof(num_final_);
526  uint64 bits;
527  size_t context_bits = num_states_ * 2 + 1;
528  size_t future_bits = num_futures_ + num_states_ + 1;
529  context_ = reinterpret_cast<const uint64*>(data_ + offset);
530  offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits);
531  future_ = reinterpret_cast<const uint64*>(data_ + offset);
532  offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits);
533  final_ = reinterpret_cast<const uint64*>(data_ + offset);
534  offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits);
535  context_words_ = reinterpret_cast<const Label*>(data_ + offset);
536  offset += (num_states_ + 1) * sizeof(*context_words_);
537  future_words_ = reinterpret_cast<const Label*>(data_ + offset);
538  offset += num_futures_ * sizeof(*future_words_);
539  offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1);
540  backoff_ = reinterpret_cast<const Weight*>(data_ + offset);
541  offset += (num_states_ + 1) * sizeof(*backoff_);
542  final_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
543  offset += num_final_ * sizeof(*final_probs_);
544  future_probs_ = reinterpret_cast<const Weight*>(data_ + offset);
545
546  context_index_.BuildIndex(context_, context_bits);
547  future_index_.BuildIndex(future_, future_bits);
548  final_index_.BuildIndex(final_, num_states_);
549
550  const size_t node_rank = context_index_.Rank1(0);
551  root_first_child_ = context_index_.Select0(node_rank) + 1;
552  if (context_index_.Get(root_first_child_) == false) {
553    FSTERROR() << "Missing unigrams";
554    SetProperties(kError, kError);
555    return;
556  }
557  const size_t last_child = context_index_.Select0(node_rank + 1) - 1;
558  root_num_children_ = last_child - root_first_child_ + 1;
559  root_children_ = context_words_ + context_index_.Rank1(root_first_child_);
560}
561
562template<typename A>
563inline typename A::StateId NGramFstImpl<A>::Transition(
564        const vector<Label> &context, Label future) const {
565  const Label *children = root_children_;
566  const Label *loc = lower_bound(children, children + root_num_children_,
567                                 future);
568  if (loc == children + root_num_children_ || *loc != future) {
569    return context_index_.Rank1(0);
570  }
571  size_t node = root_first_child_ + loc - children;
572  size_t node_rank = context_index_.Rank1(node);
573  size_t first_child = context_index_.Select0(node_rank) + 1;
574  if (context_index_.Get(first_child) == false) {
575    return context_index_.Rank1(node);
576  }
577  size_t last_child = context_index_.Select0(node_rank + 1) - 1;
578  for (int word = context.size() - 1; word >= 0; --word) {
579    children = context_words_ + context_index_.Rank1(first_child);
580    loc = lower_bound(children, children + last_child - first_child + 1,
581                      context[word]);
582    if (loc == children + last_child - first_child + 1 ||
583        *loc != context[word]) {
584      break;
585    }
586    node = first_child + loc - children;
587    node_rank = context_index_.Rank1(node);
588    first_child = context_index_.Select0(node_rank) + 1;
589    if (context_index_.Get(first_child) == false) break;
590    last_child = context_index_.Select0(node_rank + 1) - 1;
591  }
592  return context_index_.Rank1(node);
593}
594
595template<typename A>
596inline void NGramFstImpl<A>::GetStates(
597    const vector<Label> &context,
598    vector<typename A::StateId>* states) const {
599  states->clear();
600  states->push_back(0);
601  typename vector<Label>::const_reverse_iterator cit = context.rbegin();
602  const Label *children = root_children_;
603  const Label *loc = lower_bound(children, children + root_num_children_, *cit);
604  if (loc == children + root_num_children_ || *loc != *cit) return;
605  size_t node = root_first_child_ + loc - children;
606  states->push_back(context_index_.Rank1(node));
607  if (context.size() == 1) return;
608  size_t node_rank = context_index_.Rank1(node);
609  size_t first_child = context_index_.Select0(node_rank) + 1;
610  ++cit;
611  if (context_index_.Get(first_child) != false) {
612    size_t last_child = context_index_.Select0(node_rank + 1) - 1;
613    while (cit != context.rend()) {
614      children = context_words_ + context_index_.Rank1(first_child);
615      loc = lower_bound(children, children + last_child - first_child + 1,
616                        *cit);
617      if (loc == children + last_child - first_child + 1 || *loc != *cit) {
618        break;
619      }
620      ++cit;
621      node = first_child + loc - children;
622      states->push_back(context_index_.Rank1(node));
623      node_rank = context_index_.Rank1(node);
624      first_child = context_index_.Select0(node_rank) + 1;
625      if (context_index_.Get(first_child) == false) break;
626      last_child = context_index_.Select0(node_rank + 1) - 1;
627    }
628  }
629}
630
631/*****************************************************************************/
632template<class A>
633class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > {
634  friend class ArcIterator<NGramFst<A> >;
635  friend class NGramFstMatcher<A>;
636
637 public:
638  typedef A Arc;
639  typedef typename A::StateId StateId;
640  typedef typename A::Label Label;
641  typedef typename A::Weight Weight;
642  typedef NGramFstImpl<A> Impl;
643
644  explicit NGramFst(const Fst<A> &dst)
645      : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {}
646
647  NGramFst(const Fst<A> &fst, vector<StateId>* order_out)
648      : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {}
649
650  // Because the NGramFstImpl is a const stateless data structure, there
651  // is never a need to do anything beside copy the reference.
652  NGramFst(const NGramFst<A> &fst, bool safe = false)
653      : ImplToExpandedFst<Impl>(fst, false) {}
654
655  NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {}
656
657  // Non-standard constructor to initialize NGramFst directly from data.
658  NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) {
659    GetImpl()->Init(data, owned, NULL);
660  }
661
662  // Get method that gets the data associated with Init().
663  const char* GetData(size_t* data_size) const {
664    return GetImpl()->GetData(data_size);
665  }
666
667  const vector<Label> GetContext(StateId s) const {
668    return GetImpl()->GetContext(s, &inst_);
669  }
670
671  // Consumes as much as possible of context from right to left, returns the
672  // the states corresponding to the increasingly conditioned input sequence.
673  void GetStates(const vector<Label>& context, vector<StateId> *state) const {
674    return GetImpl()->GetStates(context, state);
675  }
676
677  virtual size_t NumArcs(StateId s) const {
678    return GetImpl()->NumArcs(s, &inst_);
679  }
680
681  virtual NGramFst<A>* Copy(bool safe = false) const {
682    return new NGramFst(*this, safe);
683  }
684
685  static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) {
686    Impl* impl = Impl::Read(strm, opts);
687    return impl ? new NGramFst<A>(impl) : 0;
688  }
689
690  static NGramFst<A>* Read(const string &filename) {
691    if (!filename.empty()) {
692      ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
693      if (!strm) {
694        LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename;
695        return 0;
696      }
697      return Read(strm, FstReadOptions(filename));
698    } else {
699      return Read(cin, FstReadOptions("standard input"));
700    }
701  }
702
703  virtual bool Write(ostream &strm, const FstWriteOptions &opts) const {
704    return GetImpl()->Write(strm, opts);
705  }
706
707  virtual bool Write(const string &filename) const {
708    return Fst<A>::WriteFile(filename);
709  }
710
711  virtual inline void InitStateIterator(StateIteratorData<A>* data) const {
712    GetImpl()->InitStateIterator(data);
713  }
714
715  virtual inline void InitArcIterator(
716      StateId s, ArcIteratorData<A>* data) const;
717
718  virtual MatcherBase<A>* InitMatcher(MatchType match_type) const {
719    return new NGramFstMatcher<A>(*this, match_type);
720  }
721
722  size_t StorageSize() const {
723    return GetImpl()->StorageSize();
724  }
725
726 private:
727  explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {}
728
729  Impl* GetImpl() const {
730    return
731        ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl();
732  }
733
734  void SetImpl(Impl* impl, bool own_impl = true) {
735    ImplToExpandedFst<Impl, Fst<A> >::SetImpl(impl, own_impl);
736  }
737
738  mutable NGramFstInst<A> inst_;
739};
740
741template <class A> inline void
742NGramFst<A>::InitArcIterator(StateId s, ArcIteratorData<A>* data) const {
743  GetImpl()->SetInstFuture(s, &inst_);
744  GetImpl()->SetInstNode(&inst_);
745  data->base = new ArcIterator<NGramFst<A> >(*this, s);
746}
747
748/*****************************************************************************/
749template <class A>
750class NGramFstMatcher : public MatcherBase<A> {
751 public:
752  typedef A Arc;
753  typedef typename A::Label Label;
754  typedef typename A::StateId StateId;
755  typedef typename A::Weight Weight;
756
757  NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type)
758      : fst_(fst), inst_(fst.inst_), match_type_(match_type),
759        current_loop_(false),
760        loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
761    if (match_type_ == MATCH_OUTPUT) {
762      swap(loop_.ilabel, loop_.olabel);
763    }
764  }
765
766  NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false)
767      : fst_(matcher.fst_), inst_(matcher.inst_),
768        match_type_(matcher.match_type_), current_loop_(false),
769        loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) {
770    if (match_type_ == MATCH_OUTPUT) {
771      swap(loop_.ilabel, loop_.olabel);
772    }
773  }
774
775  virtual NGramFstMatcher<A>* Copy(bool safe = false) const {
776    return new NGramFstMatcher<A>(*this, safe);
777  }
778
779  virtual MatchType Type(bool test) const {
780    return match_type_;
781  }
782
783  virtual const Fst<A> &GetFst() const {
784    return fst_;
785  }
786
787  virtual uint64 Properties(uint64 props) const {
788    return props;
789  }
790
791 private:
792  virtual void SetState_(StateId s) {
793    fst_.GetImpl()->SetInstFuture(s, &inst_);
794    current_loop_ = false;
795  }
796
797  virtual bool Find_(Label label) {
798    const Label nolabel = kNoLabel;
799    done_ = true;
800    if (label == 0 || label == nolabel) {
801      if (label == 0) {
802        current_loop_ = true;
803        loop_.nextstate = inst_.state_;
804      }
805      // The unigram state has no epsilon arc.
806      if (inst_.state_ != 0) {
807        arc_.ilabel = arc_.olabel = 0;
808        fst_.GetImpl()->SetInstNode(&inst_);
809        arc_.nextstate = fst_.GetImpl()->context_index_.Rank1(
810            fst_.GetImpl()->context_index_.Select1(
811                fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1));
812        arc_.weight = fst_.GetImpl()->backoff_[inst_.state_];
813        done_ = false;
814      }
815    } else {
816      const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_;
817      const Label *end = start + inst_.num_futures_;
818      const Label* search = lower_bound(start, end, label);
819      if (search != end && *search == label) {
820        size_t state = search - start;
821        arc_.ilabel = arc_.olabel = label;
822        arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state];
823        fst_.GetImpl()->SetInstContext(&inst_);
824        arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label);
825        done_ = false;
826      }
827    }
828    return !Done_();
829  }
830
831  virtual bool Done_() const {
832    return !current_loop_ && done_;
833  }
834
835  virtual const Arc& Value_() const {
836    return (current_loop_) ? loop_ : arc_;
837  }
838
839  virtual void Next_() {
840    if (current_loop_) {
841      current_loop_ = false;
842    } else {
843      done_ = true;
844    }
845  }
846
847  const NGramFst<A>& fst_;
848  NGramFstInst<A> inst_;
849  MatchType match_type_;             // Supplied by caller
850  bool done_;
851  Arc arc_;
852  bool current_loop_;                // Current arc is the implicit loop
853  Arc loop_;
854};
855
856/*****************************************************************************/
857template<class A>
858class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> {
859 public:
860  typedef A Arc;
861  typedef typename A::Label Label;
862  typedef typename A::StateId StateId;
863  typedef typename A::Weight Weight;
864
865  ArcIterator(const NGramFst<A> &fst, StateId state)
866      : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) {
867    inst_ = fst.inst_;
868    impl_->SetInstFuture(state, &inst_);
869    impl_->SetInstNode(&inst_);
870  }
871
872  bool Done() const {
873    return i_ >= ((inst_.node_ == 0) ? inst_.num_futures_ :
874                  inst_.num_futures_ + 1);
875  }
876
877  const Arc &Value() const {
878    bool eps = (inst_.node_ != 0 && i_ == 0);
879    StateId state = (inst_.node_ == 0) ? i_ : i_ - 1;
880    if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) {
881      arc_.ilabel =
882          arc_.olabel = eps ? 0 : impl_->future_words_[inst_.offset_ + state];
883      lazy_ &= ~(kArcILabelValue | kArcOLabelValue);
884    }
885    if (flags_ & lazy_ & kArcNextStateValue) {
886      if (eps) {
887        arc_.nextstate = impl_->context_index_.Rank1(
888            impl_->context_index_.Select1(
889                impl_->context_index_.Rank0(inst_.node_) - 1));
890      } else {
891        if (lazy_ & kArcNextStateValue) {
892          impl_->SetInstContext(&inst_);  // first time only.
893        }
894        arc_.nextstate =
895            impl_->Transition(inst_.context_,
896                              impl_->future_words_[inst_.offset_ + state]);
897      }
898      lazy_ &= ~kArcNextStateValue;
899    }
900    if (flags_ & lazy_ & kArcWeightValue) {
901      arc_.weight = eps ?  impl_->backoff_[inst_.state_] :
902          impl_->future_probs_[inst_.offset_ + state];
903      lazy_ &= ~kArcWeightValue;
904    }
905    return arc_;
906  }
907
908  void Next() {
909    ++i_;
910    lazy_ = ~0;
911  }
912
913  size_t Position() const { return i_; }
914
915  void Reset() {
916    i_ = 0;
917    lazy_ = ~0;
918  }
919
920  void Seek(size_t a) {
921    if (i_ != a) {
922      i_ = a;
923      lazy_ = ~0;
924    }
925  }
926
927  uint32 Flags() const {
928    return flags_;
929  }
930
931  void SetFlags(uint32 f, uint32 m) {
932    flags_ &= ~m;
933    flags_ |= (f & kArcValueFlags);
934  }
935
936 private:
937  virtual bool Done_() const { return Done(); }
938  virtual const Arc& Value_() const { return Value(); }
939  virtual void Next_() { Next(); }
940  virtual size_t Position_() const { return Position(); }
941  virtual void Reset_() { Reset(); }
942  virtual void Seek_(size_t a) { Seek(a); }
943  uint32 Flags_() const { return Flags(); }
944  void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); }
945
946  mutable Arc arc_;
947  mutable uint32 lazy_;
948  const NGramFstImpl<A> *impl_;
949  mutable NGramFstInst<A> inst_;
950
951  size_t i_;
952  uint32 flags_;
953
954  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
955};
956
957/*****************************************************************************/
958// Specialization for NGramFst; see generic version in fst.h
959// for sample usage (but use the ProdLmFst type!). This version
960// should inline.
961template <class A>
962class StateIterator<NGramFst<A> > : public StateIteratorBase<A> {
963  public:
964  typedef typename A::StateId StateId;
965
966  explicit StateIterator(const NGramFst<A> &fst)
967    : s_(0), num_states_(fst.NumStates()) { }
968
969  bool Done() const { return s_ >= num_states_; }
970  StateId Value() const { return s_; }
971  void Next() { ++s_; }
972  void Reset() { s_ = 0; }
973
974 private:
975  virtual bool Done_() const { return Done(); }
976  virtual StateId Value_() const { return Value(); }
977  virtual void Next_() { Next(); }
978  virtual void Reset_() { Reset(); }
979
980  StateId s_, num_states_;
981
982  DISALLOW_COPY_AND_ASSIGN(StateIterator);
983};
984}  // namespace fst
985#endif  // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_
986