1#include "trie.h"
2
3extern "C" {
4
5namespace {
6
7class FindCallback {
8 public:
9  typedef int (*Func)(void *, marisa_uint32, size_t);
10
11  FindCallback(Func func, void *first_arg)
12      : func_(func), first_arg_(first_arg) {}
13  FindCallback(const FindCallback &callback)
14      : func_(callback.func_), first_arg_(callback.first_arg_) {}
15
16  bool operator()(marisa::UInt32 key_id, std::size_t key_length) const {
17    return func_(first_arg_, key_id, key_length) != 0;
18  }
19
20 private:
21  Func func_;
22  void *first_arg_;
23
24  // Disallows assignment.
25  FindCallback &operator=(const FindCallback &);
26};
27
28class PredictCallback {
29 public:
30  typedef int (*Func)(void *, marisa_uint32, const char *, size_t);
31
32  PredictCallback(Func func, void *first_arg)
33      : func_(func), first_arg_(first_arg) {}
34  PredictCallback(const PredictCallback &callback)
35      : func_(callback.func_), first_arg_(callback.first_arg_) {}
36
37  bool operator()(marisa::UInt32 key_id, const std::string &key) const {
38    return func_(first_arg_, key_id, key.c_str(), key.length()) != 0;
39  }
40
41 private:
42  Func func_;
43  void *first_arg_;
44
45  // Disallows assignment.
46  PredictCallback &operator=(const PredictCallback &);
47};
48
49}  // namespace
50
51struct marisa_trie_ {
52 public:
53  marisa_trie_() : trie(), mapper() {}
54
55  marisa::Trie trie;
56  marisa::Mapper mapper;
57
58 private:
59  // Disallows copy and assignment.
60  marisa_trie_(const marisa_trie_ &);
61  marisa_trie_ &operator=(const marisa_trie_ &);
62};
63
64marisa_status marisa_init(marisa_trie **h) {
65  if ((h == NULL) || (*h != NULL)) {
66    return MARISA_HANDLE_ERROR;
67  }
68  *h = new (std::nothrow) marisa_trie_();
69  return (*h != NULL) ? MARISA_OK : MARISA_MEMORY_ERROR;
70}
71
72marisa_status marisa_end(marisa_trie *h) {
73  if (h == NULL) {
74    return MARISA_HANDLE_ERROR;
75  }
76  delete h;
77  return MARISA_OK;
78}
79
80marisa_status marisa_build(marisa_trie *h, const char * const *keys,
81    size_t num_keys, const size_t *key_lengths, const double *key_weights,
82    marisa_uint32 *key_ids, int flags) {
83  if (h == NULL) {
84    return MARISA_HANDLE_ERROR;
85  }
86  h->trie.build(keys, num_keys, key_lengths, key_weights, key_ids, flags);
87  h->mapper.clear();
88  return MARISA_OK;
89}
90
91marisa_status marisa_mmap(marisa_trie *h, const char *filename,
92    long offset, int whence) {
93  if (h == NULL) {
94    return MARISA_HANDLE_ERROR;
95  }
96  h->trie.mmap(&h->mapper, filename, offset, whence);
97  return MARISA_OK;
98}
99
100marisa_status marisa_map(marisa_trie *h, const void *ptr, size_t size) {
101  if (h == NULL) {
102    return MARISA_HANDLE_ERROR;
103  }
104  h->trie.map(ptr, size);
105  h->mapper.clear();
106  return MARISA_OK;
107}
108
109marisa_status marisa_load(marisa_trie *h, const char *filename,
110    long offset, int whence) {
111  if (h == NULL) {
112    return MARISA_HANDLE_ERROR;
113  }
114  h->trie.load(filename, offset, whence);
115  h->mapper.clear();
116  return MARISA_OK;
117}
118
119marisa_status marisa_fread(marisa_trie *h, FILE *file) {
120  if (h == NULL) {
121    return MARISA_HANDLE_ERROR;
122  }
123  h->trie.fread(file);
124  h->mapper.clear();
125  return MARISA_OK;
126}
127
128marisa_status marisa_read(marisa_trie *h, int fd) {
129  if (h == NULL) {
130    return MARISA_HANDLE_ERROR;
131  }
132  h->trie.read(fd);
133  h->mapper.clear();
134  return MARISA_OK;
135}
136
137marisa_status marisa_save(const marisa_trie *h, const char *filename,
138    int trunc_flag, long offset, int whence) {
139  if (h == NULL) {
140    return MARISA_HANDLE_ERROR;
141  }
142  h->trie.save(filename, trunc_flag != 0, offset, whence);
143  return MARISA_OK;
144}
145
146marisa_status marisa_fwrite(const marisa_trie *h, FILE *file) {
147  if (h == NULL) {
148    return MARISA_HANDLE_ERROR;
149  }
150  h->trie.fwrite(file);
151  return MARISA_OK;
152}
153
154marisa_status marisa_write(const marisa_trie *h, int fd) {
155  if (h == NULL) {
156    return MARISA_HANDLE_ERROR;
157  }
158  h->trie.write(fd);
159  return MARISA_OK;
160}
161
162marisa_status marisa_restore(const marisa_trie *h, marisa_uint32 key_id,
163    char *key_buf, size_t key_buf_size, size_t *key_length) {
164  if (h == NULL) {
165    return MARISA_HANDLE_ERROR;
166  } else if (key_length == NULL) {
167    return MARISA_PARAM_ERROR;
168  }
169  *key_length = h->trie.restore(key_id, key_buf, key_buf_size);
170  return MARISA_OK;
171}
172
173marisa_status marisa_lookup(const marisa_trie *h,
174    const char *ptr, size_t length, marisa_uint32 *key_id) {
175  if (h == NULL) {
176    return MARISA_HANDLE_ERROR;
177  } else if (key_id == NULL) {
178    return MARISA_PARAM_ERROR;
179  }
180  if (length == MARISA_ZERO_TERMINATED) {
181    *key_id = h->trie.lookup(ptr);
182  } else {
183    *key_id = h->trie.lookup(ptr, length);
184  }
185  return MARISA_OK;
186}
187
188marisa_status marisa_find(const marisa_trie *h,
189    const char *ptr, size_t length,
190    marisa_uint32 *key_ids, size_t *key_lengths,
191    size_t max_num_results, size_t *num_results) {
192  if (h == NULL) {
193    return MARISA_HANDLE_ERROR;
194  } else if (num_results == NULL) {
195    return MARISA_PARAM_ERROR;
196  }
197  if (length == MARISA_ZERO_TERMINATED) {
198    *num_results = h->trie.find(ptr, key_ids, key_lengths, max_num_results);
199  } else {
200    *num_results = h->trie.find(ptr, length,
201        key_ids, key_lengths, max_num_results);
202  }
203  return MARISA_OK;
204}
205
206marisa_status marisa_find_first(const marisa_trie *h,
207    const char *ptr, size_t length,
208    marisa_uint32 *key_id, size_t *key_length) {
209  if (h == NULL) {
210    return MARISA_HANDLE_ERROR;
211  } else if (key_id == NULL) {
212    return MARISA_PARAM_ERROR;
213  }
214  if (length == MARISA_ZERO_TERMINATED) {
215    *key_id = h->trie.find_first(ptr, key_length);
216  } else {
217    *key_id = h->trie.find_first(ptr, length, key_length);
218  }
219  return MARISA_OK;
220}
221
222marisa_status marisa_find_last(const marisa_trie *h,
223    const char *ptr, size_t length,
224    marisa_uint32 *key_id, size_t *key_length) {
225  if (h == NULL) {
226    return MARISA_HANDLE_ERROR;
227  } else if (key_id == NULL) {
228    return MARISA_PARAM_ERROR;
229  }
230  if (length == MARISA_ZERO_TERMINATED) {
231    *key_id = h->trie.find_last(ptr, key_length);
232  } else {
233    *key_id = h->trie.find_last(ptr, length, key_length);
234  }
235  return MARISA_OK;
236}
237
238marisa_status marisa_find_callback(const marisa_trie *h,
239    const char *ptr, size_t length,
240    int (*callback)(void *, marisa_uint32, size_t),
241    void *first_arg_to_callback) {
242  if (h == NULL) {
243    return MARISA_HANDLE_ERROR;
244  } else if (callback == NULL) {
245    return MARISA_PARAM_ERROR;
246  }
247  if (length == MARISA_ZERO_TERMINATED) {
248    h->trie.find_callback(ptr,
249        ::FindCallback(callback, first_arg_to_callback));
250  } else {
251    h->trie.find_callback(ptr, length,
252        ::FindCallback(callback, first_arg_to_callback));
253  }
254  return MARISA_OK;
255}
256
257marisa_status marisa_predict(const marisa_trie *h,
258    const char *ptr, size_t length, marisa_uint32 *key_ids,
259    size_t max_num_results, size_t *num_results) {
260  return marisa_predict_breadth_first(h, ptr, length,
261      key_ids, max_num_results, num_results);
262}
263
264marisa_status marisa_predict_breadth_first(const marisa_trie *h,
265    const char *ptr, size_t length, marisa_uint32 *key_ids,
266    size_t max_num_results, size_t *num_results) {
267  if (h == NULL) {
268    return MARISA_HANDLE_ERROR;
269  } else if (num_results == NULL) {
270    return MARISA_PARAM_ERROR;
271  }
272  if (length == MARISA_ZERO_TERMINATED) {
273    *num_results = h->trie.predict_breadth_first(
274        ptr, key_ids, NULL, max_num_results);
275  } else {
276    *num_results = h->trie.predict_breadth_first(
277        ptr, length, key_ids, NULL, max_num_results);
278  }
279  return MARISA_OK;
280}
281
282marisa_status marisa_predict_depth_first(const marisa_trie *h,
283    const char *ptr, size_t length, marisa_uint32 *key_ids,
284    size_t max_num_results, size_t *num_results) {
285  if (h == NULL) {
286    return MARISA_HANDLE_ERROR;
287  } else if (num_results == NULL) {
288    return MARISA_PARAM_ERROR;
289  }
290  if (length == MARISA_ZERO_TERMINATED) {
291    *num_results = h->trie.predict_depth_first(
292        ptr, key_ids, NULL, max_num_results);
293  } else {
294    *num_results = h->trie.predict_depth_first(
295        ptr, length, key_ids, NULL, max_num_results);
296  }
297  return MARISA_OK;
298}
299
300marisa_status marisa_predict_callback(const marisa_trie *h,
301    const char *ptr, size_t length,
302    int (*callback)(void *, marisa_uint32, const char *, size_t),
303    void *first_arg_to_callback) {
304  if (h == NULL) {
305    return MARISA_HANDLE_ERROR;
306  } else if (callback == NULL) {
307    return MARISA_PARAM_ERROR;
308  }
309  if (length == MARISA_ZERO_TERMINATED) {
310    h->trie.predict_callback(ptr,
311        ::PredictCallback(callback, first_arg_to_callback));
312  } else {
313    h->trie.predict_callback(ptr, length,
314        ::PredictCallback(callback, first_arg_to_callback));
315  }
316  return MARISA_OK;
317}
318
319size_t marisa_get_num_tries(const marisa_trie *h) {
320  return (h != NULL) ? h->trie.num_tries() : 0;
321}
322
323size_t marisa_get_num_keys(const marisa_trie *h) {
324  return (h != NULL) ? h->trie.num_keys() : 0;
325}
326
327size_t marisa_get_num_nodes(const marisa_trie *h) {
328  return (h != NULL) ? h->trie.num_nodes() : 0;
329}
330
331size_t marisa_get_total_size(const marisa_trie *h) {
332  return (h != NULL) ? h->trie.total_size() : 0;
333}
334
335marisa_status marisa_clear(marisa_trie *h) {
336  if (h == NULL) {
337    return MARISA_HANDLE_ERROR;
338  }
339  h->trie.clear();
340  h->mapper.clear();
341  return MARISA_OK;
342}
343
344}  // extern "C"
345