// dic_node_utils.cpp revision a71ed8caa27c4a0174f25750171282980bc26880
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include <cstring> 18#include <vector> 19 20#include "suggest/core/dicnode/dic_node.h" 21#include "suggest/core/dicnode/dic_node_utils.h" 22#include "suggest/core/dicnode/dic_node_vector.h" 23#include "suggest/core/dictionary/binary_dictionary_info.h" 24#include "suggest/core/dictionary/binary_format.h" 25#include "suggest/core/dictionary/multi_bigram_map.h" 26#include "suggest/core/dictionary/probability_utils.h" 27#include "suggest/core/layout/proximity_info.h" 28#include "suggest/core/layout/proximity_info_state.h" 29#include "utils/char_utils.h" 30 31namespace latinime { 32 33/////////////////////////////// 34// Node initialization utils // 35/////////////////////////////// 36 37/* static */ void DicNodeUtils::initAsRoot(const BinaryDictionaryInfo *const binaryDictionaryInfo, 38 const int prevWordNodePos, DicNode *const newRootNode) { 39 int curPos = binaryDictionaryInfo->getRootPosition(); 40 const int pos = curPos; 41 const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer( 42 binaryDictionaryInfo->getDictRoot(), &curPos); 43 const int childrenPos = curPos; 44 newRootNode->initAsRoot(pos, childrenPos, childrenCount, prevWordNodePos); 45} 46 47/*static */ void DicNodeUtils::initAsRootWithPreviousWord( 48 const BinaryDictionaryInfo *const binaryDictionaryInfo, 49 DicNode *const prevWordLastNode, DicNode *const 
newRootNode) { 50 int curPos = binaryDictionaryInfo->getRootPosition(); 51 const int pos = curPos; 52 const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer( 53 binaryDictionaryInfo->getDictRoot(), &curPos); 54 const int childrenPos = curPos; 55 newRootNode->initAsRootWithPreviousWord(prevWordLastNode, pos, childrenPos, childrenCount); 56} 57 58/* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) { 59 destNode->initByCopy(srcNode); 60} 61 62/////////////////////////////////// 63// Traverse node expansion utils // 64/////////////////////////////////// 65 66/* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode, 67 const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, 68 DicNodeVector *childDicNodes) { 69 // Passing multiple chars node. No need to traverse child 70 const int codePoint = dicNode->getNodeTypedCodePoint(); 71 const int baseLowerCaseCodePoint = CharUtils::toBaseLowerCase(codePoint); 72 const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint); 73 if (isMatch || CharUtils::isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) { 74 childDicNodes->pushPassingChild(dicNode); 75 } 76} 77 78/* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos, 79 const BinaryDictionaryInfo *const binaryDictionaryInfo, const int terminalDepth, 80 const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, 81 const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo, 82 DicNodeVector *childDicNodes) { 83 int nextPos = pos; 84 const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer( 85 binaryDictionaryInfo->getDictRoot(), &pos); 86 const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); 87 const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags)); 88 const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags); 89 90 int 
codePoint = BinaryFormat::getCodePointAndForwardPointer( 91 binaryDictionaryInfo->getDictRoot(), &pos); 92 ASSERT(NOT_A_CODE_POINT != codePoint); 93 const int nodeCodePoint = codePoint; 94 // TODO: optimize this 95 int additionalWordBuf[MAX_WORD_LENGTH]; 96 uint16_t additionalSubwordLength = 0; 97 additionalWordBuf[additionalSubwordLength++] = codePoint; 98 99 do { 100 const int nextCodePoint = hasMultipleChars 101 ? BinaryFormat::getCodePointAndForwardPointer( 102 binaryDictionaryInfo->getDictRoot(), &pos) : NOT_A_CODE_POINT; 103 const bool isLastChar = (NOT_A_CODE_POINT == nextCodePoint); 104 if (!isLastChar) { 105 additionalWordBuf[additionalSubwordLength++] = nextCodePoint; 106 } 107 codePoint = nextCodePoint; 108 } while (NOT_A_CODE_POINT != codePoint); 109 110 const int probability = isTerminal ? BinaryFormat::readProbabilityWithoutMovingPointer( 111 binaryDictionaryInfo->getDictRoot(), pos) : -1; 112 pos = BinaryFormat::skipProbability(flags, pos); 113 int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition( 114 binaryDictionaryInfo->getDictRoot(), flags, pos) : 0; 115 const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos); 116 const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes( 117 binaryDictionaryInfo->getDictRoot(), flags, pos); 118 119 if (isDicNodeFilteredOut(nodeCodePoint, pInfo, codePointsFilter)) { 120 return siblingPos; 121 } 122 if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, nodeCodePoint)) { 123 return siblingPos; 124 } 125 const int childrenCount = hasChildren ? 
BinaryFormat::getGroupCountAndForwardPointer( 126 binaryDictionaryInfo->getDictRoot(), &childrenPos) : 0; 127 childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos, siblingPos, 128 nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal, 129 hasMultipleChars, hasChildren, additionalSubwordLength, additionalWordBuf); 130 return siblingPos; 131} 132 133/* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint, 134 const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) { 135 const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; 136 if (filterSize <= 0) { 137 return false; 138 } 139 if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX 140 || CharUtils::isIntentionalOmissionCodePoint(nodeCodePoint))) { 141 // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never 142 // filtered. 143 return false; 144 } 145 const int lowerCodePoint = CharUtils::toLowerCase(nodeCodePoint); 146 const int baseLowerCodePoint = CharUtils::toBaseCodePoint(lowerCodePoint); 147 // TODO: Avoid linear search 148 for (int i = 0; i < filterSize; ++i) { 149 // Checking if a normalized code point is in filter characters when pInfo is not 150 // null. When pInfo is null, nodeCodePoint is used to check filtering without 151 // normalizing. 
152 if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint 153 || (*codePointsFilter)[i] == baseLowerCodePoint)) 154 || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) { 155 return false; 156 } 157 } 158 return true; 159} 160 161/* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode, 162 const BinaryDictionaryInfo *const binaryDictionaryInfo, 163 const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, 164 const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo, 165 DicNodeVector *childDicNodes) { 166 const int terminalDepth = dicNode->getLeavingDepth(); 167 const int childCount = dicNode->getChildrenCount(); 168 int nextPos = dicNode->getChildrenPos(); 169 for (int i = 0; i < childCount; i++) { 170 const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; 171 nextPos = createAndGetLeavingChildNode(dicNode, nextPos, binaryDictionaryInfo, 172 terminalDepth, pInfoState, pointIndex, exactOnly, codePointsFilter, pInfo, 173 childDicNodes); 174 if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) { 175 // All code points have been found. 
176 break; 177 } 178 } 179} 180 181/* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode, 182 const BinaryDictionaryInfo *const binaryDictionaryInfo, DicNodeVector *childDicNodes) { 183 getProximityChildDicNodes(dicNode, binaryDictionaryInfo, 0, 0, false, childDicNodes); 184} 185 186/* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode, 187 const BinaryDictionaryInfo *const binaryDictionaryInfo, 188 const ProximityInfoState *pInfoState, const int pointIndex, bool exactOnly, 189 DicNodeVector *childDicNodes) { 190 if (dicNode->isTotalInputSizeExceedingLimit()) { 191 return; 192 } 193 if (!dicNode->isLeavingNode()) { 194 DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly, 195 childDicNodes); 196 } else { 197 DicNodeUtils::createAndGetAllLeavingChildNodes( 198 dicNode, binaryDictionaryInfo, pInfoState, pointIndex, exactOnly, 199 0 /* codePointsFilter */, 0 /* pInfo */, childDicNodes); 200 } 201} 202 203/////////////////// 204// Scoring utils // 205/////////////////// 206/** 207 * Computes the combined bigram / unigram cost for the given dicNode. 208 */ 209/* static */ float DicNodeUtils::getBigramNodeImprobability( 210 const BinaryDictionaryInfo *const binaryDictionaryInfo, 211 const DicNode *const node, MultiBigramMap *multiBigramMap) { 212 if (node->isImpossibleBigramWord()) { 213 return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); 214 } 215 const int probability = getBigramNodeProbability(binaryDictionaryInfo, node, multiBigramMap); 216 // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. 
217 const float cost = static_cast<float>(MAX_PROBABILITY - probability) 218 / static_cast<float>(MAX_PROBABILITY); 219 return cost; 220} 221 222/* static */ int DicNodeUtils::getBigramNodeProbability( 223 const BinaryDictionaryInfo *const binaryDictionaryInfo, 224 const DicNode *const node, MultiBigramMap *multiBigramMap) { 225 const int unigramProbability = node->getProbability(); 226 const int wordPos = node->getPos(); 227 const int prevWordPos = node->getPrevWordPos(); 228 if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) { 229 // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD. 230 return ProbabilityUtils::backoff(unigramProbability); 231 } 232 if (multiBigramMap) { 233 return multiBigramMap->getBigramProbability( 234 binaryDictionaryInfo, prevWordPos, wordPos, unigramProbability); 235 } 236 return ProbabilityUtils::backoff(unigramProbability); 237} 238 239/////////////////////////////////////// 240// Bigram / Unigram dictionary utils // 241/////////////////////////////////////// 242 243/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, 244 const int pointIndex, const bool exactOnly, const int nodeCodePoint) { 245 if (!pInfoState) { 246 return true; 247 } 248 if (exactOnly) { 249 return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint; 250 } 251 const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint, 252 true /* checkProximityChars */); 253 return isProximityChar(matchedId); 254} 255 256//////////////// 257// Char utils // 258//////////////// 259 260// TODO: Move to char_utils? 
261/* static */ int DicNodeUtils::appendTwoWords(const int *const src0, const int16_t length0, 262 const int *const src1, const int16_t length1, int *dest) { 263 int actualLength0 = 0; 264 for (int i = 0; i < length0; ++i) { 265 if (src0[i] == 0) { 266 break; 267 } 268 actualLength0 = i + 1; 269 } 270 actualLength0 = min(actualLength0, MAX_WORD_LENGTH); 271 memcpy(dest, src0, actualLength0 * sizeof(dest[0])); 272 if (!src1 || length1 == 0) { 273 return actualLength0; 274 } 275 int actualLength1 = 0; 276 for (int i = 0; i < length1; ++i) { 277 if (src1[i] == 0) { 278 break; 279 } 280 actualLength1 = i + 1; 281 } 282 actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0 - 1); 283 memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0])); 284 return actualLength0 + actualLength1; 285} 286} // namespace latinime 287