// dic_node_utils.cpp revision 6379a4de29fee7019b32b93bc424eda720e02dcf
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include <cstring> 18#include <vector> 19 20#include "suggest/core/dicnode/dic_node.h" 21#include "suggest/core/dicnode/dic_node_utils.h" 22#include "suggest/core/dicnode/dic_node_vector.h" 23#include "suggest/core/dictionary/binary_dictionary_info.h" 24#include "suggest/core/dictionary/binary_format.h" 25#include "suggest/core/dictionary/multi_bigram_map.h" 26#include "suggest/core/dictionary/probability_utils.h" 27#include "suggest/core/layout/proximity_info.h" 28#include "suggest/core/layout/proximity_info_state.h" 29#include "utils/char_utils.h" 30 31namespace latinime { 32 33/////////////////////////////// 34// Node initialization utils // 35/////////////////////////////// 36 37/* static */ void DicNodeUtils::initAsRoot(const BinaryDictionaryInfo *const binaryDictionaryInfo, 38 const int prevWordNodePos, DicNode *const newRootNode) { 39 const int rootPos = binaryDictionaryInfo->getRootPosition(); 40 const int childrenPos = rootPos; 41 newRootNode->initAsRoot(rootPos, childrenPos, prevWordNodePos); 42} 43 44/*static */ void DicNodeUtils::initAsRootWithPreviousWord( 45 const BinaryDictionaryInfo *const binaryDictionaryInfo, 46 DicNode *const prevWordLastNode, DicNode *const newRootNode) { 47 const int rootPos = binaryDictionaryInfo->getRootPosition(); 48 const int childrenPos = rootPos; 49 
newRootNode->initAsRootWithPreviousWord(prevWordLastNode, rootPos, childrenPos); 50} 51 52/* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) { 53 destNode->initByCopy(srcNode); 54} 55 56/////////////////////////////////// 57// Traverse node expansion utils // 58/////////////////////////////////// 59 60/* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode, 61 const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, 62 DicNodeVector *childDicNodes) { 63 // Passing multiple chars node. No need to traverse child 64 const int codePoint = dicNode->getNodeTypedCodePoint(); 65 const int baseLowerCaseCodePoint = CharUtils::toBaseLowerCase(codePoint); 66 const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint); 67 if (isMatch || CharUtils::isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) { 68 childDicNodes->pushPassingChild(dicNode); 69 } 70} 71 72/* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos, 73 const BinaryDictionaryInfo *const binaryDictionaryInfo, 74 const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, 75 const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo, 76 DicNodeVector *childDicNodes) { 77 int nextPos = pos; 78 const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer( 79 binaryDictionaryInfo->getDictRoot(), &pos); 80 const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); 81 const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags)); 82 const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags); 83 84 int codePoint = BinaryFormat::getCodePointAndForwardPointer( 85 binaryDictionaryInfo->getDictRoot(), &pos); 86 ASSERT(NOT_A_CODE_POINT != codePoint); 87 // TODO: optimize this 88 int mergedNodeCodePoints[MAX_WORD_LENGTH]; 89 uint16_t mergedNodeCodePointCount = 0; 90 
mergedNodeCodePoints[mergedNodeCodePointCount++] = codePoint; 91 92 do { 93 const int nextCodePoint = hasMultipleChars 94 ? BinaryFormat::getCodePointAndForwardPointer( 95 binaryDictionaryInfo->getDictRoot(), &pos) : NOT_A_CODE_POINT; 96 const bool isLastChar = (NOT_A_CODE_POINT == nextCodePoint); 97 if (!isLastChar) { 98 mergedNodeCodePoints[mergedNodeCodePointCount++] = nextCodePoint; 99 } 100 codePoint = nextCodePoint; 101 } while (NOT_A_CODE_POINT != codePoint); 102 103 const int probability = isTerminal ? BinaryFormat::readProbabilityWithoutMovingPointer( 104 binaryDictionaryInfo->getDictRoot(), pos) : -1; 105 pos = BinaryFormat::skipProbability(flags, pos); 106 int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition( 107 binaryDictionaryInfo->getDictRoot(), flags, pos) : 0; 108 const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos); 109 const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes( 110 binaryDictionaryInfo->getDictRoot(), flags, pos); 111 112 if (isDicNodeFilteredOut(mergedNodeCodePoints[0], pInfo, codePointsFilter)) { 113 return siblingPos; 114 } 115 if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, mergedNodeCodePoints[0])) { 116 return siblingPos; 117 } 118 childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos, 119 probability, isTerminal, hasChildren, mergedNodeCodePointCount, mergedNodeCodePoints); 120 return siblingPos; 121} 122 123/* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint, 124 const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) { 125 const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; 126 if (filterSize <= 0) { 127 return false; 128 } 129 if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX 130 || CharUtils::isIntentionalOmissionCodePoint(nodeCodePoint))) { 131 // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never 132 // filtered. 
133 return false; 134 } 135 const int lowerCodePoint = CharUtils::toLowerCase(nodeCodePoint); 136 const int baseLowerCodePoint = CharUtils::toBaseCodePoint(lowerCodePoint); 137 // TODO: Avoid linear search 138 for (int i = 0; i < filterSize; ++i) { 139 // Checking if a normalized code point is in filter characters when pInfo is not 140 // null. When pInfo is null, nodeCodePoint is used to check filtering without 141 // normalizing. 142 if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint 143 || (*codePointsFilter)[i] == baseLowerCodePoint)) 144 || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) { 145 return false; 146 } 147 } 148 return true; 149} 150 151/* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode, 152 const BinaryDictionaryInfo *const binaryDictionaryInfo, 153 const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, 154 const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo, 155 DicNodeVector *childDicNodes) { 156 if (!dicNode->hasChildren()) { 157 return; 158 } 159 int nextPos = dicNode->getChildrenPos(); 160 const int childCount = BinaryFormat::getGroupCountAndForwardPointer( 161 binaryDictionaryInfo->getDictRoot(), &nextPos); 162 for (int i = 0; i < childCount; i++) { 163 const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; 164 nextPos = createAndGetLeavingChildNode(dicNode, nextPos, binaryDictionaryInfo, 165 pInfoState, pointIndex, exactOnly, codePointsFilter, pInfo, 166 childDicNodes); 167 if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) { 168 // All code points have been found. 
169 break; 170 } 171 } 172} 173 174/* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode, 175 const BinaryDictionaryInfo *const binaryDictionaryInfo, DicNodeVector *childDicNodes) { 176 getProximityChildDicNodes(dicNode, binaryDictionaryInfo, 0, 0, false, childDicNodes); 177} 178 179/* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode, 180 const BinaryDictionaryInfo *const binaryDictionaryInfo, 181 const ProximityInfoState *pInfoState, const int pointIndex, bool exactOnly, 182 DicNodeVector *childDicNodes) { 183 if (dicNode->isTotalInputSizeExceedingLimit()) { 184 return; 185 } 186 if (!dicNode->isLeavingNode()) { 187 DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly, 188 childDicNodes); 189 } else { 190 DicNodeUtils::createAndGetAllLeavingChildNodes( 191 dicNode, binaryDictionaryInfo, pInfoState, pointIndex, exactOnly, 192 0 /* codePointsFilter */, 0 /* pInfo */, childDicNodes); 193 } 194} 195 196/////////////////// 197// Scoring utils // 198/////////////////// 199/** 200 * Computes the combined bigram / unigram cost for the given dicNode. 201 */ 202/* static */ float DicNodeUtils::getBigramNodeImprobability( 203 const BinaryDictionaryInfo *const binaryDictionaryInfo, 204 const DicNode *const node, MultiBigramMap *multiBigramMap) { 205 if (node->isImpossibleBigramWord()) { 206 return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); 207 } 208 const int probability = getBigramNodeProbability(binaryDictionaryInfo, node, multiBigramMap); 209 // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. 
210 const float cost = static_cast<float>(MAX_PROBABILITY - probability) 211 / static_cast<float>(MAX_PROBABILITY); 212 return cost; 213} 214 215/* static */ int DicNodeUtils::getBigramNodeProbability( 216 const BinaryDictionaryInfo *const binaryDictionaryInfo, 217 const DicNode *const node, MultiBigramMap *multiBigramMap) { 218 const int unigramProbability = node->getProbability(); 219 const int wordPos = node->getPos(); 220 const int prevWordPos = node->getPrevWordPos(); 221 if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) { 222 // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD. 223 return ProbabilityUtils::backoff(unigramProbability); 224 } 225 if (multiBigramMap) { 226 return multiBigramMap->getBigramProbability( 227 binaryDictionaryInfo, prevWordPos, wordPos, unigramProbability); 228 } 229 return ProbabilityUtils::backoff(unigramProbability); 230} 231 232/////////////////////////////////////// 233// Bigram / Unigram dictionary utils // 234/////////////////////////////////////// 235 236/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, 237 const int pointIndex, const bool exactOnly, const int nodeCodePoint) { 238 if (!pInfoState) { 239 return true; 240 } 241 if (exactOnly) { 242 return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint; 243 } 244 const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint, 245 true /* checkProximityChars */); 246 return isProximityChar(matchedId); 247} 248 249//////////////// 250// Char utils // 251//////////////// 252 253// TODO: Move to char_utils? 
254/* static */ int DicNodeUtils::appendTwoWords(const int *const src0, const int16_t length0, 255 const int *const src1, const int16_t length1, int *dest) { 256 int actualLength0 = 0; 257 for (int i = 0; i < length0; ++i) { 258 if (src0[i] == 0) { 259 break; 260 } 261 actualLength0 = i + 1; 262 } 263 actualLength0 = min(actualLength0, MAX_WORD_LENGTH); 264 memcpy(dest, src0, actualLength0 * sizeof(dest[0])); 265 if (!src1 || length1 == 0) { 266 return actualLength0; 267 } 268 int actualLength1 = 0; 269 for (int i = 0; i < length1; ++i) { 270 if (src1[i] == 0) { 271 break; 272 } 273 actualLength1 = i + 1; 274 } 275 actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0 - 1); 276 memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0])); 277 return actualLength0 + actualLength1; 278} 279} // namespace latinime 280