1#include <algorithm>
2#include <stdexcept>
3
4#include "trie.h"
5
6namespace marisa_alpha {
7namespace {
8
9template <typename T, typename U>
10class PredictCallback {
11 public:
12  PredictCallback(T key_ids, U keys, std::size_t max_num_results)
13      : key_ids_(key_ids), keys_(keys),
14        max_num_results_(max_num_results), num_results_(0) {}
15  PredictCallback(const PredictCallback &callback)
16      : key_ids_(callback.key_ids_), keys_(callback.keys_),
17        max_num_results_(callback.max_num_results_),
18        num_results_(callback.num_results_) {}
19
20  bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) {
21    if (key_ids_.is_valid()) {
22      key_ids_.insert(num_results_, key_id);
23    }
24    if (keys_.is_valid()) {
25      keys_.insert(num_results_, key);
26    }
27    return ++num_results_ < max_num_results_;
28  }
29
30 private:
31  T key_ids_;
32  U keys_;
33  const std::size_t max_num_results_;
34  std::size_t num_results_;
35
36  // Disallows assignment.
37  PredictCallback &operator=(const PredictCallback &);
38};
39
40}  // namespace
41
42std::string Trie::restore(UInt32 key_id) const {
43  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
44  MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
45  std::string key;
46  restore_(key_id, &key);
47  return key;
48}
49
50void Trie::restore(UInt32 key_id, std::string *key) const {
51  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
52  MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
53  MARISA_ALPHA_THROW_IF(key == NULL, MARISA_ALPHA_PARAM_ERROR);
54  restore_(key_id, key);
55}
56
57std::size_t Trie::restore(UInt32 key_id, char *key_buf,
58    std::size_t key_buf_size) const {
59  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
60  MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
61  MARISA_ALPHA_THROW_IF((key_buf == NULL) && (key_buf_size != 0),
62      MARISA_ALPHA_PARAM_ERROR);
63  return restore_(key_id, key_buf, key_buf_size);
64}
65
66UInt32 Trie::lookup(const char *str) const {
67  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
68  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
69  return lookup_<CQuery>(CQuery(str));
70}
71
72UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
73  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
74  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
75      MARISA_ALPHA_PARAM_ERROR);
76  return lookup_<const Query &>(Query(ptr, length));
77}
78
79std::size_t Trie::find(const char *str,
80    UInt32 *key_ids, std::size_t *key_lengths,
81    std::size_t max_num_results) const {
82  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
83  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
84  return find_<CQuery>(CQuery(str),
85      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
86}
87
88std::size_t Trie::find(const char *ptr, std::size_t length,
89    UInt32 *key_ids, std::size_t *key_lengths,
90    std::size_t max_num_results) const {
91  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
92  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
93      MARISA_ALPHA_PARAM_ERROR);
94  return find_<const Query &>(Query(ptr, length),
95      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
96}
97
98std::size_t Trie::find(const char *str,
99    std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
100    std::size_t max_num_results) const {
101  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
102  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
103  return find_<CQuery>(CQuery(str),
104      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
105}
106
107std::size_t Trie::find(const char *ptr, std::size_t length,
108    std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
109    std::size_t max_num_results) const {
110  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
111  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
112      MARISA_ALPHA_PARAM_ERROR);
113  return find_<const Query &>(Query(ptr, length),
114      MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
115}
116
117UInt32 Trie::find_first(const char *str,
118    std::size_t *key_length) const {
119  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
120  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
121  return find_first_<CQuery>(CQuery(str), key_length);
122}
123
124UInt32 Trie::find_first(const char *ptr, std::size_t length,
125    std::size_t *key_length) const {
126  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
127  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
128      MARISA_ALPHA_PARAM_ERROR);
129  return find_first_<const Query &>(Query(ptr, length), key_length);
130}
131
132UInt32 Trie::find_last(const char *str,
133    std::size_t *key_length) const {
134  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
135  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
136  return find_last_<CQuery>(CQuery(str), key_length);
137}
138
139UInt32 Trie::find_last(const char *ptr, std::size_t length,
140    std::size_t *key_length) const {
141  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
142  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
143      MARISA_ALPHA_PARAM_ERROR);
144  return find_last_<const Query &>(Query(ptr, length), key_length);
145}
146
147std::size_t Trie::predict(const char *str,
148    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
149  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
150  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
151  return (keys == NULL) ?
152      predict_breadth_first(str, key_ids, keys, max_num_results) :
153      predict_depth_first(str, key_ids, keys, max_num_results);
154}
155
156std::size_t Trie::predict(const char *ptr, std::size_t length,
157    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
158  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
159  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
160      MARISA_ALPHA_PARAM_ERROR);
161  return (keys == NULL) ?
162      predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
163      predict_depth_first(ptr, length, key_ids, keys, max_num_results);
164}
165
166std::size_t Trie::predict(const char *str,
167    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
168    std::size_t max_num_results) const {
169  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
170  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
171  return (keys == NULL) ?
172      predict_breadth_first(str, key_ids, keys, max_num_results) :
173      predict_depth_first(str, key_ids, keys, max_num_results);
174}
175
176std::size_t Trie::predict(const char *ptr, std::size_t length,
177    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
178    std::size_t max_num_results) const {
179  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
180  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
181      MARISA_ALPHA_PARAM_ERROR);
182  return (keys == NULL) ?
183      predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
184      predict_depth_first(ptr, length, key_ids, keys, max_num_results);
185}
186
187std::size_t Trie::predict_breadth_first(const char *str,
188    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
189  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
190  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
191  return predict_breadth_first_<CQuery>(CQuery(str),
192      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
193}
194
195std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
196    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
197  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
198  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
199      MARISA_ALPHA_PARAM_ERROR);
200  return predict_breadth_first_<const Query &>(Query(ptr, length),
201      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
202}
203
204std::size_t Trie::predict_breadth_first(const char *str,
205    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
206    std::size_t max_num_results) const {
207  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
208  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
209  return predict_breadth_first_<CQuery>(CQuery(str),
210      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
211}
212
213std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
214    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
215    std::size_t max_num_results) const {
216  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
217  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
218      MARISA_ALPHA_PARAM_ERROR);
219  return predict_breadth_first_<const Query &>(Query(ptr, length),
220      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
221}
222
223std::size_t Trie::predict_depth_first(const char *str,
224    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
225  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
226  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
227  return predict_depth_first_<CQuery>(CQuery(str),
228      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
229}
230
231std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
232    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
233  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
234  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
235      MARISA_ALPHA_PARAM_ERROR);
236  return predict_depth_first_<const Query &>(Query(ptr, length),
237      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
238}
239
240std::size_t Trie::predict_depth_first(
241    const char *str, std::vector<UInt32> *key_ids,
242    std::vector<std::string> *keys, std::size_t max_num_results) const {
243  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
244  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
245  return predict_depth_first_<CQuery>(CQuery(str),
246      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
247}
248
249std::size_t Trie::predict_depth_first(
250    const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
251    std::vector<std::string> *keys, std::size_t max_num_results) const {
252  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
253  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
254      MARISA_ALPHA_PARAM_ERROR);
255  return predict_depth_first_<const Query &>(Query(ptr, length),
256      MakeContainer(key_ids), MakeContainer(keys), max_num_results);
257}
258
259void Trie::restore_(UInt32 key_id, std::string *key) const {
260  const std::size_t start_pos = key->length();
261  try {
262    UInt32 node = key_id_to_node(key_id);
263    while (node != 0) {
264      if (has_link(node)) {
265        const std::size_t prev_pos = key->length();
266        if (has_trie()) {
267          trie_->trie_restore(get_link(node), key);
268        } else {
269          tail_restore(node, key);
270        }
271        std::reverse(key->begin() + prev_pos, key->end());
272      } else {
273        *key += labels_[node];
274      }
275      node = get_parent(node);
276    }
277    std::reverse(key->begin() + start_pos, key->end());
278  } catch (const std::bad_alloc &) {
279    key->resize(start_pos);
280    MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
281  } catch (const std::length_error &) {
282    key->resize(start_pos);
283    MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
284  }
285}
286
287void Trie::trie_restore(UInt32 node, std::string *key) const {
288  do {
289    if (has_link(node)) {
290      if (has_trie()) {
291        trie_->trie_restore(get_link(node), key);
292      } else {
293        tail_restore(node, key);
294      }
295    } else {
296      *key += labels_[node];
297    }
298    node = get_parent(node);
299  } while (node != 0);
300}
301
302void Trie::tail_restore(UInt32 node, std::string *key) const {
303  const UInt32 link_id = link_flags_.rank1(node);
304  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
305  if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
306    const UInt32 length = (links_[link_id + 1] * 256)
307        + labels_[link_flags_.select1(link_id + 1)] - offset;
308    key->append(reinterpret_cast<const char *>(tail_[offset]), length);
309  } else {
310    key->append(reinterpret_cast<const char *>(tail_[offset]));
311  }
312}
313
314std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
315    std::size_t key_buf_size) const {
316  std::size_t pos = 0;
317  UInt32 node = key_id_to_node(key_id);
318  while (node != 0) {
319    if (has_link(node)) {
320      const std::size_t prev_pos = pos;
321      if (has_trie()) {
322        trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
323      } else {
324        tail_restore(node, key_buf, key_buf_size, pos);
325      }
326      if (pos < key_buf_size) {
327        std::reverse(key_buf + prev_pos, key_buf + pos);
328      }
329    } else {
330      if (pos < key_buf_size) {
331        key_buf[pos] = labels_[node];
332      }
333      ++pos;
334    }
335    node = get_parent(node);
336  }
337  if (pos < key_buf_size) {
338    key_buf[pos] = '\0';
339    std::reverse(key_buf, key_buf + pos);
340  }
341  return pos;
342}
343
344void Trie::trie_restore(UInt32 node, char *key_buf,
345    std::size_t key_buf_size, std::size_t &pos) const {
346  do {
347    if (has_link(node)) {
348      if (has_trie()) {
349        trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
350      } else {
351        tail_restore(node, key_buf, key_buf_size, pos);
352      }
353    } else {
354      if (pos < key_buf_size) {
355        key_buf[pos] = labels_[node];
356      }
357      ++pos;
358    }
359    node = get_parent(node);
360  } while (node != 0);
361}
362
363void Trie::tail_restore(UInt32 node, char *key_buf,
364    std::size_t key_buf_size, std::size_t &pos) const {
365  const UInt32 link_id = link_flags_.rank1(node);
366  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
367  if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
368    const UInt8 *ptr = tail_[offset];
369    const UInt32 length = (links_[link_id + 1] * 256)
370        + labels_[link_flags_.select1(link_id + 1)] - offset;
371    for (UInt32 i = 0; i < length; ++i) {
372      if (pos < key_buf_size) {
373        key_buf[pos] = ptr[i];
374      }
375      ++pos;
376    }
377  } else {
378    for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str) {
379      if (pos < key_buf_size) {
380        key_buf[pos] = *str;
381      }
382      ++pos;
383    }
384  }
385}
386
387template <typename T>
388UInt32 Trie::lookup_(T query) const {
389  UInt32 node = 0;
390  std::size_t pos = 0;
391  while (!query.ends_at(pos)) {
392    if (!find_child<T>(node, query, pos)) {
393      return notfound();
394    }
395  }
396  return terminal_flags_[node] ? node_to_key_id(node) : notfound();
397}
398
// Matches `query`, starting at `pos`, against the label spelled by walking
// from `node` up toward the root of this trie. Nodes that carry a link are
// expanded through the next trie (has_trie()) or through the tail.
//
// Returns:
//   pos         - the label's first character already differs from `query`
//                 (nothing consumed);
//   mismatch()  - the label diverges from `query` after its first character,
//                 or `query` ends before the label does;
//   otherwise   - the position in `query` just past the matched label.
template <typename T>
std::size_t Trie::trie_match(UInt32 node, T query,
    std::size_t pos) const {
  // First node: a miss here is reported as the unchanged `pos`, not as
  // mismatch(), because nothing has been consumed yet.
  if (has_link(node)) {
    std::size_t next_pos;
    if (has_trie()) {
      next_pos = trie_->trie_match<T>(get_link(node), query, pos);
    } else {
      next_pos = tail_match<T>(node, get_link_id(node), query, pos);
    }
    if ((next_pos == mismatch()) || (next_pos == pos)) {
      return next_pos;
    }
    pos = next_pos;
  } else if (labels_[node] != query[pos]) {
    return pos;
  } else {
    ++pos;
  }
  node = get_parent(node);
  // Remaining nodes: part of the label has matched, so every failure,
  // including a running-out query, is a mismatch().
  while (node != 0) {
    if (query.ends_at(pos)) {
      return mismatch();
    }
    if (has_link(node)) {
      std::size_t next_pos;
      if (has_trie()) {
        next_pos = trie_->trie_match<T>(get_link(node), query, pos);
      } else {
        next_pos = tail_match<T>(node, get_link_id(node), query, pos);
      }
      if ((next_pos == mismatch()) || (next_pos == pos)) {
        return mismatch();
      }
      pos = next_pos;
    } else if (labels_[node] != query[pos]) {
      return mismatch();
    } else {
      ++pos;
    }
    node = get_parent(node);
  }
  return pos;
}

