1/*
2 * Copyright (C) 2013, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "dictionary/structure/v4/ver4_patricia_trie_node_writer.h"
18
19#include "dictionary/header/header_policy.h"
20#include "dictionary/property/unigram_property.h"
21#include "dictionary/structure/pt_common/dynamic_pt_reading_utils.h"
22#include "dictionary/structure/pt_common/dynamic_pt_writing_utils.h"
23#include "dictionary/structure/pt_common/patricia_trie_reading_utils.h"
24#include "dictionary/structure/v4/content/probability_entry.h"
25#include "dictionary/structure/v4/shortcut/ver4_shortcut_list_policy.h"
26#include "dictionary/structure/v4/ver4_patricia_trie_node_reader.h"
27#include "dictionary/structure/v4/ver4_dict_buffers.h"
28#include "dictionary/utils/buffer_with_extendable_buffer.h"
29#include "dictionary/utils/forgetting_curve_utils.h"
30
31namespace latinime {
32
33const int Ver4PatriciaTrieNodeWriter::CHILDREN_POSITION_FIELD_SIZE = 3;
34
35bool Ver4PatriciaTrieNodeWriter::markPtNodeAsDeleted(
36        const PtNodeParams *const toBeUpdatedPtNodeParams) {
37    int pos = toBeUpdatedPtNodeParams->getHeadPos();
38    const bool usesAdditionalBuffer = mTrieBuffer->isInAdditionalBuffer(pos);
39    const uint8_t *const dictBuf = mTrieBuffer->getBuffer(usesAdditionalBuffer);
40    if (usesAdditionalBuffer) {
41        pos -= mTrieBuffer->getOriginalBufferSize();
42    }
43    // Read original flags
44    const PatriciaTrieReadingUtils::NodeFlags originalFlags =
45            PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos);
46    const PatriciaTrieReadingUtils::NodeFlags updatedFlags =
47            DynamicPtReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */,
48                    true /* isDeleted */, false /* willBecomeNonTerminal */);
49    int writingPos = toBeUpdatedPtNodeParams->getHeadPos();
50    // Update flags.
51    if (!DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags,
52            &writingPos)) {
53        return false;
54    }
55    if (toBeUpdatedPtNodeParams->isTerminal()) {
56        // The PtNode is a terminal. Delete entry from the terminal position lookup table.
57        return mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition(
58                toBeUpdatedPtNodeParams->getTerminalId(), NOT_A_DICT_POS /* ptNodePos */);
59    } else {
60        return true;
61    }
62}
63
64// TODO: Quit using bigramLinkedNodePos.
65bool Ver4PatriciaTrieNodeWriter::markPtNodeAsMoved(
66        const PtNodeParams *const toBeUpdatedPtNodeParams,
67        const int movedPos, const int bigramLinkedNodePos) {
68    int pos = toBeUpdatedPtNodeParams->getHeadPos();
69    const bool usesAdditionalBuffer = mTrieBuffer->isInAdditionalBuffer(pos);
70    const uint8_t *const dictBuf = mTrieBuffer->getBuffer(usesAdditionalBuffer);
71    if (usesAdditionalBuffer) {
72        pos -= mTrieBuffer->getOriginalBufferSize();
73    }
74    // Read original flags
75    const PatriciaTrieReadingUtils::NodeFlags originalFlags =
76            PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos);
77    const PatriciaTrieReadingUtils::NodeFlags updatedFlags =
78            DynamicPtReadingUtils::updateAndGetFlags(originalFlags, true /* isMoved */,
79                    false /* isDeleted */, false /* willBecomeNonTerminal */);
80    int writingPos = toBeUpdatedPtNodeParams->getHeadPos();
81    // Update flags.
82    if (!DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags,
83            &writingPos)) {
84        return false;
85    }
86    // Update moved position, which is stored in the parent offset field.
87    if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition(
88            mTrieBuffer, movedPos, toBeUpdatedPtNodeParams->getHeadPos(), &writingPos)) {
89        return false;
90    }
91    if (toBeUpdatedPtNodeParams->hasChildren()) {
92        // Update children's parent position.
93        mReadingHelper.initWithPtNodeArrayPos(toBeUpdatedPtNodeParams->getChildrenPos());
94        while (!mReadingHelper.isEnd()) {
95            const PtNodeParams childPtNodeParams(mReadingHelper.getPtNodeParams());
96            int parentOffsetFieldPos = childPtNodeParams.getHeadPos()
97                    + DynamicPtWritingUtils::NODE_FLAG_FIELD_SIZE;
98            if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition(
99                    mTrieBuffer, bigramLinkedNodePos, childPtNodeParams.getHeadPos(),
100                    &parentOffsetFieldPos)) {
101                // Parent offset cannot be written because of a bug or a broken dictionary; thus,
102                // we give up to update dictionary.
103                return false;
104            }
105            mReadingHelper.readNextSiblingNode(childPtNodeParams);
106        }
107    }
108    return true;
109}
110
111bool Ver4PatriciaTrieNodeWriter::markPtNodeAsWillBecomeNonTerminal(
112        const PtNodeParams *const toBeUpdatedPtNodeParams) {
113    int pos = toBeUpdatedPtNodeParams->getHeadPos();
114    const bool usesAdditionalBuffer = mTrieBuffer->isInAdditionalBuffer(pos);
115    const uint8_t *const dictBuf = mTrieBuffer->getBuffer(usesAdditionalBuffer);
116    if (usesAdditionalBuffer) {
117        pos -= mTrieBuffer->getOriginalBufferSize();
118    }
119    // Read original flags
120    const PatriciaTrieReadingUtils::NodeFlags originalFlags =
121            PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos);
122    const PatriciaTrieReadingUtils::NodeFlags updatedFlags =
123            DynamicPtReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */,
124                    false /* isDeleted */, true /* willBecomeNonTerminal */);
125    if (!mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition(
126            toBeUpdatedPtNodeParams->getTerminalId(), NOT_A_DICT_POS /* ptNodePos */)) {
127        AKLOGE("Cannot update terminal position lookup table. terminal id: %d",
128                toBeUpdatedPtNodeParams->getTerminalId());
129        return false;
130    }
131    // Update flags.
132    int writingPos = toBeUpdatedPtNodeParams->getHeadPos();
133    return DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags,
134            &writingPos);
135}
136
137bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
138        const PtNodeParams *const toBeUpdatedPtNodeParams,
139        const UnigramProperty *const unigramProperty) {
140    // Update probability and historical information.
141    // TODO: Update other information in the unigram property.
142    if (!toBeUpdatedPtNodeParams->isTerminal()) {
143        return false;
144    }
145    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
146    return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
147            toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntryOfUnigramProperty);
148}
149
150bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
151        const PtNodeParams *const toBeUpdatedPtNodeParams, bool *const outNeedsToKeepPtNode) {
152    if (!toBeUpdatedPtNodeParams->isTerminal()) {
153        AKLOGE("updatePtNodeProbabilityAndGetNeedsToSaveForGC is called for non-terminal PtNode.");
154        return false;
155    }
156    const ProbabilityEntry originalProbabilityEntry =
157            mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
158                    toBeUpdatedPtNodeParams->getTerminalId());
159    if (originalProbabilityEntry.isValid()) {
160        *outNeedsToKeepPtNode = true;
161        return true;
162    }
163    if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
164        AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
165        return false;
166    }
167    *outNeedsToKeepPtNode = false;
168    return true;
169}
170
171bool Ver4PatriciaTrieNodeWriter::updateChildrenPosition(
172        const PtNodeParams *const toBeUpdatedPtNodeParams, const int newChildrenPosition) {
173    int childrenPosFieldPos = toBeUpdatedPtNodeParams->getChildrenPosFieldPos();
174    return DynamicPtWritingUtils::writeChildrenPositionAndAdvancePosition(mTrieBuffer,
175            newChildrenPosition, &childrenPosFieldPos);
176}
177
178bool Ver4PatriciaTrieNodeWriter::updateTerminalId(const PtNodeParams *const toBeUpdatedPtNodeParams,
179        const int newTerminalId) {
180    return mTrieBuffer->writeUint(newTerminalId, Ver4DictConstants::TERMINAL_ID_FIELD_SIZE,
181            toBeUpdatedPtNodeParams->getTerminalIdFieldPos());
182}
183
184bool Ver4PatriciaTrieNodeWriter::writePtNodeAndAdvancePosition(
185        const PtNodeParams *const ptNodeParams, int *const ptNodeWritingPos) {
186    return writePtNodeAndGetTerminalIdAndAdvancePosition(ptNodeParams, 0 /* outTerminalId */,
187            ptNodeWritingPos);
188}
189
190bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
191        const PtNodeParams *const ptNodeParams, const UnigramProperty *const unigramProperty,
192        int *const ptNodeWritingPos) {
193    int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID;
194    if (!writePtNodeAndGetTerminalIdAndAdvancePosition(ptNodeParams, &terminalId,
195            ptNodeWritingPos)) {
196        return false;
197    }
198    // Write probability.
199    ProbabilityEntry newProbabilityEntry;
200    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
201    return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
202            terminalId, &probabilityEntryOfUnigramProperty);
203}
204
205// TODO: Support counting ngram entries.
206bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
207        const NgramProperty *const ngramProperty, bool *const outAddedNewBigram) {
208    LanguageModelDictContent *const languageModelDictContent =
209            mBuffers->getMutableLanguageModelDictContent();
210    const ProbabilityEntry probabilityEntry =
211            languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId);
212    const ProbabilityEntry probabilityEntryOfNgramProperty(ngramProperty);
213    if (!languageModelDictContent->setNgramProbabilityEntry(
214            prevWordIds, wordId, &probabilityEntryOfNgramProperty)) {
215        AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d",
216                prevWordIds[0], prevWordIds.size(), wordId);
217        return false;
218    }
219    if (!probabilityEntry.isValid() && outAddedNewBigram) {
220        *outAddedNewBigram = true;
221    }
222    return true;
223}
224
225bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds,
226        const int wordId) {
227    LanguageModelDictContent *const languageModelDictContent =
228            mBuffers->getMutableLanguageModelDictContent();
229    return languageModelDictContent->removeNgramProbabilityEntry(prevWordIds, wordId);
230}
231
232// TODO: Remove when we stop supporting v402 format.
233bool Ver4PatriciaTrieNodeWriter::updateAllBigramEntriesAndDeleteUselessEntries(
234            const PtNodeParams *const sourcePtNodeParams, int *const outBigramEntryCount) {
235    // Do nothing.
236    return true;
237}
238
239bool Ver4PatriciaTrieNodeWriter::updateAllPositionFields(
240        const PtNodeParams *const toBeUpdatedPtNodeParams,
241        const DictPositionRelocationMap *const dictPositionRelocationMap,
242        int *const outBigramEntryCount) {
243    int parentPos = toBeUpdatedPtNodeParams->getParentPos();
244    if (parentPos != NOT_A_DICT_POS) {
245        PtNodeWriter::PtNodePositionRelocationMap::const_iterator it =
246                dictPositionRelocationMap->mPtNodePositionRelocationMap.find(parentPos);
247        if (it != dictPositionRelocationMap->mPtNodePositionRelocationMap.end()) {
248            parentPos = it->second;
249        }
250    }
251    int writingPos = toBeUpdatedPtNodeParams->getHeadPos()
252            + DynamicPtWritingUtils::NODE_FLAG_FIELD_SIZE;
253    // Write updated parent offset.
254    if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition(mTrieBuffer,
255            parentPos, toBeUpdatedPtNodeParams->getHeadPos(), &writingPos)) {
256        return false;
257    }
258
259    // Updates children position.
260    int childrenPos = toBeUpdatedPtNodeParams->getChildrenPos();
261    if (childrenPos != NOT_A_DICT_POS) {
262        PtNodeWriter::PtNodeArrayPositionRelocationMap::const_iterator it =
263                dictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.find(childrenPos);
264        if (it != dictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.end()) {
265            childrenPos = it->second;
266        }
267    }
268    if (!updateChildrenPosition(toBeUpdatedPtNodeParams, childrenPos)) {
269        return false;
270    }
271    return true;
272}
273
274bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptNodeParams,
275        const int *const targetCodePoints, const int targetCodePointCount,
276        const int shortcutProbability) {
277    if (!mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(),
278            targetCodePoints, targetCodePointCount, shortcutProbability)) {
279        AKLOGE("Cannot add new shortcut entry. terminalId: %d", ptNodeParams->getTerminalId());
280        return false;
281    }
282    return true;
283}
284
285bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
286        const PtNodeParams *const ptNodeParams, int *const outTerminalId,
287        int *const ptNodeWritingPos) {
288    const int nodePos = *ptNodeWritingPos;
289    // Write dummy flags. The Node flags are updated with appropriate flags at the last step of the
290    // PtNode writing.
291    if (!DynamicPtWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer,
292            0 /* nodeFlags */, ptNodeWritingPos)) {
293        return false;
294    }
295    // Calculate a parent offset and write the offset.
296    if (!DynamicPtWritingUtils::writeParentPosOffsetAndAdvancePosition(mTrieBuffer,
297            ptNodeParams->getParentPos(), nodePos, ptNodeWritingPos)) {
298        return false;
299    }
300    // Write code points
301    if (!DynamicPtWritingUtils::writeCodePointsAndAdvancePosition(mTrieBuffer,
302            ptNodeParams->getCodePoints(), ptNodeParams->getCodePointCount(), ptNodeWritingPos)) {
303        return false;
304    }
305    int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID;
306    if (!ptNodeParams->willBecomeNonTerminal()) {
307        if (ptNodeParams->getTerminalId() != Ver4DictConstants::NOT_A_TERMINAL_ID) {
308            terminalId = ptNodeParams->getTerminalId();
309        } else if (ptNodeParams->isTerminal()) {
310            // Write terminal information using a new terminal id.
311            // Get a new unused terminal id.
312            terminalId = mBuffers->getTerminalPositionLookupTable()->getNextTerminalId();
313        }
314    }
315    const int isTerminal = terminalId != Ver4DictConstants::NOT_A_TERMINAL_ID;
316    if (isTerminal) {
317        // Update the lookup table.
318        if (!mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition(
319                terminalId, nodePos)) {
320            return false;
321        }
322        // Write terminal Id.
323        if (!mTrieBuffer->writeUintAndAdvancePosition(terminalId,
324                Ver4DictConstants::TERMINAL_ID_FIELD_SIZE, ptNodeWritingPos)) {
325            return false;
326        }
327        if (outTerminalId) {
328            *outTerminalId = terminalId;
329        }
330    }
331    // Write children position
332    if (!DynamicPtWritingUtils::writeChildrenPositionAndAdvancePosition(mTrieBuffer,
333            ptNodeParams->getChildrenPos(), ptNodeWritingPos)) {
334        return false;
335    }
336    return updatePtNodeFlags(nodePos, isTerminal,
337            ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
338}
339
340bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal,
341        const bool hasMultipleChars) {
342    // Create node flags and write them.
343    PatriciaTrieReadingUtils::NodeFlags nodeFlags =
344            PatriciaTrieReadingUtils::createAndGetFlags(false /* isNotAWord */,
345                    false /* isPossiblyOffensive */, isTerminal, false /* hasShortcutTargets */,
346                    false /* hasBigrams */, hasMultipleChars, CHILDREN_POSITION_FIELD_SIZE);
347    if (!DynamicPtWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) {
348        AKLOGE("Cannot write PtNode flags. flags: %x, pos: %d", nodeFlags, ptNodePos);
349        return false;
350    }
351    return true;
352}
353
354} // namespace latinime
355