1/*
2 * Copyright (C) 2009 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <assert.h>
18#include <stdlib.h>
19#include <string.h>
20#include "../include/dictlist.h"
21#include "../include/mystdlib.h"
22#include "../include/ngram.h"
23#include "../include/searchutility.h"
24
25namespace ime_pinyin {
26
27DictList::DictList() {
28  initialized_ = false;
29  scis_num_ = 0;
30  scis_hz_ = NULL;
31  scis_splid_ = NULL;
32  buf_ = NULL;
33  spl_trie_ = SpellingTrie::get_cpinstance();
34
35  assert(kMaxLemmaSize == 8);
36  cmp_func_[0] = cmp_hanzis_1;
37  cmp_func_[1] = cmp_hanzis_2;
38  cmp_func_[2] = cmp_hanzis_3;
39  cmp_func_[3] = cmp_hanzis_4;
40  cmp_func_[4] = cmp_hanzis_5;
41  cmp_func_[5] = cmp_hanzis_6;
42  cmp_func_[6] = cmp_hanzis_7;
43  cmp_func_[7] = cmp_hanzis_8;
44}
45
46DictList::~DictList() {
47  free_resource();
48}
49
50bool DictList::alloc_resource(size_t buf_size, size_t scis_num) {
51  // Allocate memory
52  buf_ = static_cast<char16*>(malloc(buf_size * sizeof(char16)));
53  if (NULL == buf_)
54    return false;
55
56  scis_num_ = scis_num;
57
58  scis_hz_ = static_cast<char16*>(malloc(scis_num_ * sizeof(char16)));
59  if (NULL == scis_hz_)
60    return false;
61
62  scis_splid_ = static_cast<SpellingId*>
63      (malloc(scis_num_ * sizeof(SpellingId)));
64
65  if (NULL == scis_splid_)
66    return false;
67
68  return true;
69}
70
71void DictList::free_resource() {
72  if (NULL != buf_)
73    free(buf_);
74  buf_ = NULL;
75
76  if (NULL != scis_hz_)
77    free(scis_hz_);
78  scis_hz_ = NULL;
79
80  if (NULL != scis_splid_)
81    free(scis_splid_);
82  scis_splid_ = NULL;
83}
84
85#ifdef ___BUILD_MODEL___
86bool DictList::init_list(const SingleCharItem *scis, size_t scis_num,
87                         const LemmaEntry *lemma_arr, size_t lemma_num) {
88  if (NULL == scis || 0 == scis_num || NULL == lemma_arr || 0 == lemma_num)
89    return false;
90
91  initialized_ = false;
92
93  if (NULL != buf_)
94    free(buf_);
95
96  // calculate the size
97  size_t buf_size = calculate_size(lemma_arr, lemma_num);
98  if (0 == buf_size)
99    return false;
100
101  if (!alloc_resource(buf_size, scis_num))
102    return false;
103
104  fill_scis(scis, scis_num);
105
106  // Copy the related content from the array to inner buffer
107  fill_list(lemma_arr, lemma_num);
108
109  initialized_ = true;
110  return true;
111}
112
// Computes the per-length layout of the flat lemma buffer from lemma_arr,
// which must be ordered by non-decreasing hz_str_len (asserted below).
// start_pos_[k] receives the buffer offset (in char16 units) where lemmas of
// length k + 1 begin, and start_id_[k] the id of the first such lemma (ids
// start at 1). Index kMaxLemmaSize holds the end offset. Returns the total
// buffer size in char16 units, i.e. start_pos_[kMaxLemmaSize].
size_t DictList::calculate_size(const LemmaEntry* lemma_arr, size_t lemma_num) {
  size_t last_hz_len = 0;
  size_t list_size = 0;
  size_t id_num = 0;

  for (size_t i = 0; i < lemma_num; i++) {
    if (0 == i) {
      last_hz_len = lemma_arr[i].hz_str_len;

      assert(last_hz_len > 0);
      assert(lemma_arr[0].idx_by_hz == 1);

      id_num++;
      start_pos_[0] = 0;
      start_id_[0] = id_num;

      // The first lemma is counted as one character regardless of its
      // hz_str_len. NOTE(review): presumably the first lemma is always a
      // single hanzi; fill_list() copies hz_str_len chars, so its final
      // assert would fire otherwise — confirm.
      last_hz_len = 1;
      list_size += last_hz_len;
    } else {
      size_t current_hz_len = lemma_arr[i].hz_str_len;

      assert(current_hz_len >= last_hz_len);

      if (current_hz_len == last_hz_len) {
          list_size += current_hz_len;
          id_num++;
      } else {
        // Fill entries for lengths that have no lemmas by copying the
        // previous entry. NOTE(review): this makes start_pos_[len] equal to
        // start_pos_[len - 1], so the range of the preceding length appears
        // empty; list_size / id_num + 1 look like the intended values. It
        // only matters when a length between the first and last is absent.
        for (size_t len = last_hz_len; len < current_hz_len - 1; len++) {
          start_pos_[len] = start_pos_[len - 1];
          start_id_[len] = start_id_[len - 1];
        }

        // First lemma of a new length: record where it starts and its id.
        start_pos_[current_hz_len - 1] = list_size;

        id_num++;
        start_id_[current_hz_len - 1] = id_num;

        last_hz_len = current_hz_len;
        list_size += current_hz_len;
      }
    }
  }

  // Lengths beyond the longest lemma (and the end slot kMaxLemmaSize) get the
  // total size. The 0 == i case is reached only when lemma_num is 0.
  // NOTE(review): start_id_ here is set to id_num (the last assigned id),
  // whereas the branch above stores the id of a length's first lemma.
  // fill_list() asserts start_id_[kMaxLemmaSize] equals the lemma count.
  // Confirm whether the last lemma is meant to fall outside the id range
  // that get_lemma_str() accepts.
  for (size_t i = last_hz_len; i <= kMaxLemmaSize; i++) {
    if (0 == i) {
      start_pos_[0] = 0;
      start_id_[0] = 1;
    } else {
      start_pos_[i] = list_size;
      start_id_[i] = id_num;
    }
  }

  return start_pos_[kMaxLemmaSize];
}
168
169void DictList::fill_scis(const SingleCharItem *scis, size_t scis_num) {
170  assert(scis_num_ == scis_num);
171
172  for (size_t pos = 0; pos < scis_num_; pos++) {
173    scis_hz_[pos] = scis[pos].hz;
174    scis_splid_[pos] = scis[pos].splid;
175  }
176}
177
178void DictList::fill_list(const LemmaEntry* lemma_arr, size_t lemma_num) {
179  size_t current_pos = 0;
180
181  utf16_strncpy(buf_, lemma_arr[0].hanzi_str,
182                lemma_arr[0].hz_str_len);
183
184  current_pos = lemma_arr[0].hz_str_len;
185
186  size_t id_num = 1;
187
188  for (size_t i = 1; i < lemma_num; i++) {
189    utf16_strncpy(buf_ + current_pos, lemma_arr[i].hanzi_str,
190                  lemma_arr[i].hz_str_len);
191
192    id_num++;
193    current_pos += lemma_arr[i].hz_str_len;
194  }
195
196  assert(current_pos == start_pos_[kMaxLemmaSize]);
197  assert(id_num == start_id_[kMaxLemmaSize]);
198}
199
200char16* DictList::find_pos2_startedbyhz(char16 hz_char) {
201  char16 *found_2w = static_cast<char16*>
202                     (mybsearch(&hz_char, buf_ + start_pos_[1],
203                                (start_pos_[2] - start_pos_[1]) / 2,
204                                sizeof(char16) * 2, cmp_hanzis_1));
205  if (NULL == found_2w)
206    return NULL;
207
208  while (found_2w > buf_ + start_pos_[1] && *found_2w == *(found_2w - 1))
209    found_2w -= 2;
210
211  return found_2w;
212}
213#endif  // ___BUILD_MODEL___
214
215char16* DictList::find_pos_startedbyhzs(const char16 last_hzs[],
216    size_t word_len, int (*cmp_func)(const void *, const void *)) {
217  char16 *found_w = static_cast<char16*>
218                    (mybsearch(last_hzs, buf_ + start_pos_[word_len - 1],
219                               (start_pos_[word_len] - start_pos_[word_len - 1])
220                               / word_len,
221                               sizeof(char16) * word_len, cmp_func));
222
223  if (NULL == found_w)
224    return NULL;
225
226  while (found_w > buf_ + start_pos_[word_len -1] &&
227         cmp_func(found_w, found_w - word_len) == 0)
228    found_w -= word_len;
229
230  return found_w;
231}
232
233size_t DictList::predict(const char16 last_hzs[], uint16 hzs_len,
234                         NPredictItem *npre_items, size_t npre_max,
235                         size_t b4_used) {
236  assert(hzs_len <= kMaxPredictSize && hzs_len > 0);
237
238  // 1. Prepare work
239  int (*cmp_func)(const void *, const void *) = cmp_func_[hzs_len - 1];
240
241  NGram& ngram = NGram::get_instance();
242
243  size_t item_num = 0;
244
245  // 2. Do prediction
246  for (uint16 pre_len = 1; pre_len <= kMaxPredictSize + 1 - hzs_len;
247       pre_len++) {
248    uint16 word_len = hzs_len + pre_len;
249    char16 *w_buf = find_pos_startedbyhzs(last_hzs, word_len, cmp_func);
250    if (NULL == w_buf)
251      continue;
252    while (w_buf < buf_ + start_pos_[word_len] &&
253           cmp_func(w_buf, last_hzs) == 0 &&
254           item_num < npre_max) {
255      memset(npre_items + item_num, 0, sizeof(NPredictItem));
256      utf16_strncpy(npre_items[item_num].pre_hzs, w_buf + hzs_len, pre_len);
257      npre_items[item_num].psb =
258        ngram.get_uni_psb((size_t)(w_buf - buf_ - start_pos_[word_len - 1])
259                          / word_len + start_id_[word_len - 1]);
260      npre_items[item_num].his_len = hzs_len;
261      item_num++;
262      w_buf += word_len;
263    }
264  }
265
266  size_t new_num = 0;
267  for (size_t i = 0; i < item_num; i++) {
268    // Try to find it in the existing items
269    size_t e_pos;
270    for (e_pos = 1; e_pos <= b4_used; e_pos++) {
271      if (utf16_strncmp((*(npre_items - e_pos)).pre_hzs, npre_items[i].pre_hzs,
272                        kMaxPredictSize) == 0)
273        break;
274    }
275    if (e_pos <= b4_used)
276      continue;
277
278    // If not found, append it to the buffer
279    npre_items[new_num] = npre_items[i];
280    new_num++;
281  }
282
283  return new_num;
284}
285
286uint16 DictList::get_lemma_str(LemmaIdType id_lemma, char16 *str_buf,
287                               uint16 str_max) {
288  if (!initialized_ || id_lemma >= start_id_[kMaxLemmaSize] || NULL == str_buf
289      || str_max <= 1)
290    return 0;
291
292  // Find the range
293  for (uint16 i = 0; i < kMaxLemmaSize; i++) {
294    if (i + 1 > str_max - 1)
295      return 0;
296    if (start_id_[i] <= id_lemma && start_id_[i + 1] > id_lemma) {
297      size_t id_span = id_lemma - start_id_[i];
298
299      uint16 *buf = buf_ + start_pos_[i] + id_span * (i + 1);
300      for (uint16 len = 0; len <= i; len++) {
301        str_buf[len] = buf[len];
302      }
303      str_buf[i+1] = (char16)'\0';
304      return i + 1;
305    }
306  }
307  return 0;
308}
309
310uint16 DictList::get_splids_for_hanzi(char16 hanzi, uint16 half_splid,
311                                      uint16 *splids, uint16 max_splids) {
312  char16 *hz_found = static_cast<char16*>
313      (mybsearch(&hanzi, scis_hz_, scis_num_, sizeof(char16), cmp_hanzis_1));
314  assert(NULL != hz_found && hanzi == *hz_found);
315
316  // Move to the first one.
317  while (hz_found > scis_hz_ && hanzi == *(hz_found - 1))
318    hz_found--;
319
320  // First try to found if strict comparison result is not zero.
321  char16 *hz_f = hz_found;
322  bool strict = false;
323  while (hz_f < scis_hz_ + scis_num_ && hanzi == *hz_f) {
324    uint16 pos = hz_f - scis_hz_;
325    if (0 == half_splid || scis_splid_[pos].half_splid == half_splid) {
326      strict = true;
327    }
328    hz_f++;
329  }
330
331  uint16 found_num = 0;
332  while (hz_found < scis_hz_ + scis_num_ && hanzi == *hz_found) {
333    uint16 pos = hz_found - scis_hz_;
334    if (0 == half_splid ||
335        (strict && scis_splid_[pos].half_splid == half_splid) ||
336        (!strict && spl_trie_->half_full_compatible(half_splid,
337        scis_splid_[pos].full_splid))) {
338      assert(found_num + 1 < max_splids);
339      splids[found_num] = scis_splid_[pos].full_splid;
340      found_num++;
341    }
342    hz_found++;
343  }
344
345  return found_num;
346}
347
348LemmaIdType DictList::get_lemma_id(const char16 *str, uint16 str_len) {
349  if (NULL == str || str_len > kMaxLemmaSize)
350    return 0;
351
352  char16 *found = find_pos_startedbyhzs(str, str_len, cmp_func_[str_len - 1]);
353  if (NULL == found)
354    return 0;
355
356  assert(found > buf_);
357  assert(static_cast<size_t>(found - buf_) >= start_pos_[str_len - 1]);
358  return static_cast<LemmaIdType>
359      (start_id_[str_len - 1] +
360       (found - buf_ - start_pos_[str_len - 1]) / str_len);
361}
362
363void DictList::convert_to_hanzis(char16 *str, uint16 str_len) {
364  assert(NULL != str);
365
366  for (uint16 str_pos = 0; str_pos < str_len; str_pos++) {
367    str[str_pos] = scis_hz_[str[str_pos]];
368  }
369}
370
371void DictList::convert_to_scis_ids(char16 *str, uint16 str_len) {
372  assert(NULL != str);
373
374  for (uint16 str_pos = 0; str_pos < str_len; str_pos++) {
375    str[str_pos] = 0x100;
376  }
377}
378
379bool DictList::save_list(FILE *fp) {
380  if (!initialized_ || NULL == fp)
381    return false;
382
383  if (NULL == buf_ || 0 == start_pos_[kMaxLemmaSize] ||
384      NULL == scis_hz_ || NULL == scis_splid_ || 0 == scis_num_)
385    return false;
386
387  if (fwrite(&scis_num_, sizeof(size_t), 1, fp) != 1)
388    return false;
389
390  if (fwrite(start_pos_, sizeof(size_t), kMaxLemmaSize + 1, fp) !=
391      kMaxLemmaSize + 1)
392    return false;
393
394  if (fwrite(start_id_, sizeof(size_t), kMaxLemmaSize + 1, fp) !=
395      kMaxLemmaSize + 1)
396    return false;
397
398  if (fwrite(scis_hz_, sizeof(char16), scis_num_, fp) != scis_num_)
399    return false;
400
401  if (fwrite(scis_splid_, sizeof(SpellingId), scis_num_, fp) != scis_num_)
402    return false;
403
404  if (fwrite(buf_, sizeof(char16), start_pos_[kMaxLemmaSize], fp) !=
405      start_pos_[kMaxLemmaSize])
406    return false;
407
408  return true;
409}
410
411bool DictList::load_list(FILE *fp) {
412  if (NULL == fp)
413    return false;
414
415  initialized_ = false;
416
417  if (fread(&scis_num_, sizeof(size_t), 1, fp) != 1)
418    return false;
419
420  if (fread(start_pos_, sizeof(size_t), kMaxLemmaSize + 1, fp) !=
421      kMaxLemmaSize + 1)
422    return false;
423
424  if (fread(start_id_, sizeof(size_t), kMaxLemmaSize + 1, fp) !=
425      kMaxLemmaSize + 1)
426    return false;
427
428  free_resource();
429
430  if (!alloc_resource(start_pos_[kMaxLemmaSize], scis_num_))
431    return false;
432
433  if (fread(scis_hz_, sizeof(char16), scis_num_, fp) != scis_num_)
434    return false;
435
436  if (fread(scis_splid_, sizeof(SpellingId), scis_num_, fp) != scis_num_)
437    return false;
438
439  if (fread(buf_, sizeof(char16), start_pos_[kMaxLemmaSize], fp) !=
440      start_pos_[kMaxLemmaSize])
441    return false;
442
443  initialized_ = true;
444  return true;
445}
446}  // namespace ime_pinyin
447