1#ifndef MARISA_ALPHA_TRIE_INLINE_H_
2#define MARISA_ALPHA_TRIE_INLINE_H_
3
4#include <stdexcept>
5
6#include "cell.h"
7
8namespace marisa_alpha {
9
10inline std::string Trie::operator[](UInt32 key_id) const {
11  std::string key;
12  restore(key_id, &key);
13  return key;
14}
15
16inline UInt32 Trie::operator[](const char *str) const {
17  return lookup(str);
18}
19
20inline UInt32 Trie::operator[](const std::string &str) const {
21  return lookup(str);
22}
23
24inline UInt32 Trie::lookup(const std::string &str) const {
25  return lookup(str.c_str(), str.length());
26}
27
28inline std::size_t Trie::find(const std::string &str,
29    UInt32 *key_ids, std::size_t *key_lengths,
30    std::size_t max_num_results) const {
31  return find(str.c_str(), str.length(),
32      key_ids, key_lengths, max_num_results);
33}
34
35inline std::size_t Trie::find(const std::string &str,
36    std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
37    std::size_t max_num_results) const {
38  return find(str.c_str(), str.length(),
39      key_ids, key_lengths, max_num_results);
40}
41
42inline UInt32 Trie::find_first(const std::string &str,
43    std::size_t *key_length) const {
44  return find_first(str.c_str(), str.length(), key_length);
45}
46
47inline UInt32 Trie::find_last(const std::string &str,
48    std::size_t *key_length) const {
49  return find_last(str.c_str(), str.length(), key_length);
50}
51
52template <typename T>
53inline std::size_t Trie::find_callback(const char *str,
54    T callback) const {
55  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
56  MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
57  return find_callback_<CQuery>(CQuery(str), callback);
58}
59
60template <typename T>
61inline std::size_t Trie::find_callback(const char *ptr, std::size_t length,
62    T callback) const {
63  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
64  MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
65      MARISA_ALPHA_PARAM_ERROR);
66  return find_callback_<const Query &>(Query(ptr, length), callback);
67}
68
69template <typename T>
70inline std::size_t Trie::find_callback(const std::string &str,
71    T callback) const {
72  return find_callback(str.c_str(), str.length(), callback);
73}
74
75inline std::size_t Trie::predict(const std::string &str,
76    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
77  return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
78}
79
80inline std::size_t Trie::predict(const std::string &str,
81    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
82    std::size_t max_num_results) const {
83  return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
84}
85
86inline std::size_t Trie::predict_breadth_first(const std::string &str,
87    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
88  return predict_breadth_first(str.c_str(), str.length(),
89      key_ids, keys, max_num_results);
90}
91
92inline std::size_t Trie::predict_breadth_first(const std::string &str,
93    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
94    std::size_t max_num_results) const {
95  return predict_breadth_first(str.c_str(), str.length(),
96      key_ids, keys, max_num_results);
97}
98
99inline std::size_t Trie::predict_depth_first(const std::string &str,
100    UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
101  return predict_depth_first(str.c_str(), str.length(),
102      key_ids, keys, max_num_results);
103}
104
105inline std::size_t Trie::predict_depth_first(const std::string &str,
106    std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
107    std::size_t max_num_results) const {
108  return predict_depth_first(str.c_str(), str.length(),
109      key_ids, keys, max_num_results);
110}
111
112template <typename T>
113inline std::size_t Trie::predict_callback(
114    const char *str, T callback) const {
115  return predict_callback_<CQuery>(CQuery(str), callback);
116}
117
118template <typename T>
119inline std::size_t Trie::predict_callback(
120    const char *ptr, std::size_t length,
121    T callback) const {
122  return predict_callback_<const Query &>(Query(ptr, length), callback);
123}
124
125template <typename T>
126inline std::size_t Trie::predict_callback(
127    const std::string &str, T callback) const {
128  return predict_callback(str.c_str(), str.length(), callback);
129}
130
131inline bool Trie::empty() const {
132  return louds_.empty();
133}
134
135inline std::size_t Trie::num_keys() const {
136  return num_keys_;
137}
138
139inline UInt32 Trie::notfound() {
140  return MARISA_ALPHA_NOT_FOUND;
141}
142
143inline std::size_t Trie::mismatch() {
144  return MARISA_ALPHA_MISMATCH;
145}
146
147template <typename T>
148inline bool Trie::find_child(UInt32 &node, T query,
149    std::size_t &pos) const {
150  UInt32 louds_pos = get_child(node);
151  if (!louds_[louds_pos]) {
152    return false;
153  }
154  node = louds_pos_to_node(louds_pos, node);
155  UInt32 link_id = MARISA_ALPHA_UINT32_MAX;
156  do {
157    if (has_link(node)) {
158      if (link_id == MARISA_ALPHA_UINT32_MAX) {
159        link_id = get_link_id(node);
160      } else {
161        ++link_id;
162      }
163      std::size_t next_pos = has_trie() ?
164          trie_->trie_match<T>(get_link(node, link_id), query, pos) :
165          tail_match<T>(node, link_id, query, pos);
166      if (next_pos == mismatch()) {
167        return false;
168      } else if (next_pos != pos) {
169        pos = next_pos;
170        return true;
171      }
172    } else if (labels_[node] == query[pos]) {
173      ++pos;
174      return true;
175    }
176    ++node;
177    ++louds_pos;
178  } while (louds_[louds_pos]);
179  return false;
180}
181
182template <typename T, typename U>
183std::size_t Trie::find_callback_(T query, U callback) const try {
184  std::size_t count = 0;
185  UInt32 node = 0;
186  std::size_t pos = 0;
187  do {
188    if (terminal_flags_[node]) {
189      ++count;
190      if (!callback(node_to_key_id(node), pos)) {
191        return count;
192      }
193    }
194  } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
195  return count;
196} catch (const std::bad_alloc &) {
197  MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
198} catch (const std::length_error &) {
199  MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
200}
201
202template <typename T>
203inline bool Trie::predict_child(UInt32 &node, T query, std::size_t &pos,
204    std::string *key) const {
205  UInt32 louds_pos = get_child(node);
206  if (!louds_[louds_pos]) {
207    return false;
208  }
209  node = louds_pos_to_node(louds_pos, node);
210  UInt32 link_id = MARISA_ALPHA_UINT32_MAX;
211  do {
212    if (has_link(node)) {
213      if (link_id == MARISA_ALPHA_UINT32_MAX) {
214        link_id = get_link_id(node);
215      } else {
216        ++link_id;
217      }
218      std::size_t next_pos = has_trie() ?
219          trie_->trie_prefix_match<T>(
220              get_link(node, link_id), query, pos, key) :
221          tail_prefix_match<T>(node, link_id, query, pos, key);
222      if (next_pos == mismatch()) {
223        return false;
224      } else if (next_pos != pos) {
225        pos = next_pos;
226        return true;
227      }
228    } else if (labels_[node] == query[pos]) {
229      ++pos;
230      return true;
231    }
232    ++node;
233    ++louds_pos;
234  } while (louds_[louds_pos]);
235  return false;
236}
237
238template <typename T, typename U>
239std::size_t Trie::predict_callback_(T query, U callback) const try {
240  std::string key;
241  UInt32 node = 0;
242  std::size_t pos = 0;
243  while (!query.ends_at(pos)) {
244    if (!predict_child<T>(node, query, pos, &key)) {
245      return 0;
246    }
247  }
248  query.insert(&key);
249  std::size_t count = 0;
250  if (terminal_flags_[node]) {
251    ++count;
252    if (!callback(node_to_key_id(node), key)) {
253      return count;
254    }
255  }
256  Cell cell;
257  cell.set_louds_pos(get_child(node));
258  if (!louds_[cell.louds_pos()]) {
259    return count;
260  }
261  cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
262  cell.set_key_id(node_to_key_id(cell.node()));
263  cell.set_length(key.length());
264  Vector<Cell> stack;
265  stack.push_back(cell);
266  std::size_t stack_pos = 1;
267  while (stack_pos != 0) {
268    Cell &cur = stack[stack_pos - 1];
269    if (!louds_[cur.louds_pos()]) {
270      cur.set_louds_pos(cur.louds_pos() + 1);
271      --stack_pos;
272      continue;
273    }
274    cur.set_louds_pos(cur.louds_pos() + 1);
275    key.resize(cur.length());
276    if (has_link(cur.node())) {
277      if (has_trie()) {
278        trie_->trie_restore(get_link(cur.node()), &key);
279      } else {
280        tail_restore(cur.node(), &key);
281      }
282    } else {
283      key += labels_[cur.node()];
284    }
285    if (terminal_flags_[cur.node()]) {
286      ++count;
287      if (!callback(cur.key_id(), key)) {
288        return count;
289      }
290      cur.set_key_id(cur.key_id() + 1);
291    }
292    if (stack_pos == stack.size()) {
293      cell.set_louds_pos(get_child(cur.node()));
294      cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
295      cell.set_key_id(node_to_key_id(cell.node()));
296      stack.push_back(cell);
297    }
298    stack[stack_pos].set_length(key.length());
299    stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
300    ++stack_pos;
301  }
302  return count;
303} catch (const std::bad_alloc &) {
304  MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
305} catch (const std::length_error &) {
306  MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
307}
308
309inline UInt32 Trie::key_id_to_node(UInt32 key_id) const {
310  return terminal_flags_.select1(key_id);
311}
312
313inline UInt32 Trie::node_to_key_id(UInt32 node) const {
314  return terminal_flags_.rank1(node);
315}
316
317inline UInt32 Trie::louds_pos_to_node(UInt32 louds_pos,
318    UInt32 parent_node) const {
319  return louds_pos - parent_node - 1;
320}
321
322inline UInt32 Trie::get_child(UInt32 node) const {
323  return louds_.select0(node) + 1;
324}
325
326inline UInt32 Trie::get_parent(UInt32 node) const {
327  return (node > num_first_branches_) ? (louds_.select1(node) - node - 1) : 0;
328}
329
330inline bool Trie::has_link(UInt32 node) const {
331  return (link_flags_.empty()) ? false : link_flags_[node];
332}
333
334inline UInt32 Trie::get_link_id(UInt32 node) const {
335  return link_flags_.rank1(node);
336}
337
338inline UInt32 Trie::get_link(UInt32 node) const {
339  return get_link(node, get_link_id(node));
340}
341
342inline UInt32 Trie::get_link(UInt32 node, UInt32 link_id) const {
343  return (links_[link_id] * 256) + labels_[node];
344}
345
346inline bool Trie::has_link() const {
347  return !link_flags_.empty();
348}
349
350inline bool Trie::has_trie() const {
351  return trie_.get() != NULL;
352}
353
354inline bool Trie::has_tail() const {
355  return !tail_.empty();
356}
357
358}  // namespace marisa_alpha
359
360#endif  // MARISA_ALPHA_TRIE_INLINE_H_
361