1#ifndef MARISA_TRIE_INLINE_H_ 2#define MARISA_TRIE_INLINE_H_ 3 4#include <stdexcept> 5 6#include "cell.h" 7 8namespace marisa { 9 10inline std::string Trie::operator[](UInt32 key_id) const { 11 std::string key; 12 restore(key_id, &key); 13 return key; 14} 15 16inline UInt32 Trie::operator[](const char *str) const { 17 return lookup(str); 18} 19 20inline UInt32 Trie::operator[](const std::string &str) const { 21 return lookup(str); 22} 23 24inline UInt32 Trie::lookup(const std::string &str) const { 25 return lookup(str.c_str(), str.length()); 26} 27 28inline std::size_t Trie::find(const std::string &str, 29 UInt32 *key_ids, std::size_t *key_lengths, 30 std::size_t max_num_results) const { 31 return find(str.c_str(), str.length(), 32 key_ids, key_lengths, max_num_results); 33} 34 35inline std::size_t Trie::find(const std::string &str, 36 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths, 37 std::size_t max_num_results) const { 38 return find(str.c_str(), str.length(), 39 key_ids, key_lengths, max_num_results); 40} 41 42inline UInt32 Trie::find_first(const std::string &str, 43 std::size_t *key_length) const { 44 return find_first(str.c_str(), str.length(), key_length); 45} 46 47inline UInt32 Trie::find_last(const std::string &str, 48 std::size_t *key_length) const { 49 return find_last(str.c_str(), str.length(), key_length); 50} 51 52template <typename T> 53inline std::size_t Trie::find_callback(const char *str, 54 T callback) const { 55 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 56 MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR); 57 return find_callback_<CQuery>(CQuery(str), callback); 58} 59 60template <typename T> 61inline std::size_t Trie::find_callback(const char *ptr, std::size_t length, 62 T callback) const { 63 MARISA_THROW_IF(empty(), MARISA_STATE_ERROR); 64 MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR); 65 return find_callback_<const Query &>(Query(ptr, length), callback); 66} 67 68template <typename T> 69inline std::size_t Trie::find_callback(const std::string &str, 70 T callback) const { 71 return find_callback(str.c_str(), str.length(), callback); 72} 73 74inline std::size_t Trie::predict(const std::string &str, 75 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 76 return predict(str.c_str(), str.length(), key_ids, keys, max_num_results); 77} 78 79inline std::size_t Trie::predict(const std::string &str, 80 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 81 std::size_t max_num_results) const { 82 return predict(str.c_str(), str.length(), key_ids, keys, max_num_results); 83} 84 85inline std::size_t Trie::predict_breadth_first(const std::string &str, 86 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 87 return predict_breadth_first(str.c_str(), str.length(), 88 key_ids, keys, max_num_results); 89} 90 91inline std::size_t Trie::predict_breadth_first(const std::string &str, 92 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 93 std::size_t max_num_results) const { 94 return predict_breadth_first(str.c_str(), str.length(), 95 key_ids, keys, max_num_results); 96} 97 98inline std::size_t Trie::predict_depth_first(const std::string &str, 99 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 100 return predict_depth_first(str.c_str(), str.length(), 101 key_ids, keys, max_num_results); 102} 103 104inline std::size_t Trie::predict_depth_first(const std::string &str, 105 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 106 std::size_t max_num_results) const { 107 return predict_depth_first(str.c_str(), str.length(), 108 key_ids, keys, max_num_results); 109} 110 111template <typename T> 112inline std::size_t Trie::predict_callback( 113 const char *str, T callback) const { 114 return predict_callback_<CQuery>(CQuery(str), callback); 115} 116 117template <typename T> 118inline std::size_t Trie::predict_callback( 119 const char *ptr, std::size_t length, 120 T callback) const { 121 return predict_callback_<const Query &>(Query(ptr, length), callback); 122} 123 124template <typename T> 125inline std::size_t Trie::predict_callback( 126 const std::string &str, T callback) const { 127 return predict_callback(str.c_str(), str.length(), callback); 128} 129 130inline bool Trie::empty() const { 131 return louds_.empty(); 132} 133 134inline std::size_t Trie::num_keys() const { 135 return num_keys_; 136} 137 138inline UInt32 Trie::notfound() { 139 return MARISA_NOT_FOUND; 140} 141 142inline std::size_t Trie::mismatch() { 143 return MARISA_MISMATCH; 144} 145 146template <typename T> 147inline bool Trie::find_child(UInt32 &node, T query, 148 std::size_t &pos) const { 149 UInt32 louds_pos = get_child(node); 150 if (!louds_[louds_pos]) { 151 return false; 152 } 153 node = louds_pos_to_node(louds_pos, node); 154 UInt32 link_id = MARISA_UINT32_MAX; 155 do { 156 if (has_link(node)) { 157 if (link_id == MARISA_UINT32_MAX) { 158 link_id = get_link_id(node); 159 } else { 160 ++link_id; 161 } 162 std::size_t next_pos = has_trie() ? 163 trie_->trie_match<T>(get_link(node, link_id), query, pos) : 164 tail_match<T>(node, link_id, query, pos); 165 if (next_pos == mismatch()) { 166 return false; 167 } else if (next_pos != pos) { 168 pos = next_pos; 169 return true; 170 } 171 } else if (labels_[node] == query[pos]) { 172 ++pos; 173 return true; 174 } 175 ++node; 176 ++louds_pos; 177 } while (louds_[louds_pos]); 178 return false; 179} 180 181template <typename T, typename U> 182std::size_t Trie::find_callback_(T query, U callback) const { 183 std::size_t count = 0; 184 UInt32 node = 0; 185 std::size_t pos = 0; 186 do { 187 if (terminal_flags_[node]) { 188 ++count; 189 if (!callback(node_to_key_id(node), pos)) { 190 return count; 191 } 192 } 193 } while (!query.ends_at(pos) && find_child<T>(node, query, pos)); 194 return count; 195} 196 197template <typename T> 198inline bool Trie::predict_child(UInt32 &node, T query, std::size_t &pos, 199 std::string *key) const { 200 UInt32 louds_pos = get_child(node); 201 if (!louds_[louds_pos]) { 202 return false; 203 } 204 node = louds_pos_to_node(louds_pos, node); 205 UInt32 link_id = MARISA_UINT32_MAX; 206 do { 207 if (has_link(node)) { 208 if (link_id == MARISA_UINT32_MAX) { 209 link_id = get_link_id(node); 210 } else { 211 ++link_id; 212 } 213 std::size_t next_pos = has_trie() ? 214 trie_->trie_prefix_match<T>( 215 get_link(node, link_id), query, pos, key) : 216 tail_prefix_match<T>(node, link_id, query, pos, key); 217 if (next_pos == mismatch()) { 218 return false; 219 } else if (next_pos != pos) { 220 pos = next_pos; 221 return true; 222 } 223 } else if (labels_[node] == query[pos]) { 224 ++pos; 225 return true; 226 } 227 ++node; 228 ++louds_pos; 229 } while (louds_[louds_pos]); 230 return false; 231} 232 233template <typename T, typename U> 234std::size_t Trie::predict_callback_(T query, U callback) const { 235 std::string key; 236 UInt32 node = 0; 237 std::size_t pos = 0; 238 while (!query.ends_at(pos)) { 239 if (!predict_child<T>(node, query, pos, &key)) { 240 return 0; 241 } 242 } 243 query.insert(&key); 244 std::size_t count = 0; 245 if (terminal_flags_[node]) { 246 ++count; 247 if (!callback(node_to_key_id(node), key)) { 248 return count; 249 } 250 } 251 Cell cell; 252 cell.set_louds_pos(get_child(node)); 253 if (!louds_[cell.louds_pos()]) { 254 return count; 255 } 256 cell.set_node(louds_pos_to_node(cell.louds_pos(), node)); 257 cell.set_key_id(node_to_key_id(cell.node())); 258 cell.set_length(key.length()); 259 Vector<Cell> stack; 260 stack.push_back(cell); 261 std::size_t stack_pos = 1; 262 while (stack_pos != 0) { 263 Cell &cur = stack[stack_pos - 1]; 264 if (!louds_[cur.louds_pos()]) { 265 cur.set_louds_pos(cur.louds_pos() + 1); 266 --stack_pos; 267 continue; 268 } 269 cur.set_louds_pos(cur.louds_pos() + 1); 270 key.resize(cur.length()); 271 if (has_link(cur.node())) { 272 if (has_trie()) { 273 trie_->trie_restore(get_link(cur.node()), &key); 274 } else { 275 tail_restore(cur.node(), &key); 276 } 277 } else { 278 key += labels_[cur.node()]; 279 } 280 if (terminal_flags_[cur.node()]) { 281 ++count; 282 if (!callback(cur.key_id(), key)) { 283 return count; 284 } 285 cur.set_key_id(cur.key_id() + 1); 286 } 287 if (stack_pos == stack.size()) { 288 cell.set_louds_pos(get_child(cur.node())); 289 cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node())); 290 cell.set_key_id(node_to_key_id(cell.node())); 291 stack.push_back(cell); 292 } 293 stack[stack_pos].set_length(key.length()); 294 stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1); 295 ++stack_pos; 296 } 297 return count; 298} 299 300inline UInt32 Trie::key_id_to_node(UInt32 key_id) const { 301 return terminal_flags_.select1(key_id); 302} 303 304inline UInt32 Trie::node_to_key_id(UInt32 node) const { 305 return terminal_flags_.rank1(node); 306} 307 308inline UInt32 Trie::louds_pos_to_node(UInt32 louds_pos, 309 UInt32 parent_node) const { 310 return louds_pos - parent_node - 1; 311} 312 313inline UInt32 Trie::get_child(UInt32 node) const { 314 return louds_.select0(node) + 1; 315} 316 317inline UInt32 Trie::get_parent(UInt32 node) const { 318 return (node > num_first_branches_) ? (louds_.select1(node) - node - 1) : 0; 319} 320 321inline bool Trie::has_link(UInt32 node) const { 322 return (link_flags_.empty()) ? false : link_flags_[node]; 323} 324 325inline UInt32 Trie::get_link_id(UInt32 node) const { 326 return link_flags_.rank1(node); 327} 328 329inline UInt32 Trie::get_link(UInt32 node) const { 330 return get_link(node, get_link_id(node)); 331} 332 333inline UInt32 Trie::get_link(UInt32 node, UInt32 link_id) const { 334 return (links_[link_id] * 256) + labels_[node]; 335} 336 337inline bool Trie::has_link() const { 338 return !link_flags_.empty(); 339} 340 341inline bool Trie::has_trie() const { 342 return trie_.get() != NULL; 343} 344 345inline bool Trie::has_tail() const { 346 return !tail_.empty(); 347} 348 349} // namespace marisa 350 351#endif // MARISA_TRIE_INLINE_H_ 352