1#include "trie.h"
2
3extern "C" {
4
5namespace {
6
7class FindCallback {
8 public:
9  typedef int (*Func)(void *, marisa_alpha_uint32, size_t);
10
11  FindCallback(Func func, void *first_arg)
12      : func_(func), first_arg_(first_arg) {}
13  FindCallback(const FindCallback &callback)
14      : func_(callback.func_), first_arg_(callback.first_arg_) {}
15
16  bool operator()(marisa_alpha::UInt32 key_id, std::size_t key_length) const {
17    return func_(first_arg_, key_id, key_length) != 0;
18  }
19
20 private:
21  Func func_;
22  void *first_arg_;
23
24  // Disallows assignment.
25  FindCallback &operator=(const FindCallback &);
26};
27
28class PredictCallback {
29 public:
30  typedef int (*Func)(void *, marisa_alpha_uint32, const char *, size_t);
31
32  PredictCallback(Func func, void *first_arg)
33      : func_(func), first_arg_(first_arg) {}
34  PredictCallback(const PredictCallback &callback)
35      : func_(callback.func_), first_arg_(callback.first_arg_) {}
36
37  bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) const {
38    return func_(first_arg_, key_id, key.c_str(), key.length()) != 0;
39  }
40
41 private:
42  Func func_;
43  void *first_arg_;
44
45  // Disallows assignment.
46  PredictCallback &operator=(const PredictCallback &);
47};
48
49}  // namespace
50
51struct marisa_alpha_trie_ {
52 public:
53  marisa_alpha_trie_() : trie(), mapper() {}
54
55  marisa_alpha::Trie trie;
56  marisa_alpha::Mapper mapper;
57
58 private:
59  // Disallows copy and assignment.
60  marisa_alpha_trie_(const marisa_alpha_trie_ &);
61  marisa_alpha_trie_ &operator=(const marisa_alpha_trie_ &);
62};
63
64marisa_alpha_status marisa_alpha_init(marisa_alpha_trie **h) {
65  if ((h == NULL) || (*h != NULL)) {
66    return MARISA_ALPHA_HANDLE_ERROR;
67  }
68  *h = new (std::nothrow) marisa_alpha_trie_();
69  return (*h != NULL) ? MARISA_ALPHA_OK : MARISA_ALPHA_MEMORY_ERROR;
70}
71
72marisa_alpha_status marisa_alpha_end(marisa_alpha_trie *h) {
73  if (h == NULL) {
74    return MARISA_ALPHA_HANDLE_ERROR;
75  }
76  delete h;
77  return MARISA_ALPHA_OK;
78}
79
80marisa_alpha_status marisa_alpha_build(marisa_alpha_trie *h,
81    const char * const *keys, size_t num_keys, const size_t *key_lengths,
82    const double *key_weights, marisa_alpha_uint32 *key_ids, int flags) try {
83  if (h == NULL) {
84    return MARISA_ALPHA_HANDLE_ERROR;
85  }
86  h->trie.build(keys, num_keys, key_lengths, key_weights, key_ids, flags);
87  h->mapper.clear();
88  return MARISA_ALPHA_OK;
89} catch (const marisa_alpha::Exception &ex) {
90  return ex.status();
91}
92
93marisa_alpha_status marisa_alpha_mmap(marisa_alpha_trie *h,
94    const char *filename, long offset, int whence) try {
95  if (h == NULL) {
96    return MARISA_ALPHA_HANDLE_ERROR;
97  }
98  h->trie.mmap(&h->mapper, filename, offset, whence);
99  return MARISA_ALPHA_OK;
100} catch (const marisa_alpha::Exception &ex) {
101  return ex.status();
102}
103
104marisa_alpha_status marisa_alpha_map(marisa_alpha_trie *h, const void *ptr,
105    size_t size) try {
106  if (h == NULL) {
107    return MARISA_ALPHA_HANDLE_ERROR;
108  }
109  h->trie.map(ptr, size);
110  h->mapper.clear();
111  return MARISA_ALPHA_OK;
112} catch (const marisa_alpha::Exception &ex) {
113  return ex.status();
114}
115
116marisa_alpha_status marisa_alpha_load(marisa_alpha_trie *h,
117    const char *filename, long offset, int whence) try {
118  if (h == NULL) {
119    return MARISA_ALPHA_HANDLE_ERROR;
120  }
121  h->trie.load(filename, offset, whence);
122  h->mapper.clear();
123  return MARISA_ALPHA_OK;
124} catch (const marisa_alpha::Exception &ex) {
125  return ex.status();
126}
127
128marisa_alpha_status marisa_alpha_fread(marisa_alpha_trie *h, FILE *file) try {
129  if (h == NULL) {
130    return MARISA_ALPHA_HANDLE_ERROR;
131  }
132  h->trie.fread(file);
133  h->mapper.clear();
134  return MARISA_ALPHA_OK;
135} catch (const marisa_alpha::Exception &ex) {
136  return ex.status();
137}
138
139marisa_alpha_status marisa_alpha_read(marisa_alpha_trie *h, int fd) try {
140  if (h == NULL) {
141    return MARISA_ALPHA_HANDLE_ERROR;
142  }
143  h->trie.read(fd);
144  h->mapper.clear();
145  return MARISA_ALPHA_OK;
146} catch (const marisa_alpha::Exception &ex) {
147  return ex.status();
148}
149
150marisa_alpha_status marisa_alpha_save(const marisa_alpha_trie *h,
151    const char *filename, int trunc_flag, long offset, int whence) try {
152  if (h == NULL) {
153    return MARISA_ALPHA_HANDLE_ERROR;
154  }
155  h->trie.save(filename, trunc_flag != 0, offset, whence);
156  return MARISA_ALPHA_OK;
157} catch (const marisa_alpha::Exception &ex) {
158  return ex.status();
159}
160
161marisa_alpha_status marisa_alpha_fwrite(const marisa_alpha_trie *h,
162    FILE *file) try {
163  if (h == NULL) {
164    return MARISA_ALPHA_HANDLE_ERROR;
165  }
166  h->trie.fwrite(file);
167  return MARISA_ALPHA_OK;
168} catch (const marisa_alpha::Exception &ex) {
169  return ex.status();
170}
171
172marisa_alpha_status marisa_alpha_write(const marisa_alpha_trie *h, int fd) try {
173  if (h == NULL) {
174    return MARISA_ALPHA_HANDLE_ERROR;
175  }
176  h->trie.write(fd);
177  return MARISA_ALPHA_OK;
178} catch (const marisa_alpha::Exception &ex) {
179  return ex.status();
180}
181
182marisa_alpha_status marisa_alpha_restore(const marisa_alpha_trie *h,
183    marisa_alpha_uint32 key_id, char *key_buf, size_t key_buf_size,
184    size_t *key_length) try {
185  if (h == NULL) {
186    return MARISA_ALPHA_HANDLE_ERROR;
187  } else if (key_length == NULL) {
188    return MARISA_ALPHA_PARAM_ERROR;
189  }
190  *key_length = h->trie.restore(key_id, key_buf, key_buf_size);
191  return MARISA_ALPHA_OK;
192} catch (const marisa_alpha::Exception &ex) {
193  return ex.status();
194}
195
196marisa_alpha_status marisa_alpha_lookup(const marisa_alpha_trie *h,
197    const char *ptr, size_t length, marisa_alpha_uint32 *key_id) try {
198  if (h == NULL) {
199    return MARISA_ALPHA_HANDLE_ERROR;
200  } else if (key_id == NULL) {
201    return MARISA_ALPHA_PARAM_ERROR;
202  }
203  if (length == MARISA_ALPHA_ZERO_TERMINATED) {
204    *key_id = h->trie.lookup(ptr);
205  } else {
206    *key_id = h->trie.lookup(ptr, length);
207  }
208  return MARISA_ALPHA_OK;
209} catch (const marisa_alpha::Exception &ex) {
210  return ex.status();
211}
212
213marisa_alpha_status marisa_alpha_find(const marisa_alpha_trie *h,
214    const char *ptr, size_t length,
215    marisa_alpha_uint32 *key_ids, size_t *key_lengths,
216    size_t max_num_results, size_t *num_results) try {
217  if (h == NULL) {
218    return MARISA_ALPHA_HANDLE_ERROR;
219  } else if (num_results == NULL) {
220    return MARISA_ALPHA_PARAM_ERROR;
221  }
222  if (length == MARISA_ALPHA_ZERO_TERMINATED) {
223    *num_results = h->trie.find(ptr, key_ids, key_lengths, max_num_results);
224  } else {
225    *num_results = h->trie.find(ptr, length,
226        key_ids, key_lengths, max_num_results);
227  }
228  return MARISA_ALPHA_OK;
229} catch (const marisa_alpha::Exception &ex) {
230  return ex.status();
231}
232
233marisa_alpha_status marisa_alpha_find_first(const marisa_alpha_trie *h,
234    const char *ptr, size_t length,
235    marisa_alpha_uint32 *key_id, size_t *key_length) {
236  if (h == NULL) {
237    return MARISA_ALPHA_HANDLE_ERROR;
238  } else if (key_id == NULL) {
239    return MARISA_ALPHA_PARAM_ERROR;
240  }
241  if (length == MARISA_ALPHA_ZERO_TERMINATED) {
242    *key_id = h->trie.find_first(ptr, key_length);
243  } else {
244    *key_id = h->trie.find_first(ptr, length, key_length);
245  }
246  return MARISA_ALPHA_OK;
247}
248
249marisa_alpha_status marisa_alpha_find_last(const marisa_alpha_trie *h,
250    const char *ptr, size_t length,
251    marisa_alpha_uint32 *key_id, size_t *key_length) {
252  if (h == NULL) {
253    return MARISA_ALPHA_HANDLE_ERROR;
254  } else if (key_id == NULL) {
255    return MARISA_ALPHA_PARAM_ERROR;
256  }
257  if (length == MARISA_ALPHA_ZERO_TERMINATED) {
258    *key_id = h->trie.find_last(ptr, key_length);
259  } else {
260    *key_id = h->trie.find_last(ptr, length, key_length);
261  }
262  return MARISA_ALPHA_OK;
263}
264
265marisa_alpha_status marisa_alpha_find_callback(const marisa_alpha_trie *h,
266    const char *ptr, size_t length,
267    int (*callback)(void *, marisa_alpha_uint32, size_t),
268    void *first_arg_to_callback) try {
269  if (h == NULL) {
270    return MARISA_ALPHA_HANDLE_ERROR;
271  } else if (callback == NULL) {
272    return MARISA_ALPHA_PARAM_ERROR;
273  }
274  if (length == MARISA_ALPHA_ZERO_TERMINATED) {
275    h->trie.find_callback(ptr,
276        ::FindCallback(callback, first_arg_to_callback));
277  } else {
278    h->trie.find_callback(ptr, length,
279        ::FindCallback(callback, first_arg_to_callback));
280  }
281  return MARISA_ALPHA_OK;
282} catch (const marisa_alpha::Exception &ex) {
283  return ex.status();
284}
285
286marisa_alpha_status marisa_alpha_predict(const marisa_alpha_trie *h,
287    const char *ptr, size_t length, marisa_alpha_uint32 *key_ids,
288    size_t max_num_results, size_t *num_results) {
289  return marisa_alpha_predict_breadth_first(h, ptr, length,
290      key_ids, max_num_results, num_results);
291}
292
293marisa_alpha_status marisa_alpha_predict_breadth_first(
294    const marisa_alpha_trie *h, const char *ptr, size_t length,
295    marisa_alpha_uint32 *key_ids, size_t max_num_results,
296    size_t *num_results) try {
297  if (h == NULL) {
298    return MARISA_ALPHA_HANDLE_ERROR;
299  } else if (num_results == NULL) {
300    return MARISA_ALPHA_PARAM_ERROR;
301  }
302  if (length == MARISA_ALPHA_ZERO_TERMINATED) {
303    *num_results = h->trie.predict_breadth_first(
304        ptr, key_ids, NULL, max_num_results);
305  } else {
306    *num_results = h->trie.predict_breadth_first(
307        ptr, length, key_ids, NULL, max_num_results);
308  }
309  return MARISA_ALPHA_OK;
310} catch (const marisa_alpha::Exception &ex) {
311  return ex.status();
312}
313
314marisa_alpha_status marisa_alpha_predict_depth_first(
315    const marisa_alpha_trie *h, const char *ptr, size_t length,
316    marisa_alpha_uint32 *key_ids, size_t max_num_results,
317    size_t *num_results) try {
318  if (h == NULL) {
319    return MARISA_ALPHA_HANDLE_ERROR;
320  } else if (num_results == NULL) {
321    return MARISA_ALPHA_PARAM_ERROR;
322  }
323  if (length == MARISA_ALPHA_ZERO_TERMINATED) {
324    *num_results = h->trie.predict_depth_first(
325        ptr, key_ids, NULL, max_num_results);
326  } else {
327    *num_results = h->trie.predict_depth_first(
328        ptr, length, key_ids, NULL, max_num_results);
329  }
330  return MARISA_ALPHA_OK;
331} catch (const marisa_alpha::Exception &ex) {
332  return ex.status();
333}
334
335marisa_alpha_status marisa_alpha_predict_callback(const marisa_alpha_trie *h,
336    const char *ptr, size_t length,
337    int (*callback)(void *, marisa_alpha_uint32, const char *, size_t),
338    void *first_arg_to_callback) try {
339  if (h == NULL) {
340    return MARISA_ALPHA_HANDLE_ERROR;
341  } else if (callback == NULL) {
342    return MARISA_ALPHA_PARAM_ERROR;
343  }
344  if (length == MARISA_ALPHA_ZERO_TERMINATED) {
345    h->trie.predict_callback(ptr,
346        ::PredictCallback(callback, first_arg_to_callback));
347  } else {
348    h->trie.predict_callback(ptr, length,
349        ::PredictCallback(callback, first_arg_to_callback));
350  }
351  return MARISA_ALPHA_OK;
352} catch (const marisa_alpha::Exception &ex) {
353  return ex.status();
354}
355
356size_t marisa_alpha_get_num_tries(const marisa_alpha_trie *h) {
357  return (h != NULL) ? h->trie.num_tries() : 0;
358}
359
360size_t marisa_alpha_get_num_keys(const marisa_alpha_trie *h) {
361  return (h != NULL) ? h->trie.num_keys() : 0;
362}
363
364size_t marisa_alpha_get_num_nodes(const marisa_alpha_trie *h) {
365  return (h != NULL) ? h->trie.num_nodes() : 0;
366}
367
368size_t marisa_alpha_get_total_size(const marisa_alpha_trie *h) {
369  return (h != NULL) ? h->trie.total_size() : 0;
370}
371
372marisa_alpha_status marisa_alpha_clear(marisa_alpha_trie *h) {
373  if (h == NULL) {
374    return MARISA_ALPHA_HANDLE_ERROR;
375  }
376  h->trie.clear();
377  h->mapper.clear();
378  return MARISA_ALPHA_OK;
379}
380
381}  // extern "C"
382