dic_node_utils.cpp revision 9559dd2e30de288a9ff7069bfc59f8500b949a88
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <cstring>
18#include <vector>
19
20#include "binary_format.h"
21#include "dic_node.h"
22#include "dic_node_utils.h"
23#include "dic_node_vector.h"
24#include "multi_bigram_map.h"
25#include "proximity_info.h"
26#include "proximity_info_state.h"
27
28namespace latinime {
29
30///////////////////////////////
31// Node initialization utils //
32///////////////////////////////
33
34/* static */ void DicNodeUtils::initAsRoot(const int rootPos, const uint8_t *const dicRoot,
35        const int prevWordNodePos, DicNode *newRootNode) {
36    int curPos = rootPos;
37    const int pos = curPos;
38    const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &curPos);
39    const int childrenPos = curPos;
40    newRootNode->initAsRoot(pos, childrenPos, childrenCount, prevWordNodePos);
41}
42
43/*static */ void DicNodeUtils::initAsRootWithPreviousWord(const int rootPos,
44        const uint8_t *const dicRoot, DicNode *prevWordLastNode, DicNode *newRootNode) {
45    int curPos = rootPos;
46    const int pos = curPos;
47    const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &curPos);
48    const int childrenPos = curPos;
49    newRootNode->initAsRootWithPreviousWord(prevWordLastNode, pos, childrenPos, childrenCount);
50}
51
52/* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) {
53    destNode->initByCopy(srcNode);
54}
55
56///////////////////////////////////
57// Traverse node expansion utils //
58///////////////////////////////////
59
60/* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode,
61        const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly,
62        DicNodeVector *childDicNodes) {
63    // Passing multiple chars node. No need to traverse child
64    const int codePoint = dicNode->getNodeTypedCodePoint();
65    const int baseLowerCaseCodePoint = toBaseLowerCase(codePoint);
66    const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint);
67    if (isMatch || isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) {
68        childDicNodes->pushPassingChild(dicNode);
69    }
70}
71
72/* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos,
73        const uint8_t *const dicRoot, const int terminalDepth, const ProximityInfoState *pInfoState,
74        const int pointIndex, const bool exactOnly, const std::vector<int> *const codePointsFilter,
75        const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) {
76    int nextPos = pos;
77    const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
78    const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags));
79    const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags));
80    const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags);
81
82    int codePoint = BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos);
83    ASSERT(NOT_A_CODE_POINT != codePoint);
84    const int nodeCodePoint = codePoint;
85    // TODO: optimize this
86    int additionalWordBuf[MAX_WORD_LENGTH];
87    uint16_t additionalSubwordLength = 0;
88    additionalWordBuf[additionalSubwordLength++] = codePoint;
89
90    do {
91        const int nextCodePoint = hasMultipleChars
92                ? BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos) : NOT_A_CODE_POINT;
93        const bool isLastChar = (NOT_A_CODE_POINT == nextCodePoint);
94        if (!isLastChar) {
95            additionalWordBuf[additionalSubwordLength++] = nextCodePoint;
96        }
97        codePoint = nextCodePoint;
98    } while (NOT_A_CODE_POINT != codePoint);
99
100    const int probability =
101            isTerminal ? BinaryFormat::readProbabilityWithoutMovingPointer(dicRoot, pos) : -1;
102    pos = BinaryFormat::skipProbability(flags, pos);
103    int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition(dicRoot, flags, pos) : 0;
104    const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos);
105    const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes(dicRoot, flags, pos);
106
107    if (isDicNodeFilteredOut(nodeCodePoint, pInfo, codePointsFilter)) {
108        return siblingPos;
109    }
110    if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, nodeCodePoint)) {
111        return siblingPos;
112    }
113    const int childrenCount = hasChildren
114            ? BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &childrenPos) : 0;
115    childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos, siblingPos,
116            nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal,
117            hasMultipleChars, hasChildren, additionalSubwordLength, additionalWordBuf);
118    return siblingPos;
119}
120
121/* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint,
122        const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) {
123    const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
124    if (filterSize <= 0) {
125        return false;
126    }
127    if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX
128            || isIntentionalOmissionCodePoint(nodeCodePoint))) {
129        // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never
130        // filtered.
131        return false;
132    }
133    const int lowerCodePoint = toLowerCase(nodeCodePoint);
134    const int baseLowerCodePoint = toBaseCodePoint(lowerCodePoint);
135    // TODO: Avoid linear search
136    for (int i = 0; i < filterSize; ++i) {
137        // Checking if a normalized code point is in filter characters when pInfo is not
138        // null. When pInfo is null, nodeCodePoint is used to check filtering without
139        // normalizing.
140        if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint
141                || (*codePointsFilter)[i] == baseLowerCodePoint))
142                        || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) {
143            return false;
144        }
145    }
146    return true;
147}
148
149/* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode,
150        const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex,
151        const bool exactOnly, const std::vector<int> *const codePointsFilter,
152        const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) {
153    const int terminalDepth = dicNode->getLeavingDepth();
154    const int childCount = dicNode->getChildrenCount();
155    int nextPos = dicNode->getChildrenPos();
156    for (int i = 0; i < childCount; i++) {
157        const int filterSize = codePointsFilter ? codePointsFilter->size() : 0;
158        nextPos = createAndGetLeavingChildNode(dicNode, nextPos, dicRoot, terminalDepth, pInfoState,
159                pointIndex, exactOnly, codePointsFilter, pInfo, childDicNodes);
160        if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) {
161            // All code points have been found.
162            break;
163        }
164    }
165}
166
167/* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot,
168        DicNodeVector *childDicNodes) {
169    getProximityChildDicNodes(dicNode, dicRoot, 0, 0, false, childDicNodes);
170}
171
172/* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode,
173        const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex,
174        bool exactOnly, DicNodeVector *childDicNodes) {
175    if (dicNode->isTotalInputSizeExceedingLimit()) {
176        return;
177    }
178    if (!dicNode->isLeavingNode()) {
179        DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly,
180                childDicNodes);
181    } else {
182        DicNodeUtils::createAndGetAllLeavingChildNodes(dicNode, dicRoot, pInfoState, pointIndex,
183                exactOnly, 0 /* codePointsFilter */, 0 /* pInfo */,
184                childDicNodes);
185    }
186}
187
188///////////////////
189// Scoring utils //
190///////////////////
191/**
192 * Computes the combined bigram / unigram cost for the given dicNode.
193 */
194/* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot,
195        const DicNode *const node, MultiBigramMap *multiBigramMap) {
196    if (node->isImpossibleBigramWord()) {
197        return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
198    }
199    const int probability = getBigramNodeProbability(dicRoot, node, multiBigramMap);
200    // TODO: This equation to calculate the improbability looks unreasonable.  Investigate this.
201    const float cost = static_cast<float>(MAX_PROBABILITY - probability)
202            / static_cast<float>(MAX_PROBABILITY);
203    return cost;
204}
205
206/* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot,
207        const DicNode *const node, MultiBigramMap *multiBigramMap) {
208    const int unigramProbability = node->getProbability();
209    const int wordPos = node->getPos();
210    const int prevWordPos = node->getPrevWordPos();
211    if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) {
212        // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD.
213        return backoff(unigramProbability);
214    }
215    if (multiBigramMap) {
216        return multiBigramMap->getBigramProbability(
217                dicRoot, prevWordPos, wordPos, unigramProbability);
218    }
219    return BinaryFormat::getBigramProbability(dicRoot, prevWordPos, wordPos, unigramProbability);
220}
221
222///////////////////////////////////////
223// Bigram / Unigram dictionary utils //
224///////////////////////////////////////
225
226/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState,
227        const int pointIndex, const bool exactOnly, const int nodeCodePoint) {
228    if (!pInfoState) {
229        return true;
230    }
231    if (exactOnly) {
232        return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint;
233    }
234    const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint,
235            true /* checkProximityChars */);
236    return isProximityChar(matchedId);
237}
238
239////////////////
240// Char utils //
241////////////////
242
243// TODO: Move to char_utils?
244/* static */ int DicNodeUtils::appendTwoWords(const int *const src0, const int16_t length0,
245        const int *const src1, const int16_t length1, int *dest) {
246    int actualLength0 = 0;
247    for (int i = 0; i < length0; ++i) {
248        if (src0[i] == 0) {
249            break;
250        }
251        actualLength0 = i + 1;
252    }
253    actualLength0 = min(actualLength0, MAX_WORD_LENGTH);
254    memcpy(dest, src0, actualLength0 * sizeof(dest[0]));
255    if (!src1 || length1 == 0) {
256        return actualLength0;
257    }
258    int actualLength1 = 0;
259    for (int i = 0; i < length1; ++i) {
260        if (src1[i] == 0) {
261            break;
262        }
263        actualLength1 = i + 1;
264    }
265    actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0 - 1);
266    memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0]));
267    return actualLength0 + actualLength1;
268}
269} // namespace latinime
270