// Explicit instantiations for the two query types used in this file
// (CQuery for NUL-terminated input, const Query & for length-delimited).
template std::size_t Trie::trie_match<CQuery>(UInt32 node,
    CQuery query, std::size_t pos) const;
template std::size_t Trie::trie_match<const Query &>(UInt32 node,
    const Query &query, std::size_t pos) const;
448
449template <typename T>
450std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
451    T query, std::size_t pos) const {
452  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
453  const UInt8 *ptr = tail_[offset];
454  if (*ptr != query[pos]) {
455    return pos;
456  } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
457    const UInt32 length = (links_[link_id + 1] * 256)
458        + labels_[link_flags_.select1(link_id + 1)] - offset;
459    for (UInt32 i = 1; i < length; ++i) {
460      if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) {
461        return mismatch();
462      }
463    }
464    return pos + length;
465  } else {
466    for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
467      if (query.ends_at(pos) || (*ptr != query[pos])) {
468        return mismatch();
469      }
470    }
471    return pos;
472  }
473}
474
475template std::size_t Trie::tail_match<CQuery>(UInt32 node,
476    UInt32 link_id, CQuery query, std::size_t pos) const;
477template std::size_t Trie::tail_match<const Query &>(UInt32 node,
478    UInt32 link_id, const Query &query, std::size_t pos) const;
479
480template <typename T, typename U, typename V>
481std::size_t Trie::find_(T query, U key_ids, V key_lengths,
482    std::size_t max_num_results) const try {
483  if (max_num_results == 0) {
484    return 0;
485  }
486  std::size_t count = 0;
487  UInt32 node = 0;
488  std::size_t pos = 0;
489  do {
490    if (terminal_flags_[node]) {
491      if (key_ids.is_valid()) {
492        key_ids.insert(count, node_to_key_id(node));
493      }
494      if (key_lengths.is_valid()) {
495        key_lengths.insert(count, pos);
496      }
497      if (++count >= max_num_results) {
498        return count;
499      }
500    }
501  } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
502  return count;
503} catch (const std::bad_alloc &) {
504  MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
505} catch (const std::length_error &) {
506  MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
507}
508
509template <typename T>
510UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
511  UInt32 node = 0;
512  std::size_t pos = 0;
513  do {
514    if (terminal_flags_[node]) {
515      if (key_length != NULL) {
516        *key_length = pos;
517      }
518      return node_to_key_id(node);
519    }
520  } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
521  return notfound();
522}
523
524template <typename T>
525UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
526  UInt32 node = 0;
527  UInt32 node_found = notfound();
528  std::size_t pos = 0;
529  std::size_t pos_found = mismatch();
530  do {
531    if (terminal_flags_[node]) {
532      node_found = node;
533      pos_found = pos;
534    }
535  } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
536  if (node_found != notfound()) {
537    if (key_length != NULL) {
538      *key_length = pos_found;
539    }
540    return node_to_key_id(node_found);
541  }
542  return notfound();
543}
544
// Collects keys that start with `query`, visiting the subtree under the
// query's node level by level. Key IDs go to `key_ids` and restored key
// strings to `keys`, for whichever containers are valid.
//
// Returns the number of results produced, capped at `max_num_results`
// (0 if no key has `query` as a prefix). Allocation failures are reported
// as MARISA_ALPHA_MEMORY_ERROR and std::length_error as
// MARISA_ALPHA_SIZE_ERROR.
template <typename T, typename U, typename V>
std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
    std::size_t max_num_results) const try {
  if (max_num_results == 0) {
    return 0;
  }
  // Phase 1: descend to the node reached by consuming the whole query.
  UInt32 node = 0;
  std::size_t pos = 0;
  while (!query.ends_at(pos)) {
    if (!predict_child<T>(node, query, pos, NULL)) {
      return 0;
    }
  }
  // Phase 2: the query node itself may be a key.
  std::string key;
  std::size_t count = 0;
  if (terminal_flags_[node]) {
    const UInt32 key_id = node_to_key_id(node);
    if (key_ids.is_valid()) {
      key_ids.insert(count, key_id);
    }
    if (keys.is_valid()) {
      restore(key_id, &key);
      keys.insert(count, key);
    }
    if (++count >= max_num_results) {
      return count;
    }
  }
  // A 0 bit at the child position means the node has no children.
  const UInt32 louds_pos = get_child(node);
  if (!louds_[louds_pos]) {
    return count;
  }
  // Phase 3: [node_begin, node_end) is one level of the subtree. The code
  // treats the key IDs in [node_to_key_id(node_begin),
  // node_to_key_id(node_end)) as that level's keys. NOTE(review): this
  // presumably relies on node_to_key_id ranking terminal nodes, so the
  // range covers exactly the terminals of the level (confirm).
  UInt32 node_begin = louds_pos_to_node(louds_pos, node);
  UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
  while (node_begin < node_end) {
    const UInt32 key_id_begin = node_to_key_id(node_begin);
    const UInt32 key_id_end = node_to_key_id(node_end);
    // Each output container is filled independently from the same starting
    // index, stopping at max_num_results.
    if (key_ids.is_valid()) {
      UInt32 temp_count = count;
      for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
        key_ids.insert(temp_count, key_id);
        if (++temp_count >= max_num_results) {
          break;
        }
      }
    }
    if (keys.is_valid()) {
      UInt32 temp_count = count;
      for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
        key.clear();
        restore(key_id, &key);
        keys.insert(temp_count, key);
        if (++temp_count >= max_num_results) {
          break;
        }
      }
    }
    count += key_id_end - key_id_begin;
    if (count >= max_num_results) {
      return max_num_results;
    }
    // Move both bounds down to their first child, giving the next level.
    node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
    node_end = louds_pos_to_node(get_child(node_end), node_end);
  }
  return count;
} catch (const std::bad_alloc &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
} catch (const std::length_error &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
}
615
// Collects keys that start with `query` in depth-first order.
//
// If key strings are requested (`keys` valid), the traversal is delegated to
// predict_callback_() with a PredictCallback that fills both containers.
// Otherwise only key IDs are gathered, using an explicit stack instead of
// recursion.
//
// Returns the number of results produced, capped at `max_num_results`
// (0 if no key has `query` as a prefix). Allocation failures are reported
// as MARISA_ALPHA_MEMORY_ERROR and std::length_error as
// MARISA_ALPHA_SIZE_ERROR.
template <typename T, typename U, typename V>
std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
    std::size_t max_num_results) const try {
  if (max_num_results == 0) {
    return 0;
  } else if (keys.is_valid()) {
    PredictCallback<U, V> callback(key_ids, keys, max_num_results);
    return predict_callback_(query, callback);
  }

  // Descend to the node reached by consuming the whole query.
  UInt32 node = 0;
  std::size_t pos = 0;
  while (!query.ends_at(pos)) {
    if (!predict_child<T>(node, query, pos, NULL)) {
      return 0;
    }
  }
  std::size_t count = 0;
  if (terminal_flags_[node]) {
    if (key_ids.is_valid()) {
      key_ids.insert(count, node_to_key_id(node));
    }
    if (++count >= max_num_results) {
      return count;
    }
  }
  // Each Cell is a cursor over one run of siblings. louds_pos scans the
  // run's LOUDS bits (a 0 bit ends the run), node is the current sibling,
  // and key_id is the next key ID to report at this depth.
  Cell cell;
  cell.set_louds_pos(get_child(node));
  if (!louds_[cell.louds_pos()]) {
    return count;
  }
  cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
  cell.set_key_id(node_to_key_id(cell.node()));
  Vector<Cell> stack;
  stack.push_back(cell);
  // stack_pos is the active depth; slots at or above it are kept for reuse.
  std::size_t stack_pos = 1;
  while (stack_pos != 0) {
    Cell &cur = stack[stack_pos - 1];
    if (!louds_[cur.louds_pos()]) {
      // End of this sibling run: step over the 0 bit and pop one level.
      cur.set_louds_pos(cur.louds_pos() + 1);
      --stack_pos;
      continue;
    }
    cur.set_louds_pos(cur.louds_pos() + 1);
    if (terminal_flags_[cur.node()]) {
      if (key_ids.is_valid()) {
        key_ids.insert(count, cur.key_id());
      }
      if (++count >= max_num_results) {
        return count;
      }
      cur.set_key_id(cur.key_id() + 1);
    }
    // Descend into cur.node()'s children. A new slot is pushed only the
    // first time this depth is reached; an existing deeper slot is reused
    // as-is. NOTE(review): presumably its cursor already points at the next
    // node's children because LOUDS stores children of consecutive nodes
    // consecutively (confirm).
    if (stack_pos == stack.size()) {
      cell.set_louds_pos(get_child(cur.node()));
      cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
      cell.set_key_id(node_to_key_id(cell.node()));
      stack.push_back(cell);
    }
    // `cur` is deliberately not used past push_back(); the slot is
    // re-indexed, presumably because push_back may invalidate references
    // into `stack`.
    stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
    ++stack_pos;
  }
  return count;
} catch (const std::bad_alloc &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
} catch (const std::length_error &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
}
684
685template <typename T>
686std::size_t Trie::trie_prefix_match(UInt32 node, T query,
687    std::size_t pos, std::string *key) const {
688  if (has_link(node)) {
689    std::size_t next_pos;
690    if (has_trie()) {
691      next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
692    } else {
693      next_pos = tail_prefix_match<T>(
694          node, get_link_id(node), query, pos, key);
695    }
696    if ((next_pos == mismatch()) || (next_pos == pos)) {
697      return next_pos;
698    }
699    pos = next_pos;
700  } else if (labels_[node] != query[pos]) {
701    return pos;
702  } else {
703    ++pos;
704  }
705  node = get_parent(node);
706  while (node != 0) {
707    if (query.ends_at(pos)) {
708      if (key != NULL) {
709        trie_restore(node, key);
710      }
711      return pos;
712    }
713    if (has_link(node)) {
714      std::size_t next_pos;
715      if (has_trie()) {
716        next_pos = trie_->trie_prefix_match<T>(
717            get_link(node), query, pos, key);
718      } else {
719        next_pos = tail_prefix_match<T>(
720            node, get_link_id(node), query, pos, key);
721      }
722      if ((next_pos == mismatch()) || (next_pos == pos)) {
723        return next_pos;
724      }
725      pos = next_pos;
726    } else if (labels_[node] != query[pos]) {
727      return mismatch();
728    } else {
729      ++pos;
730    }
731    node = get_parent(node);
732  }
733  return pos;
734}
735
736template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
737    CQuery query, std::size_t pos, std::string *key) const;
738template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
739    const Query &query, std::size_t pos, std::string *key) const;
740
741template <typename T>
742std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
743    T query, std::size_t pos, std::string *key) const {
744  const UInt32 offset = (links_[link_id] * 256) + labels_[node];
745  const UInt8 *ptr = tail_[offset];
746  if (*ptr != query[pos]) {
747    return pos;
748  } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
749    const UInt32 length = (links_[link_id + 1] * 256)
750        + labels_[link_flags_.select1(link_id + 1)] - offset;
751    for (UInt32 i = 1; i < length; ++i) {
752      if (query.ends_at(pos + i)) {
753        if (key != NULL) {
754          key->append(reinterpret_cast<const char *>(ptr + i), length - i);
755        }
756        return pos + i;
757      } else if (ptr[i] != query[pos + i]) {
758        return mismatch();
759      }
760    }
761    return pos + length;
762  } else {
763    for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
764      if (query.ends_at(pos)) {
765        if (key != NULL) {
766          key->append(reinterpret_cast<const char *>(ptr));
767        }
768        return pos;
769      } else if (*ptr != query[pos]) {
770        return mismatch();
771      }
772    }
773    return pos;
774  }
775}
776
777template std::size_t Trie::tail_prefix_match<CQuery>(
778    UInt32 node, UInt32 link_id,
779    CQuery query, std::size_t pos, std::string *key) const;
780template std::size_t Trie::tail_prefix_match<const Query &>(
781    UInt32 node, UInt32 link_id,
782    const Query &query, std::size_t pos, std::string *key) const;
783
784}  // namespace marisa_alpha
785