ver4_patricia_trie_policy.cpp revision 4ce480d5ce2d47f607448ce439aaf2cefba1bdd8
1/* 2 * Copyright (C) 2013, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h" 18 19#include <vector> 20 21#include "suggest/core/dicnode/dic_node.h" 22#include "suggest/core/dicnode/dic_node_vector.h" 23#include "suggest/core/dictionary/word_property.h" 24#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" 25#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" 26#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" 27#include "suggest/policyimpl/dictionary/utils/probability_utils.h" 28 29namespace latinime { 30 31// Note that there are corresponding definitions in Java side in BinaryDictionaryTests and 32// BinaryDictionaryDecayingTests. 33const char *const Ver4PatriciaTriePolicy::UNIGRAM_COUNT_QUERY = "UNIGRAM_COUNT"; 34const char *const Ver4PatriciaTriePolicy::BIGRAM_COUNT_QUERY = "BIGRAM_COUNT"; 35const char *const Ver4PatriciaTriePolicy::MAX_UNIGRAM_COUNT_QUERY = "MAX_UNIGRAM_COUNT"; 36const char *const Ver4PatriciaTriePolicy::MAX_BIGRAM_COUNT_QUERY = "MAX_BIGRAM_COUNT"; 37const int Ver4PatriciaTriePolicy::MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS = 1024; 38const int Ver4PatriciaTriePolicy::MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS = 39 Ver4DictConstants::MAX_DICTIONARY_SIZE - MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS; 40 41void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode, 42 DicNodeVector *const childDicNodes) const { 43 if (!dicNode->hasChildren()) { 44 return; 45 } 46 DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); 47 readingHelper.initWithPtNodeArrayPos(dicNode->getChildrenPtNodeArrayPos()); 48 while (!readingHelper.isEnd()) { 49 const PtNodeParams ptNodeParams = readingHelper.getPtNodeParams(); 50 if (!ptNodeParams.isValid()) { 51 break; 52 } 53 bool isTerminal = ptNodeParams.isTerminal() && !ptNodeParams.isDeleted(); 54 if (isTerminal && mHeaderPolicy->isDecayingDict()) { 55 // A DecayingDict may have a terminal PtNode that has a terminal DicNode whose 56 // probability is NOT_A_PROBABILITY. In such case, we don't want to treat it as a 57 // valid terminal DicNode. 58 isTerminal = ptNodeParams.getProbability() != NOT_A_PROBABILITY; 59 } 60 childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getHeadPos(), 61 ptNodeParams.getChildrenPos(), ptNodeParams.getProbability(), isTerminal, 62 ptNodeParams.hasChildren(), 63 ptNodeParams.isBlacklisted() 64 || ptNodeParams.isNotAWord() /* isBlacklistedOrNotAWord */, 65 ptNodeParams.getCodePointCount(), ptNodeParams.getCodePoints()); 66 readingHelper.readNextSiblingNode(ptNodeParams); 67 } 68 if (readingHelper.isError()) { 69 mIsCorrupted = true; 70 AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); 71 } 72} 73 74int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( 75 const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, 76 int *const outUnigramProbability) const { 77 DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); 78 readingHelper.initWithPtNodePos(ptNodePos); 79 const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount( 80 maxCodePointCount, outCodePoints, outUnigramProbability); 81 if (readingHelper.isError()) { 82 mIsCorrupted = true; 83 AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount()."); 84 } 85 return codePointCount; 86} 87 88int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const inWord, 89 const int length, const bool forceLowerCaseSearch) const { 90 DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); 91 readingHelper.initWithPtNodeArrayPos(getRootPosition()); 92 const int ptNodePos = 93 readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch); 94 if (readingHelper.isError()) { 95 mIsCorrupted = true; 96 AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); 97 } 98 return ptNodePos; 99} 100 101int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, 102 const int bigramProbability) const { 103 if (mHeaderPolicy->isDecayingDict()) { 104 // Both probabilities are encoded. Decode them and get probability. 105 return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability); 106 } else { 107 if (unigramProbability == NOT_A_PROBABILITY) { 108 return NOT_A_PROBABILITY; 109 } else if (bigramProbability == NOT_A_PROBABILITY) { 110 return ProbabilityUtils::backoff(unigramProbability); 111 } else { 112 // bigramProbability is a bigram probability delta. 113 return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, 114 bigramProbability); 115 } 116 } 117} 118 119int Ver4PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const { 120 if (ptNodePos == NOT_A_DICT_POS) { 121 return NOT_A_PROBABILITY; 122 } 123 const PtNodeParams ptNodeParams(mNodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos)); 124 if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { 125 return NOT_A_PROBABILITY; 126 } 127 return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); 128} 129 130int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { 131 if (ptNodePos == NOT_A_DICT_POS) { 132 return NOT_A_DICT_POS; 133 } 134 const PtNodeParams ptNodeParams(mNodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos)); 135 if (ptNodeParams.isDeleted()) { 136 return NOT_A_DICT_POS; 137 } 138 return mBuffers->getShortcutDictContent()->getShortcutListHeadPos( 139 ptNodeParams.getTerminalId()); 140} 141 142int Ver4PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { 143 if (ptNodePos == NOT_A_DICT_POS) { 144 return NOT_A_DICT_POS; 145 } 146 const PtNodeParams ptNodeParams(mNodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos)); 147 if (ptNodeParams.isDeleted()) { 148 return NOT_A_DICT_POS; 149 } 150 return mBuffers->getBigramDictContent()->getBigramListHeadPos( 151 ptNodeParams.getTerminalId()); 152} 153 154bool Ver4PatriciaTriePolicy::addUnigramWord(const int *const word, const int length, 155 const int probability, const int *const shortcutTargetCodePoints, const int shortcutLength, 156 const int shortcutProbability, const bool isNotAWord, const bool isBlacklisted, 157 const int timestamp) { 158 if (!mBuffers->isUpdatable()) { 159 AKLOGI("Warning: addUnigramWord() is called for non-updatable dictionary."); 160 return false; 161 } 162 if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { 163 AKLOGE("The dictionary is too large to dynamically update. Dictionary size: %d", 164 mDictBuffer->getTailPosition()); 165 return false; 166 } 167 if (length > MAX_WORD_LENGTH) { 168 AKLOGE("The word is too long to insert to the dictionary, length: %d", length); 169 return false; 170 } 171 if (shortcutLength > MAX_WORD_LENGTH) { 172 AKLOGE("The shortcutTarget is too long to insert to the dictionary, length: %d", 173 shortcutLength); 174 return false; 175 } 176 DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); 177 readingHelper.initWithPtNodeArrayPos(getRootPosition()); 178 bool addedNewUnigram = false; 179 if (mUpdatingHelper.addUnigramWord(&readingHelper, word, length, probability, isNotAWord, 180 isBlacklisted, timestamp, &addedNewUnigram)) { 181 if (addedNewUnigram) { 182 mUnigramCount++; 183 } 184 if (shortcutLength > 0) { 185 // Add shortcut target. 186 const int wordPos = getTerminalPtNodePositionOfWord(word, length, 187 false /* forceLowerCaseSearch */); 188 if (wordPos == NOT_A_DICT_POS) { 189 AKLOGE("Cannot find terminal PtNode position to add shortcut target."); 190 return false; 191 } 192 if (!mUpdatingHelper.addShortcutTarget(wordPos, shortcutTargetCodePoints, 193 shortcutLength, shortcutProbability)) { 194 AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %d, probability: %d", 195 wordPos, shortcutLength, shortcutProbability); 196 return false; 197 } 198 } 199 return true; 200 } else { 201 return false; 202 } 203} 204 205bool Ver4PatriciaTriePolicy::addBigramWords(const int *const word0, const int length0, 206 const int *const word1, const int length1, const int probability, 207 const int timestamp) { 208 if (!mBuffers->isUpdatable()) { 209 AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); 210 return false; 211 } 212 if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { 213 AKLOGE("The dictionary is too large to dynamically update. Dictionary size: %d", 214 mDictBuffer->getTailPosition()); 215 return false; 216 } 217 if (length0 > MAX_WORD_LENGTH || length1 > MAX_WORD_LENGTH) { 218 AKLOGE("Either src word or target word is too long to insert the bigram to the dictionary. " 219 "length0: %d, length1: %d", length0, length1); 220 return false; 221 } 222 const int word0Pos = getTerminalPtNodePositionOfWord(word0, length0, 223 false /* forceLowerCaseSearch */); 224 if (word0Pos == NOT_A_DICT_POS) { 225 return false; 226 } 227 const int word1Pos = getTerminalPtNodePositionOfWord(word1, length1, 228 false /* forceLowerCaseSearch */); 229 if (word1Pos == NOT_A_DICT_POS) { 230 return false; 231 } 232 bool addedNewBigram = false; 233 if (mUpdatingHelper.addBigramWords(word0Pos, word1Pos, probability, timestamp, 234 &addedNewBigram)) { 235 if (addedNewBigram) { 236 mBigramCount++; 237 } 238 return true; 239 } else { 240 return false; 241 } 242} 243 244bool Ver4PatriciaTriePolicy::removeBigramWords(const int *const word0, const int length0, 245 const int *const word1, const int length1) { 246 if (!mBuffers->isUpdatable()) { 247 AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); 248 return false; 249 } 250 if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { 251 AKLOGE("The dictionary is too large to dynamically update. Dictionary size: %d", 252 mDictBuffer->getTailPosition()); 253 return false; 254 } 255 if (length0 > MAX_WORD_LENGTH || length1 > MAX_WORD_LENGTH) { 256 AKLOGE("Either src word or target word is too long to remove the bigram to from the " 257 "dictionary. length0: %d, length1: %d", length0, length1); 258 return false; 259 } 260 const int word0Pos = getTerminalPtNodePositionOfWord(word0, length0, 261 false /* forceLowerCaseSearch */); 262 if (word0Pos == NOT_A_DICT_POS) { 263 return false; 264 } 265 const int word1Pos = getTerminalPtNodePositionOfWord(word1, length1, 266 false /* forceLowerCaseSearch */); 267 if (word1Pos == NOT_A_DICT_POS) { 268 return false; 269 } 270 if (mUpdatingHelper.removeBigramWords(word0Pos, word1Pos)) { 271 mBigramCount--; 272 return true; 273 } else { 274 return false; 275 } 276} 277 278void Ver4PatriciaTriePolicy::flush(const char *const filePath) { 279 if (!mBuffers->isUpdatable()) { 280 AKLOGI("Warning: flush() is called for non-updatable dictionary. filePath: %s", filePath); 281 return; 282 } 283 if (!mWritingHelper.writeToDictFile(filePath, mUnigramCount, mBigramCount)) { 284 AKLOGE("Cannot flush the dictionary to file."); 285 mIsCorrupted = true; 286 } 287} 288 289void Ver4PatriciaTriePolicy::flushWithGC(const char *const filePath) { 290 if (!mBuffers->isUpdatable()) { 291 AKLOGI("Warning: flushWithGC() is called for non-updatable dictionary."); 292 return; 293 } 294 if (!mWritingHelper.writeToDictFileWithGC(getRootPosition(), filePath)) { 295 AKLOGE("Cannot flush the dictionary to file with GC."); 296 mIsCorrupted = true; 297 } 298} 299 300bool Ver4PatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const { 301 if (!mBuffers->isUpdatable()) { 302 AKLOGI("Warning: needsToRunGC() is called for non-updatable dictionary."); 303 return false; 304 } 305 if (mBuffers->isNearSizeLimit()) { 306 // Additional buffer size is near the limit. 307 return true; 308 } else if (mHeaderPolicy->getExtendedRegionSize() + mDictBuffer->getUsedAdditionalBufferSize() 309 > Ver4DictConstants::MAX_DICT_EXTENDED_REGION_SIZE) { 310 // Total extended region size of the trie exceeds the limit. 311 return true; 312 } else if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS 313 && mDictBuffer->getUsedAdditionalBufferSize() > 0) { 314 // Needs to reduce dictionary size. 315 return true; 316 } else if (mHeaderPolicy->isDecayingDict()) { 317 return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mUnigramCount, mBigramCount, 318 mHeaderPolicy); 319 } 320 return false; 321} 322 323void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int queryLength, 324 char *const outResult, const int maxResultLength) { 325 const int compareLength = queryLength + 1 /* terminator */; 326 if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) { 327 snprintf(outResult, maxResultLength, "%d", mUnigramCount); 328 } else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) { 329 snprintf(outResult, maxResultLength, "%d", mBigramCount); 330 } else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) { 331 snprintf(outResult, maxResultLength, "%d", 332 mHeaderPolicy->isDecayingDict() ? 333 ForgettingCurveUtils::getUnigramCountHardLimit( 334 mHeaderPolicy->getMaxUnigramCount()) : 335 static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); 336 } else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) { 337 snprintf(outResult, maxResultLength, "%d", 338 mHeaderPolicy->isDecayingDict() ? 339 ForgettingCurveUtils::getBigramCountHardLimit( 340 mHeaderPolicy->getMaxBigramCount()) : 341 static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); 342 } 343} 344 345const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const codePoints, 346 const int codePointCount) const { 347 const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount, 348 false /* forceLowerCaseSearch */); 349 if (ptNodePos == NOT_A_DICT_POS) { 350 AKLOGE("getWordProperty is called for invalid word."); 351 return WordProperty(); 352 } 353 const PtNodeParams ptNodeParams = mNodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos); 354 std::vector<int> codePointVector(ptNodeParams.getCodePoints(), 355 ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); 356 const ProbabilityEntry probabilityEntry = 357 mBuffers->getProbabilityDictContent()->getProbabilityEntry( 358 ptNodeParams.getTerminalId()); 359 const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); 360 // Fetch bigram information. 361 std::vector<WordProperty::BigramProperty> bigrams; 362 const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos); 363 if (bigramListPos != NOT_A_DICT_POS) { 364 int bigramWord1CodePoints[MAX_WORD_LENGTH]; 365 const BigramDictContent *const bigramDictContent = mBuffers->getBigramDictContent(); 366 const TerminalPositionLookupTable *const terminalPositionLookupTable = 367 mBuffers->getTerminalPositionLookupTable(); 368 bool hasNext = true; 369 int readingPos = bigramListPos; 370 while (hasNext) { 371 const BigramEntry bigramEntry = 372 bigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); 373 hasNext = bigramEntry.hasNext(); 374 const int word1TerminalId = bigramEntry.getTargetTerminalId(); 375 const int word1TerminalPtNodePos = 376 terminalPositionLookupTable->getTerminalPtNodePosition(word1TerminalId); 377 if (word1TerminalPtNodePos == NOT_A_DICT_POS) { 378 continue; 379 } 380 // Word (unigram) probability 381 int word1Probability = NOT_A_PROBABILITY; 382 const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( 383 word1TerminalPtNodePos, MAX_WORD_LENGTH, bigramWord1CodePoints, 384 &word1Probability); 385 std::vector<int> word1(bigramWord1CodePoints, 386 bigramWord1CodePoints + codePointCount); 387 const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo(); 388 const int probability = bigramEntry.hasHistoricalInfo() ? 389 ForgettingCurveUtils::decodeProbability( 390 bigramEntry.getHistoricalInfo(), mHeaderPolicy) : 391 bigramEntry.getProbability(); 392 bigrams.push_back(WordProperty::BigramProperty(&word1, probability, 393 historicalInfo->getTimeStamp(), historicalInfo->getLevel(), 394 historicalInfo->getCount())); 395 } 396 } 397 // Fetch shortcut information. 398 std::vector<WordProperty::ShortcutProperty> shortcuts; 399 int shortcutPos = getShortcutPositionOfPtNode(ptNodePos); 400 if (shortcutPos != NOT_A_DICT_POS) { 401 int shortcutTarget[MAX_WORD_LENGTH]; 402 const ShortcutDictContent *const shortcutDictContent = 403 mBuffers->getShortcutDictContent(); 404 bool hasNext = true; 405 while (hasNext) { 406 int shortcutTargetLength = 0; 407 int shortcutProbability = NOT_A_PROBABILITY; 408 shortcutDictContent->getShortcutEntryAndAdvancePosition(MAX_WORD_LENGTH, shortcutTarget, 409 &shortcutTargetLength, &shortcutProbability, &hasNext, &shortcutPos); 410 std::vector<int> target(shortcutTarget, shortcutTarget + shortcutTargetLength); 411 shortcuts.push_back(WordProperty::ShortcutProperty(&target, shortcutProbability)); 412 } 413 } 414 return WordProperty(&codePointVector, ptNodeParams.isNotAWord(), 415 ptNodeParams.isBlacklisted(), ptNodeParams.hasBigrams(), 416 ptNodeParams.hasShortcutTargets(), ptNodeParams.getProbability(), 417 historicalInfo->getTimeStamp(), historicalInfo->getLevel(), 418 historicalInfo->getCount(), &bigrams, &shortcuts); 419} 420 421int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) { 422 if (token == 0) { 423 mTerminalPtNodePositionsForIteratingWords.clear(); 424 DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy( 425 &mTerminalPtNodePositionsForIteratingWords); 426 DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); 427 readingHelper.initWithPtNodeArrayPos(getRootPosition()); 428 readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(&traversePolicy); 429 } 430 const int terminalPtNodePositionsVectorSize = 431 static_cast<int>(mTerminalPtNodePositionsForIteratingWords.size()); 432 if (token < 0 || token >= terminalPtNodePositionsVectorSize) { 433 AKLOGE("Given token %d is invalid.", token); 434 return 0; 435 } 436 const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; 437 int unigramProbability = NOT_A_PROBABILITY; 438 getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH, 439 outCodePoints, &unigramProbability); 440 const int nextToken = token + 1; 441 if (nextToken >= terminalPtNodePositionsVectorSize) { 442 // All words have been iterated. 443 mTerminalPtNodePositionsForIteratingWords.clear(); 444 return 0; 445 } 446 return nextToken; 447} 448 449} // namespace latinime 450