// dic_node.h, revision 9f8c9a0161924f515c5ff9617db2317cdc1d01e2
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#ifndef LATINIME_DIC_NODE_H 18#define LATINIME_DIC_NODE_H 19 20#include "defines.h" 21#include "suggest/core/dicnode/dic_node_profiler.h" 22#include "suggest/core/dicnode/dic_node_utils.h" 23#include "suggest/core/dicnode/internal/dic_node_state.h" 24#include "suggest/core/dicnode/internal/dic_node_properties.h" 25#include "suggest/core/dictionary/digraph_utils.h" 26#include "suggest/core/dictionary/error_type_utils.h" 27#include "suggest/core/layout/proximity_info_state.h" 28#include "utils/char_utils.h" 29 30#if DEBUG_DICT 31#define LOGI_SHOW_ADD_COST_PROP \ 32 do { \ 33 char charBuf[50]; \ 34 INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf, NELEMS(charBuf)); \ 35 AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \ 36 __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \ 37 getInputIndex(0), getNormalizedCompoundDistance(), charBuf); \ 38 } while (0) 39#define DUMP_WORD_AND_SCORE(header) \ 40 do { \ 41 char charBuf[50]; \ 42 INTS_TO_CHARS(getOutputWordBuf(), \ 43 getNodeCodePointCount() \ 44 + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(), \ 45 charBuf, NELEMS(charBuf)); \ 46 AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %d, %5f,", header, \ 47 getSpatialDistanceForScoring(), \ 48 mDicNodeState.mDicNodeStateScoring.getLanguageDistance(), \ 49 
getNormalizedCompoundDistance(), getRawLength(), charBuf, \ 50 getInputIndex(0), getNormalizedCompoundDistanceAfterFirstWord()); \ 51 } while (0) 52#else 53#define LOGI_SHOW_ADD_COST_PROP 54#define DUMP_WORD_AND_SCORE(header) 55#endif 56 57namespace latinime { 58 59// This struct is purely a bucket to return values. No instances of this struct should be kept. 60struct DicNode_InputStateG { 61 DicNode_InputStateG() 62 : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0), 63 mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f), 64 mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {} 65 66 bool mNeedsToUpdateInputStateG; 67 int mPointerId; 68 int16_t mInputIndex; 69 int mPrevCodePoint; 70 float mTerminalDiffCost; 71 float mRawLength; 72 DoubleLetterLevel mDoubleLetterLevel; 73}; 74 75class DicNode { 76 // Caveat: We define Weighting as a friend class of DicNode to let Weighting change 77 // the distance of DicNode. 78 // Caution!!! In general, we avoid using the "friend" access modifier. 79 // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting. 
80 friend class Weighting; 81 82 public: 83#if DEBUG_DICT 84 DicNodeProfiler mProfiler; 85#endif 86 87 AK_FORCE_INLINE DicNode() 88 : 89#if DEBUG_DICT 90 mProfiler(), 91#endif 92 mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false) {} 93 94 DicNode(const DicNode &dicNode); 95 DicNode &operator=(const DicNode &dicNode); 96 ~DicNode() {} 97 98 // Init for copy 99 void initByCopy(const DicNode *const dicNode) { 100 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 101 mDicNodeProperties.initByCopy(&dicNode->mDicNodeProperties); 102 mDicNodeState.initByCopy(&dicNode->mDicNodeState); 103 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 104 } 105 106 // Init for root with prevWordsPtNodePos which is used for n-gram 107 void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordsPtNodePos) { 108 mIsCachedForNextSuggestion = false; 109 mDicNodeProperties.init(rootPtNodeArrayPos, prevWordsPtNodePos); 110 mDicNodeState.init(); 111 PROF_NODE_RESET(mProfiler); 112 } 113 114 // Init for root with previous word 115 void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) { 116 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 117 int newPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; 118 newPrevWordsPtNodePos[0] = dicNode->mDicNodeProperties.getPtNodePos(); 119 for (size_t i = 1; i < NELEMS(newPrevWordsPtNodePos); ++i) { 120 newPrevWordsPtNodePos[i] = dicNode->getNthPrevWordTerminalPtNodePos(i); 121 } 122 mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordsPtNodePos); 123 mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState, 124 dicNode->mDicNodeProperties.getDepth()); 125 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 126 } 127 128 void initAsPassingChild(DicNode *parentDicNode) { 129 mIsCachedForNextSuggestion = parentDicNode->mIsCachedForNextSuggestion; 130 const int codePoint = 131 parentDicNode->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt( 
132 parentDicNode->getNodeCodePointCount()); 133 mDicNodeProperties.init(&parentDicNode->mDicNodeProperties, codePoint); 134 mDicNodeState.initByCopy(&parentDicNode->mDicNodeState); 135 PROF_NODE_COPY(&parentDicNode->mProfiler, mProfiler); 136 } 137 138 void initAsChild(const DicNode *const dicNode, const int ptNodePos, 139 const int childrenPtNodeArrayPos, const int probability, const bool isTerminal, 140 const bool hasChildren, const bool isBlacklistedOrNotAWord, 141 const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) { 142 uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1); 143 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 144 const uint16_t newLeavingDepth = static_cast<uint16_t>( 145 dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount); 146 mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0], 147 probability, isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth, 148 newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordsTerminalPtNodePos()); 149 mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount, 150 mergedNodeCodePoints); 151 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 152 } 153 154 bool isRoot() const { 155 return getNodeCodePointCount() == 0; 156 } 157 158 bool hasChildren() const { 159 return mDicNodeProperties.hasChildren(); 160 } 161 162 bool isLeavingNode() const { 163 ASSERT(getNodeCodePointCount() <= mDicNodeProperties.getLeavingDepth()); 164 return getNodeCodePointCount() == mDicNodeProperties.getLeavingDepth(); 165 } 166 167 AK_FORCE_INLINE bool isFirstLetter() const { 168 return getNodeCodePointCount() == 1; 169 } 170 171 bool isCached() const { 172 return mIsCachedForNextSuggestion; 173 } 174 175 void setCached() { 176 mIsCachedForNextSuggestion = true; 177 } 178 179 // Check if the current word and the previous word can be considered as a valid multiple word 180 // suggestion. 
181 bool isValidMultipleWordSuggestion() const { 182 if (isBlacklistedOrNotAWord()) { 183 return false; 184 } 185 // Treat suggestion as invalid if the current and the previous word are single character 186 // words. 187 const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() 188 - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1; 189 const int currentWordLen = getNodeCodePointCount(); 190 return (prevWordLen != 1 || currentWordLen != 1); 191 } 192 193 bool isFirstCharUppercase() const { 194 const int c = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(0); 195 return CharUtils::isAsciiUpper(c); 196 } 197 198 bool isCompletion(const int inputSize) const { 199 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize; 200 } 201 202 bool canDoLookAheadCorrection(const int inputSize) const { 203 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1; 204 } 205 206 // Used to get n-gram probability in DicNodeUtils. 207 int getPtNodePos() const { 208 return mDicNodeProperties.getPtNodePos(); 209 } 210 211 // Used to get n-gram probability in DicNodeUtils. n is 1-indexed. 
212 int getNthPrevWordTerminalPtNodePos(const int n) const { 213 if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { 214 return NOT_A_DICT_POS; 215 } 216 return mDicNodeProperties.getPrevWordsTerminalPtNodePos()[n - 1]; 217 } 218 219 // Used in DicNodeUtils 220 int getChildrenPtNodeArrayPos() const { 221 return mDicNodeProperties.getChildrenPtNodeArrayPos(); 222 } 223 224 int getProbability() const { 225 return mDicNodeProperties.getProbability(); 226 } 227 228 AK_FORCE_INLINE bool isTerminalDicNode() const { 229 const bool isTerminalPtNode = mDicNodeProperties.isTerminal(); 230 const int currentDicNodeDepth = getNodeCodePointCount(); 231 const int terminalDicNodeDepth = mDicNodeProperties.getLeavingDepth(); 232 return isTerminalPtNode && currentDicNodeDepth > 0 233 && currentDicNodeDepth == terminalDicNodeDepth; 234 } 235 236 bool shouldBeFilteredBySafetyNetForBigram() const { 237 const uint16_t currentDepth = getNodeCodePointCount(); 238 const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() 239 - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1; 240 return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1)); 241 } 242 243 bool hasMatchedOrProximityCodePoints() const { 244 // This DicNode does not have matched or proximity code points when all code points have 245 // been handled as edit corrections or completion so far. 246 const int editCorrectionCount = mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount(); 247 const int completionCount = mDicNodeState.mDicNodeStateScoring.getCompletionCount(); 248 return (editCorrectionCount + completionCount) < getNodeCodePointCount(); 249 } 250 251 bool isTotalInputSizeExceedingLimit() const { 252 // TODO: 3 can be 2? Needs to be investigated. 
253 // TODO: Have a const variable for 3 (or 2) 254 return getTotalNodeCodePointCount() > MAX_WORD_LENGTH - 3; 255 } 256 257 void outputResult(int *dest) const { 258 memmove(dest, getOutputWordBuf(), getTotalNodeCodePointCount() * sizeof(dest[0])); 259 DUMP_WORD_AND_SCORE("OUTPUT"); 260 } 261 262 // "Total" in this context (and other methods in this class) means the whole suggestion. When 263 // this represents a multi-word suggestion, the referenced PtNode (in mDicNodeState) is only 264 // the one that corresponds to the last word of the suggestion, and all the previous words 265 // are concatenated together in mDicNodeStateOutput. 266 int getTotalNodeSpaceCount() const { 267 if (!hasMultipleWords()) { 268 return 0; 269 } 270 return CharUtils::getSpaceCount(mDicNodeState.mDicNodeStateOutput.getCodePointBuf(), 271 mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()); 272 } 273 274 int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const { 275 const int inputIndex = mDicNodeState.mDicNodeStateOutput.getSecondWordFirstInputIndex(); 276 if (inputIndex == NOT_AN_INDEX) { 277 return NOT_AN_INDEX; 278 } else { 279 return pInfoState->getInputIndexOfSampledPoint(inputIndex); 280 } 281 } 282 283 bool hasMultipleWords() const { 284 return mDicNodeState.mDicNodeStateOutput.getPrevWordCount() > 0; 285 } 286 287 int getProximityCorrectionCount() const { 288 return mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount(); 289 } 290 291 int getEditCorrectionCount() const { 292 return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount(); 293 } 294 295 // Used to prune nodes 296 float getNormalizedCompoundDistance() const { 297 return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance(); 298 } 299 300 // Used to prune nodes 301 float getNormalizedSpatialDistance() const { 302 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance() 303 / static_cast<float>(getInputIndex(0) + 1); 304 } 305 306 // Used to prune nodes 307 
float getCompoundDistance() const { 308 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(); 309 } 310 311 // Used to prune nodes 312 float getCompoundDistance(const float languageWeight) const { 313 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight); 314 } 315 316 AK_FORCE_INLINE const int *getOutputWordBuf() const { 317 return mDicNodeState.mDicNodeStateOutput.getCodePointBuf(); 318 } 319 320 int getPrevCodePointG(int pointerId) const { 321 return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId); 322 } 323 324 // Whether the current codepoint can be an intentional omission, in which case the traversal 325 // algorithm will always check for a possible omission here. 326 bool canBeIntentionalOmission() const { 327 return CharUtils::isIntentionalOmissionCodePoint(getNodeCodePoint()); 328 } 329 330 // Whether the omission is so frequent that it should incur zero cost. 331 bool isZeroCostOmission() const { 332 // TODO: do not hardcode and read from header 333 return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE); 334 } 335 336 // TODO: remove 337 float getTerminalDiffCostG(int path) const { 338 return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path); 339 } 340 341 ////////////////////// 342 // Temporary getter // 343 // TODO: Remove // 344 ////////////////////// 345 // TODO: Remove once touch path is merged into ProximityInfoState 346 // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph. 
347 int getNodeCodePoint() const { 348 const int codePoint = mDicNodeProperties.getDicNodeCodePoint(); 349 const DigraphUtils::DigraphCodePointIndex digraphIndex = 350 mDicNodeState.mDicNodeStateScoring.getDigraphIndex(); 351 if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) { 352 return codePoint; 353 } 354 return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex); 355 } 356 357 //////////////////////////////// 358 // Utils for cost calculation // 359 //////////////////////////////// 360 AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const { 361 return mDicNodeProperties.getDicNodeCodePoint() 362 == dicNode->mDicNodeProperties.getDicNodeCodePoint(); 363 } 364 365 // TODO: remove 366 // TODO: rename getNextInputIndex 367 int16_t getInputIndex(int pointerId) const { 368 return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId); 369 } 370 371 //////////////////////////////////// 372 // Getter of features for scoring // 373 //////////////////////////////////// 374 float getSpatialDistanceForScoring() const { 375 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance(); 376 } 377 378 // For space-aware gestures, we store the normalized distance at the char index 379 // that ends the first word of the suggestion. We call this the distance after 380 // first word. 
381 float getNormalizedCompoundDistanceAfterFirstWord() const { 382 return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistanceAfterFirstWord(); 383 } 384 385 float getRawLength() const { 386 return mDicNodeState.mDicNodeStateScoring.getRawLength(); 387 } 388 389 DoubleLetterLevel getDoubleLetterLevel() const { 390 return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel(); 391 } 392 393 void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { 394 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel); 395 } 396 397 bool isInDigraph() const { 398 return mDicNodeState.mDicNodeStateScoring.getDigraphIndex() 399 != DigraphUtils::NOT_A_DIGRAPH_INDEX; 400 } 401 402 void advanceDigraphIndex() { 403 mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); 404 } 405 406 ErrorTypeUtils::ErrorType getContainedErrorTypes() const { 407 return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes(); 408 } 409 410 bool isBlacklistedOrNotAWord() const { 411 return mDicNodeProperties.isBlacklistedOrNotAWord(); 412 } 413 414 inline uint16_t getNodeCodePointCount() const { 415 return mDicNodeProperties.getDepth(); 416 } 417 418 // Returns code point count including spaces 419 inline uint16_t getTotalNodeCodePointCount() const { 420 return getNodeCodePointCount() + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(); 421 } 422 423 AK_FORCE_INLINE void dump(const char *tag) const { 424#if DEBUG_DICT 425 DUMP_WORD_AND_SCORE(tag); 426#if DEBUG_DUMP_ERROR 427 mProfiler.dump(); 428#endif 429#endif 430 } 431 432 AK_FORCE_INLINE bool compare(const DicNode *right) const { 433 // Promote exact matches to prevent them from being pruned. 
434 const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes()); 435 const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes()); 436 if (leftExactMatch != rightExactMatch) { 437 return leftExactMatch; 438 } 439 const float diff = 440 right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance(); 441 static const float MIN_DIFF = 0.000001f; 442 if (diff > MIN_DIFF) { 443 return true; 444 } else if (diff < -MIN_DIFF) { 445 return false; 446 } 447 const int depth = getNodeCodePointCount(); 448 const int depthDiff = right->getNodeCodePointCount() - depth; 449 if (depthDiff != 0) { 450 return depthDiff > 0; 451 } 452 for (int i = 0; i < depth; ++i) { 453 const int codePoint = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i); 454 const int rightCodePoint = 455 right->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i); 456 if (codePoint != rightCodePoint) { 457 return rightCodePoint > codePoint; 458 } 459 } 460 // Compare pointer values here for stable comparison 461 return this > right; 462 } 463 464 private: 465 DicNodeProperties mDicNodeProperties; 466 DicNodeState mDicNodeState; 467 // TODO: Remove 468 bool mIsCachedForNextSuggestion; 469 470 AK_FORCE_INLINE int getTotalInputIndex() const { 471 int index = 0; 472 for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { 473 index += mDicNodeState.mDicNodeStateInput.getInputIndex(i); 474 } 475 return index; 476 } 477 478 // Caveat: Must not be called outside Weighting 479 // This restriction is guaranteed by "friend" 480 AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost, 481 const bool doNormalization, const int inputSize, 482 const ErrorTypeUtils::ErrorType errorType) { 483 if (DEBUG_GEO_FULL) { 484 LOGI_SHOW_ADD_COST_PROP; 485 } 486 mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization, 487 inputSize, getTotalInputIndex(), errorType); 488 } 489 490 // Saves the current 
normalized compound distance for space-aware gestures. 491 // See getNormalizedCompoundDistanceAfterFirstWord for details. 492 AK_FORCE_INLINE void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() { 493 mDicNodeState.mDicNodeStateScoring.saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet(); 494 } 495 496 // Caveat: Must not be called outside Weighting 497 // This restriction is guaranteed by "friend" 498 AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count, 499 const bool overwritesPrevCodePointByNodeCodePoint) { 500 if (count == 0) { 501 return; 502 } 503 mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count); 504 if (overwritesPrevCodePointByNodeCodePoint) { 505 mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint()); 506 } 507 } 508 509 AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) { 510 if (mDicNodeState.mDicNodeStateOutput.getPrevWordCount() == 1 && isFirstLetter()) { 511 mDicNodeState.mDicNodeStateOutput.setSecondWordFirstInputIndex( 512 inputStateG->mInputIndex); 513 } 514 mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId, 515 inputStateG->mInputIndex, inputStateG->mPrevCodePoint, 516 inputStateG->mTerminalDiffCost, inputStateG->mRawLength); 517 mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength); 518 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel); 519 } 520}; 521} // namespace latinime 522#endif // LATINIME_DIC_NODE_H 523