dicttrie.cpp revision 4248fb0083e2e0f2f9b379ea5ce898036b900218
1/*
2 * Copyright (C) 2009 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <assert.h>
18#include <stdio.h>
19#include <string.h>
20#include "../include/dicttrie.h"
21#include "../include/dictbuilder.h"
22#include "../include/lpicache.h"
23#include "../include/mystdlib.h"
24#include "../include/ngram.h"
25
26namespace ime_pinyin {
27
28DictTrie::DictTrie() {
29  spl_trie_ = SpellingTrie::get_cpinstance();
30
31  root_ = NULL;
32  splid_le0_index_ = NULL;
33  lma_node_num_le0_ = 0;
34  nodes_ge1_ = NULL;
35  lma_node_num_ge1_ = 0;
36  lma_idx_buf_ = NULL;
37  lma_idx_buf_len_ = 0;
38  total_lma_num_ = 0;
39  top_lmas_num_ = 0;
40  dict_list_ = NULL;
41
42  parsing_marks_ = NULL;
43  mile_stones_ = NULL;
44  reset_milestones(0, kFirstValidMileStoneHandle);
45}
46
47DictTrie::~DictTrie() {
48  free_resource(true);
49}
50
51void DictTrie::free_resource(bool free_dict_list) {
52  if (NULL != root_)
53    free(root_);
54  root_ = NULL;
55
56  if (NULL != splid_le0_index_)
57    free(splid_le0_index_);
58  splid_le0_index_ = NULL;
59
60  if (NULL != nodes_ge1_)
61    free(nodes_ge1_);
62  nodes_ge1_ = NULL;
63
64  if (NULL != nodes_ge1_)
65    free(nodes_ge1_);
66  nodes_ge1_ = NULL;
67
68  if (free_dict_list) {
69    if (NULL != dict_list_) {
70      delete dict_list_;
71    }
72    dict_list_ = NULL;
73  }
74
75  if (parsing_marks_)
76    delete [] parsing_marks_;
77  parsing_marks_ = NULL;
78
79  if (mile_stones_)
80    delete [] mile_stones_;
81  mile_stones_ = NULL;
82
83  reset_milestones(0, kFirstValidMileStoneHandle);
84}
85
86inline size_t DictTrie::get_son_offset(const LmaNodeGE1 *node) {
87  return ((size_t)node->son_1st_off_l + ((size_t)node->son_1st_off_h << 16));
88}
89
90inline size_t DictTrie::get_homo_idx_buf_offset(const LmaNodeGE1 *node) {
91  return ((size_t)node->homo_idx_buf_off_l +
92          ((size_t)node->homo_idx_buf_off_h << 16));
93}
94
95inline LemmaIdType DictTrie::get_lemma_id(size_t id_offset) {
96  LemmaIdType id = 0;
97  for (uint16 pos = kLemmaIdSize - 1; pos > 0; pos--)
98    id = (id << 8) + lma_idx_buf_[id_offset * kLemmaIdSize + pos];
99  id = (id << 8) + lma_idx_buf_[id_offset * kLemmaIdSize];
100  return id;
101}
102
103#ifdef ___BUILD_MODEL___
104bool DictTrie::build_dict(const char* fn_raw, const char* fn_validhzs) {
105  DictBuilder* dict_builder = new DictBuilder();
106
107  free_resource(true);
108
109  return dict_builder->build_dict(fn_raw, fn_validhzs, this);
110}
111
112bool DictTrie::save_dict(FILE *fp) {
113  if (NULL == fp)
114    return false;
115
116  if (fwrite(&lma_node_num_le0_, sizeof(size_t), 1, fp) != 1)
117    return false;
118
119  if (fwrite(&lma_node_num_ge1_, sizeof(size_t), 1, fp) != 1)
120    return false;
121
122  if (fwrite(&lma_idx_buf_len_, sizeof(size_t), 1, fp) != 1)
123    return false;
124
125  if (fwrite(&top_lmas_num_, sizeof(size_t), 1, fp) != 1)
126    return false;
127
128  if (fwrite(root_, sizeof(LmaNodeLE0), lma_node_num_le0_, fp)
129      != lma_node_num_le0_)
130    return false;
131
132  if (fwrite(nodes_ge1_, sizeof(LmaNodeGE1), lma_node_num_ge1_, fp)
133      != lma_node_num_ge1_)
134    return false;
135
136  if (fwrite(lma_idx_buf_, sizeof(unsigned char), lma_idx_buf_len_, fp) !=
137      lma_idx_buf_len_)
138    return false;
139
140  return true;
141}
142
143bool DictTrie::save_dict(const char *filename) {
144  if (NULL == filename)
145    return false;
146
147  if (NULL == root_ || NULL == dict_list_)
148    return false;
149
150  SpellingTrie &spl_trie = SpellingTrie::get_instance();
151  NGram &ngram = NGram::get_instance();
152
153  FILE *fp = fopen(filename, "wb");
154  if (NULL == fp)
155    return false;
156
157  if (!spl_trie.save_spl_trie(fp) || !dict_list_->save_list(fp) ||
158      !save_dict(fp) || !ngram.save_ngram(fp)) {
159    fclose(fp);
160    return false;
161  }
162
163  fclose(fp);
164  return true;
165}
166#endif  // ___BUILD_MODEL___
167
168bool DictTrie::load_dict(FILE *fp) {
169  if (NULL == fp)
170    return false;
171
172  if (fread(&lma_node_num_le0_, sizeof(size_t), 1, fp) != 1)
173    return false;
174
175  if (fread(&lma_node_num_ge1_, sizeof(size_t), 1, fp) != 1)
176    return false;
177
178  if (fread(&lma_idx_buf_len_, sizeof(size_t), 1, fp) != 1)
179    return false;
180
181  if (fread(&top_lmas_num_, sizeof(size_t), 1, fp) != 1 ||
182      top_lmas_num_ >= lma_idx_buf_len_)
183    return false;
184
185  free_resource(false);
186
187  root_ = static_cast<LmaNodeLE0*>
188          (malloc(lma_node_num_le0_ * sizeof(LmaNodeLE0)));
189  nodes_ge1_ = static_cast<LmaNodeGE1*>
190               (malloc(lma_node_num_ge1_ * sizeof(LmaNodeGE1)));
191  lma_idx_buf_ = (unsigned char*)malloc(lma_idx_buf_len_);
192  total_lma_num_ = lma_idx_buf_len_ / kLemmaIdSize;
193
194  size_t buf_size = SpellingTrie::get_instance().get_spelling_num() + 1;
195  assert(lma_node_num_le0_ <= buf_size);
196  splid_le0_index_ = static_cast<uint16*>(malloc(buf_size * sizeof(uint16)));
197
198  // Init the space for parsing.
199  parsing_marks_ = new ParsingMark[kMaxParsingMark];
200  mile_stones_ = new MileStone[kMaxMileStone];
201  reset_milestones(0, kFirstValidMileStoneHandle);
202
203  if (NULL == root_ || NULL == nodes_ge1_ || NULL == lma_idx_buf_ ||
204      NULL == splid_le0_index_ || NULL == parsing_marks_ ||
205      NULL == mile_stones_) {
206    free_resource(false);
207    return false;
208  }
209
210  if (fread(root_, sizeof(LmaNodeLE0), lma_node_num_le0_, fp)
211      != lma_node_num_le0_)
212    return false;
213
214  if (fread(nodes_ge1_, sizeof(LmaNodeGE1), lma_node_num_ge1_, fp)
215      != lma_node_num_ge1_)
216    return false;
217
218  if (fread(lma_idx_buf_, sizeof(unsigned char), lma_idx_buf_len_, fp) !=
219      lma_idx_buf_len_)
220    return false;
221
222  // The quick index for the first level sons
223  uint16 last_splid = kFullSplIdStart;
224  size_t last_pos = 0;
225  for (size_t i = 1; i < lma_node_num_le0_; i++) {
226    for (uint16 splid = last_splid; splid < root_[i].spl_idx; splid++)
227      splid_le0_index_[splid - kFullSplIdStart] = last_pos;
228
229    splid_le0_index_[root_[i].spl_idx - kFullSplIdStart] =
230        static_cast<uint16>(i);
231    last_splid = root_[i].spl_idx;
232    last_pos = i;
233  }
234
235  for (uint16 splid = last_splid + 1;
236       splid < buf_size + kFullSplIdStart; splid++) {
237    assert(static_cast<size_t>(splid - kFullSplIdStart) < buf_size);
238    splid_le0_index_[splid - kFullSplIdStart] = last_pos + 1;
239  }
240
241  return true;
242}
243
244bool DictTrie::load_dict(const char *filename, LemmaIdType start_id,
245                         LemmaIdType end_id) {
246  if (NULL == filename || end_id <= start_id)
247    return false;
248
249  FILE *fp = fopen(filename, "rb");
250  if (NULL == fp)
251    return false;
252
253  free_resource(true);
254
255  dict_list_ = new DictList();
256  if (NULL == dict_list_) {
257    fclose(fp);
258    return false;
259  }
260
261  SpellingTrie &spl_trie = SpellingTrie::get_instance();
262  NGram &ngram = NGram::get_instance();
263
264  if (!spl_trie.load_spl_trie(fp) || !dict_list_->load_list(fp) ||
265      !load_dict(fp) || !ngram.load_ngram(fp) ||
266      total_lma_num_ > end_id - start_id + 1) {
267    free_resource(true);
268    fclose(fp);
269    return false;
270  }
271
272  fclose(fp);
273  return true;
274}
275
276bool DictTrie::load_dict_fd(int sys_fd, long start_offset,
277                            long length, LemmaIdType start_id,
278                            LemmaIdType end_id) {
279  if (start_offset < 0 || length <= 0 || end_id <= start_id)
280    return false;
281
282  FILE *fp = fdopen(sys_fd, "rb");
283  if (NULL == fp)
284    return false;
285
286  if (-1 == fseek(fp, start_offset, SEEK_SET)) {
287    fclose(fp);
288    return false;
289  }
290
291  free_resource(true);
292
293  dict_list_ = new DictList();
294  if (NULL == dict_list_) {
295    fclose(fp);
296    return false;
297  }
298
299  SpellingTrie &spl_trie = SpellingTrie::get_instance();
300  NGram &ngram = NGram::get_instance();
301
302  if (!spl_trie.load_spl_trie(fp) || !dict_list_->load_list(fp) ||
303      !load_dict(fp) || !ngram.load_ngram(fp) ||
304      ftell(fp) < start_offset + length ||
305      total_lma_num_ > end_id - start_id + 1) {
306    free_resource(true);
307    fclose(fp);
308    return false;
309  }
310
311  fclose(fp);
312  return true;
313}
314
315size_t DictTrie::fill_lpi_buffer(LmaPsbItem lpi_items[], size_t lpi_max,
316                                 LmaNodeLE0 *node) {
317  size_t lpi_num = 0;
318  NGram& ngram = NGram::get_instance();
319  for (size_t homo = 0; homo < (size_t)node->num_of_homo; homo++) {
320    lpi_items[lpi_num].id = get_lemma_id(node->homo_idx_buf_off +
321                                         homo);
322    lpi_items[lpi_num].lma_len = 1;
323    lpi_items[lpi_num].psb =
324        static_cast<LmaScoreType>(ngram.get_uni_psb(lpi_items[lpi_num].id));
325    lpi_num++;
326    if (lpi_num >= lpi_max)
327      break;
328  }
329
330  return lpi_num;
331}
332
333size_t DictTrie::fill_lpi_buffer(LmaPsbItem lpi_items[], size_t lpi_max,
334                                 size_t homo_buf_off, LmaNodeGE1 *node,
335                                 uint16 lma_len) {
336  size_t lpi_num = 0;
337  NGram& ngram = NGram::get_instance();
338  for (size_t homo = 0; homo < (size_t)node->num_of_homo; homo++) {
339    lpi_items[lpi_num].id = get_lemma_id(homo_buf_off + homo);
340    lpi_items[lpi_num].lma_len = lma_len;
341    lpi_items[lpi_num].psb =
342        static_cast<LmaScoreType>(ngram.get_uni_psb(lpi_items[lpi_num].id));
343    lpi_num++;
344    if (lpi_num >= lpi_max)
345      break;
346  }
347
348  return lpi_num;
349}
350
351void DictTrie::reset_milestones(uint16 from_step, MileStoneHandle from_handle) {
352  if (0 == from_step) {
353    parsing_marks_pos_ = 0;
354    mile_stones_pos_ = kFirstValidMileStoneHandle;
355  } else {
356    if (from_handle > 0 && from_handle < mile_stones_pos_) {
357      mile_stones_pos_ = from_handle;
358
359      MileStone *mile_stone = mile_stones_ + from_handle;
360      parsing_marks_pos_ = mile_stone->mark_start;
361    }
362  }
363}
364
365MileStoneHandle DictTrie::extend_dict(MileStoneHandle from_handle,
366                                      const DictExtPara *dep,
367                                      LmaPsbItem *lpi_items, size_t lpi_max,
368                                      size_t *lpi_num) {
369  if (NULL == dep)
370    return 0;
371
372  // from LmaNodeLE0 (root) to LmaNodeLE0
373  if (0 == from_handle) {
374    assert(0 == dep->splids_extended);
375    return extend_dict0(from_handle, dep, lpi_items, lpi_max, lpi_num);
376  }
377
378  // from LmaNodeLE0 to LmaNodeGE1
379  if (1 == dep->splids_extended)
380    return extend_dict1(from_handle, dep, lpi_items, lpi_max, lpi_num);
381
382  // From LmaNodeGE1 to LmaNodeGE1
383  return extend_dict2(from_handle, dep, lpi_items, lpi_max, lpi_num);
384}
385
386MileStoneHandle DictTrie::extend_dict0(MileStoneHandle from_handle,
387                                       const DictExtPara *dep,
388                                       LmaPsbItem *lpi_items,
389                                       size_t lpi_max, size_t *lpi_num) {
390  assert(NULL != dep && 0 == from_handle);
391  *lpi_num = 0;
392  MileStoneHandle ret_handle = 0;
393
394  uint16 splid = dep->splids[dep->splids_extended];
395  uint16 id_start = dep->id_start;
396  uint16 id_num = dep->id_num;
397
398  LpiCache& lpi_cache = LpiCache::get_instance();
399  bool cached = lpi_cache.is_cached(splid);
400
401  // 2. Begin exgtending
402  // 2.1 Get the LmaPsbItem list
403  LmaNodeLE0 *node = root_;
404  size_t son_start = splid_le0_index_[id_start - kFullSplIdStart];
405  size_t son_end = splid_le0_index_[id_start + id_num - kFullSplIdStart];
406  for (size_t son_pos = son_start; son_pos < son_end; son_pos++) {
407    assert(1 == node->son_1st_off);
408    LmaNodeLE0 *son = root_ + son_pos;
409    assert(son->spl_idx >= id_start && son->spl_idx < id_start + id_num);
410
411    if (!cached && *lpi_num < lpi_max) {
412      bool need_lpi = true;
413      if (spl_trie_->is_half_id_yunmu(splid) && son_pos != son_start)
414        need_lpi = false;
415
416      if (need_lpi)
417        *lpi_num += fill_lpi_buffer(lpi_items + (*lpi_num),
418                                    lpi_max - *lpi_num, son);
419    }
420
421    // If necessary, fill in a new mile stone.
422    if (son->spl_idx == id_start) {
423      if (mile_stones_pos_ < kMaxMileStone &&
424          parsing_marks_pos_ < kMaxParsingMark) {
425        parsing_marks_[parsing_marks_pos_].node_offset = son_pos;
426        parsing_marks_[parsing_marks_pos_].node_num = id_num;
427        mile_stones_[mile_stones_pos_].mark_start = parsing_marks_pos_;
428        mile_stones_[mile_stones_pos_].mark_num = 1;
429        ret_handle = mile_stones_pos_;
430        parsing_marks_pos_++;
431        mile_stones_pos_++;
432      }
433    }
434
435    if (son->spl_idx >= id_start + id_num -1)
436      break;
437  }
438
439  //  printf("----- parsing marks: %d, mile stone: %d \n", parsing_marks_pos_,
440  //      mile_stones_pos_);
441  return ret_handle;
442}
443
444MileStoneHandle DictTrie::extend_dict1(MileStoneHandle from_handle,
445                                       const DictExtPara *dep,
446                                       LmaPsbItem *lpi_items,
447                                       size_t lpi_max, size_t *lpi_num) {
448  assert(NULL != dep && from_handle > 0 && from_handle < mile_stones_pos_);
449
450  MileStoneHandle ret_handle = 0;
451
452  // 1. If this is a half Id, get its corresponding full starting Id and
453  // number of full Id.
454  size_t ret_val = 0;
455
456  uint16 id_start = dep->id_start;
457  uint16 id_num = dep->id_num;
458
459  // 2. Begin extending.
460  MileStone *mile_stone = mile_stones_ + from_handle;
461
462  for (uint16 h_pos = 0; h_pos < mile_stone->mark_num; h_pos++) {
463    ParsingMark p_mark = parsing_marks_[mile_stone->mark_start + h_pos];
464    uint16 ext_num = p_mark.node_num;
465    for (uint16 ext_pos = 0; ext_pos < ext_num; ext_pos++) {
466      LmaNodeLE0 *node = root_ + p_mark.node_offset + ext_pos;
467      size_t found_start = 0;
468      size_t found_num = 0;
469      for (size_t son_pos = 0; son_pos < (size_t)node->num_of_son; son_pos++) {
470        assert(node->son_1st_off <= lma_node_num_ge1_);
471        LmaNodeGE1 *son = nodes_ge1_ + node->son_1st_off + son_pos;
472        if (son->spl_idx >= id_start
473            && son->spl_idx < id_start + id_num) {
474          if (*lpi_num < lpi_max) {
475            size_t homo_buf_off = get_homo_idx_buf_offset(son);
476            *lpi_num += fill_lpi_buffer(lpi_items + (*lpi_num),
477                                        lpi_max - *lpi_num, homo_buf_off, son,
478                                        2);
479          }
480
481          // If necessary, fill in the new DTMI
482          if (0 == found_num) {
483            found_start = son_pos;
484          }
485          found_num++;
486        }
487        if (son->spl_idx >= id_start + id_num - 1 || son_pos ==
488            (size_t)node->num_of_son - 1) {
489          if (found_num > 0) {
490            if (mile_stones_pos_ < kMaxMileStone &&
491                parsing_marks_pos_ < kMaxParsingMark) {
492              parsing_marks_[parsing_marks_pos_].node_offset =
493                node->son_1st_off + found_start;
494              parsing_marks_[parsing_marks_pos_].node_num = found_num;
495              if (0 == ret_val)
496                mile_stones_[mile_stones_pos_].mark_start =
497                  parsing_marks_pos_;
498              parsing_marks_pos_++;
499            }
500
501            ret_val++;
502          }
503          break;
504        }  // for son_pos
505      }  // for ext_pos
506    }  // for h_pos
507  }
508
509  if (ret_val > 0) {
510    mile_stones_[mile_stones_pos_].mark_num = ret_val;
511    ret_handle = mile_stones_pos_;
512    mile_stones_pos_++;
513    ret_val = 1;
514  }
515
516  //  printf("----- parsing marks: %d, mile stone: %d \n", parsing_marks_pos_,
517  //         mile_stones_pos_);
518  return ret_handle;
519}
520
521MileStoneHandle DictTrie::extend_dict2(MileStoneHandle from_handle,
522                                       const DictExtPara *dep,
523                                       LmaPsbItem *lpi_items,
524                                       size_t lpi_max, size_t *lpi_num) {
525  assert(NULL != dep && from_handle > 0 && from_handle < mile_stones_pos_);
526
527  MileStoneHandle ret_handle = 0;
528
529  // 1. If this is a half Id, get its corresponding full starting Id and
530  // number of full Id.
531  size_t ret_val = 0;
532
533  uint16 id_start = dep->id_start;
534  uint16 id_num = dep->id_num;
535
536  // 2. Begin extending.
537  MileStone *mile_stone = mile_stones_ + from_handle;
538
539  for (uint16 h_pos = 0; h_pos < mile_stone->mark_num; h_pos++) {
540    ParsingMark p_mark = parsing_marks_[mile_stone->mark_start + h_pos];
541    uint16 ext_num = p_mark.node_num;
542    for (uint16 ext_pos = 0; ext_pos < ext_num; ext_pos++) {
543      LmaNodeGE1 *node = nodes_ge1_ + p_mark.node_offset + ext_pos;
544      size_t found_start = 0;
545      size_t found_num = 0;
546
547      for (size_t son_pos = 0; son_pos < (size_t)node->num_of_son; son_pos++) {
548        assert(node->son_1st_off_l > 0 || node->son_1st_off_h > 0);
549        LmaNodeGE1 *son = nodes_ge1_ + get_son_offset(node) + son_pos;
550        if (son->spl_idx >= id_start
551            && son->spl_idx < id_start + id_num) {
552          if (*lpi_num < lpi_max) {
553            size_t homo_buf_off = get_homo_idx_buf_offset(son);
554            *lpi_num += fill_lpi_buffer(lpi_items + (*lpi_num),
555                                        lpi_max - *lpi_num, homo_buf_off, son,
556                                        dep->splids_extended + 1);
557          }
558
559          // If necessary, fill in the new DTMI
560          if (0 == found_num) {
561            found_start = son_pos;
562          }
563          found_num++;
564        }
565        if (son->spl_idx >= id_start + id_num - 1 || son_pos ==
566            (size_t)node->num_of_son - 1) {
567          if (found_num > 0) {
568            if (mile_stones_pos_ < kMaxMileStone &&
569                parsing_marks_pos_ < kMaxParsingMark) {
570              parsing_marks_[parsing_marks_pos_].node_offset =
571                get_son_offset(node) + found_start;
572              parsing_marks_[parsing_marks_pos_].node_num = found_num;
573              if (0 == ret_val)
574                mile_stones_[mile_stones_pos_].mark_start =
575                  parsing_marks_pos_;
576              parsing_marks_pos_++;
577            }
578
579            ret_val++;
580          }
581          break;
582        }
583      }  // for son_pos
584    }  // for ext_pos
585  }  // for h_pos
586
587  if (ret_val > 0) {
588    mile_stones_[mile_stones_pos_].mark_num = ret_val;
589    ret_handle = mile_stones_pos_;
590    mile_stones_pos_++;
591  }
592
593  // printf("----- parsing marks: %d, mile stone: %d \n", parsing_marks_pos_,
594  //        mile_stones_pos_);
595  return ret_handle;
596}
597
598bool DictTrie::try_extend(const uint16 *splids, uint16 splid_num,
599                          LemmaIdType id_lemma) {
600  if (0 == splid_num || NULL == splids)
601    return false;
602
603  void *node = root_ + splid_le0_index_[splids[0] - kFullSplIdStart];
604
605  for (uint16 pos = 1; pos < splid_num; pos++) {
606    if (1 == pos) {
607      LmaNodeLE0 *node_le0 = reinterpret_cast<LmaNodeLE0*>(node);
608      LmaNodeGE1 *node_son;
609      uint16 son_pos;
610      for (son_pos = 0; son_pos < static_cast<uint16>(node_le0->num_of_son);
611           son_pos++) {
612        assert(node_le0->son_1st_off <= lma_node_num_ge1_);
613        node_son = nodes_ge1_ + node_le0->son_1st_off
614            + son_pos;
615        if (node_son->spl_idx == splids[pos])
616          break;
617      }
618      if (son_pos < node_le0->num_of_son)
619        node = reinterpret_cast<void*>(node_son);
620      else
621        return false;
622    } else {
623      LmaNodeGE1 *node_ge1 = reinterpret_cast<LmaNodeGE1*>(node);
624      LmaNodeGE1 *node_son;
625      uint16 son_pos;
626      for (son_pos = 0; son_pos < static_cast<uint16>(node_ge1->num_of_son);
627           son_pos++) {
628        assert(node_ge1->son_1st_off_l > 0 || node_ge1->son_1st_off_h > 0);
629        node_son = nodes_ge1_ + get_son_offset(node_ge1) + son_pos;
630        if (node_son->spl_idx == splids[pos])
631          break;
632      }
633      if (son_pos < node_ge1->num_of_son)
634        node = reinterpret_cast<void*>(node_son);
635      else
636        return false;
637    }
638  }
639
640  if (1 == splid_num) {
641    LmaNodeLE0* node_le0 = reinterpret_cast<LmaNodeLE0*>(node);
642    size_t num_of_homo = (size_t)node_le0->num_of_homo;
643    for (size_t homo_pos = 0; homo_pos < num_of_homo; homo_pos++) {
644      LemmaIdType id_this = get_lemma_id(node_le0->homo_idx_buf_off + homo_pos);
645      char16 str[2];
646      get_lemma_str(id_this, str, 2);
647      if (id_this == id_lemma)
648        return true;
649    }
650  } else {
651    LmaNodeGE1* node_ge1 = reinterpret_cast<LmaNodeGE1*>(node);
652    size_t num_of_homo = (size_t)node_ge1->num_of_homo;
653    for (size_t homo_pos = 0; homo_pos < num_of_homo; homo_pos++) {
654      size_t node_homo_off = get_homo_idx_buf_offset(node_ge1);
655      if (get_lemma_id(node_homo_off + homo_pos) == id_lemma)
656        return true;
657    }
658  }
659
660  return false;
661}
662
663size_t DictTrie::get_lpis(const uint16* splid_str, uint16 splid_str_len,
664                          LmaPsbItem* lma_buf, size_t max_lma_buf) {
665  if (splid_str_len > kMaxLemmaSize)
666    return 0;
667
668#define MAX_EXTENDBUF_LEN 200
669
670  size_t* node_buf1[MAX_EXTENDBUF_LEN];  // use size_t for data alignment
671  size_t* node_buf2[MAX_EXTENDBUF_LEN];
672  LmaNodeLE0** node_fr_le0 =
673    reinterpret_cast<LmaNodeLE0**>(node_buf1);      // Nodes from.
674  LmaNodeLE0** node_to_le0 =
675    reinterpret_cast<LmaNodeLE0**>(node_buf2);      // Nodes to.
676  LmaNodeGE1** node_fr_ge1 = NULL;
677  LmaNodeGE1** node_to_ge1 = NULL;
678  size_t node_fr_num = 1;
679  size_t node_to_num = 0;
680  node_fr_le0[0] = root_;
681  if (NULL == node_fr_le0[0])
682    return 0;
683
684  size_t spl_pos = 0;
685
686  while (spl_pos < splid_str_len) {
687    uint16 id_num = 1;
688    uint16 id_start = splid_str[spl_pos];
689    // If it is a half id
690    if (spl_trie_->is_half_id(splid_str[spl_pos])) {
691      id_num = spl_trie_->half_to_full(splid_str[spl_pos], &id_start);
692      assert(id_num > 0);
693    }
694
695    // Extend the nodes
696    if (0 == spl_pos) {  // From LmaNodeLE0 (root) to LmaNodeLE0 nodes
697      for (size_t node_fr_pos = 0; node_fr_pos < node_fr_num; node_fr_pos++) {
698        LmaNodeLE0 *node = node_fr_le0[node_fr_pos];
699        assert(node == root_ && 1 == node_fr_num);
700        size_t son_start = splid_le0_index_[id_start - kFullSplIdStart];
701        size_t son_end =
702            splid_le0_index_[id_start + id_num - kFullSplIdStart];
703        for (size_t son_pos = son_start; son_pos < son_end; son_pos++) {
704          assert(1 == node->son_1st_off);
705          LmaNodeLE0 *node_son = root_ + son_pos;
706          assert(node_son->spl_idx >= id_start
707                 && node_son->spl_idx < id_start + id_num);
708          if (node_to_num < MAX_EXTENDBUF_LEN) {
709            node_to_le0[node_to_num] = node_son;
710            node_to_num++;
711          }
712          // id_start + id_num - 1 is the last one, which has just been
713          // recorded.
714          if (node_son->spl_idx >= id_start + id_num - 1)
715            break;
716        }
717      }
718
719      spl_pos++;
720      if (spl_pos >= splid_str_len || node_to_num == 0)
721        break;
722      // Prepare the nodes for next extending
723      // next time, from LmaNodeLE0 to LmaNodeGE1
724      LmaNodeLE0** node_tmp = node_fr_le0;
725      node_fr_le0 = node_to_le0;
726      node_to_le0 = NULL;
727      node_to_ge1 = reinterpret_cast<LmaNodeGE1**>(node_tmp);
728    } else if (1 == spl_pos) {  // From LmaNodeLE0 to LmaNodeGE1 nodes
729      for (size_t node_fr_pos = 0; node_fr_pos < node_fr_num; node_fr_pos++) {
730        LmaNodeLE0 *node = node_fr_le0[node_fr_pos];
731        for (size_t son_pos = 0; son_pos < (size_t)node->num_of_son;
732             son_pos++) {
733          assert(node->son_1st_off <= lma_node_num_ge1_);
734          LmaNodeGE1 *node_son = nodes_ge1_ + node->son_1st_off
735                                  + son_pos;
736          if (node_son->spl_idx >= id_start
737              && node_son->spl_idx < id_start + id_num) {
738            if (node_to_num < MAX_EXTENDBUF_LEN) {
739              node_to_ge1[node_to_num] = node_son;
740              node_to_num++;
741            }
742          }
743          // id_start + id_num - 1 is the last one, which has just been
744          // recorded.
745          if (node_son->spl_idx >= id_start + id_num - 1)
746            break;
747        }
748      }
749
750      spl_pos++;
751      if (spl_pos >= splid_str_len || node_to_num == 0)
752        break;
753      // Prepare the nodes for next extending
754      // next time, from LmaNodeGE1 to LmaNodeGE1
755      node_fr_ge1 = node_to_ge1;
756      node_to_ge1 = reinterpret_cast<LmaNodeGE1**>(node_fr_le0);
757      node_fr_le0 = NULL;
758      node_to_le0 = NULL;
759    } else {  // From LmaNodeGE1 to LmaNodeGE1 nodes
760      for (size_t node_fr_pos = 0; node_fr_pos < node_fr_num; node_fr_pos++) {
761        LmaNodeGE1 *node = node_fr_ge1[node_fr_pos];
762        for (size_t son_pos = 0; son_pos < (size_t)node->num_of_son;
763             son_pos++) {
764          assert(node->son_1st_off_l > 0 || node->son_1st_off_h > 0);
765          LmaNodeGE1 *node_son = nodes_ge1_
766                                  + get_son_offset(node) + son_pos;
767          if (node_son->spl_idx >= id_start
768              && node_son->spl_idx < id_start + id_num) {
769            if (node_to_num < MAX_EXTENDBUF_LEN) {
770              node_to_ge1[node_to_num] = node_son;
771              node_to_num++;
772            }
773          }
774          // id_start + id_num - 1 is the last one, which has just been
775          // recorded.
776          if (node_son->spl_idx >= id_start + id_num - 1)
777            break;
778        }
779      }
780
781      spl_pos++;
782      if (spl_pos >= splid_str_len || node_to_num == 0)
783        break;
784      // Prepare the nodes for next extending
785      // next time, from LmaNodeGE1 to LmaNodeGE1
786      LmaNodeGE1 **node_tmp = node_fr_ge1;
787      node_fr_ge1 = node_to_ge1;
788      node_to_ge1 = node_tmp;
789    }
790
791    // The number of node for next extending
792    node_fr_num = node_to_num;
793    node_to_num = 0;
794  }  // while
795
796  if (0 == node_to_num)
797    return 0;
798
799  NGram &ngram = NGram::get_instance();
800  size_t lma_num = 0;
801
802  // If the length is 1, and the splid is a one-char Yunmu like 'a', 'o', 'e',
803  // only those candidates for the full matched one-char id will be returned.
804  if (1 == splid_str_len && spl_trie_->is_half_id_yunmu(splid_str[0]))
805    node_to_num = node_to_num > 0 ? 1 : 0;
806
807  for (size_t node_pos = 0; node_pos < node_to_num; node_pos++) {
808    size_t num_of_homo = 0;
809    if (spl_pos <= 1) {  // Get from LmaNodeLE0 nodes
810      LmaNodeLE0* node_le0 = node_to_le0[node_pos];
811      num_of_homo = (size_t)node_le0->num_of_homo;
812      for (size_t homo_pos = 0; homo_pos < num_of_homo; homo_pos++) {
813        size_t ch_pos = lma_num + homo_pos;
814        lma_buf[ch_pos].id =
815            get_lemma_id(node_le0->homo_idx_buf_off + homo_pos);
816        lma_buf[ch_pos].lma_len = 1;
817        lma_buf[ch_pos].psb =
818            static_cast<LmaScoreType>(ngram.get_uni_psb(lma_buf[ch_pos].id));
819
820        if (lma_num + homo_pos >= max_lma_buf - 1)
821          break;
822      }
823    } else {  // Get from LmaNodeGE1 nodes
824      LmaNodeGE1* node_ge1 = node_to_ge1[node_pos];
825      num_of_homo = (size_t)node_ge1->num_of_homo;
826      for (size_t homo_pos = 0; homo_pos < num_of_homo; homo_pos++) {
827        size_t ch_pos = lma_num + homo_pos;
828        size_t node_homo_off = get_homo_idx_buf_offset(node_ge1);
829        lma_buf[ch_pos].id = get_lemma_id(node_homo_off + homo_pos);
830        lma_buf[ch_pos].lma_len = splid_str_len;
831        lma_buf[ch_pos].psb =
832            static_cast<LmaScoreType>(ngram.get_uni_psb(lma_buf[ch_pos].id));
833
834        if (lma_num + homo_pos >= max_lma_buf - 1)
835          break;
836      }
837    }
838
839    lma_num += num_of_homo;
840    if (lma_num >= max_lma_buf) {
841      lma_num = max_lma_buf;
842      break;
843    }
844  }
845  return lma_num;
846}
847
848uint16 DictTrie::get_lemma_str(LemmaIdType id_lemma, char16 *str_buf,
849                               uint16 str_max) {
850  return dict_list_->get_lemma_str(id_lemma, str_buf, str_max);
851}
852
853uint16 DictTrie::get_lemma_splids(LemmaIdType id_lemma, uint16 *splids,
854                                  uint16 splids_max, bool arg_valid) {
855  char16 lma_str[kMaxLemmaSize + 1];
856  uint16 lma_len = get_lemma_str(id_lemma, lma_str, kMaxLemmaSize + 1);
857  assert((!arg_valid && splids_max >= lma_len) || lma_len == splids_max);
858
859  uint16 spl_mtrx[kMaxLemmaSize * 5];
860  uint16 spl_start[kMaxLemmaSize + 1];
861  spl_start[0] = 0;
862  uint16 try_num = 1;
863
864  for (uint16 pos = 0; pos < lma_len; pos++) {
865    uint16 cand_splids_this = 0;
866    if (arg_valid && spl_trie_->is_full_id(splids[pos])) {
867      spl_mtrx[spl_start[pos]] = splids[pos];
868      cand_splids_this = 1;
869    } else {
870      cand_splids_this = dict_list_->get_splids_for_hanzi(lma_str[pos],
871          arg_valid ? splids[pos] : 0, spl_mtrx + spl_start[pos],
872          kMaxLemmaSize * 5 - spl_start[pos]);
873      assert(cand_splids_this > 0);
874    }
875    spl_start[pos + 1] = spl_start[pos] + cand_splids_this;
876    try_num *= cand_splids_this;
877  }
878
879  for (uint16 try_pos = 0; try_pos < try_num; try_pos++) {
880    uint16 mod = 1;
881    for (uint16 pos = 0; pos < lma_len; pos++) {
882      uint16 radix = spl_start[pos + 1] - spl_start[pos];
883      splids[pos] = spl_mtrx[ spl_start[pos] + try_pos / mod % radix];
884      mod *= radix;
885    }
886
887    if (try_extend(splids, lma_len, id_lemma))
888      return lma_len;
889  }
890
891  return 0;
892}
893
894void DictTrie::set_total_lemma_count_of_others(size_t count) {
895  NGram& ngram = NGram::get_instance();
896  ngram.set_total_freq_none_sys(count);
897}
898
899void DictTrie::convert_to_hanzis(char16 *str, uint16 str_len) {
900  return dict_list_->convert_to_hanzis(str, str_len);
901}
902
903void DictTrie::convert_to_scis_ids(char16 *str, uint16 str_len) {
904  return dict_list_->convert_to_scis_ids(str, str_len);
905}
906
907LemmaIdType DictTrie::get_lemma_id(const char16 lemma_str[], uint16 lemma_len) {
908  if (NULL == lemma_str || lemma_len > kMaxLemmaSize)
909    return 0;
910
911  return dict_list_->get_lemma_id(lemma_str, lemma_len);
912}
913
914size_t DictTrie::predict_top_lmas(size_t his_len, NPredictItem *npre_items,
915                                  size_t npre_max, size_t b4_used) {
916  NGram &ngram = NGram::get_instance();
917
918  size_t item_num = 0;
919  size_t top_lmas_id_offset = lma_idx_buf_len_ / kLemmaIdSize - top_lmas_num_;
920  size_t top_lmas_pos = 0;
921  while (item_num < npre_max && top_lmas_pos < top_lmas_num_) {
922    memset(npre_items + item_num, 0, sizeof(NPredictItem));
923    LemmaIdType top_lma_id = get_lemma_id(top_lmas_id_offset + top_lmas_pos);
924    top_lmas_pos += 1;
925    if (dict_list_->get_lemma_str(top_lma_id,
926                                  npre_items[item_num].pre_hzs,
927                                  kMaxLemmaSize - 1) == 0) {
928      continue;
929    }
930    npre_items[item_num].psb = ngram.get_uni_psb(top_lma_id);
931    npre_items[item_num].his_len = his_len;
932    item_num++;
933  }
934  return item_num;
935}
936
937size_t DictTrie::predict(const char16 *last_hzs, uint16 hzs_len,
938                         NPredictItem *npre_items, size_t npre_max,
939                         size_t b4_used) {
940  return dict_list_->predict(last_hzs, hzs_len, npre_items, npre_max, b4_used);
941}
942}  // namespace ime_pinyin
943