1#include <algorithm> 2#include <stdexcept> 3 4#include "trie.h" 5 6namespace marisa_alpha { 7namespace { 8 9template <typename T, typename U> 10class PredictCallback { 11 public: 12 PredictCallback(T key_ids, U keys, std::size_t max_num_results) 13 : key_ids_(key_ids), keys_(keys), 14 max_num_results_(max_num_results), num_results_(0) {} 15 PredictCallback(const PredictCallback &callback) 16 : key_ids_(callback.key_ids_), keys_(callback.keys_), 17 max_num_results_(callback.max_num_results_), 18 num_results_(callback.num_results_) {} 19 20 bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) { 21 if (key_ids_.is_valid()) { 22 key_ids_.insert(num_results_, key_id); 23 } 24 if (keys_.is_valid()) { 25 keys_.insert(num_results_, key); 26 } 27 return ++num_results_ < max_num_results_; 28 } 29 30 private: 31 T key_ids_; 32 U keys_; 33 const std::size_t max_num_results_; 34 std::size_t num_results_; 35 36 // Disallows assignment. 37 PredictCallback &operator=(const PredictCallback &); 38}; 39 40} // namespace 41 42std::string Trie::restore(UInt32 key_id) const { 43 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 44 MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR); 45 std::string key; 46 restore_(key_id, &key); 47 return key; 48} 49 50void Trie::restore(UInt32 key_id, std::string *key) const { 51 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 52 MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR); 53 MARISA_ALPHA_THROW_IF(key == NULL, MARISA_ALPHA_PARAM_ERROR); 54 restore_(key_id, key); 55} 56 57std::size_t Trie::restore(UInt32 key_id, char *key_buf, 58 std::size_t key_buf_size) const { 59 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 60 MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR); 61 MARISA_ALPHA_THROW_IF((key_buf == NULL) && (key_buf_size != 0), 62 MARISA_ALPHA_PARAM_ERROR); 63 return restore_(key_id, key_buf, key_buf_size); 64} 65 66UInt32 Trie::lookup(const char *str) const { 67 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 68 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 69 return lookup_<CQuery>(CQuery(str)); 70} 71 72UInt32 Trie::lookup(const char *ptr, std::size_t length) const { 73 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 74 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 75 MARISA_ALPHA_PARAM_ERROR); 76 return lookup_<const Query &>(Query(ptr, length)); 77} 78 79std::size_t Trie::find(const char *str, 80 UInt32 *key_ids, std::size_t *key_lengths, 81 std::size_t max_num_results) const { 82 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 83 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 84 return find_<CQuery>(CQuery(str), 85 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results); 86} 87 88std::size_t Trie::find(const char *ptr, std::size_t length, 89 UInt32 *key_ids, std::size_t *key_lengths, 90 std::size_t max_num_results) const { 91 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 92 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 93 MARISA_ALPHA_PARAM_ERROR); 94 return find_<const Query &>(Query(ptr, length), 95 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results); 96} 97 98std::size_t Trie::find(const char *str, 99 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths, 100 std::size_t max_num_results) const { 101 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 102 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 103 return find_<CQuery>(CQuery(str), 104 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results); 105} 106 107std::size_t Trie::find(const char *ptr, std::size_t length, 108 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths, 109 std::size_t max_num_results) const { 110 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 111 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 112 MARISA_ALPHA_PARAM_ERROR); 113 return find_<const Query &>(Query(ptr, length), 114 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results); 115} 116 117UInt32 Trie::find_first(const char *str, 118 std::size_t *key_length) const { 119 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 120 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 121 return find_first_<CQuery>(CQuery(str), key_length); 122} 123 124UInt32 Trie::find_first(const char *ptr, std::size_t length, 125 std::size_t *key_length) const { 126 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 127 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 128 MARISA_ALPHA_PARAM_ERROR); 129 return find_first_<const Query &>(Query(ptr, length), key_length); 130} 131 132UInt32 Trie::find_last(const char *str, 133 std::size_t *key_length) const { 134 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 135 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 136 return find_last_<CQuery>(CQuery(str), key_length); 137} 138 139UInt32 Trie::find_last(const char *ptr, std::size_t length, 140 std::size_t *key_length) const { 141 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 142 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 143 MARISA_ALPHA_PARAM_ERROR); 144 return find_last_<const Query &>(Query(ptr, length), key_length); 145} 146 147std::size_t Trie::predict(const char *str, 148 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 149 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 150 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 151 return (keys == NULL) ? 152 predict_breadth_first(str, key_ids, keys, max_num_results) : 153 predict_depth_first(str, key_ids, keys, max_num_results); 154} 155 156std::size_t Trie::predict(const char *ptr, std::size_t length, 157 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 158 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 159 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 160 MARISA_ALPHA_PARAM_ERROR); 161 return (keys == NULL) ? 162 predict_breadth_first(ptr, length, key_ids, keys, max_num_results) : 163 predict_depth_first(ptr, length, key_ids, keys, max_num_results); 164} 165 166std::size_t Trie::predict(const char *str, 167 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 168 std::size_t max_num_results) const { 169 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 170 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 171 return (keys == NULL) ? 172 predict_breadth_first(str, key_ids, keys, max_num_results) : 173 predict_depth_first(str, key_ids, keys, max_num_results); 174} 175 176std::size_t Trie::predict(const char *ptr, std::size_t length, 177 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 178 std::size_t max_num_results) const { 179 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 180 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 181 MARISA_ALPHA_PARAM_ERROR); 182 return (keys == NULL) ? 183 predict_breadth_first(ptr, length, key_ids, keys, max_num_results) : 184 predict_depth_first(ptr, length, key_ids, keys, max_num_results); 185} 186 187std::size_t Trie::predict_breadth_first(const char *str, 188 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 189 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 190 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 191 return predict_breadth_first_<CQuery>(CQuery(str), 192 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 193} 194 195std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length, 196 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 197 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 198 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 199 MARISA_ALPHA_PARAM_ERROR); 200 return predict_breadth_first_<const Query &>(Query(ptr, length), 201 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 202} 203 204std::size_t Trie::predict_breadth_first(const char *str, 205 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 206 std::size_t max_num_results) const { 207 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 208 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 209 return predict_breadth_first_<CQuery>(CQuery(str), 210 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 211} 212 213std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length, 214 std::vector<UInt32> *key_ids, std::vector<std::string> *keys, 215 std::size_t max_num_results) const { 216 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 217 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 218 MARISA_ALPHA_PARAM_ERROR); 219 return predict_breadth_first_<const Query &>(Query(ptr, length), 220 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 221} 222 223std::size_t Trie::predict_depth_first(const char *str, 224 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 225 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 226 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 227 return predict_depth_first_<CQuery>(CQuery(str), 228 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 229} 230 231std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length, 232 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const { 233 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 234 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 235 MARISA_ALPHA_PARAM_ERROR); 236 return predict_depth_first_<const Query &>(Query(ptr, length), 237 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 238} 239 240std::size_t Trie::predict_depth_first( 241 const char *str, std::vector<UInt32> *key_ids, 242 std::vector<std::string> *keys, std::size_t max_num_results) const { 243 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 244 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR); 245 return predict_depth_first_<CQuery>(CQuery(str), 246 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 247} 248 249std::size_t Trie::predict_depth_first( 250 const char *ptr, std::size_t length, std::vector<UInt32> *key_ids, 251 std::vector<std::string> *keys, std::size_t max_num_results) const { 252 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR); 253 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0), 254 MARISA_ALPHA_PARAM_ERROR); 255 return predict_depth_first_<const Query &>(Query(ptr, length), 256 MakeContainer(key_ids), MakeContainer(keys), max_num_results); 257} 258 259void Trie::restore_(UInt32 key_id, std::string *key) const { 260 const std::size_t start_pos = key->length(); 261 try { 262 UInt32 node = key_id_to_node(key_id); 263 while (node != 0) { 264 if (has_link(node)) { 265 const std::size_t prev_pos = key->length(); 266 if (has_trie()) { 267 trie_->trie_restore(get_link(node), key); 268 } else { 269 tail_restore(node, key); 270 } 271 std::reverse(key->begin() + prev_pos, key->end()); 272 } else { 273 *key += labels_[node]; 274 } 275 node = get_parent(node); 276 } 277 std::reverse(key->begin() + start_pos, key->end()); 278 } catch (const std::bad_alloc &) { 279 key->resize(start_pos); 280 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR); 281 } catch (const std::length_error &) { 282 key->resize(start_pos); 283 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR); 284 } 285} 286 287void Trie::trie_restore(UInt32 node, std::string *key) const { 288 do { 289 if (has_link(node)) { 290 if (has_trie()) { 291 trie_->trie_restore(get_link(node), key); 292 } else { 293 tail_restore(node, key); 294 } 295 } else { 296 *key += labels_[node]; 297 } 298 node = get_parent(node); 299 } while (node != 0); 300} 301 302void Trie::tail_restore(UInt32 node, std::string *key) const { 303 const UInt32 link_id = link_flags_.rank1(node); 304 const UInt32 offset = (links_[link_id] * 256) + labels_[node]; 305 if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) { 306 const UInt32 length = (links_[link_id + 1] * 256) 307 + labels_[link_flags_.select1(link_id + 1)] - offset; 308 key->append(reinterpret_cast<const char *>(tail_[offset]), length); 309 } else { 310 key->append(reinterpret_cast<const char *>(tail_[offset])); 311 } 312} 313 314std::size_t Trie::restore_(UInt32 key_id, char *key_buf, 315 std::size_t key_buf_size) const { 316 std::size_t pos = 0; 317 UInt32 node = key_id_to_node(key_id); 318 while (node != 0) { 319 if (has_link(node)) { 320 const std::size_t prev_pos = pos; 321 if (has_trie()) { 322 trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos); 323 } else { 324 tail_restore(node, key_buf, key_buf_size, pos); 325 } 326 if (pos < key_buf_size) { 327 std::reverse(key_buf + prev_pos, key_buf + pos); 328 } 329 } else { 330 if (pos < key_buf_size) { 331 key_buf[pos] = labels_[node]; 332 } 333 ++pos; 334 } 335 node = get_parent(node); 336 } 337 if (pos < key_buf_size) { 338 key_buf[pos] = '\0'; 339 std::reverse(key_buf, key_buf + pos); 340 } 341 return pos; 342} 343 344void Trie::trie_restore(UInt32 node, char *key_buf, 345 std::size_t key_buf_size, std::size_t &pos) const { 346 do { 347 if (has_link(node)) { 348 if (has_trie()) { 349 trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos); 350 } else { 351 tail_restore(node, key_buf, key_buf_size, pos); 352 } 353 } else { 354 if (pos < key_buf_size) { 355 key_buf[pos] = labels_[node]; 356 } 357 ++pos; 358 } 359 node = get_parent(node); 360 } while (node != 0); 361} 362 363void Trie::tail_restore(UInt32 node, char *key_buf, 364 std::size_t key_buf_size, std::size_t &pos) const { 365 const UInt32 link_id = link_flags_.rank1(node); 366 const UInt32 offset = (links_[link_id] * 256) + labels_[node]; 367 if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) { 368 const UInt8 *ptr = tail_[offset]; 369 const UInt32 length = (links_[link_id + 1] * 256) 370 + labels_[link_flags_.select1(link_id + 1)] - offset; 371 for (UInt32 i = 0; i < length; ++i) { 372 if (pos < key_buf_size) { 373 key_buf[pos] = ptr[i]; 374 } 375 ++pos; 376 } 377 } else { 378 for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str) { 379 if (pos < key_buf_size) { 380 key_buf[pos] = *str; 381 } 382 ++pos; 383 } 384 } 385} 386 387template <typename T> 388UInt32 Trie::lookup_(T query) const { 389 UInt32 node = 0; 390 std::size_t pos = 0; 391 while (!query.ends_at(pos)) { 392 if (!find_child<T>(node, query, pos)) { 393 return notfound(); 394 } 395 } 396 return terminal_flags_[node] ? node_to_key_id(node) : notfound(); 397} 398 399template <typename T> 400std::size_t Trie::trie_match(UInt32 node, T query, 401 std::size_t pos) const { 402 if (has_link(node)) { 403 std::size_t next_pos; 404 if (has_trie()) { 405 next_pos = trie_->trie_match<T>(get_link(node), query, pos); 406 } else { 407 next_pos = tail_match<T>(node, get_link_id(node), query, pos); 408 } 409 if ((next_pos == mismatch()) || (next_pos == pos)) { 410 return next_pos; 411 } 412 pos = next_pos; 413 } else if (labels_[node] != query[pos]) { 414 return pos; 415 } else { 416 ++pos; 417 } 418 node = get_parent(node); 419 while (node != 0) { 420 if (query.ends_at(pos)) { 421 return mismatch(); 422 } 423 if (has_link(node)) { 424 std::size_t next_pos; 425 if (has_trie()) { 426 next_pos = trie_->trie_match<T>(get_link(node), query, pos); 427 } else { 428 next_pos = tail_match<T>(node, get_link_id(node), query, pos); 429 } 430 if ((next_pos == mismatch()) || (next_pos == pos)) { 431 return mismatch(); 432 } 433 pos = next_pos; 434 } else if (labels_[node] != query[pos]) { 435 return mismatch(); 436 } else { 437 ++pos; 438 } 439 node = get_parent(node); 440 } 441 return pos; 442} 443 444template std::size_t Trie::trie_match<CQuery>(UInt32 node, 445 CQuery query, std::size_t pos) const; 446template std::size_t Trie::trie_match<const Query &>(UInt32 node, 447 const Query &query, std::size_t pos) const; 448 449template <typename T> 450std::size_t Trie::tail_match(UInt32 node, UInt32 link_id, 451 T query, std::size_t pos) const { 452 const UInt32 offset = (links_[link_id] * 256) + labels_[node]; 453 const UInt8 *ptr = tail_[offset]; 454 if (*ptr != query[pos]) { 455 return pos; 456 } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) { 457 const UInt32 length = (links_[link_id + 1] * 256) 458 + labels_[link_flags_.select1(link_id + 1)] - offset; 459 for (UInt32 i = 1; i < length; ++i) { 460 if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) { 461 return mismatch(); 462 } 463 } 464 return pos + length; 465 } else { 466 for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) { 467 if (query.ends_at(pos) || (*ptr != query[pos])) { 468 return mismatch(); 469 } 470 } 471 return pos; 472 } 473} 474 475template std::size_t Trie::tail_match<CQuery>(UInt32 node, 476 UInt32 link_id, CQuery query, std::size_t pos) const; 477template std::size_t Trie::tail_match<const Query &>(UInt32 node, 478 UInt32 link_id, const Query &query, std::size_t pos) const; 479 480template <typename T, typename U, typename V> 481std::size_t Trie::find_(T query, U key_ids, V key_lengths, 482 std::size_t max_num_results) const try { 483 if (max_num_results == 0) { 484 return 0; 485 } 486 std::size_t count = 0; 487 UInt32 node = 0; 488 std::size_t pos = 0; 489 do { 490 if (terminal_flags_[node]) { 491 if (key_ids.is_valid()) { 492 key_ids.insert(count, node_to_key_id(node)); 493 } 494 if (key_lengths.is_valid()) { 495 key_lengths.insert(count, pos); 496 } 497 if (++count >= max_num_results) { 498 return count; 499 } 500 } 501 } while (!query.ends_at(pos) && find_child<T>(node, query, pos)); 502 return count; 503} catch (const std::bad_alloc &) { 504 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR); 505} catch (const std::length_error &) { 506 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR); 507} 508 509template <typename T> 510UInt32 Trie::find_first_(T query, std::size_t *key_length) const { 511 UInt32 node = 0; 512 std::size_t pos = 0; 513 do { 514 if (terminal_flags_[node]) { 515 if (key_length != NULL) { 516 *key_length = pos; 517 } 518 return node_to_key_id(node); 519 } 520 } while (!query.ends_at(pos) && find_child<T>(node, query, pos)); 521 return notfound(); 522} 523 524template <typename T> 525UInt32 Trie::find_last_(T query, std::size_t *key_length) const { 526 UInt32 node = 0; 527 UInt32 node_found = notfound(); 528 std::size_t pos = 0; 529 std::size_t pos_found = mismatch(); 530 do { 531 if (terminal_flags_[node]) { 532 node_found = node; 533 pos_found = pos; 534 } 535 } while (!query.ends_at(pos) && find_child<T>(node, query, pos)); 536 if (node_found != notfound()) { 537 if (key_length != NULL) { 538 *key_length = pos_found; 539 } 540 return node_to_key_id(node_found); 541 } 542 return notfound(); 543} 544 545template <typename T, typename U, typename V> 546std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys, 547 std::size_t max_num_results) const try { 548 if (max_num_results == 0) { 549 return 0; 550 } 551 UInt32 node = 0; 552 std::size_t pos = 0; 553 while (!query.ends_at(pos)) { 554 if (!predict_child<T>(node, query, pos, NULL)) { 555 return 0; 556 } 557 } 558 std::string key; 559 std::size_t count = 0; 560 if (terminal_flags_[node]) { 561 const UInt32 key_id = node_to_key_id(node); 562 if (key_ids.is_valid()) { 563 key_ids.insert(count, key_id); 564 } 565 if (keys.is_valid()) { 566 restore(key_id, &key); 567 keys.insert(count, key); 568 } 569 if (++count >= max_num_results) { 570 return count; 571 } 572 } 573 const UInt32 louds_pos = get_child(node); 574 if (!louds_[louds_pos]) { 575 return count; 576 } 577 UInt32 node_begin = louds_pos_to_node(louds_pos, node); 578 UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1); 579 while (node_begin < node_end) { 580 const UInt32 key_id_begin = node_to_key_id(node_begin); 581 const UInt32 key_id_end = node_to_key_id(node_end); 582 if (key_ids.is_valid()) { 583 UInt32 temp_count = count; 584 for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) { 585 key_ids.insert(temp_count, key_id); 586 if (++temp_count >= max_num_results) { 587 break; 588 } 589 } 590 } 591 if (keys.is_valid()) { 592 UInt32 temp_count = count; 593 for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) { 594 key.clear(); 595 restore(key_id, &key); 596 keys.insert(temp_count, key); 597 if (++temp_count >= max_num_results) { 598 break; 599 } 600 } 601 } 602 count += key_id_end - key_id_begin; 603 if (count >= max_num_results) { 604 return max_num_results; 605 } 606 node_begin = louds_pos_to_node(get_child(node_begin), node_begin); 607 node_end = louds_pos_to_node(get_child(node_end), node_end); 608 } 609 return count; 610} catch (const std::bad_alloc &) { 611 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR); 612} catch (const std::length_error &) { 613 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR); 614} 615 616template <typename T, typename U, typename V> 617std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys, 618 std::size_t max_num_results) const try { 619 if (max_num_results == 0) { 620 return 0; 621 } else if (keys.is_valid()) { 622 PredictCallback<U, V> callback(key_ids, keys, max_num_results); 623 return predict_callback_(query, callback); 624 } 625 626 UInt32 node = 0; 627 std::size_t pos = 0; 628 while (!query.ends_at(pos)) { 629 if (!predict_child<T>(node, query, pos, NULL)) { 630 return 0; 631 } 632 } 633 std::size_t count = 0; 634 if (terminal_flags_[node]) { 635 if (key_ids.is_valid()) { 636 key_ids.insert(count, node_to_key_id(node)); 637 } 638 if (++count >= max_num_results) { 639 return count; 640 } 641 } 642 Cell cell; 643 cell.set_louds_pos(get_child(node)); 644 if (!louds_[cell.louds_pos()]) { 645 return count; 646 } 647 cell.set_node(louds_pos_to_node(cell.louds_pos(), node)); 648 cell.set_key_id(node_to_key_id(cell.node())); 649 Vector<Cell> stack; 650 stack.push_back(cell); 651 std::size_t stack_pos = 1; 652 while (stack_pos != 0) { 653 Cell &cur = stack[stack_pos - 1]; 654 if (!louds_[cur.louds_pos()]) { 655 cur.set_louds_pos(cur.louds_pos() + 1); 656 --stack_pos; 657 continue; 658 } 659 cur.set_louds_pos(cur.louds_pos() + 1); 660 if (terminal_flags_[cur.node()]) { 661 if (key_ids.is_valid()) { 662 key_ids.insert(count, cur.key_id()); 663 } 664 if (++count >= max_num_results) { 665 return count; 666 } 667 cur.set_key_id(cur.key_id() + 1); 668 } 669 if (stack_pos == stack.size()) { 670 cell.set_louds_pos(get_child(cur.node())); 671 cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node())); 672 cell.set_key_id(node_to_key_id(cell.node())); 673 stack.push_back(cell); 674 } 675 stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1); 676 ++stack_pos; 677 } 678 return count; 679} catch (const std::bad_alloc &) { 680 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR); 681} catch (const std::length_error &) { 682 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR); 683} 684 685template <typename T> 686std::size_t Trie::trie_prefix_match(UInt32 node, T query, 687 std::size_t pos, std::string *key) const { 688 if (has_link(node)) { 689 std::size_t next_pos; 690 if (has_trie()) { 691 next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key); 692 } else { 693 next_pos = tail_prefix_match<T>( 694 node, get_link_id(node), query, pos, key); 695 } 696 if ((next_pos == mismatch()) || (next_pos == pos)) { 697 return next_pos; 698 } 699 pos = next_pos; 700 } else if (labels_[node] != query[pos]) { 701 return pos; 702 } else { 703 ++pos; 704 } 705 node = get_parent(node); 706 while (node != 0) { 707 if (query.ends_at(pos)) { 708 if (key != NULL) { 709 trie_restore(node, key); 710 } 711 return pos; 712 } 713 if (has_link(node)) { 714 std::size_t next_pos; 715 if (has_trie()) { 716 next_pos = trie_->trie_prefix_match<T>( 717 get_link(node), query, pos, key); 718 } else { 719 next_pos = tail_prefix_match<T>( 720 node, get_link_id(node), query, pos, key); 721 } 722 if ((next_pos == mismatch()) || (next_pos == pos)) { 723 return next_pos; 724 } 725 pos = next_pos; 726 } else if (labels_[node] != query[pos]) { 727 return mismatch(); 728 } else { 729 ++pos; 730 } 731 node = get_parent(node); 732 } 733 return pos; 734} 735 736template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node, 737 CQuery query, std::size_t pos, std::string *key) const; 738template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node, 739 const Query &query, std::size_t pos, std::string *key) const; 740 741template <typename T> 742std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id, 743 T query, std::size_t pos, std::string *key) const { 744 const UInt32 offset = (links_[link_id] * 256) + labels_[node]; 745 const UInt8 *ptr = tail_[offset]; 746 if (*ptr != query[pos]) { 747 return pos; 748 } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) { 749 const UInt32 length = (links_[link_id + 1] * 256) 750 + labels_[link_flags_.select1(link_id + 1)] - offset; 751 for (UInt32 i = 1; i < length; ++i) { 752 if (query.ends_at(pos + i)) { 753 if (key != NULL) { 754 key->append(reinterpret_cast<const char *>(ptr + i), length - i); 755 } 756 return pos + i; 757 } else if (ptr[i] != query[pos + i]) { 758 return mismatch(); 759 } 760 } 761 return pos + length; 762 } else { 763 for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) { 764 if (query.ends_at(pos)) { 765 if (key != NULL) { 766 key->append(reinterpret_cast<const char *>(ptr)); 767 } 768 return pos; 769 } else if (*ptr != query[pos]) { 770 return mismatch(); 771 } 772 } 773 return pos; 774 } 775} 776 777template std::size_t Trie::tail_prefix_match<CQuery>( 778 UInt32 node, UInt32 link_id, 779 CQuery query, std::size_t pos, std::string *key) const; 780template std::size_t Trie::tail_prefix_match<const Query &>( 781 UInt32 node, UInt32 link_id, 782 const Query &query, std::size_t pos, std::string *key) const; 783 784} // namespace marisa_alpha 785