patricia_trie_policy.cpp revision 1229879e7c5892e818ab53b3c2162a158cc5e177
1/* 2 * Copyright (C) 2013, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 18#include "suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h" 19 20#include "defines.h" 21#include "suggest/core/dicnode/dic_node.h" 22#include "suggest/core/dicnode/dic_node_vector.h" 23#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" 24#include "suggest/core/session/prev_words_info.h" 25#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" 26#include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" 27#include "suggest/policyimpl/dictionary/utils/probability_utils.h" 28#include "utils/char_utils.h" 29 30namespace latinime { 31 32void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode, 33 DicNodeVector *const childDicNodes) const { 34 if (!dicNode->hasChildren()) { 35 return; 36 } 37 int nextPos = dicNode->getChildrenPtNodeArrayPos(); 38 if (nextPos < 0 || nextPos >= mDictBufferSize) { 39 AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %d", 40 nextPos, mDictBufferSize); 41 mIsCorrupted = true; 42 ASSERT(false); 43 return; 44 } 45 const int childCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( 46 mDictRoot, &nextPos); 47 for (int i = 0; i < childCount; i++) { 48 if (nextPos < 0 || nextPos >= mDictBufferSize) { 49 AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %d, childCount: %d / %d", 50 nextPos, mDictBufferSize, i, childCount); 51 mIsCorrupted = true; 52 ASSERT(false); 53 return; 54 } 55 nextPos = createAndGetLeavingChildNode(dicNode, nextPos, childDicNodes); 56 } 57} 58 59// This retrieves code points and the probability of the word by its terminal position. 60// Due to the fact that words are ordered in the dictionary in a strict breadth-first order, 61// it is possible to check for this with advantageous complexity. For each PtNode array, we search 62// for PtNodes with children and compare the children position with the position we look for. 63// When we shoot the position we look for, it means the word we look for is in the children 64// of the previous PtNode. The only tricky part is the fact that if we arrive at the end of a 65// PtNode array with the last PtNode's children position still less than what we are searching for, 66// we must descend the last PtNode's children (for example, if the word we are searching for starts 67// with a z, it's the last PtNode of the root array, so all children addresses will be smaller 68// than the position we look for, and we have to descend the z PtNode). 69/* Parameters : 70 * ptNodePos: the byte position of the terminal PtNode of the word we are searching for (this is 71 * what is stored as the "bigram position" in each bigram) 72 * outCodePoints: an array to write the found word, with MAX_WORD_LENGTH size. 73 * outUnigramProbability: a pointer to an int to write the probability into. 74 * Return value : the code point count, of 0 if the word was not found. 75 */ 76// TODO: Split this function to be more readable 77int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( 78 const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, 79 int *const outUnigramProbability) const { 80 int pos = getRootPosition(); 81 int wordPos = 0; 82 // One iteration of the outer loop iterates through PtNode arrays. As stated above, we will 83 // only traverse PtNodes that are actually a part of the terminal we are searching, so each 84 // time we enter this loop we are one depth level further than last time. 85 // The only reason we count PtNodes is because we want to reduce the probability of infinite 86 // looping in case there is a bug. Since we know there is an upper bound to the depth we are 87 // supposed to traverse, it does not hurt to count iterations. 88 for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) { 89 int lastCandidatePtNodePos = 0; 90 // Let's loop through PtNodes in this PtNode array searching for either the terminal 91 // or one of its ascendants. 92 if (pos < 0 || pos >= mDictBufferSize) { 93 AKLOGE("PtNode array position is invalid. pos: %d, dict size: %d", 94 pos, mDictBufferSize); 95 mIsCorrupted = true; 96 ASSERT(false); 97 *outUnigramProbability = NOT_A_PROBABILITY; 98 return 0; 99 } 100 for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( 101 mDictRoot, &pos); ptNodeCount > 0; --ptNodeCount) { 102 const int startPos = pos; 103 if (pos < 0 || pos >= mDictBufferSize) { 104 AKLOGE("PtNode position is invalid. pos: %d, dict size: %d", pos, mDictBufferSize); 105 mIsCorrupted = true; 106 ASSERT(false); 107 *outUnigramProbability = NOT_A_PROBABILITY; 108 return 0; 109 } 110 const PatriciaTrieReadingUtils::NodeFlags flags = 111 PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); 112 const int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( 113 mDictRoot, &pos); 114 if (ptNodePos == startPos) { 115 // We found the position. Copy the rest of the code points in the buffer and return 116 // the length. 117 outCodePoints[wordPos] = character; 118 if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { 119 int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( 120 mDictRoot, &pos); 121 // We count code points in order to avoid infinite loops if the file is broken 122 // or if there is some other bug 123 int charCount = maxCodePointCount; 124 while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { 125 outCodePoints[++wordPos] = nextChar; 126 nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( 127 mDictRoot, &pos); 128 } 129 } 130 *outUnigramProbability = 131 PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, 132 &pos); 133 return ++wordPos; 134 } 135 // We need to skip past this PtNode, so skip any remaining code points after the 136 // first and possibly the probability. 137 if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { 138 PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); 139 } 140 if (PatriciaTrieReadingUtils::isTerminal(flags)) { 141 PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); 142 } 143 // The fact that this PtNode has children is very important. Since we already know 144 // that this PtNode does not match, if it has no children we know it is irrelevant 145 // to what we are searching for. 146 const bool hasChildren = PatriciaTrieReadingUtils::hasChildrenInFlags(flags); 147 // We will write in `found' whether we have passed the children position we are 148 // searching for. For example if we search for "beer", the children of b are less 149 // than the address we are searching for and the children of c are greater. When we 150 // come here for c, we realize this is too big, and that we should descend b. 151 bool found; 152 if (hasChildren) { 153 int currentPos = pos; 154 // Here comes the tricky part. First, read the children position. 155 const int childrenPos = PatriciaTrieReadingUtils 156 ::readChildrenPositionAndAdvancePosition(mDictRoot, flags, ¤tPos); 157 if (childrenPos > ptNodePos) { 158 // If the children pos is greater than the position, it means the previous 159 // PtNode, which position is stored in lastCandidatePtNodePos, was the right 160 // one. 161 found = true; 162 } else if (1 >= ptNodeCount) { 163 // However if we are on the LAST PtNode of this array, and we have NOT shot the 164 // position we should descend THIS PtNode. So we trick the 165 // lastCandidatePtNodePos so that we will descend this PtNode, not the previous 166 // one. 167 lastCandidatePtNodePos = startPos; 168 found = true; 169 } else { 170 // Else, we should continue looking. 171 found = false; 172 } 173 } else { 174 // Even if we don't have children here, we could still be on the last PtNode of 175 // this array. If this is the case, we should descend the last PtNode that had 176 // children, and their position is already in lastCandidatePtNodePos. 177 found = (1 >= ptNodeCount); 178 } 179 180 if (found) { 181 // Okay, we found the PtNode we should descend. Its position is in 182 // the lastCandidatePtNodePos variable, so we just re-read it. 183 if (0 != lastCandidatePtNodePos) { 184 const PatriciaTrieReadingUtils::NodeFlags lastFlags = 185 PatriciaTrieReadingUtils::getFlagsAndAdvancePosition( 186 mDictRoot, &lastCandidatePtNodePos); 187 const int lastChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( 188 mDictRoot, &lastCandidatePtNodePos); 189 // We copy all the characters in this PtNode to the buffer 190 outCodePoints[wordPos] = lastChar; 191 if (PatriciaTrieReadingUtils::hasMultipleChars(lastFlags)) { 192 int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( 193 mDictRoot, &lastCandidatePtNodePos); 194 int charCount = maxCodePointCount; 195 while (-1 != nextChar && --charCount > 0) { 196 outCodePoints[++wordPos] = nextChar; 197 nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( 198 mDictRoot, &lastCandidatePtNodePos); 199 } 200 } 201 ++wordPos; 202 // Now we only need to branch to the children address. Skip the probability if 203 // it's there, read pos, and break to resume the search at pos. 204 if (PatriciaTrieReadingUtils::isTerminal(lastFlags)) { 205 PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, 206 &lastCandidatePtNodePos); 207 } 208 pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( 209 mDictRoot, lastFlags, &lastCandidatePtNodePos); 210 break; 211 } else { 212 // Here is a little tricky part: we come here if we found out that all children 213 // addresses in this PtNode are bigger than the address we are searching for. 214 // Should we conclude the word is not in the dictionary? No! It could still be 215 // one of the remaining PtNodes in this array, so we have to keep looking in 216 // this array until we find it (or we realize it's not there either, in which 217 // case it's actually not in the dictionary). Pass the end of this PtNode, 218 // ready to start the next one. 219 if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { 220 PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( 221 mDictRoot, flags, &pos); 222 } 223 if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { 224 mShortcutListPolicy.skipAllShortcuts(&pos); 225 } 226 if (PatriciaTrieReadingUtils::hasBigrams(flags)) { 227 mBigramListPolicy.skipAllBigrams(&pos); 228 } 229 } 230 } else { 231 // If we did not find it, we should record the last children address for the next 232 // iteration. 233 if (hasChildren) lastCandidatePtNodePos = startPos; 234 // Now skip the end of this PtNode (children pos and the attributes if any) so that 235 // our pos is after the end of this PtNode, at the start of the next one. 236 if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { 237 PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( 238 mDictRoot, flags, &pos); 239 } 240 if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { 241 mShortcutListPolicy.skipAllShortcuts(&pos); 242 } 243 if (PatriciaTrieReadingUtils::hasBigrams(flags)) { 244 mBigramListPolicy.skipAllBigrams(&pos); 245 } 246 } 247 248 } 249 } 250 // If we have looked through all the PtNodes and found no match, the ptNodePos is 251 // not the position of a terminal in this dictionary. 252 return 0; 253} 254 255// This function gets the position of the terminal PtNode of the exact matching word in the 256// dictionary. If no match is found, it returns NOT_A_DICT_POS. 257int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const inWord, 258 const int length, const bool forceLowerCaseSearch) const { 259 DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader); 260 readingHelper.initWithPtNodeArrayPos(getRootPosition()); 261 const int ptNodePos = 262 readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch); 263 if (readingHelper.isError()) { 264 mIsCorrupted = true; 265 AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); 266 } 267 return ptNodePos; 268} 269 270int PatriciaTriePolicy::getProbability(const int unigramProbability, 271 const int bigramProbability) const { 272 // Due to space constraints, the probability for bigrams is approximate - the lower the unigram 273 // probability, the worse the precision. The theoritical maximum error in resulting probability 274 // is 8 - although in the practice it's never bigger than 3 or 4 in very bad cases. This means 275 // that sometimes, we'll see some bigrams interverted here, but it can't get too bad. 276 if (unigramProbability == NOT_A_PROBABILITY) { 277 return NOT_A_PROBABILITY; 278 } else if (bigramProbability == NOT_A_PROBABILITY) { 279 return ProbabilityUtils::backoff(unigramProbability); 280 } else { 281 return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, 282 bigramProbability); 283 } 284} 285 286int PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo, 287 const int ptNodePos) const { 288 if (ptNodePos == NOT_A_DICT_POS) { 289 return NOT_A_PROBABILITY; 290 } 291 const PtNodeParams ptNodeParams = 292 mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); 293 if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) { 294 // If this is not a word, or if it's a blacklisted entry, it should behave as 295 // having no probability outside of the suggestion process (where it should be used 296 // for shortcuts). 297 return NOT_A_PROBABILITY; 298 } 299 if (prevWordsInfo) { 300 BinaryDictionaryBigramsIterator bigramsIt = 301 prevWordsInfo->getBigramsIteratorForPrediction(this /* dictStructurePolicy */); 302 while (bigramsIt.hasNext()) { 303 bigramsIt.next(); 304 if (bigramsIt.getBigramPos() == ptNodePos 305 && bigramsIt.getProbability() != NOT_A_PROBABILITY) { 306 return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability()); 307 } 308 } 309 return NOT_A_PROBABILITY; 310 } 311 return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); 312} 313 314int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { 315 if (ptNodePos == NOT_A_DICT_POS) { 316 return NOT_A_DICT_POS; 317 } 318 return mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos).getShortcutPos(); 319} 320 321BinaryDictionaryBigramsIterator PatriciaTriePolicy::getBigramsIteratorOfPtNode( 322 const int ptNodePos) const { 323 const int bigramsPosition = getBigramsPositionOfPtNode(ptNodePos); 324 return BinaryDictionaryBigramsIterator(&mBigramListPolicy, bigramsPosition); 325} 326 327int PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { 328 if (ptNodePos == NOT_A_DICT_POS) { 329 return NOT_A_DICT_POS; 330 } 331 return mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos).getBigramsPos(); 332} 333 334int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNode, 335 const int ptNodePos, DicNodeVector *childDicNodes) const { 336 PatriciaTrieReadingUtils::NodeFlags flags; 337 int mergedNodeCodePointCount = 0; 338 int mergedNodeCodePoints[MAX_WORD_LENGTH]; 339 int probability = NOT_A_PROBABILITY; 340 int childrenPos = NOT_A_DICT_POS; 341 int shortcutPos = NOT_A_DICT_POS; 342 int bigramPos = NOT_A_DICT_POS; 343 int siblingPos = NOT_A_DICT_POS; 344 PatriciaTrieReadingUtils::readPtNodeInfo(mDictRoot, ptNodePos, getShortcutsStructurePolicy(), 345 &mBigramListPolicy, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, 346 &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos); 347 // Skip PtNodes don't start with Unicode code point because they represent non-word information. 348 if (CharUtils::isInUnicodeSpace(mergedNodeCodePoints[0])) { 349 childDicNodes->pushLeavingChild(dicNode, ptNodePos, childrenPos, probability, 350 PatriciaTrieReadingUtils::isTerminal(flags), 351 PatriciaTrieReadingUtils::hasChildrenInFlags(flags), 352 PatriciaTrieReadingUtils::isBlacklisted(flags) 353 || PatriciaTrieReadingUtils::isNotAWord(flags), 354 mergedNodeCodePointCount, mergedNodeCodePoints); 355 } 356 return siblingPos; 357} 358 359const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoints, 360 const int codePointCount) const { 361 const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount, 362 false /* forceLowerCaseSearch */); 363 if (ptNodePos == NOT_A_DICT_POS) { 364 AKLOGE("getWordProperty was called for invalid word."); 365 return WordProperty(); 366 } 367 const PtNodeParams ptNodeParams = 368 mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); 369 std::vector<int> codePointVector(ptNodeParams.getCodePoints(), 370 ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); 371 // Fetch bigram information. 372 std::vector<BigramProperty> bigrams; 373 const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos); 374 int bigramWord1CodePoints[MAX_WORD_LENGTH]; 375 BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramListPos); 376 while (bigramsIt.hasNext()) { 377 // Fetch the next bigram information and forward the iterator. 378 bigramsIt.next(); 379 // Skip the entry if the entry has been deleted. This never happens for ver2 dicts. 380 if (bigramsIt.getBigramPos() != NOT_A_DICT_POS) { 381 int word1Probability = NOT_A_PROBABILITY; 382 const int word1CodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( 383 bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramWord1CodePoints, 384 &word1Probability); 385 const std::vector<int> word1(bigramWord1CodePoints, 386 bigramWord1CodePoints + word1CodePointCount); 387 const int probability = getProbability(word1Probability, bigramsIt.getProbability()); 388 bigrams.emplace_back(&word1, probability, 389 NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */); 390 } 391 } 392 // Fetch shortcut information. 393 std::vector<UnigramProperty::ShortcutProperty> shortcuts; 394 int shortcutPos = getShortcutPositionOfPtNode(ptNodePos); 395 if (shortcutPos != NOT_A_DICT_POS) { 396 int shortcutTargetCodePoints[MAX_WORD_LENGTH]; 397 ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mDictRoot, &shortcutPos); 398 bool hasNext = true; 399 while (hasNext) { 400 const ShortcutListReadingUtils::ShortcutFlags shortcutFlags = 401 ShortcutListReadingUtils::getFlagsAndForwardPointer(mDictRoot, &shortcutPos); 402 hasNext = ShortcutListReadingUtils::hasNext(shortcutFlags); 403 const int shortcutTargetLength = ShortcutListReadingUtils::readShortcutTarget( 404 mDictRoot, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos); 405 const std::vector<int> shortcutTarget(shortcutTargetCodePoints, 406 shortcutTargetCodePoints + shortcutTargetLength); 407 const int shortcutProbability = 408 ShortcutListReadingUtils::getProbabilityFromFlags(shortcutFlags); 409 shortcuts.emplace_back(&shortcutTarget, shortcutProbability); 410 } 411 } 412 const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), 413 ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), 414 NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); 415 return WordProperty(&codePointVector, &unigramProperty, &bigrams); 416} 417 418int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints, 419 int *const outCodePointCount) { 420 *outCodePointCount = 0; 421 if (token == 0) { 422 // Start iterating the dictionary. 423 mTerminalPtNodePositionsForIteratingWords.clear(); 424 DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy( 425 &mTerminalPtNodePositionsForIteratingWords); 426 DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader); 427 readingHelper.initWithPtNodeArrayPos(getRootPosition()); 428 readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(&traversePolicy); 429 } 430 const int terminalPtNodePositionsVectorSize = 431 static_cast<int>(mTerminalPtNodePositionsForIteratingWords.size()); 432 if (token < 0 || token >= terminalPtNodePositionsVectorSize) { 433 AKLOGE("Given token %d is invalid.", token); 434 return 0; 435 } 436 const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; 437 int unigramProbability = NOT_A_PROBABILITY; 438 *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, 439 MAX_WORD_LENGTH, outCodePoints, &unigramProbability); 440 const int nextToken = token + 1; 441 if (nextToken >= terminalPtNodePositionsVectorSize) { 442 // All words have been iterated. 443 mTerminalPtNodePositionsForIteratingWords.clear(); 444 return 0; 445 } 446 return nextToken; 447} 448 449} // namespace latinime 450