ver4_patricia_trie_node_writer.cpp revision 5520e84e16f3886548dc11cf888977a6dbbef5f1
1/*
2 * Copyright (C) 2013, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h"
18
19#include "suggest/core/dictionary/property/unigram_property.h"
20#include "suggest/policyimpl/dictionary/header/header_policy.h"
21#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_utils.h"
22#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_writing_utils.h"
23#include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h"
24#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
25#include "suggest/policyimpl/dictionary/structure/v4/shortcut/ver4_shortcut_list_policy.h"
26#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h"
27#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h"
28#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
29#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
30
31namespace latinime {
32
33const int Ver4PatriciaTrieNodeWriter::CHILDREN_POSITION_FIELD_SIZE = 3;
34
35bool Ver4PatriciaTrieNodeWriter::markPtNodeAsDeleted(
36        const PtNodeParams *const toBeUpdatedPtNodeParams) {
37    int pos = toBeUpdatedPtNodeParams->getHeadPos();
38    const bool usesAdditionalBuffer = mTrieBuffer->isInAdditionalBuffer(pos);
39    const uint8_t *const dictBuf = mTrieBuffer->getBuffer(usesAdditionalBuffer);
40    if (usesAdditionalBuffer) {
41        pos -= mTrieBuffer->getOriginalBufferSize();
42    }
43    // Read original flags
44    const PatriciaTrieReadingUtils::NodeFlags originalFlags =
45            PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos);
46    const PatriciaTrieReadingUtils::NodeFlags updatedFlags =
47            DynamicPtReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */,
48                    true /* isDeleted */, false /* willBecomeNonTerminal */);
49    int writingPos = toBeUpdatedPtNodeParams->getHeadPos();
50    // Update flags.
51    if (!DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags,
52            &writingPos)) {
53        return false;
54    }
55    if (toBeUpdatedPtNodeParams->isTerminal()) {
56        // The PtNode is a terminal. Delete entry from the terminal position lookup table.
57        return mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition(
58                toBeUpdatedPtNodeParams->getTerminalId(), NOT_A_DICT_POS /* ptNodePos */);
59    } else {
60        return true;
61    }
62}
63
64bool Ver4PatriciaTrieNodeWriter::markPtNodeAsMoved(
65        const PtNodeParams *const toBeUpdatedPtNodeParams,
66        const int movedPos, const int bigramLinkedNodePos) {
67    int pos = toBeUpdatedPtNodeParams->getHeadPos();
68    const bool usesAdditionalBuffer = mTrieBuffer->isInAdditionalBuffer(pos);
69    const uint8_t *const dictBuf = mTrieBuffer->getBuffer(usesAdditionalBuffer);
70    if (usesAdditionalBuffer) {
71        pos -= mTrieBuffer->getOriginalBufferSize();
72    }
73    // Read original flags
74    const PatriciaTrieReadingUtils::NodeFlags originalFlags =
75            PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos);
76    const PatriciaTrieReadingUtils::NodeFlags updatedFlags =
77            DynamicPtReadingUtils::updateAndGetFlags(originalFlags, true /* isMoved */,
78                    false /* isDeleted */, false /* willBecomeNonTerminal */);
79    int writingPos = toBeUpdatedPtNodeParams->getHeadPos();
80    // Update flags.
81    if (!DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags,
82            &writingPos)) {
83        return false;
84    }
85    // Update moved position, which is stored in the parent offset field.
86    if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition(
87            mTrieBuffer, movedPos, toBeUpdatedPtNodeParams->getHeadPos(), &writingPos)) {
88        return false;
89    }
90    if (toBeUpdatedPtNodeParams->hasChildren()) {
91        // Update children's parent position.
92        mReadingHelper.initWithPtNodeArrayPos(toBeUpdatedPtNodeParams->getChildrenPos());
93        while (!mReadingHelper.isEnd()) {
94            const PtNodeParams childPtNodeParams(mReadingHelper.getPtNodeParams());
95            int parentOffsetFieldPos = childPtNodeParams.getHeadPos()
96                    + DynamicPtWritingUtils::NODE_FLAG_FIELD_SIZE;
97            if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition(
98                    mTrieBuffer, bigramLinkedNodePos, childPtNodeParams.getHeadPos(),
99                    &parentOffsetFieldPos)) {
100                // Parent offset cannot be written because of a bug or a broken dictionary; thus,
101                // we give up to update dictionary.
102                return false;
103            }
104            mReadingHelper.readNextSiblingNode(childPtNodeParams);
105        }
106    }
107    return true;
108}
109
110bool Ver4PatriciaTrieNodeWriter::markPtNodeAsWillBecomeNonTerminal(
111        const PtNodeParams *const toBeUpdatedPtNodeParams) {
112    int pos = toBeUpdatedPtNodeParams->getHeadPos();
113    const bool usesAdditionalBuffer = mTrieBuffer->isInAdditionalBuffer(pos);
114    const uint8_t *const dictBuf = mTrieBuffer->getBuffer(usesAdditionalBuffer);
115    if (usesAdditionalBuffer) {
116        pos -= mTrieBuffer->getOriginalBufferSize();
117    }
118    // Read original flags
119    const PatriciaTrieReadingUtils::NodeFlags originalFlags =
120            PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos);
121    const PatriciaTrieReadingUtils::NodeFlags updatedFlags =
122            DynamicPtReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */,
123                    false /* isDeleted */, true /* willBecomeNonTerminal */);
124    if (!mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition(
125            toBeUpdatedPtNodeParams->getTerminalId(), NOT_A_DICT_POS /* ptNodePos */)) {
126        AKLOGE("Cannot update terminal position lookup table. terminal id: %d",
127                toBeUpdatedPtNodeParams->getTerminalId());
128        return false;
129    }
130    // Update flags.
131    int writingPos = toBeUpdatedPtNodeParams->getHeadPos();
132    return DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags,
133            &writingPos);
134}
135
136bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
137        const PtNodeParams *const toBeUpdatedPtNodeParams,
138        const UnigramProperty *const unigramProperty) {
139    // Update probability and historical information.
140    // TODO: Update other information in the unigram property.
141    if (!toBeUpdatedPtNodeParams->isTerminal()) {
142        return false;
143    }
144    const ProbabilityEntry originalProbabilityEntry =
145            mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
146                    toBeUpdatedPtNodeParams->getTerminalId());
147    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
148    const ProbabilityEntry updatedProbabilityEntry =
149            createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty);
150    return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
151            toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry);
152}
153
154bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
155        const PtNodeParams *const toBeUpdatedPtNodeParams, bool *const outNeedsToKeepPtNode) {
156    if (!toBeUpdatedPtNodeParams->isTerminal()) {
157        AKLOGE("updatePtNodeProbabilityAndGetNeedsToSaveForGC is called for non-terminal PtNode.");
158        return false;
159    }
160    const ProbabilityEntry originalProbabilityEntry =
161            mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
162                    toBeUpdatedPtNodeParams->getTerminalId());
163    if (originalProbabilityEntry.isValid()) {
164        *outNeedsToKeepPtNode = true;
165        return true;
166    }
167    if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
168        AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
169        return false;
170    }
171    *outNeedsToKeepPtNode = false;
172    return true;
173}
174
175bool Ver4PatriciaTrieNodeWriter::updateChildrenPosition(
176        const PtNodeParams *const toBeUpdatedPtNodeParams, const int newChildrenPosition) {
177    int childrenPosFieldPos = toBeUpdatedPtNodeParams->getChildrenPosFieldPos();
178    return DynamicPtWritingUtils::writeChildrenPositionAndAdvancePosition(mTrieBuffer,
179            newChildrenPosition, &childrenPosFieldPos);
180}
181
182bool Ver4PatriciaTrieNodeWriter::updateTerminalId(const PtNodeParams *const toBeUpdatedPtNodeParams,
183        const int newTerminalId) {
184    return mTrieBuffer->writeUint(newTerminalId, Ver4DictConstants::TERMINAL_ID_FIELD_SIZE,
185            toBeUpdatedPtNodeParams->getTerminalIdFieldPos());
186}
187
188bool Ver4PatriciaTrieNodeWriter::writePtNodeAndAdvancePosition(
189        const PtNodeParams *const ptNodeParams, int *const ptNodeWritingPos) {
190    return writePtNodeAndGetTerminalIdAndAdvancePosition(ptNodeParams, 0 /* outTerminalId */,
191            ptNodeWritingPos);
192}
193
194
195bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
196        const PtNodeParams *const ptNodeParams, const UnigramProperty *const unigramProperty,
197        int *const ptNodeWritingPos) {
198    int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID;
199    if (!writePtNodeAndGetTerminalIdAndAdvancePosition(ptNodeParams, &terminalId,
200            ptNodeWritingPos)) {
201        return false;
202    }
203    // Write probability.
204    ProbabilityEntry newProbabilityEntry;
205    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
206    const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom(
207            &newProbabilityEntry, &probabilityEntryOfUnigramProperty);
208    return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
209            terminalId, &probabilityEntryToWrite);
210}
211
212bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
213        const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) {
214    LanguageModelDictContent *const languageModelDictContent =
215            mBuffers->getMutableLanguageModelDictContent();
216    const ProbabilityEntry probabilityEntry =
217            languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId);
218    const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty);
219    const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom(
220            &probabilityEntry, &probabilityEntryOfBigramProperty);
221    if (!languageModelDictContent->setNgramProbabilityEntry(
222            prevWordIds, wordId, &updatedProbabilityEntry)) {
223        AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d",
224                prevWordIds[0], prevWordIds.size(), wordId);
225        return false;
226    }
227    if (!probabilityEntry.isValid() && outAddedNewBigram) {
228        *outAddedNewBigram = true;
229    }
230    return true;
231}
232
233bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds,
234        const int wordId) {
235    LanguageModelDictContent *const languageModelDictContent =
236            mBuffers->getMutableLanguageModelDictContent();
237    return languageModelDictContent->removeNgramProbabilityEntry(prevWordIds, wordId);
238}
239
240// TODO: Remove when we stop supporting v402 format.
241bool Ver4PatriciaTrieNodeWriter::updateAllBigramEntriesAndDeleteUselessEntries(
242            const PtNodeParams *const sourcePtNodeParams, int *const outBigramEntryCount) {
243    // Do nothing.
244    return true;
245}
246
247bool Ver4PatriciaTrieNodeWriter::updateAllPositionFields(
248        const PtNodeParams *const toBeUpdatedPtNodeParams,
249        const DictPositionRelocationMap *const dictPositionRelocationMap,
250        int *const outBigramEntryCount) {
251    int parentPos = toBeUpdatedPtNodeParams->getParentPos();
252    if (parentPos != NOT_A_DICT_POS) {
253        PtNodeWriter::PtNodePositionRelocationMap::const_iterator it =
254                dictPositionRelocationMap->mPtNodePositionRelocationMap.find(parentPos);
255        if (it != dictPositionRelocationMap->mPtNodePositionRelocationMap.end()) {
256            parentPos = it->second;
257        }
258    }
259    int writingPos = toBeUpdatedPtNodeParams->getHeadPos()
260            + DynamicPtWritingUtils::NODE_FLAG_FIELD_SIZE;
261    // Write updated parent offset.
262    if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition(mTrieBuffer,
263            parentPos, toBeUpdatedPtNodeParams->getHeadPos(), &writingPos)) {
264        return false;
265    }
266
267    // Updates children position.
268    int childrenPos = toBeUpdatedPtNodeParams->getChildrenPos();
269    if (childrenPos != NOT_A_DICT_POS) {
270        PtNodeWriter::PtNodeArrayPositionRelocationMap::const_iterator it =
271                dictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.find(childrenPos);
272        if (it != dictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.end()) {
273            childrenPos = it->second;
274        }
275    }
276    if (!updateChildrenPosition(toBeUpdatedPtNodeParams, childrenPos)) {
277        return false;
278    }
279    return true;
280}
281
282bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptNodeParams,
283        const int *const targetCodePoints, const int targetCodePointCount,
284        const int shortcutProbability) {
285    if (!mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(),
286            targetCodePoints, targetCodePointCount, shortcutProbability)) {
287        AKLOGE("Cannot add new shortuct entry. terminalId: %d", ptNodeParams->getTerminalId());
288        return false;
289    }
290    return true;
291}
292
293bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
294        const PtNodeParams *const ptNodeParams, int *const outTerminalId,
295        int *const ptNodeWritingPos) {
296    const int nodePos = *ptNodeWritingPos;
297    // Write dummy flags. The Node flags are updated with appropriate flags at the last step of the
298    // PtNode writing.
299    if (!DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer,
300            0 /* nodeFlags */, ptNodeWritingPos)) {
301        return false;
302    }
303    // Calculate a parent offset and write the offset.
304    if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition(mTrieBuffer,
305            ptNodeParams->getParentPos(), nodePos, ptNodeWritingPos)) {
306        return false;
307    }
308    // Write code points
309    if (!DynamicPtWritingUtils::writeCodePointsAndAdvancePosition(mTrieBuffer,
310            ptNodeParams->getCodePoints(), ptNodeParams->getCodePointCount(), ptNodeWritingPos)) {
311        return false;
312    }
313    int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID;
314    if (!ptNodeParams->willBecomeNonTerminal()) {
315        if (ptNodeParams->getTerminalId() != Ver4DictConstants::NOT_A_TERMINAL_ID) {
316            terminalId = ptNodeParams->getTerminalId();
317        } else if (ptNodeParams->isTerminal()) {
318            // Write terminal information using a new terminal id.
319            // Get a new unused terminal id.
320            terminalId = mBuffers->getTerminalPositionLookupTable()->getNextTerminalId();
321        }
322    }
323    const int isTerminal = terminalId != Ver4DictConstants::NOT_A_TERMINAL_ID;
324    if (isTerminal) {
325        // Update the lookup table.
326        if (!mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition(
327                terminalId, nodePos)) {
328            return false;
329        }
330        // Write terminal Id.
331        if (!mTrieBuffer->writeUintAndAdvancePosition(terminalId,
332                Ver4DictConstants::TERMINAL_ID_FIELD_SIZE, ptNodeWritingPos)) {
333            return false;
334        }
335        if (outTerminalId) {
336            *outTerminalId = terminalId;
337        }
338    }
339    // Write children position
340    if (!DynamicPtWritingUtils::writeChildrenPositionAndAdvancePosition(mTrieBuffer,
341            ptNodeParams->getChildrenPos(), ptNodeWritingPos)) {
342        return false;
343    }
344    return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(),
345            isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
346}
347
348// TODO: Move probability handling code to LanguageModelDictContent.
349const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
350        const ProbabilityEntry *const originalProbabilityEntry,
351        const ProbabilityEntry *const probabilityEntry) const {
352    if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
353        const HistoricalInfo updatedHistoricalInfo =
354                ForgettingCurveUtils::createUpdatedHistoricalInfo(
355                        originalProbabilityEntry->getHistoricalInfo(),
356                        probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
357                        mHeaderPolicy);
358        return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo);
359    } else {
360        return *probabilityEntry;
361    }
362}
363
364bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos,
365        const bool isBlacklisted, const bool isNotAWord, const bool isTerminal,
366        const bool hasMultipleChars) {
367    // Create node flags and write them.
368    PatriciaTrieReadingUtils::NodeFlags nodeFlags =
369            PatriciaTrieReadingUtils::createAndGetFlags(isBlacklisted, isNotAWord, isTerminal,
370                    false /* hasShortcutTargets */, false /* hasBigrams */, hasMultipleChars,
371                    CHILDREN_POSITION_FIELD_SIZE);
372    if (!DynamicPtWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) {
373        AKLOGE("Cannot write PtNode flags. flags: %x, pos: %d", nodeFlags, ptNodePos);
374        return false;
375    }
376    return true;
377}
378
379} // namespace latinime
380