dic_node.h revision 38c26dd0bf8cd5c4511e4a02d5eeae4b3553f03a
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#ifndef LATINIME_DIC_NODE_H 18#define LATINIME_DIC_NODE_H 19 20#include "char_utils.h" 21#include "defines.h" 22#include "dic_node_state.h" 23#include "dic_node_profiler.h" 24#include "dic_node_properties.h" 25#include "dic_node_release_listener.h" 26 27#if DEBUG_DICT 28#define LOGI_SHOW_ADD_COST_PROP \ 29 do { char charBuf[50]; \ 30 INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ 31 AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \ 32 __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \ 33 getInputIndex(0), getNormalizedCompoundDistance(), charBuf); } while (0) 34#define DUMP_WORD_AND_SCORE(header) \ 35 do { char charBuf[50]; char prevWordCharBuf[50]; \ 36 INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ 37 INTS_TO_CHARS(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, \ 38 mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf); \ 39 AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d,,", header, \ 40 getSpatialDistanceForScoring(), getLanguageDistanceForScoring(), \ 41 getNormalizedCompoundDistance(), getRawLength(), prevWordCharBuf, charBuf, \ 42 getInputIndex(0)); \ 43 } while (0) 44#else 45#define LOGI_SHOW_ADD_COST_PROP 46#define DUMP_WORD_AND_SCORE(header) 47#endif 48 49namespace latinime { 50 51// Naming convention 52// - Distance: "Weighted" edit distance -- used both for spatial and language. 53// - Compound Distance: Spatial Distance + Language Distance -- used for pruning and scoring 54// - Cost: delta/diff for Distance -- used both for spatial and language 55// - Length: "Non-weighted" -- used only for spatial 56// - Probability: "Non-weighted" -- used only for language 57 58// This struct is purely a bucket to return values. No instances of this struct should be kept. 59struct DicNode_InputStateG { 60 bool mNeedsToUpdateInputStateG; 61 int mPointerId; 62 int16_t mInputIndex; 63 int mPrevCodePoint; 64 float mTerminalDiffCost; 65 float mRawLength; 66 DoubleLetterLevel mDoubleLetterLevel; 67}; 68 69class DicNode { 70 // Caveat: We define Weighting as a friend class of DicNode to let Weighting change 71 // the distance of DicNode. 72 // Caution!!! In general, we avoid using the "friend" access modifier. 73 // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting. 74 friend class Weighting; 75 76 public: 77#if DEBUG_DICT 78 DicNodeProfiler mProfiler; 79#endif 80 ////////////////// 81 // Memory utils // 82 ////////////////// 83 AK_FORCE_INLINE static void managedDelete(DicNode *node) { 84 node->remove(); 85 } 86 // end 87 ///////////////// 88 89 AK_FORCE_INLINE DicNode() 90 : 91#if DEBUG_DICT 92 mProfiler(), 93#endif 94 mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false), 95 mIsUsed(false), mReleaseListener(0) {} 96 97 DicNode(const DicNode &dicNode); 98 DicNode &operator=(const DicNode &dicNode); 99 virtual ~DicNode() {} 100 101 // TODO: minimize arguments by looking binary_format 102 // Init for copy 103 void initByCopy(const DicNode *dicNode) { 104 mIsUsed = true; 105 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 106 mDicNodeProperties.init(&dicNode->mDicNodeProperties); 107 mDicNodeState.init(&dicNode->mDicNodeState); 108 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 109 } 110 111 // TODO: minimize arguments by looking binary_format 112 // Init for root with prevWordNodePos which is used for bigram 113 void initAsRoot(const int pos, const int childrenPos, const int childrenCount, 114 const int prevWordNodePos) { 115 mIsUsed = true; 116 mIsCachedForNextSuggestion = false; 117 mDicNodeProperties.init( 118 pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); 119 mDicNodeState.init(prevWordNodePos); 120 PROF_NODE_RESET(mProfiler); 121 } 122 123 void initAsPassingChild(DicNode *parentNode) { 124 mIsUsed = true; 125 mIsCachedForNextSuggestion = parentNode->mIsCachedForNextSuggestion; 126 const int c = parentNode->getNodeTypedCodePoint(); 127 mDicNodeProperties.init(&parentNode->mDicNodeProperties, c); 128 mDicNodeState.init(&parentNode->mDicNodeState); 129 PROF_NODE_COPY(&parentNode->mProfiler, mProfiler); 130 } 131 132 // TODO: minimize arguments by looking binary_format 133 // Init for root with previous word 134 void initAsRootWithPreviousWord(DicNode *dicNode, const int pos, const int childrenPos, 135 const int childrenCount) { 136 mIsUsed = true; 137 mIsCachedForNextSuggestion = false; 138 mDicNodeProperties.init( 139 pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); 140 // TODO: Move to dicNodeState? 141 mDicNodeState.mDicNodeStateOutput.init(); // reset for next word 142 mDicNodeState.mDicNodeStateInput.init( 143 &dicNode->mDicNodeState.mDicNodeStateInput, true /* resetTerminalDiffCost */); 144 mDicNodeState.mDicNodeStateScoring.init( 145 &dicNode->mDicNodeState.mDicNodeStateScoring); 146 mDicNodeState.mDicNodeStatePrevWord.init( 147 dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1, 148 dicNode->mDicNodeProperties.getProbability(), 149 dicNode->mDicNodeProperties.getPos(), 150 dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevWord, 151 dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), 152 dicNode->getOutputWordBuf(), 153 dicNode->mDicNodeProperties.getDepth(), 154 dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevSpacePositions, 155 mDicNodeState.mDicNodeStateInput.getInputIndex(0) /* lastInputIndex */); 156 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 157 } 158 159 // TODO: minimize arguments by looking binary_format 160 void initAsChild(DicNode *dicNode, const int pos, const uint8_t flags, const int childrenPos, 161 const int attributesPos, const int siblingPos, const int nodeCodePoint, 162 const int childrenCount, const int probability, const int bigramProbability, 163 const bool isTerminal, const bool hasMultipleChars, const bool hasChildren, 164 const uint16_t additionalSubwordLength, const int *additionalSubword) { 165 mIsUsed = true; 166 uint16_t newDepth = static_cast<uint16_t>(dicNode->getDepth() + 1); 167 mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; 168 const uint16_t newLeavingDepth = static_cast<uint16_t>( 169 dicNode->mDicNodeProperties.getLeavingDepth() + additionalSubwordLength); 170 mDicNodeProperties.init(pos, flags, childrenPos, attributesPos, siblingPos, nodeCodePoint, 171 childrenCount, probability, bigramProbability, isTerminal, hasMultipleChars, 172 hasChildren, newDepth, newLeavingDepth); 173 mDicNodeState.init(&dicNode->mDicNodeState, additionalSubwordLength, additionalSubword); 174 PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); 175 } 176 177 AK_FORCE_INLINE void remove() { 178 mIsUsed = false; 179 if (mReleaseListener) { 180 mReleaseListener->onReleased(this); 181 } 182 } 183 184 bool isUsed() const { 185 return mIsUsed; 186 } 187 188 bool isRoot() const { 189 return getDepth() == 0; 190 } 191 192 bool hasChildren() const { 193 return mDicNodeProperties.hasChildren(); 194 } 195 196 bool isLeavingNode() const { 197 ASSERT(getDepth() <= getLeavingDepth()); 198 return getDepth() == getLeavingDepth(); 199 } 200 201 AK_FORCE_INLINE bool isFirstLetter() const { 202 return getDepth() == 1; 203 } 204 205 bool isCached() const { 206 return mIsCachedForNextSuggestion; 207 } 208 209 void setCached() { 210 mIsCachedForNextSuggestion = true; 211 } 212 213 // Used to expand the node in DicNodeUtils 214 int getNodeTypedCodePoint() const { 215 return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getDepth()); 216 } 217 218 bool isImpossibleBigramWord() const { 219 const int probability = mDicNodeProperties.getProbability(); 220 if (probability == 0) { 221 return true; 222 } 223 const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() 224 - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; 225 const int currentWordLen = getDepth(); 226 return (prevWordLen == 1 && currentWordLen == 1); 227 } 228 229 bool isCapitalized() const { 230 const int c = getOutputWordBuf()[0]; 231 return isAsciiUpper(c); 232 } 233 234 bool isFirstWord() const { 235 return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_VALID_WORD; 236 } 237 238 bool isCompletion(const int inputSize) const { 239 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize; 240 } 241 242 bool canDoLookAheadCorrection(const int inputSize) const { 243 return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1; 244 } 245 246 // Used to get bigram probability in DicNodeUtils 247 int getPos() const { 248 return mDicNodeProperties.getPos(); 249 } 250 251 // Used to get bigram probability in DicNodeUtils 252 int getPrevWordPos() const { 253 return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos(); 254 } 255 256 // Used in DicNodeUtils 257 int getChildrenPos() const { 258 return mDicNodeProperties.getChildrenPos(); 259 } 260 261 // Used in DicNodeUtils 262 int getChildrenCount() const { 263 return mDicNodeProperties.getChildrenCount(); 264 } 265 266 // Used in DicNodeUtils 267 int getProbability() const { 268 return mDicNodeProperties.getProbability(); 269 } 270 271 AK_FORCE_INLINE bool isTerminalWordNode() const { 272 const bool isTerminalNodes = mDicNodeProperties.isTerminal(); 273 const int currentNodeDepth = getDepth(); 274 const int terminalNodeDepth = mDicNodeProperties.getLeavingDepth(); 275 return isTerminalNodes && currentNodeDepth > 0 && currentNodeDepth == terminalNodeDepth; 276 } 277 278 bool shouldBeFilterdBySafetyNetForBigram() const { 279 const uint16_t currentDepth = getDepth(); 280 const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() 281 - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; 282 return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1)); 283 } 284 285 uint16_t getLeavingDepth() const { 286 return mDicNodeProperties.getLeavingDepth(); 287 } 288 289 bool isTotalInputSizeExceedingLimit() const { 290 const int prevWordsLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); 291 const int currentWordDepth = getDepth(); 292 // TODO: 3 can be 2? Needs to be investigated. 293 // TODO: Have a const variable for 3 (or 2) 294 return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3; 295 } 296 297 // TODO: This may be defective. Needs to be revised. 298 bool truncateNode(const DicNode *const topNode, const int inputCommitPoint) { 299 const int prevWordLenOfTop = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); 300 int newPrevWordStartIndex = inputCommitPoint; 301 int charCount = 0; 302 // Find new word start index 303 for (int i = 0; i < prevWordLenOfTop; ++i) { 304 const int c = mDicNodeState.mDicNodeStatePrevWord.getPrevWordCodePointAt(i); 305 // TODO: Check other separators. 306 if (c != KEYCODE_SPACE && c != KEYCODE_SINGLE_QUOTE) { 307 if (charCount == inputCommitPoint) { 308 newPrevWordStartIndex = i; 309 break; 310 } 311 ++charCount; 312 } 313 } 314 if (!mDicNodeState.mDicNodeStatePrevWord.startsWith( 315 &topNode->mDicNodeState.mDicNodeStatePrevWord, newPrevWordStartIndex - 1)) { 316 // Node mismatch. 317 return false; 318 } 319 mDicNodeState.mDicNodeStateInput.truncate(inputCommitPoint); 320 mDicNodeState.mDicNodeStatePrevWord.truncate(newPrevWordStartIndex); 321 return true; 322 } 323 324 void outputResult(int *dest) const { 325 const uint16_t prevWordLength = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); 326 const uint16_t currentDepth = getDepth(); 327 DicNodeUtils::appendTwoWords(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, 328 prevWordLength, getOutputWordBuf(), currentDepth, dest); 329 DUMP_WORD_AND_SCORE("OUTPUT"); 330 } 331 332 void outputSpacePositionsResult(int *spaceIndices) const { 333 mDicNodeState.mDicNodeStatePrevWord.outputSpacePositions(spaceIndices); 334 } 335 336 bool hasMultipleWords() const { 337 return mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() > 0; 338 } 339 340 float getProximityCorrectionCount() const { 341 return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount()); 342 } 343 344 float getEditCorrectionCount() const { 345 return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount()); 346 } 347 348 // Used to prune nodes 349 float getNormalizedCompoundDistance() const { 350 return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance(); 351 } 352 353 // Used to prune nodes 354 float getNormalizedSpatialDistance() const { 355 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance() 356 / static_cast<float>(getInputIndex(0) + 1); 357 } 358 359 // Used to prune nodes 360 float getCompoundDistance() const { 361 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(); 362 } 363 364 // Used to prune nodes 365 float getCompoundDistance(const float languageWeight) const { 366 return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight); 367 } 368 369 // Note that "cost" means delta for "distance" that is weighted. 370 float getTotalPrevWordsLanguageCost() const { 371 return mDicNodeState.mDicNodeStateScoring.getTotalPrevWordsLanguageCost(); 372 } 373 374 // Used to commit input partially 375 int getPrevWordNodePos() const { 376 return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos(); 377 } 378 379 AK_FORCE_INLINE const int *getOutputWordBuf() const { 380 return mDicNodeState.mDicNodeStateOutput.mWordBuf; 381 } 382 383 int getPrevCodePointG(int pointerId) const { 384 return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId); 385 } 386 387 // Whether the current codepoint can be an intentional omission, in which case the traversal 388 // algorithm will always check for a possible omission here. 389 bool canBeIntentionalOmission() const { 390 return isIntentionalOmissionCodePoint(getNodeCodePoint()); 391 } 392 393 // Whether the omission is so frequent that it should incur zero cost. 394 bool isZeroCostOmission() const { 395 // TODO: do not hardcode and read from header 396 return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE); 397 } 398 399 // TODO: remove 400 float getTerminalDiffCostG(int path) const { 401 return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path); 402 } 403 404 ////////////////////// 405 // Temporary getter // 406 // TODO: Remove // 407 ////////////////////// 408 // TODO: Remove once touch path is merged into ProximityInfoState 409 int getNodeCodePoint() const { 410 return mDicNodeProperties.getNodeCodePoint(); 411 } 412 413 //////////////////////////////// 414 // Utils for cost calculation // 415 //////////////////////////////// 416 AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const { 417 return mDicNodeProperties.getNodeCodePoint() 418 == dicNode->mDicNodeProperties.getNodeCodePoint(); 419 } 420 421 // TODO: remove 422 // TODO: rename getNextInputIndex 423 int16_t getInputIndex(int pointerId) const { 424 return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId); 425 } 426 427 //////////////////////////////////// 428 // Getter of features for scoring // 429 //////////////////////////////////// 430 float getSpatialDistanceForScoring() const { 431 return mDicNodeState.mDicNodeStateScoring.getSpatialDistance(); 432 } 433 434 float getLanguageDistanceForScoring() const { 435 return mDicNodeState.mDicNodeStateScoring.getLanguageDistance(); 436 } 437 438 float getLanguageDistanceRatePerWordForScoring() const { 439 const float langDist = getLanguageDistanceForScoring(); 440 const float totalWordCount = 441 static_cast<float>(mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1); 442 return langDist / totalWordCount; 443 } 444 445 float getRawLength() const { 446 return mDicNodeState.mDicNodeStateScoring.getRawLength(); 447 } 448 449 bool isLessThanOneErrorForScoring() const { 450 return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount() 451 + mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount() <= 1; 452 } 453 454 DoubleLetterLevel getDoubleLetterLevel() const { 455 return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel(); 456 } 457 458 void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { 459 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel); 460 } 461 462 uint8_t getFlags() const { 463 return mDicNodeProperties.getFlags(); 464 } 465 466 int getAttributesPos() const { 467 return mDicNodeProperties.getAttributesPos(); 468 } 469 470 inline uint16_t getDepth() const { 471 return mDicNodeProperties.getDepth(); 472 } 473 474 AK_FORCE_INLINE void dump(const char *tag) const { 475#if DEBUG_DICT 476 DUMP_WORD_AND_SCORE(tag); 477#if DEBUG_DUMP_ERROR 478 mProfiler.dump(); 479#endif 480#endif 481 } 482 483 void setReleaseListener(DicNodeReleaseListener *releaseListener) { 484 mReleaseListener = releaseListener; 485 } 486 487 AK_FORCE_INLINE bool compare(const DicNode *right) { 488 if (!isUsed() && !right->isUsed()) { 489 // Compare pointer values here for stable comparison 490 return this > right; 491 } 492 if (!isUsed()) { 493 return true; 494 } 495 if (!right->isUsed()) { 496 return false; 497 } 498 const float diff = 499 right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance(); 500 static const float MIN_DIFF = 0.000001f; 501 if (diff > MIN_DIFF) { 502 return true; 503 } else if (diff < -MIN_DIFF) { 504 return false; 505 } 506 const int depth = getDepth(); 507 const int depthDiff = right->getDepth() - depth; 508 if (depthDiff != 0) { 509 return depthDiff > 0; 510 } 511 for (int i = 0; i < depth; ++i) { 512 const int codePoint = mDicNodeState.mDicNodeStateOutput.getCodePointAt(i); 513 const int rightCodePoint = right->mDicNodeState.mDicNodeStateOutput.getCodePointAt(i); 514 if (codePoint != rightCodePoint) { 515 return rightCodePoint > codePoint; 516 } 517 } 518 // Compare pointer values here for stable comparison 519 return this > right; 520 } 521 522 private: 523 DicNodeProperties mDicNodeProperties; 524 DicNodeState mDicNodeState; 525 // TODO: Remove 526 bool mIsCachedForNextSuggestion; 527 bool mIsUsed; 528 DicNodeReleaseListener *mReleaseListener; 529 530 AK_FORCE_INLINE int getTotalInputIndex() const { 531 int index = 0; 532 for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { 533 index += mDicNodeState.mDicNodeStateInput.getInputIndex(i); 534 } 535 return index; 536 } 537 538 // Caveat: Must not be called outside Weighting 539 // This restriction is guaranteed by "friend" 540 AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost, 541 const bool doNormalization, const int inputSize, const bool isEditCorrection, 542 const bool isProximityCorrection) { 543 if (DEBUG_GEO_FULL) { 544 LOGI_SHOW_ADD_COST_PROP; 545 } 546 mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization, 547 inputSize, getTotalInputIndex(), isEditCorrection, isProximityCorrection); 548 } 549 550 // Caveat: Must not be called outside Weighting 551 // This restriction is guaranteed by "friend" 552 AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count, 553 const bool overwritesPrevCodePointByNodeCodePoint) { 554 if (count == 0) { 555 return; 556 } 557 mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count); 558 if (overwritesPrevCodePointByNodeCodePoint) { 559 mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint()); 560 } 561 } 562 563 AK_FORCE_INLINE void updateInputIndexG(DicNode_InputStateG *inputStateG) { 564 mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId, 565 inputStateG->mInputIndex, inputStateG->mPrevCodePoint, 566 inputStateG->mTerminalDiffCost, inputStateG->mRawLength); 567 mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength); 568 mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel); 569 } 570}; 571} // namespace latinime 572#endif // LATINIME_DIC_NODE_H 573