dic_node_utils.cpp revision 6379a4de29fee7019b32b93bc424eda720e02dcf
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <cstring>
18#include <vector>
19
20#include "suggest/core/dicnode/dic_node.h"
21#include "suggest/core/dicnode/dic_node_utils.h"
22#include "suggest/core/dicnode/dic_node_vector.h"
23#include "suggest/core/dictionary/binary_dictionary_info.h"
24#include "suggest/core/dictionary/binary_format.h"
25#include "suggest/core/dictionary/multi_bigram_map.h"
26#include "suggest/core/dictionary/probability_utils.h"
27#include "suggest/core/layout/proximity_info.h"
28#include "suggest/core/layout/proximity_info_state.h"
29#include "utils/char_utils.h"
30
31namespace latinime {
32
33///////////////////////////////
34// Node initialization utils //
35///////////////////////////////
36
37/* static */ void DicNodeUtils::initAsRoot(const BinaryDictionaryInfo *const binaryDictionaryInfo,
38        const int prevWordNodePos, DicNode *const newRootNode) {
39    const int rootPos = binaryDictionaryInfo->getRootPosition();
40    const int childrenPos = rootPos;
41    newRootNode->initAsRoot(rootPos, childrenPos, prevWordNodePos);
42}
43
44/*static */ void DicNodeUtils::initAsRootWithPreviousWord(
45        const BinaryDictionaryInfo *const binaryDictionaryInfo,
46        DicNode *const prevWordLastNode, DicNode *const newRootNode) {
47    const int rootPos = binaryDictionaryInfo->getRootPosition();
48    const int childrenPos = rootPos;
49    newRootNode->initAsRootWithPreviousWord(prevWordLastNode, rootPos, childrenPos);
50}
51
52/* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) {
53    destNode->initByCopy(srcNode);
54}
55
56///////////////////////////////////
57// Traverse node expansion utils //
58///////////////////////////////////
59
60/* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode,
61        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
62        DicNodeVector *childDicNodes) {
63    // Passing multiple chars node. No need to traverse child
64    const int codePoint = dicNode->getNodeTypedCodePoint();
65    const int baseLowerCaseCodePoint = CharUtils::toBaseLowerCase(codePoint);
66    const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint);
67    if (isMatch || CharUtils::isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) {
68        childDicNodes->pushPassingChild(dicNode);
69    }
70}
71
72/* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos,
73        const BinaryDictionaryInfo *const binaryDictionaryInfo,
74        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
75        const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo,
76        DicNodeVector *childDicNodes) {
77    int nextPos = pos;
78    const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(
79            binaryDictionaryInfo->getDictRoot(), &pos);
80    const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags));
81    const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags));
82    const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags);
83
84    int codePoint = BinaryFormat::getCodePointAndForwardPointer(
85            binaryDictionaryInfo->getDictRoot(), &pos);
86    ASSERT(NOT_A_CODE_POINT != codePoint);
87    // TODO: optimize this
88    int mergedNodeCodePoints[MAX_WORD_LENGTH];
89    uint16_t mergedNodeCodePointCount = 0;
90    mergedNodeCodePoints[mergedNodeCodePointCount++] = codePoint;
91
92    do {
93        const int nextCodePoint = hasMultipleChars
94                ? BinaryFormat::getCodePointAndForwardPointer(
95                        binaryDictionaryInfo->getDictRoot(), &pos) : NOT_A_CODE_POINT;
96        const bool isLastChar = (NOT_A_CODE_POINT == nextCodePoint);
97        if (!isLastChar) {
98            mergedNodeCodePoints[mergedNodeCodePointCount++] = nextCodePoint;
99        }
100        codePoint = nextCodePoint;
101    } while (NOT_A_CODE_POINT != codePoint);
102
103    const int probability = isTerminal ? BinaryFormat::readProbabilityWithoutMovingPointer(
104            binaryDictionaryInfo->getDictRoot(), pos) : -1;
105    pos = BinaryFormat::skipProbability(flags, pos);
106    int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition(
107            binaryDictionaryInfo->getDictRoot(), flags, pos) : 0;
108    const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos);
109    const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes(
110            binaryDictionaryInfo->getDictRoot(), flags, pos);
111
112    if (isDicNodeFilteredOut(mergedNodeCodePoints[0], pInfo, codePointsFilter)) {
113        return siblingPos;
114    }
115    if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, mergedNodeCodePoints[0])) {
116        return siblingPos;
117    }
118    childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos,
119            probability, isTerminal, hasChildren, mergedNodeCodePointCount, mergedNodeCodePoints);
120    return siblingPos;
121}
122
123/* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint,
124        const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) {
125    const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
126    if (filterSize <= 0) {
127        return false;
128    }
129    if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX
130            || CharUtils::isIntentionalOmissionCodePoint(nodeCodePoint))) {
131        // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never
132        // filtered.
133        return false;
134    }
135    const int lowerCodePoint = CharUtils::toLowerCase(nodeCodePoint);
136    const int baseLowerCodePoint = CharUtils::toBaseCodePoint(lowerCodePoint);
137    // TODO: Avoid linear search
138    for (int i = 0; i < filterSize; ++i) {
139        // Checking if a normalized code point is in filter characters when pInfo is not
140        // null. When pInfo is null, nodeCodePoint is used to check filtering without
141        // normalizing.
142        if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint
143                || (*codePointsFilter)[i] == baseLowerCodePoint))
144                        || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) {
145            return false;
146        }
147    }
148    return true;
149}
150
151/* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode,
152        const BinaryDictionaryInfo *const binaryDictionaryInfo,
153        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
154        const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo,
155        DicNodeVector *childDicNodes) {
156    if (!dicNode->hasChildren()) {
157        return;
158    }
159    int nextPos = dicNode->getChildrenPos();
160    const int childCount = BinaryFormat::getGroupCountAndForwardPointer(
161            binaryDictionaryInfo->getDictRoot(), &nextPos);
162    for (int i = 0; i < childCount; i++) {
163        const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
164        nextPos = createAndGetLeavingChildNode(dicNode, nextPos, binaryDictionaryInfo,
165                pInfoState, pointIndex, exactOnly, codePointsFilter, pInfo,
166                childDicNodes);
167        if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) {
168            // All code points have been found.
169            break;
170        }
171    }
172}
173
174/* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode,
175        const BinaryDictionaryInfo *const binaryDictionaryInfo, DicNodeVector *childDicNodes) {
176    getProximityChildDicNodes(dicNode, binaryDictionaryInfo, 0, 0, false, childDicNodes);
177}
178
179/* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode,
180        const BinaryDictionaryInfo *const binaryDictionaryInfo,
181        const ProximityInfoState *pInfoState, const int pointIndex, bool exactOnly,
182        DicNodeVector *childDicNodes) {
183    if (dicNode->isTotalInputSizeExceedingLimit()) {
184        return;
185    }
186    if (!dicNode->isLeavingNode()) {
187        DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly,
188                childDicNodes);
189    } else {
190        DicNodeUtils::createAndGetAllLeavingChildNodes(
191                dicNode, binaryDictionaryInfo, pInfoState, pointIndex, exactOnly,
192                0 /* codePointsFilter */, 0 /* pInfo */, childDicNodes);
193    }
194}
195
196///////////////////
197// Scoring utils //
198///////////////////
199/**
200 * Computes the combined bigram / unigram cost for the given dicNode.
201 */
202/* static */ float DicNodeUtils::getBigramNodeImprobability(
203        const BinaryDictionaryInfo *const binaryDictionaryInfo,
204        const DicNode *const node, MultiBigramMap *multiBigramMap) {
205    if (node->isImpossibleBigramWord()) {
206        return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
207    }
208    const int probability = getBigramNodeProbability(binaryDictionaryInfo, node, multiBigramMap);
209    // TODO: This equation to calculate the improbability looks unreasonable.  Investigate this.
210    const float cost = static_cast<float>(MAX_PROBABILITY - probability)
211            / static_cast<float>(MAX_PROBABILITY);
212    return cost;
213}
214
215/* static */ int DicNodeUtils::getBigramNodeProbability(
216        const BinaryDictionaryInfo *const binaryDictionaryInfo,
217        const DicNode *const node, MultiBigramMap *multiBigramMap) {
218    const int unigramProbability = node->getProbability();
219    const int wordPos = node->getPos();
220    const int prevWordPos = node->getPrevWordPos();
221    if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) {
222        // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD.
223        return ProbabilityUtils::backoff(unigramProbability);
224    }
225    if (multiBigramMap) {
226        return multiBigramMap->getBigramProbability(
227                binaryDictionaryInfo, prevWordPos, wordPos, unigramProbability);
228    }
229    return ProbabilityUtils::backoff(unigramProbability);
230}
231
232///////////////////////////////////////
233// Bigram / Unigram dictionary utils //
234///////////////////////////////////////
235
236/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState,
237        const int pointIndex, const bool exactOnly, const int nodeCodePoint) {
238    if (!pInfoState) {
239        return true;
240    }
241    if (exactOnly) {
242        return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint;
243    }
244    const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint,
245            true /* checkProximityChars */);
246    return isProximityChar(matchedId);
247}
248
249////////////////
250// Char utils //
251////////////////
252
253// TODO: Move to char_utils?
254/* static */ int DicNodeUtils::appendTwoWords(const int *const src0, const int16_t length0,
255        const int *const src1, const int16_t length1, int *dest) {
256    int actualLength0 = 0;
257    for (int i = 0; i < length0; ++i) {
258        if (src0[i] == 0) {
259            break;
260        }
261        actualLength0 = i + 1;
262    }
263    actualLength0 = min(actualLength0, MAX_WORD_LENGTH);
264    memcpy(dest, src0, actualLength0 * sizeof(dest[0]));
265    if (!src1 || length1 == 0) {
266        return actualLength0;
267    }
268    int actualLength1 = 0;
269    for (int i = 0; i < length1; ++i) {
270        if (src1[i] == 0) {
271            break;
272        }
273        actualLength1 = i + 1;
274    }
275    actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0 - 1);
276    memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0]));
277    return actualLength0 + actualLength1;
278}
279} // namespace latinime
280