ver4_patricia_trie_policy.cpp revision 4ce480d5ce2d47f607448ce439aaf2cefba1bdd8
1/*
2 * Copyright (C) 2013, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h"
18
19#include <vector>
20
21#include "suggest/core/dicnode/dic_node.h"
22#include "suggest/core/dicnode/dic_node_vector.h"
23#include "suggest/core/dictionary/word_property.h"
24#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
25#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h"
26#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
27#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
28
29namespace latinime {
30
31// Note that there are corresponding definitions in Java side in BinaryDictionaryTests and
32// BinaryDictionaryDecayingTests.
33const char *const Ver4PatriciaTriePolicy::UNIGRAM_COUNT_QUERY = "UNIGRAM_COUNT";
34const char *const Ver4PatriciaTriePolicy::BIGRAM_COUNT_QUERY = "BIGRAM_COUNT";
35const char *const Ver4PatriciaTriePolicy::MAX_UNIGRAM_COUNT_QUERY = "MAX_UNIGRAM_COUNT";
36const char *const Ver4PatriciaTriePolicy::MAX_BIGRAM_COUNT_QUERY = "MAX_BIGRAM_COUNT";
37const int Ver4PatriciaTriePolicy::MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS = 1024;
38const int Ver4PatriciaTriePolicy::MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS =
39        Ver4DictConstants::MAX_DICTIONARY_SIZE - MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS;
40
41void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode,
42        DicNodeVector *const childDicNodes) const {
43    if (!dicNode->hasChildren()) {
44        return;
45    }
46    DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
47    readingHelper.initWithPtNodeArrayPos(dicNode->getChildrenPtNodeArrayPos());
48    while (!readingHelper.isEnd()) {
49        const PtNodeParams ptNodeParams = readingHelper.getPtNodeParams();
50        if (!ptNodeParams.isValid()) {
51            break;
52        }
53        bool isTerminal = ptNodeParams.isTerminal() && !ptNodeParams.isDeleted();
54        if (isTerminal && mHeaderPolicy->isDecayingDict()) {
55            // A DecayingDict may have a terminal PtNode that has a terminal DicNode whose
56            // probability is NOT_A_PROBABILITY. In such case, we don't want to treat it as a
57            // valid terminal DicNode.
58            isTerminal = ptNodeParams.getProbability() != NOT_A_PROBABILITY;
59        }
60        childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getHeadPos(),
61                ptNodeParams.getChildrenPos(), ptNodeParams.getProbability(), isTerminal,
62                ptNodeParams.hasChildren(),
63                ptNodeParams.isBlacklisted()
64                        || ptNodeParams.isNotAWord() /* isBlacklistedOrNotAWord */,
65                ptNodeParams.getCodePointCount(), ptNodeParams.getCodePoints());
66        readingHelper.readNextSiblingNode(ptNodeParams);
67    }
68    if (readingHelper.isError()) {
69        mIsCorrupted = true;
70        AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
71    }
72}
73
74int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
75        const int ptNodePos, const int maxCodePointCount, int *const outCodePoints,
76        int *const outUnigramProbability) const {
77    DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
78    readingHelper.initWithPtNodePos(ptNodePos);
79    const int codePointCount =  readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount(
80            maxCodePointCount, outCodePoints, outUnigramProbability);
81    if (readingHelper.isError()) {
82        mIsCorrupted = true;
83        AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount().");
84    }
85    return codePointCount;
86}
87
88int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const inWord,
89        const int length, const bool forceLowerCaseSearch) const {
90    DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
91    readingHelper.initWithPtNodeArrayPos(getRootPosition());
92    const int ptNodePos =
93            readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch);
94    if (readingHelper.isError()) {
95        mIsCorrupted = true;
96        AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
97    }
98    return ptNodePos;
99}
100
101int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
102        const int bigramProbability) const {
103    if (mHeaderPolicy->isDecayingDict()) {
104        // Both probabilities are encoded. Decode them and get probability.
105        return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability);
106    } else {
107        if (unigramProbability == NOT_A_PROBABILITY) {
108            return NOT_A_PROBABILITY;
109        } else if (bigramProbability == NOT_A_PROBABILITY) {
110            return ProbabilityUtils::backoff(unigramProbability);
111        } else {
112            // bigramProbability is a bigram probability delta.
113            return ProbabilityUtils::computeProbabilityForBigram(unigramProbability,
114                    bigramProbability);
115        }
116    }
117}
118
119int Ver4PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const {
120    if (ptNodePos == NOT_A_DICT_POS) {
121        return NOT_A_PROBABILITY;
122    }
123    const PtNodeParams ptNodeParams(mNodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos));
124    if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
125        return NOT_A_PROBABILITY;
126    }
127    return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
128}
129
130int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const {
131    if (ptNodePos == NOT_A_DICT_POS) {
132        return NOT_A_DICT_POS;
133    }
134    const PtNodeParams ptNodeParams(mNodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos));
135    if (ptNodeParams.isDeleted()) {
136        return NOT_A_DICT_POS;
137    }
138    return mBuffers->getShortcutDictContent()->getShortcutListHeadPos(
139            ptNodeParams.getTerminalId());
140}
141
142int Ver4PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const {
143    if (ptNodePos == NOT_A_DICT_POS) {
144        return NOT_A_DICT_POS;
145    }
146    const PtNodeParams ptNodeParams(mNodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos));
147    if (ptNodeParams.isDeleted()) {
148        return NOT_A_DICT_POS;
149    }
150    return mBuffers->getBigramDictContent()->getBigramListHeadPos(
151            ptNodeParams.getTerminalId());
152}
153
154bool Ver4PatriciaTriePolicy::addUnigramWord(const int *const word, const int length,
155        const int probability, const int *const shortcutTargetCodePoints, const int shortcutLength,
156        const int shortcutProbability, const bool isNotAWord, const bool isBlacklisted,
157        const int timestamp) {
158    if (!mBuffers->isUpdatable()) {
159        AKLOGI("Warning: addUnigramWord() is called for non-updatable dictionary.");
160        return false;
161    }
162    if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) {
163        AKLOGE("The dictionary is too large to dynamically update. Dictionary size: %d",
164                mDictBuffer->getTailPosition());
165        return false;
166    }
167    if (length > MAX_WORD_LENGTH) {
168        AKLOGE("The word is too long to insert to the dictionary, length: %d", length);
169        return false;
170    }
171    if (shortcutLength > MAX_WORD_LENGTH) {
172        AKLOGE("The shortcutTarget is too long to insert to the dictionary, length: %d",
173                shortcutLength);
174        return false;
175    }
176    DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
177    readingHelper.initWithPtNodeArrayPos(getRootPosition());
178    bool addedNewUnigram = false;
179    if (mUpdatingHelper.addUnigramWord(&readingHelper, word, length, probability, isNotAWord,
180            isBlacklisted, timestamp,  &addedNewUnigram)) {
181        if (addedNewUnigram) {
182            mUnigramCount++;
183        }
184        if (shortcutLength > 0) {
185            // Add shortcut target.
186            const int wordPos = getTerminalPtNodePositionOfWord(word, length,
187                    false /* forceLowerCaseSearch */);
188            if (wordPos == NOT_A_DICT_POS) {
189                AKLOGE("Cannot find terminal PtNode position to add shortcut target.");
190                return false;
191            }
192            if (!mUpdatingHelper.addShortcutTarget(wordPos, shortcutTargetCodePoints,
193                    shortcutLength, shortcutProbability)) {
194                AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %d, probability: %d",
195                        wordPos, shortcutLength, shortcutProbability);
196                return false;
197            }
198        }
199        return true;
200    } else {
201        return false;
202    }
203}
204
205bool Ver4PatriciaTriePolicy::addBigramWords(const int *const word0, const int length0,
206        const int *const word1, const int length1, const int probability,
207        const int timestamp) {
208    if (!mBuffers->isUpdatable()) {
209        AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary.");
210        return false;
211    }
212    if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) {
213        AKLOGE("The dictionary is too large to dynamically update. Dictionary size: %d",
214                mDictBuffer->getTailPosition());
215        return false;
216    }
217    if (length0 > MAX_WORD_LENGTH || length1 > MAX_WORD_LENGTH) {
218        AKLOGE("Either src word or target word is too long to insert the bigram to the dictionary. "
219                "length0: %d, length1: %d", length0, length1);
220        return false;
221    }
222    const int word0Pos = getTerminalPtNodePositionOfWord(word0, length0,
223            false /* forceLowerCaseSearch */);
224    if (word0Pos == NOT_A_DICT_POS) {
225        return false;
226    }
227    const int word1Pos = getTerminalPtNodePositionOfWord(word1, length1,
228            false /* forceLowerCaseSearch */);
229    if (word1Pos == NOT_A_DICT_POS) {
230        return false;
231    }
232    bool addedNewBigram = false;
233    if (mUpdatingHelper.addBigramWords(word0Pos, word1Pos, probability, timestamp,
234            &addedNewBigram)) {
235        if (addedNewBigram) {
236            mBigramCount++;
237        }
238        return true;
239    } else {
240        return false;
241    }
242}
243
244bool Ver4PatriciaTriePolicy::removeBigramWords(const int *const word0, const int length0,
245        const int *const word1, const int length1) {
246    if (!mBuffers->isUpdatable()) {
247        AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary.");
248        return false;
249    }
250    if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) {
251        AKLOGE("The dictionary is too large to dynamically update. Dictionary size: %d",
252                mDictBuffer->getTailPosition());
253        return false;
254    }
255    if (length0 > MAX_WORD_LENGTH || length1 > MAX_WORD_LENGTH) {
256        AKLOGE("Either src word or target word is too long to remove the bigram to from the "
257                "dictionary. length0: %d, length1: %d", length0, length1);
258        return false;
259    }
260    const int word0Pos = getTerminalPtNodePositionOfWord(word0, length0,
261            false /* forceLowerCaseSearch */);
262    if (word0Pos == NOT_A_DICT_POS) {
263        return false;
264    }
265    const int word1Pos = getTerminalPtNodePositionOfWord(word1, length1,
266            false /* forceLowerCaseSearch */);
267    if (word1Pos == NOT_A_DICT_POS) {
268        return false;
269    }
270    if (mUpdatingHelper.removeBigramWords(word0Pos, word1Pos)) {
271        mBigramCount--;
272        return true;
273    } else {
274        return false;
275    }
276}
277
278void Ver4PatriciaTriePolicy::flush(const char *const filePath) {
279    if (!mBuffers->isUpdatable()) {
280        AKLOGI("Warning: flush() is called for non-updatable dictionary. filePath: %s", filePath);
281        return;
282    }
283    if (!mWritingHelper.writeToDictFile(filePath, mUnigramCount, mBigramCount)) {
284        AKLOGE("Cannot flush the dictionary to file.");
285        mIsCorrupted = true;
286    }
287}
288
289void Ver4PatriciaTriePolicy::flushWithGC(const char *const filePath) {
290    if (!mBuffers->isUpdatable()) {
291        AKLOGI("Warning: flushWithGC() is called for non-updatable dictionary.");
292        return;
293    }
294    if (!mWritingHelper.writeToDictFileWithGC(getRootPosition(), filePath)) {
295        AKLOGE("Cannot flush the dictionary to file with GC.");
296        mIsCorrupted = true;
297    }
298}
299
300bool Ver4PatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const {
301    if (!mBuffers->isUpdatable()) {
302        AKLOGI("Warning: needsToRunGC() is called for non-updatable dictionary.");
303        return false;
304    }
305    if (mBuffers->isNearSizeLimit()) {
306        // Additional buffer size is near the limit.
307        return true;
308    } else if (mHeaderPolicy->getExtendedRegionSize() + mDictBuffer->getUsedAdditionalBufferSize()
309            > Ver4DictConstants::MAX_DICT_EXTENDED_REGION_SIZE) {
310        // Total extended region size of the trie exceeds the limit.
311        return true;
312    } else if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS
313            && mDictBuffer->getUsedAdditionalBufferSize() > 0) {
314        // Needs to reduce dictionary size.
315        return true;
316    } else if (mHeaderPolicy->isDecayingDict()) {
317        return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mUnigramCount, mBigramCount,
318                mHeaderPolicy);
319    }
320    return false;
321}
322
323void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int queryLength,
324        char *const outResult, const int maxResultLength) {
325    const int compareLength = queryLength + 1 /* terminator */;
326    if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
327        snprintf(outResult, maxResultLength, "%d", mUnigramCount);
328    } else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
329        snprintf(outResult, maxResultLength, "%d", mBigramCount);
330    } else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
331        snprintf(outResult, maxResultLength, "%d",
332                mHeaderPolicy->isDecayingDict() ?
333                        ForgettingCurveUtils::getUnigramCountHardLimit(
334                                mHeaderPolicy->getMaxUnigramCount()) :
335                        static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
336    } else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
337        snprintf(outResult, maxResultLength, "%d",
338                mHeaderPolicy->isDecayingDict() ?
339                        ForgettingCurveUtils::getBigramCountHardLimit(
340                                mHeaderPolicy->getMaxBigramCount()) :
341                        static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
342    }
343}
344
345const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const codePoints,
346        const int codePointCount) const {
347    const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount,
348            false /* forceLowerCaseSearch */);
349    if (ptNodePos == NOT_A_DICT_POS) {
350        AKLOGE("getWordProperty is called for invalid word.");
351        return WordProperty();
352    }
353    const PtNodeParams ptNodeParams = mNodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos);
354    std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
355            ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
356    const ProbabilityEntry probabilityEntry =
357            mBuffers->getProbabilityDictContent()->getProbabilityEntry(
358                    ptNodeParams.getTerminalId());
359    const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
360    // Fetch bigram information.
361    std::vector<WordProperty::BigramProperty> bigrams;
362    const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos);
363    if (bigramListPos != NOT_A_DICT_POS) {
364        int bigramWord1CodePoints[MAX_WORD_LENGTH];
365        const BigramDictContent *const bigramDictContent = mBuffers->getBigramDictContent();
366        const TerminalPositionLookupTable *const terminalPositionLookupTable =
367                mBuffers->getTerminalPositionLookupTable();
368        bool hasNext = true;
369        int readingPos = bigramListPos;
370        while (hasNext) {
371            const BigramEntry bigramEntry =
372                    bigramDictContent->getBigramEntryAndAdvancePosition(&readingPos);
373            hasNext = bigramEntry.hasNext();
374            const int word1TerminalId = bigramEntry.getTargetTerminalId();
375            const int word1TerminalPtNodePos =
376                    terminalPositionLookupTable->getTerminalPtNodePosition(word1TerminalId);
377            if (word1TerminalPtNodePos == NOT_A_DICT_POS) {
378                continue;
379            }
380            // Word (unigram) probability
381            int word1Probability = NOT_A_PROBABILITY;
382            const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
383                    word1TerminalPtNodePos, MAX_WORD_LENGTH, bigramWord1CodePoints,
384                    &word1Probability);
385            std::vector<int> word1(bigramWord1CodePoints,
386                    bigramWord1CodePoints + codePointCount);
387            const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
388            const int probability = bigramEntry.hasHistoricalInfo() ?
389                    ForgettingCurveUtils::decodeProbability(
390                            bigramEntry.getHistoricalInfo(), mHeaderPolicy) :
391                    bigramEntry.getProbability();
392            bigrams.push_back(WordProperty::BigramProperty(&word1, probability,
393                    historicalInfo->getTimeStamp(), historicalInfo->getLevel(),
394                    historicalInfo->getCount()));
395        }
396    }
397    // Fetch shortcut information.
398    std::vector<WordProperty::ShortcutProperty> shortcuts;
399    int shortcutPos = getShortcutPositionOfPtNode(ptNodePos);
400    if (shortcutPos != NOT_A_DICT_POS) {
401        int shortcutTarget[MAX_WORD_LENGTH];
402        const ShortcutDictContent *const shortcutDictContent =
403                mBuffers->getShortcutDictContent();
404        bool hasNext = true;
405        while (hasNext) {
406            int shortcutTargetLength = 0;
407            int shortcutProbability = NOT_A_PROBABILITY;
408            shortcutDictContent->getShortcutEntryAndAdvancePosition(MAX_WORD_LENGTH, shortcutTarget,
409                    &shortcutTargetLength, &shortcutProbability, &hasNext, &shortcutPos);
410            std::vector<int> target(shortcutTarget, shortcutTarget + shortcutTargetLength);
411            shortcuts.push_back(WordProperty::ShortcutProperty(&target, shortcutProbability));
412        }
413    }
414    return WordProperty(&codePointVector, ptNodeParams.isNotAWord(),
415            ptNodeParams.isBlacklisted(), ptNodeParams.hasBigrams(),
416            ptNodeParams.hasShortcutTargets(), ptNodeParams.getProbability(),
417            historicalInfo->getTimeStamp(), historicalInfo->getLevel(),
418            historicalInfo->getCount(), &bigrams, &shortcuts);
419}
420
421int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) {
422    if (token == 0) {
423        mTerminalPtNodePositionsForIteratingWords.clear();
424        DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy(
425                &mTerminalPtNodePositionsForIteratingWords);
426        DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
427        readingHelper.initWithPtNodeArrayPos(getRootPosition());
428        readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(&traversePolicy);
429    }
430    const int terminalPtNodePositionsVectorSize =
431            static_cast<int>(mTerminalPtNodePositionsForIteratingWords.size());
432    if (token < 0 || token >= terminalPtNodePositionsVectorSize) {
433        AKLOGE("Given token %d is invalid.", token);
434        return 0;
435    }
436    const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
437    int unigramProbability = NOT_A_PROBABILITY;
438    getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH,
439            outCodePoints, &unigramProbability);
440    const int nextToken = token + 1;
441    if (nextToken >= terminalPtNodePositionsVectorSize) {
442        // All words have been iterated.
443        mTerminalPtNodePositionsForIteratingWords.clear();
444        return 0;
445    }
446    return nextToken;
447}
448
449} // namespace latinime
450