1#ifndef MARISA_TRIE_INLINE_H_
2#define MARISA_TRIE_INLINE_H_
3
4#include <stdexcept>
5
6#include "cell.h"
7
8namespace marisa {
9
10inline std::string Trie::operator[](UInt32 key_id) const {
11  std::string key;
12  restore(key_id, &key);
13  return key;
14}
15
16inline UInt32 Trie::operator[](const char *str) const {
17  return lookup(str);
18}
19
20inline UInt32 Trie::operator[](const std::string &str) const {
21  return lookup(str);
22}
23
24inline UInt32 Trie::lookup(const std::string &str) const {
25  return lookup(str.c_str(), str.length());
26}
27
28inline std::size_t Trie::find(const std::string &str,
29    UInt32 *key_ids, std::size_t *key_lengths,
30    std::size_t max_num_results) const {
31  return find(str.c_str(), str.length(),
32      key_ids, key_lengths, max_num_results);
33}
34
35inline std::size_t Trie::find(const std::string &str,
36    std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
37    std::size_t max_num_results) const {
38  return find(str.c_str(), str.length(),
39      key_ids, key_lengths, max_num_results);
40}
41
42inline UInt32 Trie::find_first(const std::string &str,
43    std::size_t *key_length) const {
44  return find_first(str.c_str(), str.length(), key_length);
45}
46
47inline UInt32 Trie::find_last(const std::string &str,
48    std::size_t *key_length) const {
49  return find_last(str.c_str(), str.length(), key_length);
50}
51
52template <typename T>
53inline std::size_t Trie::find_callback(const char *str,
54    T callback) const {
55  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
56  MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
57  return find_callback_<CQuery>(CQuery(str), callback);
58}
59
60template <typename T>
61inline std::size_t Trie::find_callback(const char *ptr, std::size_t length,
62    T callback) const {
63  MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
64  MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
65  return find_callback_<const Query &>(Query(ptr, length), callback);
66}
67
68template <typename T>
69inline std::size_t Trie::find_callback(const std::string &str,
70    T callback) const {
71  return find_callback(str.c_str(), str.length(), callback);
72}
73
74inline std::size_t Trie::predict(const std::string &str,
75    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
76  return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
77}
78
79inline std::size_t Trie::predict(const std::string &str,
80    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
81    std::size_t max_num_results) const {
82  return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
83}
84
85inline std::size_t Trie::predict_breadth_first(const std::string &str,
86    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
87  return predict_breadth_first(str.c_str(), str.length(),
88      key_ids, keys, max_num_results);
89}
90
91inline std::size_t Trie::predict_breadth_first(const std::string &str,
92    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
93    std::size_t max_num_results) const {
94  return predict_breadth_first(str.c_str(), str.length(),
95      key_ids, keys, max_num_results);
96}
97
98inline std::size_t Trie::predict_depth_first(const std::string &str,
99    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
100  return predict_depth_first(str.c_str(), str.length(),
101      key_ids, keys, max_num_results);
102}
103
104inline std::size_t Trie::predict_depth_first(const std::string &str,
105    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
106    std::size_t max_num_results) const {
107  return predict_depth_first(str.c_str(), str.length(),
108      key_ids, keys, max_num_results);
109}
110
111template <typename T>
112inline std::size_t Trie::predict_callback(
113    const char *str, T callback) const {
114  return predict_callback_<CQuery>(CQuery(str), callback);
115}
116
117template <typename T>
118inline std::size_t Trie::predict_callback(
119    const char *ptr, std::size_t length,
120    T callback) const {
121  return predict_callback_<const Query &>(Query(ptr, length), callback);
122}
123
124template <typename T>
125inline std::size_t Trie::predict_callback(
126    const std::string &str, T callback) const {
127  return predict_callback(str.c_str(), str.length(), callback);
128}
129
130inline bool Trie::empty() const {
131  return louds_.empty();
132}
133
134inline std::size_t Trie::num_keys() const {
135  return num_keys_;
136}
137
138inline UInt32 Trie::notfound() {
139  return MARISA_NOT_FOUND;
140}
141
142inline std::size_t Trie::mismatch() {
143  return MARISA_MISMATCH;
144}
145
146template <typename T>
147inline bool Trie::find_child(UInt32 &node, T query,
148    std::size_t &pos) const {
149  UInt32 louds_pos = get_child(node);
150  if (!louds_[louds_pos]) {
151    return false;
152  }
153  node = louds_pos_to_node(louds_pos, node);
154  UInt32 link_id = MARISA_UINT32_MAX;
155  do {
156    if (has_link(node)) {
157      if (link_id == MARISA_UINT32_MAX) {
158        link_id = get_link_id(node);
159      } else {
160        ++link_id;
161      }
162      std::size_t next_pos = has_trie() ?
163          trie_->trie_match<T>(get_link(node, link_id), query, pos) :
164          tail_match<T>(node, link_id, query, pos);
165      if (next_pos == mismatch()) {
166        return false;
167      } else if (next_pos != pos) {
168        pos = next_pos;
169        return true;
170      }
171    } else if (labels_[node] == query[pos]) {
172      ++pos;
173      return true;
174    }
175    ++node;
176    ++louds_pos;
177  } while (louds_[louds_pos]);
178  return false;
179}
180
181template <typename T, typename U>
182std::size_t Trie::find_callback_(T query, U callback) const {
183  std::size_t count = 0;
184  UInt32 node = 0;
185  std::size_t pos = 0;
186  do {
187    if (terminal_flags_[node]) {
188      ++count;
189      if (!callback(node_to_key_id(node), pos)) {
190        return count;
191      }
192    }
193  } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
194  return count;
195}
196
197template <typename T>
198inline bool Trie::predict_child(UInt32 &node, T query, std::size_t &pos,
199    std::string *key) const {
200  UInt32 louds_pos = get_child(node);
201  if (!louds_[louds_pos]) {
202    return false;
203  }
204  node = louds_pos_to_node(louds_pos, node);
205  UInt32 link_id = MARISA_UINT32_MAX;
206  do {
207    if (has_link(node)) {
208      if (link_id == MARISA_UINT32_MAX) {
209        link_id = get_link_id(node);
210      } else {
211        ++link_id;
212      }
213      std::size_t next_pos = has_trie() ?
214          trie_->trie_prefix_match<T>(
215              get_link(node, link_id), query, pos, key) :
216          tail_prefix_match<T>(node, link_id, query, pos, key);
217      if (next_pos == mismatch()) {
218        return false;
219      } else if (next_pos != pos) {
220        pos = next_pos;
221        return true;
222      }
223    } else if (labels_[node] == query[pos]) {
224      ++pos;
225      return true;
226    }
227    ++node;
228    ++louds_pos;
229  } while (louds_[louds_pos]);
230  return false;
231}
232
// Shared implementation behind the predict_callback() overloads.
//
// Phase 1 consumes the whole query by repeatedly calling predict_child();
// if any step fails, 0 is returned and the callback is never invoked.
// Phase 2 reports the node reached by the query (if it is terminal) and
// then walks the nodes below it, calling callback(key_id, key) for every
// terminal node. The walk stops early as soon as the callback returns
// false.
//
// Returns the number of callback invocations, including one that returned
// false.
template <typename T, typename U>
std::size_t Trie::predict_callback_(T query, U callback) const {
  std::string key;
  UInt32 node = 0;
  std::size_t pos = 0;
  // Phase 1: descend until the query is exhausted. predict_child() is given
  // &key so the matching helpers can write into it.
  while (!query.ends_at(pos)) {
    if (!predict_child<T>(node, query, pos, &key)) {
      return 0;
    }
  }
  // NOTE(review): presumably puts the query text in front of whatever
  // predict_child() stored in key — confirm against the query type's
  // insert().
  query.insert(&key);
  std::size_t count = 0;
  // The node matched by the query itself may already be a key.
  if (terminal_flags_[node]) {
    ++count;
    if (!callback(node_to_key_id(node), key)) {
      return count;
    }
  }
  // Seed the traversal with the first child of the matched node; if that
  // node has no children there is nothing more to report.
  Cell cell;
  cell.set_louds_pos(get_child(node));
  if (!louds_[cell.louds_pos()]) {
    return count;
  }
  cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
  cell.set_key_id(node_to_key_id(cell.node()));
  cell.set_length(key.length());
  // stack[i] holds the traversal state for depth i below the matched node;
  // stack_pos is the number of levels currently active. Cells above
  // stack_pos are never removed: they are reused the next time the walk
  // descends to that depth.
  Vector<Cell> stack;
  stack.push_back(cell);
  std::size_t stack_pos = 1;
  while (stack_pos != 0) {
    Cell &cur = stack[stack_pos - 1];
    // An unset bit means this level has no more siblings: step past it and
    // return to the level above.
    if (!louds_[cur.louds_pos()]) {
      cur.set_louds_pos(cur.louds_pos() + 1);
      --stack_pos;
      continue;
    }
    cur.set_louds_pos(cur.louds_pos() + 1);
    // Cut key back to the length recorded for this level, then extend it
    // with the current node's label (from the nested trie, the tail, or the
    // single label byte).
    key.resize(cur.length());
    if (has_link(cur.node())) {
      if (has_trie()) {
        trie_->trie_restore(get_link(cur.node()), &key);
      } else {
        tail_restore(cur.node(), &key);
      }
    } else {
      key += labels_[cur.node()];
    }
    if (terminal_flags_[cur.node()]) {
      ++count;
      if (!callback(cur.key_id(), key)) {
        return count;
      }
      // The level's key ID is advanced incrementally instead of being
      // recomputed with node_to_key_id() for every terminal node.
      cur.set_key_id(cur.key_id() + 1);
    }
    // First visit to this depth: create the cell for the next level from
    // the current node's first child. On later visits the existing cell is
    // reused with the node, key ID and LOUDS position it already holds.
    if (stack_pos == stack.size()) {
      cell.set_louds_pos(get_child(cur.node()));
      cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
      cell.set_key_id(node_to_key_id(cell.node()));
      stack.push_back(cell);
    }
    // From here on the stack is indexed directly instead of through `cur`
    // (presumably because push_back() may have invalidated the reference —
    // confirm against Vector).
    // The child level starts from the current key length; the current level
    // moves on to its next sibling node.
    stack[stack_pos].set_length(key.length());
    stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
    ++stack_pos;
  }
  return count;
}
299
300inline UInt32 Trie::key_id_to_node(UInt32 key_id) const {
301  return terminal_flags_.select1(key_id);
302}
303
304inline UInt32 Trie::node_to_key_id(UInt32 node) const {
305  return terminal_flags_.rank1(node);
306}
307
308inline UInt32 Trie::louds_pos_to_node(UInt32 louds_pos,
309    UInt32 parent_node) const {
310  return louds_pos - parent_node - 1;
311}
312
313inline UInt32 Trie::get_child(UInt32 node) const {
314  return louds_.select0(node) + 1;
315}
316
317inline UInt32 Trie::get_parent(UInt32 node) const {
318  return (node > num_first_branches_) ? (louds_.select1(node) - node - 1) : 0;
319}
320
321inline bool Trie::has_link(UInt32 node) const {
322  return (link_flags_.empty()) ? false : link_flags_[node];
323}
324
325inline UInt32 Trie::get_link_id(UInt32 node) const {
326  return link_flags_.rank1(node);
327}
328
329inline UInt32 Trie::get_link(UInt32 node) const {
330  return get_link(node, get_link_id(node));
331}
332
333inline UInt32 Trie::get_link(UInt32 node, UInt32 link_id) const {
334  return (links_[link_id] * 256) + labels_[node];
335}
336
337inline bool Trie::has_link() const {
338  return !link_flags_.empty();
339}
340
341inline bool Trie::has_trie() const {
342  return trie_.get() != NULL;
343}
344
345inline bool Trie::has_tail() const {
346  return !tail_.empty();
347}
348
349}  // namespace marisa
350
351#endif  // MARISA_TRIE_INLINE_H_
352