1#include <algorithm>
2#include <stdexcept>
3
4#include "trie.h"
5
6namespace marisa {
7namespace {
8
9template <typename T, typename U>
10class PredictCallback {
11 public:
12  PredictCallback(T key_ids, U keys, std::size_t max_num_results)
13      : key_ids_(key_ids), keys_(keys),
14        max_num_results_(max_num_results), num_results_(0) {}
15  PredictCallback(const PredictCallback &callback)
16      : key_ids_(callback.key_ids_), keys_(callback.keys_),
17        max_num_results_(callback.max_num_results_),
18        num_results_(callback.num_results_) {}
19
20  bool operator()(marisa::UInt32 key_id, const std::string &key) {
21    if (key_ids_.is_valid()) {
22      key_ids_.insert(num_results_, key_id);
23    }
24    if (keys_.is_valid()) {
25      keys_.insert(num_results_, key);
26    }
27    return ++num_results_ < max_num_results_;
28  }
29
30 private:
31  T key_ids_;
32  U keys_;
33  const std::size_t max_num_results_;
34  std::size_t num_results_;
35
36  // Disallows assignment.
37  PredictCallback &operator=(const PredictCallback &);
38};
39
40}  // namespace
41
42std::string Trie::restore(UInt32 key_id) const {
43  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
44  MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
45  std::string key;
46  restore_(key_id, &key);
47  return key;
48}
49
50void Trie::restore(UInt32 key_id, std::string *key) const {
51  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
52  MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
53  MARISA_THROW_IF(key == NULL, MARISA_PARAM_ERROR);
54  restore_(key_id, key);
55}
56
57std::size_t Trie::restore(UInt32 key_id, char *key_buf,
58    std::size_t key_buf_size) const {
59  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
60  MARISA_THROW_IF(key_id >= num_keys_, MARISA_PARAM_ERROR);
61  MARISA_THROW_IF((key_buf == NULL) && (key_buf_size != 0),
62      MARISA_PARAM_ERROR);
63  return restore_(key_id, key_buf, key_buf_size);
64}
65
66UInt32 Trie::lookup(const char *str) const {
67  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
68  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
69  return lookup_<CQuery>(CQuery(str));
70}
71
72UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
73  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
74  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
75  return lookup_<const Query &>(Query(ptr, length));
76}
77
78std::size_t Trie::find(const char *str,
79    UInt32 *key_ids, std::size_t *key_lengths,
80    std::size_t max_num_results) const {
81  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
82  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
83  return find_<CQuery>(CQuery(str),
84      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
85}
86
87std::size_t Trie::find(const char *ptr, std::size_t length,
88    UInt32 *key_ids, std::size_t *key_lengths,
89    std::size_t max_num_results) const {
90  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
91  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
92  return find_<const Query &>(Query(ptr, length),
93      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
94}
95
96std::size_t Trie::find(const char *str,
97    std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
98    std::size_t max_num_results) const {
99  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
100  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
101  return find_<CQuery>(CQuery(str),
102      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
103}
104
105std::size_t Trie::find(const char *ptr, std::size_t length,
106    std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
107    std::size_t max_num_results) const {
108  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
109  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
110  return find_<const Query &>(Query(ptr, length),
111      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
112}
113
114UInt32 Trie::find_first(const char *str,
115    std::size_t *key_length) const {
116  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
117  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
118  return find_first_<CQuery>(CQuery(str), key_length);
119}
120
121UInt32 Trie::find_first(const char *ptr, std::size_t length,
122    std::size_t *key_length) const {
123  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
124  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
125  return find_first_<const Query &>(Query(ptr, length), key_length);
126}
127
128UInt32 Trie::find_last(const char *str,
129    std::size_t *key_length) const {
130  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
131  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
132  return find_last_<CQuery>(CQuery(str), key_length);
133}
134
135UInt32 Trie::find_last(const char *ptr, std::size_t length,
136    std::size_t *key_length) const {
137  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
138  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
139  return find_last_<const Query &>(Query(ptr, length), key_length);
140}
141
142std::size_t Trie::predict(const char *str,
143    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
144  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
145  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
146  return (keys == NULL) ?
147      predict_breadth_first(str, key_ids, keys, max_num_results) :
148      predict_depth_first(str, key_ids, keys, max_num_results);
149}
150
151std::size_t Trie::predict(const char *ptr, std::size_t length,
152    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
153  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
154  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
155  return (keys == NULL) ?
156      predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
157      predict_depth_first(ptr, length, key_ids, keys, max_num_results);
158}
159
160std::size_t Trie::predict(const char *str,
161    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
162    std::size_t max_num_results) const {
163  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
164  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
165  return (keys == NULL) ?
166      predict_breadth_first(str, key_ids, keys, max_num_results) :
167      predict_depth_first(str, key_ids, keys, max_num_results);
168}
169
170std::size_t Trie::predict(const char *ptr, std::size_t length,
171    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
172    std::size_t max_num_results) const {
173  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
174  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
175  return (keys == NULL) ?
176      predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
177      predict_depth_first(ptr, length, key_ids, keys, max_num_results);
178}
179
180std::size_t Trie::predict_breadth_first(const char *str,
181    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
182  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
183  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
184  return predict_breadth_first_<CQuery>(CQuery(str),
185      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
186}
187
188std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
189    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
190  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
191  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
192  return predict_breadth_first_<const Query &>(Query(ptr, length),
193      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
194}
195
196std::size_t Trie::predict_breadth_first(const char *str,
197    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
198    std::size_t max_num_results) const {
199  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
200  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
201  return predict_breadth_first_<CQuery>(CQuery(str),
202      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
203}
204
205std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
206    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
207    std::size_t max_num_results) const {
208  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
209  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
210  return predict_breadth_first_<const Query &>(Query(ptr, length),
211      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
212}
213
214std::size_t Trie::predict_depth_first(const char *str,
215    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
216  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
217  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
218  return predict_depth_first_<CQuery>(CQuery(str),
219      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
220}
221
222std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
223    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
224  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
225  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
226  return predict_depth_first_<const Query &>(Query(ptr, length),
227      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
228}
229
230std::size_t Trie::predict_depth_first(
231    const char *str, std::vector<UInt32> *key_ids,
232    std::vector<std::string> *keys, std::size_t max_num_results) const {
233  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
234  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
235  return predict_depth_first_<CQuery>(CQuery(str),
236      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
237}
238
239std::size_t Trie::predict_depth_first(
240    const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
241    std::vector<std::string> *keys, std::size_t max_num_results) const {
242  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
243  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
244  return predict_depth_first_<const Query &>(Query(ptr, length),
245      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
246}
247
248void Trie::restore_(UInt32 key_id, std::string *key) const {
249  const std::size_t start_pos = key->length();
250  UInt32 node = key_id_to_node(key_id);
251  while (node != 0) {
252    if (has_link(node)) {
253      const std::size_t prev_pos = key->length();
254      if (has_trie()) {
255        trie_->trie_restore(get_link(node), key);
256      } else {
257        tail_restore(node, key);
258      }
259      std::reverse(key->begin() + prev_pos, key->end());
260    } else {
261      *key += labels_[node];
262    }
263    node = get_parent(node);
264  }
265  std::reverse(key->begin() + start_pos, key->end());
266}
267
268void Trie::trie_restore(UInt32 node, std::string *key) const {
269  do {
270    if (has_link(node)) {
271      if (has_trie()) {
272        trie_->trie_restore(get_link(node), key);
273      } else {
274        tail_restore(node, key);
275      }
276    } else {
277      *key += labels_[node];
278    }
279    node = get_parent(node);
280  } while (node != 0);
281}
282
283void Trie::tail_restore(UInt32 node, std::string *key) const {
284  const UInt32 link_id = link_flags_.rank1(node);
285  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
286  if (tail_.mode() == MARISA_BINARY_TAIL) {
287    const UInt32 length = (links_[link_id + 1] * 256)
288        + labels_[link_flags_.select1(link_id + 1)] - offset;
289    key->append(reinterpret_cast<const char *>(tail_[offset]), length);
290  } else {
291    key->append(reinterpret_cast<const char *>(tail_[offset]));
292  }
293}
294
295std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
296    std::size_t key_buf_size) const {
297  std::size_t pos = 0;
298  UInt32 node = key_id_to_node(key_id);
299  while (node != 0) {
300    if (has_link(node)) {
301      const std::size_t prev_pos = pos;
302      if (has_trie()) {
303        trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
304      } else {
305        tail_restore(node, key_buf, key_buf_size, pos);
306      }
307      if (pos < key_buf_size) {
308        std::reverse(key_buf + prev_pos, key_buf + pos);
309      }
310    } else {
311      if (pos < key_buf_size) {
312        key_buf[pos] = labels_[node];
313      }
314      ++pos;
315    }
316    node = get_parent(node);
317  }
318  if (pos < key_buf_size) {
319    key_buf[pos] = '\0';
320    std::reverse(key_buf, key_buf + pos);
321  }
322  return pos;
323}
324
325void Trie::trie_restore(UInt32 node, char *key_buf,
326    std::size_t key_buf_size, std::size_t &pos) const {
327  do {
328    if (has_link(node)) {
329      if (has_trie()) {
330        trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
331      } else {
332        tail_restore(node, key_buf, key_buf_size, pos);
333      }
334    } else {
335      if (pos < key_buf_size) {
336        key_buf[pos] = labels_[node];
337      }
338      ++pos;
339    }
340    node = get_parent(node);
341  } while (node != 0);
342}
343
344void Trie::tail_restore(UInt32 node, char *key_buf,
345    std::size_t key_buf_size, std::size_t &pos) const {
346  const UInt32 link_id = link_flags_.rank1(node);
347  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
348  if (tail_.mode() == MARISA_BINARY_TAIL) {
349    const UInt8 *ptr = tail_[offset];
350    const UInt32 length = (links_[link_id + 1] * 256)
351        + labels_[link_flags_.select1(link_id + 1)] - offset;
352    for (UInt32 i = 0; i < length; ++i) {
353      if (pos < key_buf_size) {
354        key_buf[pos] = ptr[i];
355      }
356      ++pos;
357    }
358  } else {
359    for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str) {
360      if (pos < key_buf_size) {
361        key_buf[pos] = *str;
362      }
363      ++pos;
364    }
365  }
366}
367
368template <typename T>
369UInt32 Trie::lookup_(T query) const {
370  UInt32 node = 0;
371  std::size_t pos = 0;
372  while (!query.ends_at(pos)) {
373    if (!find_child<T>(node, query, pos)) {
374      return notfound();
375    }
376  }
377  return terminal_flags_[node] ? node_to_key_id(node) : notfound();
378}
379
380template <typename T>
381std::size_t Trie::trie_match(UInt32 node, T query,
382    std::size_t pos) const {
383  if (has_link(node)) {
384    std::size_t next_pos;
385    if (has_trie()) {
386      next_pos = trie_->trie_match<T>(get_link(node), query, pos);
387    } else {
388      next_pos = tail_match<T>(node, get_link_id(node), query, pos);
389    }
390    if ((next_pos == mismatch()) || (next_pos == pos)) {
391      return next_pos;
392    }
393    pos = next_pos;
394  } else if (labels_[node] != query[pos]) {
395    return pos;
396  } else {
397    ++pos;
398  }
399  node = get_parent(node);
400  while (node != 0) {
401    if (query.ends_at(pos)) {
402      return mismatch();
403    }
404    if (has_link(node)) {
405      std::size_t next_pos;
406      if (has_trie()) {
407        next_pos = trie_->trie_match<T>(get_link(node), query, pos);
408      } else {
409        next_pos = tail_match<T>(node, get_link_id(node), query, pos);
410      }
411      if ((next_pos == mismatch()) || (next_pos == pos)) {
412        return mismatch();
413      }
414      pos = next_pos;
415    } else if (labels_[node] != query[pos]) {
416      return mismatch();
417    } else {
418      ++pos;
419    }
420    node = get_parent(node);
421  }
422  return pos;
423}
424
425template std::size_t Trie::trie_match<CQuery>(UInt32 node,
426    CQuery query, std::size_t pos) const;
427template std::size_t Trie::trie_match<const Query &>(UInt32 node,
428    const Query &query, std::size_t pos) const;
429
430template <typename T>
431std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
432    T query, std::size_t pos) const {
433  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
434  const UInt8 *ptr = tail_[offset];
435  if (*ptr != query[pos]) {
436    return pos;
437  } else if (tail_.mode() == MARISA_BINARY_TAIL) {
438    const UInt32 length = (links_[link_id + 1] * 256)
439        + labels_[link_flags_.select1(link_id + 1)] - offset;
440    for (UInt32 i = 1; i < length; ++i) {
441      if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) {
442        return mismatch();
443      }
444    }
445    return pos + length;
446  } else {
447    for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
448      if (query.ends_at(pos) || (*ptr != query[pos])) {
449        return mismatch();
450      }
451    }
452    return pos;
453  }
454}
455
456template std::size_t Trie::tail_match<CQuery>(UInt32 node,
457    UInt32 link_id, CQuery query, std::size_t pos) const;
458template std::size_t Trie::tail_match<const Query &>(UInt32 node,
459    UInt32 link_id, const Query &query, std::size_t pos) const;
460
461template <typename T, typename U, typename V>
462std::size_t Trie::find_(T query, U key_ids, V key_lengths,
463    std::size_t max_num_results) const {
464  if (max_num_results == 0) {
465    return 0;
466  }
467  std::size_t count = 0;
468  UInt32 node = 0;
469  std::size_t pos = 0;
470  do {
471    if (terminal_flags_[node]) {
472      if (key_ids.is_valid()) {
473        key_ids.insert(count, node_to_key_id(node));
474      }
475      if (key_lengths.is_valid()) {
476        key_lengths.insert(count, pos);
477      }
478      if (++count >= max_num_results) {
479        return count;
480      }
481    }
482  } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
483  return count;
484}
485
486template <typename T>
487UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
488  UInt32 node = 0;
489  std::size_t pos = 0;
490  do {
491    if (terminal_flags_[node]) {
492      if (key_length != NULL) {
493        *key_length = pos;
494      }
495      return node_to_key_id(node);
496    }
497  } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
498  return notfound();
499}
500
501template <typename T>
502UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
503  UInt32 node = 0;
504  UInt32 node_found = notfound();
505  std::size_t pos = 0;
506  std::size_t pos_found = mismatch();
507  do {
508    if (terminal_flags_[node]) {
509      node_found = node;
510      pos_found = pos;
511    }
512  } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
513  if (node_found != notfound()) {
514    if (key_length != NULL) {
515      *key_length = pos_found;
516    }
517    return node_to_key_id(node_found);
518  }
519  return notfound();
520}
521
522template <typename T, typename U, typename V>
523std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
524    std::size_t max_num_results) const {
525  if (max_num_results == 0) {
526    return 0;
527  }
528  UInt32 node = 0;
529  std::size_t pos = 0;
530  while (!query.ends_at(pos)) {
531    if (!predict_child<T>(node, query, pos, NULL)) {
532      return 0;
533    }
534  }
535  std::string key;
536  std::size_t count = 0;
537  if (terminal_flags_[node]) {
538    const UInt32 key_id = node_to_key_id(node);
539    if (key_ids.is_valid()) {
540      key_ids.insert(count, key_id);
541    }
542    if (keys.is_valid()) {
543      restore(key_id, &key);
544      keys.insert(count, key);
545    }
546    if (++count >= max_num_results) {
547      return count;
548    }
549  }
550  const UInt32 louds_pos = get_child(node);
551  if (!louds_[louds_pos]) {
552    return count;
553  }
554  UInt32 node_begin = louds_pos_to_node(louds_pos, node);
555  UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
556  while (node_begin < node_end) {
557    const UInt32 key_id_begin = node_to_key_id(node_begin);
558    const UInt32 key_id_end = node_to_key_id(node_end);
559    if (key_ids.is_valid()) {
560      UInt32 temp_count = count;
561      for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
562        key_ids.insert(temp_count, key_id);
563        if (++temp_count >= max_num_results) {
564          break;
565        }
566      }
567    }
568    if (keys.is_valid()) {
569      UInt32 temp_count = count;
570      for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
571        key.clear();
572        restore(key_id, &key);
573        keys.insert(temp_count, key);
574        if (++temp_count >= max_num_results) {
575          break;
576        }
577      }
578    }
579    count += key_id_end - key_id_begin;
580    if (count >= max_num_results) {
581      return max_num_results;
582    }
583    node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
584    node_end = louds_pos_to_node(get_child(node_end), node_end);
585  }
586  return count;
587}
588
589template <typename T, typename U, typename V>
590std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
591    std::size_t max_num_results) const {
592  if (max_num_results == 0) {
593    return 0;
594  } else if (keys.is_valid()) {
595    PredictCallback<U, V> callback(key_ids, keys, max_num_results);
596    return predict_callback_(query, callback);
597  }
598
599  UInt32 node = 0;
600  std::size_t pos = 0;
601  while (!query.ends_at(pos)) {
602    if (!predict_child<T>(node, query, pos, NULL)) {
603      return 0;
604    }
605  }
606  std::size_t count = 0;
607  if (terminal_flags_[node]) {
608    if (key_ids.is_valid()) {
609      key_ids.insert(count, node_to_key_id(node));
610    }
611    if (++count >= max_num_results) {
612      return count;
613    }
614  }
615  Cell cell;
616  cell.set_louds_pos(get_child(node));
617  if (!louds_[cell.louds_pos()]) {
618    return count;
619  }
620  cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
621  cell.set_key_id(node_to_key_id(cell.node()));
622  Vector<Cell> stack;
623  stack.push_back(cell);
624  std::size_t stack_pos = 1;
625  while (stack_pos != 0) {
626    Cell &cur = stack[stack_pos - 1];
627    if (!louds_[cur.louds_pos()]) {
628      cur.set_louds_pos(cur.louds_pos() + 1);
629      --stack_pos;
630      continue;
631    }
632    cur.set_louds_pos(cur.louds_pos() + 1);
633    if (terminal_flags_[cur.node()]) {
634      if (key_ids.is_valid()) {
635        key_ids.insert(count, cur.key_id());
636      }
637      if (++count >= max_num_results) {
638        return count;
639      }
640      cur.set_key_id(cur.key_id() + 1);
641    }
642    if (stack_pos == stack.size()) {
643      cell.set_louds_pos(get_child(cur.node()));
644      cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
645      cell.set_key_id(node_to_key_id(cell.node()));
646      stack.push_back(cell);
647    }
648    stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
649    ++stack_pos;
650  }
651  return count;
652}
653
654template <typename T>
655std::size_t Trie::trie_prefix_match(UInt32 node, T query,
656    std::size_t pos, std::string *key) const {
657  if (has_link(node)) {
658    std::size_t next_pos;
659    if (has_trie()) {
660      next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
661    } else {
662      next_pos = tail_prefix_match<T>(
663          node, get_link_id(node), query, pos, key);
664    }
665    if ((next_pos == mismatch()) || (next_pos == pos)) {
666      return next_pos;
667    }
668    pos = next_pos;
669  } else if (labels_[node] != query[pos]) {
670    return pos;
671  } else {
672    ++pos;
673  }
674  node = get_parent(node);
675  while (node != 0) {
676    if (query.ends_at(pos)) {
677      if (key != NULL) {
678        trie_restore(node, key);
679      }
680      return pos;
681    }
682    if (has_link(node)) {
683      std::size_t next_pos;
684      if (has_trie()) {
685        next_pos = trie_->trie_prefix_match<T>(
686            get_link(node), query, pos, key);
687      } else {
688        next_pos = tail_prefix_match<T>(
689            node, get_link_id(node), query, pos, key);
690      }
691      if ((next_pos == mismatch()) || (next_pos == pos)) {
692        return next_pos;
693      }
694      pos = next_pos;
695    } else if (labels_[node] != query[pos]) {
696      return mismatch();
697    } else {
698      ++pos;
699    }
700    node = get_parent(node);
701  }
702  return pos;
703}
704
705template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
706    CQuery query, std::size_t pos, std::string *key) const;
707template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
708    const Query &query, std::size_t pos, std::string *key) const;
709
710template <typename T>
711std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
712    T query, std::size_t pos, std::string *key) const {
713  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
714  const UInt8 *ptr = tail_[offset];
715  if (*ptr != query[pos]) {
716    return pos;
717  } else if (tail_.mode() == MARISA_BINARY_TAIL) {
718    const UInt32 length = (links_[link_id + 1] * 256)
719        + labels_[link_flags_.select1(link_id + 1)] - offset;
720    for (UInt32 i = 1; i < length; ++i) {
721      if (query.ends_at(pos + i)) {
722        if (key != NULL) {
723          key->append(reinterpret_cast<const char *>(ptr + i), length - i);
724        }
725        return pos + i;
726      } else if (ptr[i] != query[pos + i]) {
727        return mismatch();
728      }
729    }
730    return pos + length;
731  } else {
732    for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
733      if (query.ends_at(pos)) {
734        if (key != NULL) {
735          key->append(reinterpret_cast<const char *>(ptr));
736        }
737        return pos;
738      } else if (*ptr != query[pos]) {
739        return mismatch();
740      }
741    }
742    return pos;
743  }
744}
745
746template std::size_t Trie::tail_prefix_match<CQuery>(
747    UInt32 node, UInt32 link_id,
748    CQuery query, std::size_t pos, std::string *key) const;
749template std::size_t Trie::tail_prefix_match<const Query &>(
750    UInt32 node, UInt32 link_id,
751    const Query &query, std::size_t pos, std::string *key) const;
752
753}  // namespace marisa
754