ver4_patricia_trie_writing_helper.cpp revision 198be3a6c5c53e63de5ed3a6a1ce618ca36ff98c
1/*
2 * Copyright (C) 2013, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h"
18
19#include <cstring>
20#include <queue>
21
22#include "suggest/policyimpl/dictionary/header/header_policy.h"
23#include "suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h"
24#include "suggest/policyimpl/dictionary/structure/v4/shortcut/ver4_shortcut_list_policy.h"
25#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h"
26#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
27#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h"
28#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h"
29#include "suggest/policyimpl/dictionary/structure/v4/ver4_pt_node_array_reader.h"
30#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
31#include "suggest/policyimpl/dictionary/utils/file_utils.h"
32#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
33
34namespace latinime {
35
36bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPath,
37        const int unigramCount, const int bigramCount) const {
38    const HeaderPolicy *const headerPolicy = mBuffers->getHeaderPolicy();
39    BufferWithExtendableBuffer headerBuffer(
40            BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
41    const int extendedRegionSize = headerPolicy->getExtendedRegionSize()
42            + mBuffers->getTrieBuffer()->getUsedAdditionalBufferSize();
43    if (!headerPolicy->fillInAndWriteHeaderToBuffer(false /* updatesLastDecayedTime */,
44            unigramCount, bigramCount, extendedRegionSize, &headerBuffer)) {
45        AKLOGE("Cannot write header structure to buffer. "
46                "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, "
47                "extendedRegionSize: %d", false, unigramCount, bigramCount,
48                extendedRegionSize);
49        return false;
50    }
51    return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
52}
53
54bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeArrayPos,
55        const char *const dictDirPath) {
56    const HeaderPolicy *const headerPolicy = mBuffers->getHeaderPolicy();
57    Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers(
58            Ver4DictBuffers::createVer4DictBuffers(headerPolicy,
59                    Ver4DictConstants::MAX_DICTIONARY_SIZE));
60    int unigramCount = 0;
61    int bigramCount = 0;
62    if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &unigramCount, &bigramCount)) {
63        return false;
64    }
65    BufferWithExtendableBuffer headerBuffer(
66            BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
67    if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
68            unigramCount, bigramCount, 0 /* extendedRegionSize */, &headerBuffer)) {
69        return false;
70    }
71    return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
72}
73
74bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
75        const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite,
76        int *const outUnigramCount, int *const outBigramCount) {
77    Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer(),
78            mBuffers->getProbabilityDictContent(), headerPolicy);
79    Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer());
80    Ver4BigramListPolicy bigramPolicy(mBuffers->getMutableBigramDictContent(),
81            mBuffers->getTerminalPositionLookupTable(), headerPolicy);
82    Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
83            mBuffers->getTerminalPositionLookupTable());
84    Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(),
85            mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
86            &shortcutPolicy);
87
88    DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
89    readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
90    DynamicPtGcEventListeners
91            ::TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted
92                    traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted(
93                            &ptNodeWriter);
94    if (!readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(
95            &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) {
96        return false;
97    }
98    const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted
99            .getValidUnigramCount();
100    const int maxUnigramCount = headerPolicy->getMaxUnigramCount();
101    if (headerPolicy->isDecayingDict() && unigramCount > maxUnigramCount) {
102        if (!truncateUnigrams(&ptNodeReader, &ptNodeWriter, maxUnigramCount)) {
103            AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount,
104                    maxUnigramCount);
105            return false;
106        }
107    }
108
109    readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
110    DynamicPtGcEventListeners::TraversePolicyToUpdateBigramProbability
111            traversePolicyToUpdateBigramProbability(&ptNodeWriter);
112    if (!readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(
113            &traversePolicyToUpdateBigramProbability)) {
114        return false;
115    }
116    const int bigramCount = traversePolicyToUpdateBigramProbability.getValidBigramEntryCount();
117    const int maxBigramCount = headerPolicy->getMaxBigramCount();
118    if (headerPolicy->isDecayingDict() && bigramCount > maxBigramCount) {
119        if (!truncateBigrams(maxBigramCount)) {
120            AKLOGE("Cannot remove bigrams. current: %d, max: %d", bigramCount, maxBigramCount);
121            return false;
122        }
123    }
124
125    // Mapping from positions in mBuffer to positions in bufferToWrite.
126    PtNodeWriter::DictPositionRelocationMap dictPositionRelocationMap;
127    readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
128    Ver4PatriciaTrieNodeWriter ptNodeWriterForNewBuffers(buffersToWrite->getWritableTrieBuffer(),
129            buffersToWrite, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
130            &shortcutPolicy);
131    DynamicPtGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer
132            traversePolicyToPlaceAndWriteValidPtNodesToBuffer(&ptNodeWriterForNewBuffers,
133                    buffersToWrite->getWritableTrieBuffer(), &dictPositionRelocationMap);
134    if (!readingHelper.traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner(
135            &traversePolicyToPlaceAndWriteValidPtNodesToBuffer)) {
136        return false;
137    }
138
139    // Create policy instances for the GCed dictionary.
140    Ver4PatriciaTrieNodeReader newPtNodeReader(buffersToWrite->getTrieBuffer(),
141            buffersToWrite->getProbabilityDictContent(), headerPolicy);
142    Ver4PtNodeArrayReader newPtNodeArrayreader(buffersToWrite->getTrieBuffer());
143    Ver4BigramListPolicy newBigramPolicy(buffersToWrite->getMutableBigramDictContent(),
144            buffersToWrite->getTerminalPositionLookupTable(), headerPolicy);
145    Ver4ShortcutListPolicy newShortcutPolicy(buffersToWrite->getMutableShortcutDictContent(),
146            buffersToWrite->getTerminalPositionLookupTable());
147    Ver4PatriciaTrieNodeWriter newPtNodeWriter(buffersToWrite->getWritableTrieBuffer(),
148            buffersToWrite, headerPolicy, &newPtNodeReader, &newPtNodeArrayreader, &newBigramPolicy,
149            &newShortcutPolicy);
150    // Re-assign terminal IDs for valid terminal PtNodes.
151    TerminalPositionLookupTable::TerminalIdMap terminalIdMap;
152    if(!buffersToWrite->getMutableTerminalPositionLookupTable()->runGCTerminalIds(
153            &terminalIdMap)) {
154        return false;
155    }
156    // Run GC for probability dict content.
157    if (!buffersToWrite->getMutableProbabilityDictContent()->runGC(&terminalIdMap,
158            mBuffers->getProbabilityDictContent())) {
159        return false;
160    }
161    // Run GC for bigram dict content.
162    if(!buffersToWrite->getMutableBigramDictContent()->runGC(&terminalIdMap,
163            mBuffers->getBigramDictContent(), outBigramCount)) {
164        return false;
165    }
166    // Run GC for shortcut dict content.
167    if(!buffersToWrite->getMutableShortcutDictContent()->runGC(&terminalIdMap,
168            mBuffers->getShortcutDictContent())) {
169        return false;
170    }
171    DynamicPtReadingHelper newDictReadingHelper(&newPtNodeReader, &newPtNodeArrayreader);
172    newDictReadingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
173    DynamicPtGcEventListeners::TraversePolicyToUpdateAllPositionFields
174            traversePolicyToUpdateAllPositionFields(&newPtNodeWriter, &dictPositionRelocationMap);
175    if (!newDictReadingHelper.traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner(
176            &traversePolicyToUpdateAllPositionFields)) {
177        return false;
178    }
179    newDictReadingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
180    TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds
181            traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds(&newPtNodeWriter, &terminalIdMap);
182    if (!newDictReadingHelper.traverseAllPtNodesInPostorderDepthFirstManner(
183            &traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) {
184        return false;
185    }
186    *outUnigramCount = traversePolicyToUpdateAllPositionFields.getUnigramCount();
187    return true;
188}
189
190bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
191        const Ver4PatriciaTrieNodeReader *const ptNodeReader,
192        Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
193    const TerminalPositionLookupTable *const terminalPosLookupTable =
194            mBuffers->getTerminalPositionLookupTable();
195    const int nextTerminalId = terminalPosLookupTable->getNextTerminalId();
196    std::priority_queue<DictProbability, std::vector<DictProbability>, DictProbabilityComparator>
197            priorityQueue;
198    for (int i = 0; i < nextTerminalId; ++i) {
199        const int terminalPos = terminalPosLookupTable->getTerminalPtNodePosition(i);
200        if (terminalPos == NOT_A_DICT_POS) {
201            continue;
202        }
203        const ProbabilityEntry probabilityEntry =
204                mBuffers->getProbabilityDictContent()->getProbabilityEntry(i);
205        const int probability = probabilityEntry.hasHistoricalInfo() ?
206                ForgettingCurveUtils::decodeProbability(
207                        probabilityEntry.getHistoricalInfo(), mBuffers->getHeaderPolicy()) :
208                probabilityEntry.getProbability();
209        priorityQueue.push(DictProbability(terminalPos, probability,
210                probabilityEntry.getHistoricalInfo()->getTimeStamp()));
211    }
212
213    // Delete unigrams.
214    while (static_cast<int>(priorityQueue.size()) > maxUnigramCount) {
215        const int ptNodePos = priorityQueue.top().getDictPos();
216        priorityQueue.pop();
217        const PtNodeParams ptNodeParams =
218                ptNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
219        if (ptNodeParams.representsNonWordInfo()) {
220            continue;
221        }
222        if (!ptNodeWriter->markPtNodeAsWillBecomeNonTerminal(&ptNodeParams)) {
223            AKLOGE("Cannot mark PtNode as willBecomeNonterminal. PtNode pos: %d", ptNodePos);
224            return false;
225        }
226    }
227    return true;
228}
229
230bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) {
231    const TerminalPositionLookupTable *const terminalPosLookupTable =
232            mBuffers->getTerminalPositionLookupTable();
233    const int nextTerminalId = terminalPosLookupTable->getNextTerminalId();
234    std::priority_queue<DictProbability, std::vector<DictProbability>, DictProbabilityComparator>
235            priorityQueue;
236    BigramDictContent *const bigramDictContent = mBuffers->getMutableBigramDictContent();
237    for (int i = 0; i < nextTerminalId; ++i) {
238        const int bigramListPos = bigramDictContent->getBigramListHeadPos(i);
239        if (bigramListPos == NOT_A_DICT_POS) {
240            continue;
241        }
242        bool hasNext = true;
243        int readingPos = bigramListPos;
244        while (hasNext) {
245            const BigramEntry bigramEntry =
246                    bigramDictContent->getBigramEntryAndAdvancePosition(&readingPos);
247            const int entryPos = readingPos - bigramDictContent->getBigramEntrySize();
248            hasNext = bigramEntry.hasNext();
249            if (!bigramEntry.isValid()) {
250                continue;
251            }
252            const int probability = bigramEntry.hasHistoricalInfo() ?
253                    ForgettingCurveUtils::decodeProbability(
254                            bigramEntry.getHistoricalInfo(), mBuffers->getHeaderPolicy()) :
255                    bigramEntry.getProbability();
256            priorityQueue.push(DictProbability(entryPos, probability,
257                    bigramEntry.getHistoricalInfo()->getTimeStamp()));
258        }
259    }
260
261    // Delete bigrams.
262    while (static_cast<int>(priorityQueue.size()) > maxBigramCount) {
263        const int entryPos = priorityQueue.top().getDictPos();
264        const BigramEntry bigramEntry = bigramDictContent->getBigramEntry(entryPos);
265        const BigramEntry invalidatedBigramEntry = bigramEntry.getInvalidatedEntry();
266        if (!bigramDictContent->writeBigramEntry(&invalidatedBigramEntry, entryPos)) {
267            AKLOGE("Cannot write bigram entry to remove. pos: %d", entryPos);
268            return false;
269        }
270        priorityQueue.pop();
271    }
272    return true;
273}
274
275bool Ver4PatriciaTrieWritingHelper::TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds
276        ::onVisitingPtNode(const PtNodeParams *const ptNodeParams) {
277    if (!ptNodeParams->isTerminal()) {
278        return true;
279    }
280    TerminalPositionLookupTable::TerminalIdMap::const_iterator it =
281            mTerminalIdMap->find(ptNodeParams->getTerminalId());
282    if (it == mTerminalIdMap->end()) {
283        AKLOGE("terminal Id %d is not in the terminal position map. map size: %zd",
284                ptNodeParams->getTerminalId(), mTerminalIdMap->size());
285        return false;
286    }
287    if (!mPtNodeWriter->updateTerminalId(ptNodeParams, it->second)) {
288        AKLOGE("Cannot update terminal id. %d -> %d", it->first, it->second);
289        return false;
290    }
291    return true;
292}
293
294} // namespace latinime
295