dic_node.h revision 67ff21f3217c9f2ff81beac6f29f5c35a83da228
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#ifndef LATINIME_DIC_NODE_H 18#define LATINIME_DIC_NODE_H 19 20#include "defines.h" 21#include "suggest/core/dicnode/dic_node_profiler.h" 22#include "suggest/core/dicnode/dic_node_utils.h" 23#include "suggest/core/dicnode/internal/dic_node_state.h" 24#include "suggest/core/dicnode/internal/dic_node_properties.h" 25#include "suggest/core/dictionary/digraph_utils.h" 26#include "suggest/core/dictionary/error_type_utils.h" 27#include "suggest/core/layout/proximity_info_state.h" 28#include "utils/char_utils.h" 29 30#if DEBUG_DICT 31#define LOGI_SHOW_ADD_COST_PROP \ 32 do { \ 33 char charBuf[50]; \ 34 INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf, NELEMS(charBuf)); \ 35 AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \ 36 __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \ 37 getInputIndex(0), getNormalizedCompoundDistance(), charBuf); \ 38 } while (0) 39#define DUMP_WORD_AND_SCORE(header) \ 40 do { \ 41 char charBuf[50]; \ 42 INTS_TO_CHARS(getOutputWordBuf(), \ 43 getNodeCodePointCount() \ 44 + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(), \ 45 charBuf, NELEMS(charBuf)); \ 46 AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %d, %5f,", header, \ 47 getSpatialDistanceForScoring(), getLanguageDistanceForScoring(), \ 48 getNormalizedCompoundDistance(), getRawLength(), charBuf, \ 49 getInputIndex(0), getNormalizedCompoundDistanceAfterFirstWord()); \ 50 } while (0) 51#else 52#define LOGI_SHOW_ADD_COST_PROP 53#define DUMP_WORD_AND_SCORE(header) 54#endif 55 56namespace latinime { 57 58// This struct is purely a bucket to return values. No instances of this struct should be kept. 59struct DicNode_InputStateG { 60 DicNode_InputStateG() 61 : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0), 62 mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f), 63 mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {} 64 65 bool mNeedsToUpdateInputStateG; 66 int mPointerId; 67 int16_t mInputIndex; 68 int mPrevCodePoint; 69 float mTerminalDiffCost; 70 float mRawLength; 71 DoubleLetterLevel mDoubleLetterLevel; 72}; 73 74class DicNode { 75 // Caveat: We define Weighting as a friend class of DicNode to let Weighting change 76 // the distance of DicNode. 77 // Caution!!! In general, we avoid using the "friend" access modifier. 78 // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting. 79 friend class Weighting; 80 81 public: 82#if DEBUG_DICT 83 DicNodeProfiler mProfiler; 84#endif 85 86 AK_FORCE_INLINE DicNode() 87 : 88#if DEBUG_DICT 89 mProfiler(), 90#endif 91 mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false) {} 92 93 DicNode(const DicNode &dicNode); 94 DicNode &operator=(const DicNode &dicNode); 95 ~DicNode() {} 96 97 // Init for copy 98 void initByCopy(const DicNode *const dicNode) { 99 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 100 mDicNodeProperties.initByCopy(&dicNode->mDicNodeProperties); 101 mDicNodeState.initByCopy(&dicNode->mDicNodeState); 102 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 103 } 104 105 // Init for root with prevWordPtNodePos which is used for bigram 106 void initAsRoot(const int rootPtNodeArrayPos, const int prevWordPtNodePos) { 107 mIsCachedForNextSuggestion = false; 108 mDicNodeProperties.init(rootPtNodeArrayPos, prevWordPtNodePos); 109 mDicNodeState.init(); 110 PROF_NODE_RESET(mProfiler); 111 } 112 113 // Init for root with previous word 114 void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) { 115 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 116 mDicNodeProperties.init(rootPtNodeArrayPos, dicNode->mDicNodeProperties.getPtNodePos()); 117 mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState, 118 dicNode->mDicNodeProperties.getDepth()); 119 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 120 } 121 122 void initAsPassingChild(DicNode *parentDicNode) { 123 mIsCachedForNextSuggestion = parentDicNode->mIsCachedForNextSuggestion; 124 const int parentCodePoint = parentDicNode->getNodeTypedCodePoint(); 125 mDicNodeProperties.init(&parentDicNode->mDicNodeProperties, parentCodePoint); 126 mDicNodeState.initByCopy(&parentDicNode->mDicNodeState); 127 PROF_NODE_COPY(&parentDicNode->mProfiler, mProfiler); 128 } 129 130 void initAsChild(const DicNode *const dicNode, const int ptNodePos, 131 const int childrenPtNodeArrayPos, const int probability, const bool isTerminal, 132 const bool hasChildren, const bool isBlacklistedOrNotAWord, 133 const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) { 134 uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1); 135 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 136 const uint16_t newLeavingDepth = static_cast<uint16_t>( 137 dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount); 138 mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0], 139 probability, isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth, 140 newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordTerminalPtNodePos()); 141 mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount, 142 mergedNodeCodePoints); 143 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 144 } 145 146 bool isRoot() const { 147 return getNodeCodePointCount() == 0; 148 } 149 150 bool hasChildren() const { 151 return mDicNodeProperties.hasChildren(); 152 } 153 154 bool isLeavingNode() const { 155 ASSERT(getNodeCodePointCount() <= mDicNodeProperties.getLeavingDepth()); 156 return getNodeCodePointCount() == mDicNodeProperties.getLeavingDepth(); 157 } 158 159 AK_FORCE_INLINE bool isFirstLetter() const { 160 return getNodeCodePointCount() == 1; 161 } 162 163 bool isCached() const { 164 return mIsCachedForNextSuggestion; 165 } 166 167 void setCached() { 168 mIsCachedForNextSuggestion = true; 169 } 170 171 // Used to expand the node in DicNodeUtils 172 int getNodeTypedCodePoint() const { 173 return mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(getNodeCodePointCount()); 174 } 175 176 // Check if the current word and the previous word can be considered as a valid multiple word 177 // suggestion. 178 bool isValidMultipleWordSuggestion() const { 179 if (isBlacklistedOrNotAWord()) { 180 return false; 181 } 182 // Treat suggestion as invalid if the current and the previous word are single character 183 // words. 184 const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() 185 - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1; 186 const int currentWordLen = getNodeCodePointCount(); 187 return (prevWordLen != 1 || currentWordLen != 1); 188 } 189 190 bool isFirstCharUppercase() const { 191 const int c = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(0); 192 return CharUtils::isAsciiUpper(c); 193 } 194 195 bool isCompletion(const int inputSize) const { 196 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize; 197 } 198 199 bool canDoLookAheadCorrection(const int inputSize) const { 200 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1; 201 } 202 203 // Used to get bigram probability in DicNodeUtils 204 int getPtNodePos() const { 205 return mDicNodeProperties.getPtNodePos(); 206 } 207 208 // Used to get bigram probability in DicNodeUtils 209 int getPrevWordTerminalPtNodePos() const { 210 return mDicNodeProperties.getPrevWordTerminalPtNodePos(); 211 } 212 213 // Used in DicNodeUtils 214 int getChildrenPtNodeArrayPos() const { 215 return mDicNodeProperties.getChildrenPtNodeArrayPos(); 216 } 217 218 int getProbability() const { 219 return mDicNodeProperties.getProbability(); 220 } 221 222 AK_FORCE_INLINE bool isTerminalDicNode() const { 223 const bool isTerminalPtNode = mDicNodeProperties.isTerminal(); 224 const int currentDicNodeDepth = getNodeCodePointCount(); 225 const int terminalDicNodeDepth = mDicNodeProperties.getLeavingDepth(); 226 return isTerminalPtNode && currentDicNodeDepth > 0 227 && currentDicNodeDepth == terminalDicNodeDepth; 228 } 229 230 bool shouldBeFilteredBySafetyNetForBigram() const { 231 const uint16_t currentDepth = getNodeCodePointCount(); 232 const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() 233 - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1; 234 return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1)); 235 } 236 237 bool hasMatchedOrProximityCodePoints() const { 238 // This DicNode does not have matched or proximity code points when all code points have 239 // been handled as edit corrections or completion so far. 240 const int editCorrectionCount = mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount(); 241 const int completionCount = mDicNodeState.mDicNodeStateScoring.getCompletionCount(); 242 return (editCorrectionCount + completionCount) < getNodeCodePointCount(); 243 } 244 245 bool isTotalInputSizeExceedingLimit() const { 246 const int prevWordsLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(); 247 const int currentWordDepth = getNodeCodePointCount(); 248 // TODO: 3 can be 2? Needs to be investigated. 249 // TODO: Have a const variable for 3 (or 2) 250 return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3; 251 } 252 253 void outputResult(int *dest) const { 254 const uint16_t prevWordLength = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(); 255 const uint16_t currentDepth = getNodeCodePointCount(); 256 memmove(dest, getOutputWordBuf(), (prevWordLength + currentDepth) * sizeof(dest[0])); 257 DUMP_WORD_AND_SCORE("OUTPUT"); 258 } 259 260 // "Total" in this context (and other methods in this class) means the whole suggestion. When 261 // this represents a multi-word suggestion, the referenced PtNode (in mDicNodeState) is only 262 // the one that corresponds to the last word of the suggestion, and all the previous words 263 // are concatenated together in mDicNodeStateOutput. 264 int getTotalNodeSpaceCount() const { 265 if (!hasMultipleWords()) { 266 return 0; 267 } 268 return CharUtils::getSpaceCount(mDicNodeState.mDicNodeStateOutput.getCodePointBuf(), 269 mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()); 270 } 271 272 int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const { 273 const int inputIndex = mDicNodeState.mDicNodeStateOutput.getSecondWordFirstInputIndex(); 274 if (inputIndex == NOT_AN_INDEX) { 275 return NOT_AN_INDEX; 276 } else { 277 return pInfoState->getInputIndexOfSampledPoint(inputIndex); 278 } 279 } 280 281 bool hasMultipleWords() const { 282 return mDicNodeState.mDicNodeStateOutput.getPrevWordCount() > 0; 283 } 284 285 int getProximityCorrectionCount() const { 286 return mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount(); 287 } 288 289 int getEditCorrectionCount() const { 290 return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount(); 291 } 292 293 // Used to prune nodes 294 float getNormalizedCompoundDistance() const { 295 return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance(); 296 } 297 298 // Used to prune nodes 299 float getNormalizedSpatialDistance() const { 300 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance() 301 / static_cast<float>(getInputIndex(0) + 1); 302 } 303 304 // Used to prune nodes 305 float getCompoundDistance() const { 306 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(); 307 } 308 309 // Used to prune nodes 310 float getCompoundDistance(const float languageWeight) const { 311 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight); 312 } 313 314 // Used to commit input partially 315 int getPrevWordPtNodePos() const { 316 return mDicNodeProperties.getPrevWordTerminalPtNodePos(); 317 } 318 319 AK_FORCE_INLINE const int *getOutputWordBuf() const { 320 return mDicNodeState.mDicNodeStateOutput.getCodePointBuf(); 321 } 322 323 int getPrevCodePointG(int pointerId) const { 324 return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId); 325 } 326 327 // Whether the current codepoint can be an intentional omission, in which case the traversal 328 // algorithm will always check for a possible omission here. 329 bool canBeIntentionalOmission() const { 330 return CharUtils::isIntentionalOmissionCodePoint(getNodeCodePoint()); 331 } 332 333 // Whether the omission is so frequent that it should incur zero cost. 334 bool isZeroCostOmission() const { 335 // TODO: do not hardcode and read from header 336 return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE); 337 } 338 339 // TODO: remove 340 float getTerminalDiffCostG(int path) const { 341 return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path); 342 } 343 344 ////////////////////// 345 // Temporary getter // 346 // TODO: Remove // 347 ////////////////////// 348 // TODO: Remove once touch path is merged into ProximityInfoState 349 // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph. 350 int getNodeCodePoint() const { 351 const int codePoint = mDicNodeProperties.getDicNodeCodePoint(); 352 const DigraphUtils::DigraphCodePointIndex digraphIndex = 353 mDicNodeState.mDicNodeStateScoring.getDigraphIndex(); 354 if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) { 355 return codePoint; 356 } 357 return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex); 358 } 359 360 //////////////////////////////// 361 // Utils for cost calculation // 362 //////////////////////////////// 363 AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const { 364 return mDicNodeProperties.getDicNodeCodePoint() 365 == dicNode->mDicNodeProperties.getDicNodeCodePoint(); 366 } 367 368 // TODO: remove 369 // TODO: rename getNextInputIndex 370 int16_t getInputIndex(int pointerId) const { 371 return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId); 372 } 373 374 //////////////////////////////////// 375 // Getter of features for scoring // 376 //////////////////////////////////// 377 float getSpatialDistanceForScoring() const { 378 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance(); 379 } 380 381 float getLanguageDistanceForScoring() const { 382 return mDicNodeState.mDicNodeStateScoring.getLanguageDistance(); 383 } 384 385 // For space-aware gestures, we store the normalized distance at the char index 386 // that ends the first word of the suggestion. We call this the distance after 387 // first word. 388 float getNormalizedCompoundDistanceAfterFirstWord() const { 389 return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistanceAfterFirstWord(); 390 } 391 392 float getLanguageDistanceRatePerWordForScoring() const { 393 const float langDist = getLanguageDistanceForScoring(); 394 const float totalWordCount = 395 static_cast<float>(mDicNodeState.mDicNodeStateOutput.getPrevWordCount() + 1); 396 return langDist / totalWordCount; 397 } 398 399 float getRawLength() const { 400 return mDicNodeState.mDicNodeStateScoring.getRawLength(); 401 } 402 403 bool isLessThanOneErrorForScoring() const { 404 return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount() 405 + mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount() <= 1; 406 } 407 408 DoubleLetterLevel getDoubleLetterLevel() const { 409 return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel(); 410 } 411 412 void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { 413 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel); 414 } 415 416 bool isInDigraph() const { 417 return mDicNodeState.mDicNodeStateScoring.getDigraphIndex() 418 != DigraphUtils::NOT_A_DIGRAPH_INDEX; 419 } 420 421 void advanceDigraphIndex() { 422 mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); 423 } 424 425 ErrorTypeUtils::ErrorType getContainedErrorTypes() const { 426 return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes(); 427 } 428 429 bool isBlacklistedOrNotAWord() const { 430 return mDicNodeProperties.isBlacklistedOrNotAWord(); 431 } 432 433 inline uint16_t getNodeCodePointCount() const { 434 return mDicNodeProperties.getDepth(); 435 } 436 437 // Returns code point count including spaces 438 inline uint16_t getTotalNodeCodePointCount() const { 439 return getNodeCodePointCount() + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(); 440 } 441 442 AK_FORCE_INLINE void dump(const char *tag) const { 443#if DEBUG_DICT 444 DUMP_WORD_AND_SCORE(tag); 445#if DEBUG_DUMP_ERROR 446 mProfiler.dump(); 447#endif 448#endif 449 } 450 451 AK_FORCE_INLINE bool compare(const DicNode *right) const { 452 // Promote exact matches to prevent them from being pruned. 453 const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes()); 454 const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes()); 455 if (leftExactMatch != rightExactMatch) { 456 return leftExactMatch; 457 } 458 const float diff = 459 right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance(); 460 static const float MIN_DIFF = 0.000001f; 461 if (diff > MIN_DIFF) { 462 return true; 463 } else if (diff < -MIN_DIFF) { 464 return false; 465 } 466 const int depth = getNodeCodePointCount(); 467 const int depthDiff = right->getNodeCodePointCount() - depth; 468 if (depthDiff != 0) { 469 return depthDiff > 0; 470 } 471 for (int i = 0; i < depth; ++i) { 472 const int codePoint = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i); 473 const int rightCodePoint = 474 right->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i); 475 if (codePoint != rightCodePoint) { 476 return rightCodePoint > codePoint; 477 } 478 } 479 // Compare pointer values here for stable comparison 480 return this > right; 481 } 482 483 private: 484 DicNodeProperties mDicNodeProperties; 485 DicNodeState mDicNodeState; 486 // TODO: Remove 487 bool mIsCachedForNextSuggestion; 488 489 AK_FORCE_INLINE int getTotalInputIndex() const { 490 int index = 0; 491 for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { 492 index += mDicNodeState.mDicNodeStateInput.getInputIndex(i); 493 } 494 return index; 495 } 496 497 // Caveat: Must not be called outside Weighting 498 // This restriction is guaranteed by "friend" 499 AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost, 500 const bool doNormalization, const int inputSize, 501 const ErrorTypeUtils::ErrorType errorType) { 502 if (DEBUG_GEO_FULL) { 503 LOGI_SHOW_ADD_COST_PROP; 504 } 505 mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization, 506 inputSize, getTotalInputIndex(), errorType); 507 } 508 509 // Saves the current normalized compound distance for space-aware gestures. 510 // See getNormalizedCompoundDistanceAfterFirstWord for details. 511 AK_FORCE_INLINE void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() { 512 mDicNodeState.mDicNodeStateScoring.saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet(); 513 } 514 515 // Caveat: Must not be called outside Weighting 516 // This restriction is guaranteed by "friend" 517 AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count, 518 const bool overwritesPrevCodePointByNodeCodePoint) { 519 if (count == 0) { 520 return; 521 } 522 mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count); 523 if (overwritesPrevCodePointByNodeCodePoint) { 524 mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint()); 525 } 526 } 527 528 AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) { 529 if (mDicNodeState.mDicNodeStateOutput.getPrevWordCount() == 1 && isFirstLetter()) { 530 mDicNodeState.mDicNodeStateOutput.setSecondWordFirstInputIndex( 531 inputStateG->mInputIndex); 532 } 533 mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId, 534 inputStateG->mInputIndex, inputStateG->mPrevCodePoint, 535 inputStateG->mTerminalDiffCost, inputStateG->mRawLength); 536 mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength); 537 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel); 538 } 539}; 540} // namespace latinime 541#endif // LATINIME_DIC_NODE_H 542