ver4_patricia_trie_node_writer.cpp revision 5520e84e16f3886548dc11cf888977a6dbbef5f1
1/* 2 * Copyright (C) 2013, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h" 18 19#include "suggest/core/dictionary/property/unigram_property.h" 20#include "suggest/policyimpl/dictionary/header/header_policy.h" 21#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_utils.h" 22#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_writing_utils.h" 23#include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" 24#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" 25#include "suggest/policyimpl/dictionary/structure/v4/shortcut/ver4_shortcut_list_policy.h" 26#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" 27#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h" 28#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" 29#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" 30 31namespace latinime { 32 33const int Ver4PatriciaTrieNodeWriter::CHILDREN_POSITION_FIELD_SIZE = 3; 34 35bool Ver4PatriciaTrieNodeWriter::markPtNodeAsDeleted( 36 const PtNodeParams *const toBeUpdatedPtNodeParams) { 37 int pos = toBeUpdatedPtNodeParams->getHeadPos(); 38 const bool usesAdditionalBuffer = mTrieBuffer->isInAdditionalBuffer(pos); 39 const uint8_t *const dictBuf = mTrieBuffer->getBuffer(usesAdditionalBuffer); 40 if (usesAdditionalBuffer) { 41 pos -= mTrieBuffer->getOriginalBufferSize(); 42 } 43 // Read original flags 44 const PatriciaTrieReadingUtils::NodeFlags originalFlags = 45 PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); 46 const PatriciaTrieReadingUtils::NodeFlags updatedFlags = 47 DynamicPtReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */, 48 true /* isDeleted */, false /* willBecomeNonTerminal */); 49 int writingPos = toBeUpdatedPtNodeParams->getHeadPos(); 50 // Update flags. 51 if (!DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags, 52 &writingPos)) { 53 return false; 54 } 55 if (toBeUpdatedPtNodeParams->isTerminal()) { 56 // The PtNode is a terminal. Delete entry from the terminal position lookup table. 57 return mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition( 58 toBeUpdatedPtNodeParams->getTerminalId(), NOT_A_DICT_POS /* ptNodePos */); 59 } else { 60 return true; 61 } 62} 63 64bool Ver4PatriciaTrieNodeWriter::markPtNodeAsMoved( 65 const PtNodeParams *const toBeUpdatedPtNodeParams, 66 const int movedPos, const int bigramLinkedNodePos) { 67 int pos = toBeUpdatedPtNodeParams->getHeadPos(); 68 const bool usesAdditionalBuffer = mTrieBuffer->isInAdditionalBuffer(pos); 69 const uint8_t *const dictBuf = mTrieBuffer->getBuffer(usesAdditionalBuffer); 70 if (usesAdditionalBuffer) { 71 pos -= mTrieBuffer->getOriginalBufferSize(); 72 } 73 // Read original flags 74 const PatriciaTrieReadingUtils::NodeFlags originalFlags = 75 PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); 76 const PatriciaTrieReadingUtils::NodeFlags updatedFlags = 77 DynamicPtReadingUtils::updateAndGetFlags(originalFlags, true /* isMoved */, 78 false /* isDeleted */, false /* willBecomeNonTerminal */); 79 int writingPos = toBeUpdatedPtNodeParams->getHeadPos(); 80 // Update flags. 81 if (!DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags, 82 &writingPos)) { 83 return false; 84 } 85 // Update moved position, which is stored in the parent offset field. 86 if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition( 87 mTrieBuffer, movedPos, toBeUpdatedPtNodeParams->getHeadPos(), &writingPos)) { 88 return false; 89 } 90 if (toBeUpdatedPtNodeParams->hasChildren()) { 91 // Update children's parent position. 92 mReadingHelper.initWithPtNodeArrayPos(toBeUpdatedPtNodeParams->getChildrenPos()); 93 while (!mReadingHelper.isEnd()) { 94 const PtNodeParams childPtNodeParams(mReadingHelper.getPtNodeParams()); 95 int parentOffsetFieldPos = childPtNodeParams.getHeadPos() 96 + DynamicPtWritingUtils::NODE_FLAG_FIELD_SIZE; 97 if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition( 98 mTrieBuffer, bigramLinkedNodePos, childPtNodeParams.getHeadPos(), 99 &parentOffsetFieldPos)) { 100 // Parent offset cannot be written because of a bug or a broken dictionary; thus, 101 // we give up to update dictionary. 102 return false; 103 } 104 mReadingHelper.readNextSiblingNode(childPtNodeParams); 105 } 106 } 107 return true; 108} 109 110bool Ver4PatriciaTrieNodeWriter::markPtNodeAsWillBecomeNonTerminal( 111 const PtNodeParams *const toBeUpdatedPtNodeParams) { 112 int pos = toBeUpdatedPtNodeParams->getHeadPos(); 113 const bool usesAdditionalBuffer = mTrieBuffer->isInAdditionalBuffer(pos); 114 const uint8_t *const dictBuf = mTrieBuffer->getBuffer(usesAdditionalBuffer); 115 if (usesAdditionalBuffer) { 116 pos -= mTrieBuffer->getOriginalBufferSize(); 117 } 118 // Read original flags 119 const PatriciaTrieReadingUtils::NodeFlags originalFlags = 120 PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); 121 const PatriciaTrieReadingUtils::NodeFlags updatedFlags = 122 DynamicPtReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */, 123 false /* isDeleted */, true /* willBecomeNonTerminal */); 124 if (!mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition( 125 toBeUpdatedPtNodeParams->getTerminalId(), NOT_A_DICT_POS /* ptNodePos */)) { 126 AKLOGE("Cannot update terminal position lookup table. terminal id: %d", 127 toBeUpdatedPtNodeParams->getTerminalId()); 128 return false; 129 } 130 // Update flags. 131 int writingPos = toBeUpdatedPtNodeParams->getHeadPos(); 132 return DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags, 133 &writingPos); 134} 135 136bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty( 137 const PtNodeParams *const toBeUpdatedPtNodeParams, 138 const UnigramProperty *const unigramProperty) { 139 // Update probability and historical information. 140 // TODO: Update other information in the unigram property. 141 if (!toBeUpdatedPtNodeParams->isTerminal()) { 142 return false; 143 } 144 const ProbabilityEntry originalProbabilityEntry = 145 mBuffers->getLanguageModelDictContent()->getProbabilityEntry( 146 toBeUpdatedPtNodeParams->getTerminalId()); 147 const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty); 148 const ProbabilityEntry updatedProbabilityEntry = 149 createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty); 150 return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( 151 toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry); 152} 153 154bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC( 155 const PtNodeParams *const toBeUpdatedPtNodeParams, bool *const outNeedsToKeepPtNode) { 156 if (!toBeUpdatedPtNodeParams->isTerminal()) { 157 AKLOGE("updatePtNodeProbabilityAndGetNeedsToSaveForGC is called for non-terminal PtNode."); 158 return false; 159 } 160 const ProbabilityEntry originalProbabilityEntry = 161 mBuffers->getLanguageModelDictContent()->getProbabilityEntry( 162 toBeUpdatedPtNodeParams->getTerminalId()); 163 if (originalProbabilityEntry.isValid()) { 164 *outNeedsToKeepPtNode = true; 165 return true; 166 } 167 if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) { 168 AKLOGE("Cannot mark PtNode as willBecomeNonTerminal."); 169 return false; 170 } 171 *outNeedsToKeepPtNode = false; 172 return true; 173} 174 175bool Ver4PatriciaTrieNodeWriter::updateChildrenPosition( 176 const PtNodeParams *const toBeUpdatedPtNodeParams, const int newChildrenPosition) { 177 int childrenPosFieldPos = toBeUpdatedPtNodeParams->getChildrenPosFieldPos(); 178 return DynamicPtWritingUtils::writeChildrenPositionAndAdvancePosition(mTrieBuffer, 179 newChildrenPosition, &childrenPosFieldPos); 180} 181 182bool Ver4PatriciaTrieNodeWriter::updateTerminalId(const PtNodeParams *const toBeUpdatedPtNodeParams, 183 const int newTerminalId) { 184 return mTrieBuffer->writeUint(newTerminalId, Ver4DictConstants::TERMINAL_ID_FIELD_SIZE, 185 toBeUpdatedPtNodeParams->getTerminalIdFieldPos()); 186} 187 188bool Ver4PatriciaTrieNodeWriter::writePtNodeAndAdvancePosition( 189 const PtNodeParams *const ptNodeParams, int *const ptNodeWritingPos) { 190 return writePtNodeAndGetTerminalIdAndAdvancePosition(ptNodeParams, 0 /* outTerminalId */, 191 ptNodeWritingPos); 192} 193 194 195bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( 196 const PtNodeParams *const ptNodeParams, const UnigramProperty *const unigramProperty, 197 int *const ptNodeWritingPos) { 198 int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID; 199 if (!writePtNodeAndGetTerminalIdAndAdvancePosition(ptNodeParams, &terminalId, 200 ptNodeWritingPos)) { 201 return false; 202 } 203 // Write probability. 204 ProbabilityEntry newProbabilityEntry; 205 const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty); 206 const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom( 207 &newProbabilityEntry, &probabilityEntryOfUnigramProperty); 208 return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( 209 terminalId, &probabilityEntryToWrite); 210} 211 212bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, 213 const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) { 214 LanguageModelDictContent *const languageModelDictContent = 215 mBuffers->getMutableLanguageModelDictContent(); 216 const ProbabilityEntry probabilityEntry = 217 languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId); 218 const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty); 219 const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom( 220 &probabilityEntry, &probabilityEntryOfBigramProperty); 221 if (!languageModelDictContent->setNgramProbabilityEntry( 222 prevWordIds, wordId, &updatedProbabilityEntry)) { 223 AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d", 224 prevWordIds[0], prevWordIds.size(), wordId); 225 return false; 226 } 227 if (!probabilityEntry.isValid() && outAddedNewBigram) { 228 *outAddedNewBigram = true; 229 } 230 return true; 231} 232 233bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds, 234 const int wordId) { 235 LanguageModelDictContent *const languageModelDictContent = 236 mBuffers->getMutableLanguageModelDictContent(); 237 return languageModelDictContent->removeNgramProbabilityEntry(prevWordIds, wordId); 238} 239 240// TODO: Remove when we stop supporting v402 format. 241bool Ver4PatriciaTrieNodeWriter::updateAllBigramEntriesAndDeleteUselessEntries( 242 const PtNodeParams *const sourcePtNodeParams, int *const outBigramEntryCount) { 243 // Do nothing. 244 return true; 245} 246 247bool Ver4PatriciaTrieNodeWriter::updateAllPositionFields( 248 const PtNodeParams *const toBeUpdatedPtNodeParams, 249 const DictPositionRelocationMap *const dictPositionRelocationMap, 250 int *const outBigramEntryCount) { 251 int parentPos = toBeUpdatedPtNodeParams->getParentPos(); 252 if (parentPos != NOT_A_DICT_POS) { 253 PtNodeWriter::PtNodePositionRelocationMap::const_iterator it = 254 dictPositionRelocationMap->mPtNodePositionRelocationMap.find(parentPos); 255 if (it != dictPositionRelocationMap->mPtNodePositionRelocationMap.end()) { 256 parentPos = it->second; 257 } 258 } 259 int writingPos = toBeUpdatedPtNodeParams->getHeadPos() 260 + DynamicPtWritingUtils::NODE_FLAG_FIELD_SIZE; 261 // Write updated parent offset. 262 if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition(mTrieBuffer, 263 parentPos, toBeUpdatedPtNodeParams->getHeadPos(), &writingPos)) { 264 return false; 265 } 266 267 // Updates children position. 268 int childrenPos = toBeUpdatedPtNodeParams->getChildrenPos(); 269 if (childrenPos != NOT_A_DICT_POS) { 270 PtNodeWriter::PtNodeArrayPositionRelocationMap::const_iterator it = 271 dictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.find(childrenPos); 272 if (it != dictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.end()) { 273 childrenPos = it->second; 274 } 275 } 276 if (!updateChildrenPosition(toBeUpdatedPtNodeParams, childrenPos)) { 277 return false; 278 } 279 return true; 280} 281 282bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptNodeParams, 283 const int *const targetCodePoints, const int targetCodePointCount, 284 const int shortcutProbability) { 285 if (!mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(), 286 targetCodePoints, targetCodePointCount, shortcutProbability)) { 287 AKLOGE("Cannot add new shortuct entry. terminalId: %d", ptNodeParams->getTerminalId()); 288 return false; 289 } 290 return true; 291} 292 293bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( 294 const PtNodeParams *const ptNodeParams, int *const outTerminalId, 295 int *const ptNodeWritingPos) { 296 const int nodePos = *ptNodeWritingPos; 297 // Write dummy flags. The Node flags are updated with appropriate flags at the last step of the 298 // PtNode writing. 299 if (!DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, 300 0 /* nodeFlags */, ptNodeWritingPos)) { 301 return false; 302 } 303 // Calculate a parent offset and write the offset. 304 if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition(mTrieBuffer, 305 ptNodeParams->getParentPos(), nodePos, ptNodeWritingPos)) { 306 return false; 307 } 308 // Write code points 309 if (!DynamicPtWritingUtils::writeCodePointsAndAdvancePosition(mTrieBuffer, 310 ptNodeParams->getCodePoints(), ptNodeParams->getCodePointCount(), ptNodeWritingPos)) { 311 return false; 312 } 313 int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID; 314 if (!ptNodeParams->willBecomeNonTerminal()) { 315 if (ptNodeParams->getTerminalId() != Ver4DictConstants::NOT_A_TERMINAL_ID) { 316 terminalId = ptNodeParams->getTerminalId(); 317 } else if (ptNodeParams->isTerminal()) { 318 // Write terminal information using a new terminal id. 319 // Get a new unused terminal id. 320 terminalId = mBuffers->getTerminalPositionLookupTable()->getNextTerminalId(); 321 } 322 } 323 const int isTerminal = terminalId != Ver4DictConstants::NOT_A_TERMINAL_ID; 324 if (isTerminal) { 325 // Update the lookup table. 326 if (!mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition( 327 terminalId, nodePos)) { 328 return false; 329 } 330 // Write terminal Id. 331 if (!mTrieBuffer->writeUintAndAdvancePosition(terminalId, 332 Ver4DictConstants::TERMINAL_ID_FIELD_SIZE, ptNodeWritingPos)) { 333 return false; 334 } 335 if (outTerminalId) { 336 *outTerminalId = terminalId; 337 } 338 } 339 // Write children position 340 if (!DynamicPtWritingUtils::writeChildrenPositionAndAdvancePosition(mTrieBuffer, 341 ptNodeParams->getChildrenPos(), ptNodeWritingPos)) { 342 return false; 343 } 344 return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(), 345 isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); 346} 347 348// TODO: Move probability handling code to LanguageModelDictContent. 349const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( 350 const ProbabilityEntry *const originalProbabilityEntry, 351 const ProbabilityEntry *const probabilityEntry) const { 352 if (mHeaderPolicy->hasHistoricalInfoOfWords()) { 353 const HistoricalInfo updatedHistoricalInfo = 354 ForgettingCurveUtils::createUpdatedHistoricalInfo( 355 originalProbabilityEntry->getHistoricalInfo(), 356 probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(), 357 mHeaderPolicy); 358 return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo); 359 } else { 360 return *probabilityEntry; 361 } 362} 363 364bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, 365 const bool isBlacklisted, const bool isNotAWord, const bool isTerminal, 366 const bool hasMultipleChars) { 367 // Create node flags and write them. 368 PatriciaTrieReadingUtils::NodeFlags nodeFlags = 369 PatriciaTrieReadingUtils::createAndGetFlags(isBlacklisted, isNotAWord, isTerminal, 370 false /* hasShortcutTargets */, false /* hasBigrams */, hasMultipleChars, 371 CHILDREN_POSITION_FIELD_SIZE); 372 if (!DynamicPtWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) { 373 AKLOGE("Cannot write PtNode flags. flags: %x, pos: %d", nodeFlags, ptNodePos); 374 return false; 375 } 376 return true; 377} 378 379} // namespace latinime 380