dic_node_utils.cpp revision a71ed8caa27c4a0174f25750171282980bc26880
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <cstring>
18#include <vector>
19
20#include "suggest/core/dicnode/dic_node.h"
21#include "suggest/core/dicnode/dic_node_utils.h"
22#include "suggest/core/dicnode/dic_node_vector.h"
23#include "suggest/core/dictionary/binary_dictionary_info.h"
24#include "suggest/core/dictionary/binary_format.h"
25#include "suggest/core/dictionary/multi_bigram_map.h"
26#include "suggest/core/dictionary/probability_utils.h"
27#include "suggest/core/layout/proximity_info.h"
28#include "suggest/core/layout/proximity_info_state.h"
29#include "utils/char_utils.h"
30
31namespace latinime {
32
33///////////////////////////////
34// Node initialization utils //
35///////////////////////////////
36
37/* static */ void DicNodeUtils::initAsRoot(const BinaryDictionaryInfo *const binaryDictionaryInfo,
38        const int prevWordNodePos, DicNode *const newRootNode) {
39    int curPos = binaryDictionaryInfo->getRootPosition();
40    const int pos = curPos;
41    const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(
42            binaryDictionaryInfo->getDictRoot(), &curPos);
43    const int childrenPos = curPos;
44    newRootNode->initAsRoot(pos, childrenPos, childrenCount, prevWordNodePos);
45}
46
47/*static */ void DicNodeUtils::initAsRootWithPreviousWord(
48        const BinaryDictionaryInfo *const binaryDictionaryInfo,
49        DicNode *const prevWordLastNode, DicNode *const newRootNode) {
50    int curPos = binaryDictionaryInfo->getRootPosition();
51    const int pos = curPos;
52    const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(
53            binaryDictionaryInfo->getDictRoot(), &curPos);
54    const int childrenPos = curPos;
55    newRootNode->initAsRootWithPreviousWord(prevWordLastNode, pos, childrenPos, childrenCount);
56}
57
58/* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) {
59    destNode->initByCopy(srcNode);
60}
61
62///////////////////////////////////
63// Traverse node expansion utils //
64///////////////////////////////////
65
66/* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode,
67        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
68        DicNodeVector *childDicNodes) {
69    // Passing multiple chars node. No need to traverse child
70    const int codePoint = dicNode->getNodeTypedCodePoint();
71    const int baseLowerCaseCodePoint = CharUtils::toBaseLowerCase(codePoint);
72    const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint);
73    if (isMatch || CharUtils::isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) {
74        childDicNodes->pushPassingChild(dicNode);
75    }
76}
77
78/* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos,
79        const BinaryDictionaryInfo *const binaryDictionaryInfo, const int terminalDepth,
80        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
81        const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo,
82        DicNodeVector *childDicNodes) {
83    int nextPos = pos;
84    const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(
85            binaryDictionaryInfo->getDictRoot(), &pos);
86    const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags));
87    const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags));
88    const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags);
89
90    int codePoint = BinaryFormat::getCodePointAndForwardPointer(
91            binaryDictionaryInfo->getDictRoot(), &pos);
92    ASSERT(NOT_A_CODE_POINT != codePoint);
93    const int nodeCodePoint = codePoint;
94    // TODO: optimize this
95    int additionalWordBuf[MAX_WORD_LENGTH];
96    uint16_t additionalSubwordLength = 0;
97    additionalWordBuf[additionalSubwordLength++] = codePoint;
98
99    do {
100        const int nextCodePoint = hasMultipleChars
101                ? BinaryFormat::getCodePointAndForwardPointer(
102                        binaryDictionaryInfo->getDictRoot(), &pos) : NOT_A_CODE_POINT;
103        const bool isLastChar = (NOT_A_CODE_POINT == nextCodePoint);
104        if (!isLastChar) {
105            additionalWordBuf[additionalSubwordLength++] = nextCodePoint;
106        }
107        codePoint = nextCodePoint;
108    } while (NOT_A_CODE_POINT != codePoint);
109
110    const int probability = isTerminal ? BinaryFormat::readProbabilityWithoutMovingPointer(
111            binaryDictionaryInfo->getDictRoot(), pos) : -1;
112    pos = BinaryFormat::skipProbability(flags, pos);
113    int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition(
114            binaryDictionaryInfo->getDictRoot(), flags, pos) : 0;
115    const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos);
116    const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes(
117            binaryDictionaryInfo->getDictRoot(), flags, pos);
118
119    if (isDicNodeFilteredOut(nodeCodePoint, pInfo, codePointsFilter)) {
120        return siblingPos;
121    }
122    if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, nodeCodePoint)) {
123        return siblingPos;
124    }
125    const int childrenCount = hasChildren ? BinaryFormat::getGroupCountAndForwardPointer(
126            binaryDictionaryInfo->getDictRoot(), &childrenPos) : 0;
127    childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos, siblingPos,
128            nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal,
129            hasMultipleChars, hasChildren, additionalSubwordLength, additionalWordBuf);
130    return siblingPos;
131}
132
133/* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint,
134        const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) {
135    const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
136    if (filterSize <= 0) {
137        return false;
138    }
139    if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX
140            || CharUtils::isIntentionalOmissionCodePoint(nodeCodePoint))) {
141        // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never
142        // filtered.
143        return false;
144    }
145    const int lowerCodePoint = CharUtils::toLowerCase(nodeCodePoint);
146    const int baseLowerCodePoint = CharUtils::toBaseCodePoint(lowerCodePoint);
147    // TODO: Avoid linear search
148    for (int i = 0; i < filterSize; ++i) {
149        // Checking if a normalized code point is in filter characters when pInfo is not
150        // null. When pInfo is null, nodeCodePoint is used to check filtering without
151        // normalizing.
152        if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint
153                || (*codePointsFilter)[i] == baseLowerCodePoint))
154                        || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) {
155            return false;
156        }
157    }
158    return true;
159}
160
161/* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode,
162        const BinaryDictionaryInfo *const binaryDictionaryInfo,
163        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
164        const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo,
165        DicNodeVector *childDicNodes) {
166    const int terminalDepth = dicNode->getLeavingDepth();
167    const int childCount = dicNode->getChildrenCount();
168    int nextPos = dicNode->getChildrenPos();
169    for (int i = 0; i < childCount; i++) {
170        const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
171        nextPos = createAndGetLeavingChildNode(dicNode, nextPos, binaryDictionaryInfo,
172                terminalDepth, pInfoState, pointIndex, exactOnly, codePointsFilter, pInfo,
173                childDicNodes);
174        if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) {
175            // All code points have been found.
176            break;
177        }
178    }
179}
180
181/* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode,
182        const BinaryDictionaryInfo *const binaryDictionaryInfo, DicNodeVector *childDicNodes) {
183    getProximityChildDicNodes(dicNode, binaryDictionaryInfo, 0, 0, false, childDicNodes);
184}
185
186/* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode,
187        const BinaryDictionaryInfo *const binaryDictionaryInfo,
188        const ProximityInfoState *pInfoState, const int pointIndex, bool exactOnly,
189        DicNodeVector *childDicNodes) {
190    if (dicNode->isTotalInputSizeExceedingLimit()) {
191        return;
192    }
193    if (!dicNode->isLeavingNode()) {
194        DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly,
195                childDicNodes);
196    } else {
197        DicNodeUtils::createAndGetAllLeavingChildNodes(
198                dicNode, binaryDictionaryInfo, pInfoState, pointIndex, exactOnly,
199                0 /* codePointsFilter */, 0 /* pInfo */, childDicNodes);
200    }
201}
202
203///////////////////
204// Scoring utils //
205///////////////////
206/**
207 * Computes the combined bigram / unigram cost for the given dicNode.
208 */
209/* static */ float DicNodeUtils::getBigramNodeImprobability(
210        const BinaryDictionaryInfo *const binaryDictionaryInfo,
211        const DicNode *const node, MultiBigramMap *multiBigramMap) {
212    if (node->isImpossibleBigramWord()) {
213        return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
214    }
215    const int probability = getBigramNodeProbability(binaryDictionaryInfo, node, multiBigramMap);
216    // TODO: This equation to calculate the improbability looks unreasonable.  Investigate this.
217    const float cost = static_cast<float>(MAX_PROBABILITY - probability)
218            / static_cast<float>(MAX_PROBABILITY);
219    return cost;
220}
221
222/* static */ int DicNodeUtils::getBigramNodeProbability(
223        const BinaryDictionaryInfo *const binaryDictionaryInfo,
224        const DicNode *const node, MultiBigramMap *multiBigramMap) {
225    const int unigramProbability = node->getProbability();
226    const int wordPos = node->getPos();
227    const int prevWordPos = node->getPrevWordPos();
228    if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) {
229        // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD.
230        return ProbabilityUtils::backoff(unigramProbability);
231    }
232    if (multiBigramMap) {
233        return multiBigramMap->getBigramProbability(
234                binaryDictionaryInfo, prevWordPos, wordPos, unigramProbability);
235    }
236    return ProbabilityUtils::backoff(unigramProbability);
237}
238
239///////////////////////////////////////
240// Bigram / Unigram dictionary utils //
241///////////////////////////////////////
242
243/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState,
244        const int pointIndex, const bool exactOnly, const int nodeCodePoint) {
245    if (!pInfoState) {
246        return true;
247    }
248    if (exactOnly) {
249        return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint;
250    }
251    const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint,
252            true /* checkProximityChars */);
253    return isProximityChar(matchedId);
254}
255
256////////////////
257// Char utils //
258////////////////
259
260// TODO: Move to char_utils?
261/* static */ int DicNodeUtils::appendTwoWords(const int *const src0, const int16_t length0,
262        const int *const src1, const int16_t length1, int *dest) {
263    int actualLength0 = 0;
264    for (int i = 0; i < length0; ++i) {
265        if (src0[i] == 0) {
266            break;
267        }
268        actualLength0 = i + 1;
269    }
270    actualLength0 = min(actualLength0, MAX_WORD_LENGTH);
271    memcpy(dest, src0, actualLength0 * sizeof(dest[0]));
272    if (!src1 || length1 == 0) {
273        return actualLength0;
274    }
275    int actualLength1 = 0;
276    for (int i = 0; i < length1; ++i) {
277        if (src1[i] == 0) {
278            break;
279        }
280        actualLength1 = i + 1;
281    }
282    actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0 - 1);
283    memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0]));
284    return actualLength0 + actualLength1;
285}
286} // namespace latinime
287