1#ifndef MARISA_ALPHA_TRIE_H_
2#define MARISA_ALPHA_TRIE_H_
3
4#include "base.h"
5
6#ifdef __cplusplus
7
8#include <memory>
9#include <vector>
10
11#include "progress.h"
12#include "key.h"
13#include "query.h"
14#include "container.h"
15#include "intvector.h"
16#include "bitvector.h"
17#include "tail.h"
18
19namespace marisa_alpha {
20
21class Trie {
22 public:
23  Trie();
24
25  void build(const char * const *keys, std::size_t num_keys,
26      const std::size_t *key_lengths = NULL,
27      const double *key_weights = NULL,
28      UInt32 *key_ids = NULL, int flags = 0);
29
30  void build(const std::vector<std::string> &keys,
31      std::vector<UInt32> *key_ids = NULL, int flags = 0);
32  void build(const std::vector<std::pair<std::string, double> > &keys,
33      std::vector<UInt32> *key_ids = NULL, int flags = 0);
34
35  void mmap(Mapper *mapper, const char *filename,
36      long offset = 0, int whence = SEEK_SET);
37  void map(const void *ptr, std::size_t size);
38  void map(Mapper &mapper);
39
40  void load(const char *filename,
41      long offset = 0, int whence = SEEK_SET);
42  void fread(std::FILE *file);
43  void read(int fd);
44  void read(std::istream &stream);
45  void read(Reader &reader);
46
47  void save(const char *filename, bool trunc_flag = true,
48      long offset = 0, int whence = SEEK_SET) const;
49  void fwrite(std::FILE *file) const;
50  void write(int fd) const;
51  void write(std::ostream &stream) const;
52  void write(Writer &writer) const;
53
54  std::string operator[](UInt32 key_id) const;
55
56  UInt32 operator[](const char *str) const;
57  UInt32 operator[](const std::string &str) const;
58
59  std::string restore(UInt32 key_id) const;
60  void restore(UInt32 key_id, std::string *key) const;
61  std::size_t restore(UInt32 key_id, char *key_buf,
62      std::size_t key_buf_size) const;
63
64  UInt32 lookup(const char *str) const;
65  UInt32 lookup(const char *ptr, std::size_t length) const;
66  UInt32 lookup(const std::string &str) const;
67
68  std::size_t find(const char *str,
69      UInt32 *key_ids, std::size_t *key_lengths,
70      std::size_t max_num_results) const;
71  std::size_t find(const char *ptr, std::size_t length,
72      UInt32 *key_ids, std::size_t *key_lengths,
73      std::size_t max_num_results) const;
74  std::size_t find(const std::string &str,
75      UInt32 *key_ids, std::size_t *key_lengths,
76      std::size_t max_num_results) const;
77
78  std::size_t find(const char *str,
79      std::vector<UInt32> *key_ids = NULL,
80      std::vector<std::size_t> *key_lengths = NULL,
81      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
82  std::size_t find(const char *ptr, std::size_t length,
83      std::vector<UInt32> *key_ids = NULL,
84      std::vector<std::size_t> *key_lengths = NULL,
85      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
86  std::size_t find(const std::string &str,
87      std::vector<UInt32> *key_ids = NULL,
88      std::vector<std::size_t> *key_lengths = NULL,
89      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
90
91  UInt32 find_first(const char *str,
92      std::size_t *key_length = NULL) const;
93  UInt32 find_first(const char *ptr, std::size_t length,
94      std::size_t *key_length = NULL) const;
95  UInt32 find_first(const std::string &str,
96      std::size_t *key_length = NULL) const;
97
98  UInt32 find_last(const char *str,
99      std::size_t *key_length = NULL) const;
100  UInt32 find_last(const char *ptr, std::size_t length,
101      std::size_t *key_length = NULL) const;
102  UInt32 find_last(const std::string &str,
103      std::size_t *key_length = NULL) const;
104
105  // bool callback(UInt32 key_id, std::size_t key_length);
106  template <typename T>
107  std::size_t find_callback(const char *str, T callback) const;
108  template <typename T>
109  std::size_t find_callback(const char *ptr, std::size_t length,
110      T callback) const;
111  template <typename T>
112  std::size_t find_callback(const std::string &str, T callback) const;
113
114  std::size_t predict(const char *str,
115      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
116  std::size_t predict(const char *ptr, std::size_t length,
117      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
118  std::size_t predict(const std::string &str,
119      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
120
121  std::size_t predict(const char *str,
122      std::vector<UInt32> *key_ids = NULL,
123      std::vector<std::string> *keys = NULL,
124      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
125  std::size_t predict(const char *ptr, std::size_t length,
126      std::vector<UInt32> *key_ids = NULL,
127      std::vector<std::string> *keys = NULL,
128      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
129  std::size_t predict(const std::string &str,
130      std::vector<UInt32> *key_ids = NULL,
131      std::vector<std::string> *keys = NULL,
132      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
133
134  std::size_t predict_breadth_first(const char *str,
135      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
136  std::size_t predict_breadth_first(const char *ptr, std::size_t length,
137      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
138  std::size_t predict_breadth_first(const std::string &str,
139      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
140
141  std::size_t predict_breadth_first(const char *str,
142      std::vector<UInt32> *key_ids = NULL,
143      std::vector<std::string> *keys = NULL,
144      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
145  std::size_t predict_breadth_first(const char *ptr, std::size_t length,
146      std::vector<UInt32> *key_ids = NULL,
147      std::vector<std::string> *keys = NULL,
148      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
149  std::size_t predict_breadth_first(const std::string &str,
150      std::vector<UInt32> *key_ids = NULL,
151      std::vector<std::string> *keys = NULL,
152      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
153
154  std::size_t predict_depth_first(const char *str,
155      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
156  std::size_t predict_depth_first(const char *ptr, std::size_t length,
157      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
158  std::size_t predict_depth_first(const std::string &str,
159      UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
160
161  std::size_t predict_depth_first(const char *str,
162      std::vector<UInt32> *key_ids = NULL,
163      std::vector<std::string> *keys = NULL,
164      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
165  std::size_t predict_depth_first(const char *ptr, std::size_t length,
166      std::vector<UInt32> *key_ids = NULL,
167      std::vector<std::string> *keys = NULL,
168      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
169  std::size_t predict_depth_first(const std::string &str,
170      std::vector<UInt32> *key_ids = NULL,
171      std::vector<std::string> *keys = NULL,
172      std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
173
174  // bool callback(UInt32 key_id, const std::string &key);
175  template <typename T>
176  std::size_t predict_callback(const char *str, T callback) const;
177  template <typename T>
178  std::size_t predict_callback(const char *ptr, std::size_t length,
179      T callback) const;
180  template <typename T>
181  std::size_t predict_callback(const std::string &str, T callback) const;
182
183  bool empty() const;
184  std::size_t num_tries() const;
185  std::size_t num_keys() const;
186  std::size_t num_nodes() const;
187  std::size_t total_size() const;
188
189  void clear();
190  void swap(Trie *rhs);
191
192  static UInt32 notfound();
193  static std::size_t mismatch();
194
195 private:
196  BitVector louds_;
197  Vector<UInt8> labels_;
198  BitVector terminal_flags_;
199  BitVector link_flags_;
200  IntVector links_;
201  std::auto_ptr<Trie> trie_;
202  Tail tail_;
203  UInt32 num_first_branches_;
204  UInt32 num_keys_;
205
206  void build_trie(Vector<Key<String> > &keys,
207      std::vector<UInt32> *key_ids, int flags);
208  void build_trie(Vector<Key<String> > &keys,
209      UInt32 *key_ids, int flags);
210
211  template <typename T>
212  void build_trie(Vector<Key<T> > &keys,
213      Vector<UInt32> *terminals, Progress &progress);
214
215  template <typename T>
216  void build_cur(Vector<Key<T> > &keys,
217      Vector<UInt32> *terminals, Progress &progress);
218
219  void build_next(Vector<Key<String> > &keys,
220      Vector<UInt32> *terminals, Progress &progress);
221  void build_next(Vector<Key<RString> > &rkeys,
222      Vector<UInt32> *terminals, Progress &progress);
223
224  template <typename T>
225  UInt32 sort_keys(Vector<Key<T> > &keys) const;
226
227  template <typename T>
228  void build_terminals(const Vector<Key<T> > &keys,
229      Vector<UInt32> *terminals) const;
230
231  void restore_(UInt32 key_id, std::string *key) const;
232  void trie_restore(UInt32 node, std::string *key) const;
233  void tail_restore(UInt32 node, std::string *key) const;
234
235  std::size_t restore_(UInt32 key_id, char *key_buf,
236      std::size_t key_buf_size) const;
237  void trie_restore(UInt32 node, char *key_buf,
238      std::size_t key_buf_size, std::size_t &key_pos) const;
239  void tail_restore(UInt32 node, char *key_buf,
240      std::size_t key_buf_size, std::size_t &key_pos) const;
241
242  template <typename T>
243  UInt32 lookup_(T query) const;
244  template <typename T>
245  bool find_child(UInt32 &node, T query, std::size_t &pos) const;
246  template <typename T>
247  std::size_t trie_match(UInt32 node, T query, std::size_t pos) const;
248  template <typename T>
249  std::size_t tail_match(UInt32 node, UInt32 link_id,
250      T query, std::size_t pos) const;
251
252  template <typename T, typename U, typename V>
253  std::size_t find_(T query, U key_ids, V key_lengths,
254      std::size_t max_num_results) const;
255  template <typename T>
256  UInt32 find_first_(T query, std::size_t *key_length) const;
257  template <typename T>
258  UInt32 find_last_(T query, std::size_t *key_length) const;
259  template <typename T, typename U>
260  std::size_t find_callback_(T query, U callback) const;
261
262  template <typename T, typename U, typename V>
263  std::size_t predict_breadth_first_(T query, U key_ids, V keys,
264      std::size_t max_num_results) const;
265  template <typename T, typename U, typename V>
266  std::size_t predict_depth_first_(T query, U key_ids, V keys,
267      std::size_t max_num_results) const;
268  template <typename T, typename U>
269  std::size_t predict_callback_(T query, U callback) const;
270
271  template <typename T>
272  bool predict_child(UInt32 &node, T query, std::size_t &pos,
273      std::string *key) const;
274  template <typename T>
275  std::size_t trie_prefix_match(UInt32 node, T query,
276      std::size_t pos, std::string *key) const;
277  template <typename T>
278  std::size_t tail_prefix_match(UInt32 node, UInt32 link_id,
279      T query, std::size_t pos, std::string *key) const;
280
281  UInt32 key_id_to_node(UInt32 key_id) const;
282  UInt32 node_to_key_id(UInt32 node) const;
283  UInt32 louds_pos_to_node(UInt32 louds_pos, UInt32 parent_node) const;
284
285  UInt32 get_child(UInt32 node) const;
286  UInt32 get_parent(UInt32 node) const;
287
288  bool has_link(UInt32 node) const;
289  UInt32 get_link_id(UInt32 node) const;
290  UInt32 get_link(UInt32 node) const;
291  UInt32 get_link(UInt32 node, UInt32 link_id) const;
292
293  bool has_link() const;
294  bool has_trie() const;
295  bool has_tail() const;
296
297  // Disallows copy and assignment.
298  Trie(const Trie &);
299  Trie &operator=(const Trie &);
300};
301
302}  // namespace marisa_alpha
303
304#include "trie-inline.h"
305
306#else  // __cplusplus
307
308#include <stdio.h>
309
310#endif  // __cplusplus
311
312#ifdef __cplusplus
313extern "C" {
314#endif  // __cplusplus
315
316typedef struct marisa_alpha_trie_ marisa_alpha_trie;
317
318marisa_alpha_status marisa_alpha_init(marisa_alpha_trie **h);
319marisa_alpha_status marisa_alpha_end(marisa_alpha_trie *h);
320
321marisa_alpha_status marisa_alpha_build(marisa_alpha_trie *h,
322    const char * const *keys, size_t num_keys, const size_t *key_lengths,
323    const double *key_weights, marisa_alpha_uint32 *key_ids, int flags);
324
325marisa_alpha_status marisa_alpha_mmap(marisa_alpha_trie *h,
326    const char *filename, long offset, int whence);
327marisa_alpha_status marisa_alpha_map(marisa_alpha_trie *h, const void *ptr,
328    size_t size);
329
330marisa_alpha_status marisa_alpha_load(marisa_alpha_trie *h,
331    const char *filename, long offset, int whence);
332marisa_alpha_status marisa_alpha_fread(marisa_alpha_trie *h, FILE *file);
333marisa_alpha_status marisa_alpha_read(marisa_alpha_trie *h, int fd);
334
335marisa_alpha_status marisa_alpha_save(const marisa_alpha_trie *h,
336    const char *filename, int trunc_flag, long offset, int whence);
337marisa_alpha_status marisa_alpha_fwrite(const marisa_alpha_trie *h,
338    FILE *file);
339marisa_alpha_status marisa_alpha_write(const marisa_alpha_trie *h, int fd);
340
341marisa_alpha_status marisa_alpha_restore(const marisa_alpha_trie *h,
342    marisa_alpha_uint32 key_id, char *key_buf, size_t key_buf_size,
343    size_t *key_length);
344
345marisa_alpha_status marisa_alpha_lookup(const marisa_alpha_trie *h,
346    const char *ptr, size_t length, marisa_alpha_uint32 *key_id);
347
348marisa_alpha_status marisa_alpha_find(const marisa_alpha_trie *h,
349    const char *ptr, size_t length,
350    marisa_alpha_uint32 *key_ids, size_t *key_lengths,
351    size_t max_num_results, size_t *num_results);
352marisa_alpha_status marisa_alpha_find_first(const marisa_alpha_trie *h,
353    const char *ptr, size_t length,
354    marisa_alpha_uint32 *key_id, size_t *key_length);
355marisa_alpha_status marisa_alpha_find_last(const marisa_alpha_trie *h,
356    const char *ptr, size_t length,
357    marisa_alpha_uint32 *key_id, size_t *key_length);
358marisa_alpha_status marisa_alpha_find_callback(const marisa_alpha_trie *h,
359    const char *ptr, size_t length,
360    int (*callback)(void *, marisa_alpha_uint32, size_t),
361    void *first_arg_to_callback);
362
363marisa_alpha_status marisa_alpha_predict(const marisa_alpha_trie *h,
364    const char *ptr, size_t length, marisa_alpha_uint32 *key_ids,
365    size_t max_num_results, size_t *num_results);
366marisa_alpha_status marisa_alpha_predict_breadth_first(
367    const marisa_alpha_trie *h, const char *ptr, size_t length,
368    marisa_alpha_uint32 *key_ids, size_t max_num_results, size_t *num_results);
369marisa_alpha_status marisa_alpha_predict_depth_first(
370    const marisa_alpha_trie *h, const char *ptr, size_t length,
371    marisa_alpha_uint32 *key_ids, size_t max_num_results, size_t *num_results);
372marisa_alpha_status marisa_alpha_predict_callback(const marisa_alpha_trie *h,
373    const char *ptr, size_t length,
374    int (*callback)(void *, marisa_alpha_uint32, const char *, size_t),
375    void *first_arg_to_callback);
376
377size_t marisa_alpha_get_num_tries(const marisa_alpha_trie *h);
378size_t marisa_alpha_get_num_keys(const marisa_alpha_trie *h);
379size_t marisa_alpha_get_num_nodes(const marisa_alpha_trie *h);
380size_t marisa_alpha_get_total_size(const marisa_alpha_trie *h);
381
382marisa_alpha_status marisa_alpha_clear(marisa_alpha_trie *h);
383
384#ifdef __cplusplus
385}  // extern "C"
386#endif  // __cplusplus
387
388#endif  // MARISA_ALPHA_TRIE_H_
389