patricia_trie_policy.cpp revision 1229879e7c5892e818ab53b3c2162a158cc5e177
1/*
2 * Copyright (C) 2013, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18#include "suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h"
19
20#include "defines.h"
21#include "suggest/core/dicnode/dic_node.h"
22#include "suggest/core/dicnode/dic_node_vector.h"
23#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
24#include "suggest/core/session/prev_words_info.h"
25#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
26#include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h"
27#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
28#include "utils/char_utils.h"
29
30namespace latinime {
31
32void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode,
33        DicNodeVector *const childDicNodes) const {
34    if (!dicNode->hasChildren()) {
35        return;
36    }
37    int nextPos = dicNode->getChildrenPtNodeArrayPos();
38    if (nextPos < 0 || nextPos >= mDictBufferSize) {
39        AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %d",
40                nextPos, mDictBufferSize);
41        mIsCorrupted = true;
42        ASSERT(false);
43        return;
44    }
45    const int childCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(
46            mDictRoot, &nextPos);
47    for (int i = 0; i < childCount; i++) {
48        if (nextPos < 0 || nextPos >= mDictBufferSize) {
49            AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %d, childCount: %d / %d",
50                    nextPos, mDictBufferSize, i, childCount);
51            mIsCorrupted = true;
52            ASSERT(false);
53            return;
54        }
55        nextPos = createAndGetLeavingChildNode(dicNode, nextPos, childDicNodes);
56    }
57}
58
59// This retrieves code points and the probability of the word by its terminal position.
60// Due to the fact that words are ordered in the dictionary in a strict breadth-first order,
61// it is possible to check for this with advantageous complexity. For each PtNode array, we search
62// for PtNodes with children and compare the children position with the position we look for.
63// When we shoot the position we look for, it means the word we look for is in the children
64// of the previous PtNode. The only tricky part is the fact that if we arrive at the end of a
65// PtNode array with the last PtNode's children position still less than what we are searching for,
66// we must descend the last PtNode's children (for example, if the word we are searching for starts
67// with a z, it's the last PtNode of the root array, so all children addresses will be smaller
68// than the position we look for, and we have to descend the z PtNode).
69/* Parameters :
70 * ptNodePos: the byte position of the terminal PtNode of the word we are searching for (this is
71 *   what is stored as the "bigram position" in each bigram)
72 * outCodePoints: an array to write the found word, with MAX_WORD_LENGTH size.
73 * outUnigramProbability: a pointer to an int to write the probability into.
74 * Return value : the code point count, of 0 if the word was not found.
75 */
76// TODO: Split this function to be more readable
77int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
78        const int ptNodePos, const int maxCodePointCount, int *const outCodePoints,
79        int *const outUnigramProbability) const {
80    int pos = getRootPosition();
81    int wordPos = 0;
82    // One iteration of the outer loop iterates through PtNode arrays. As stated above, we will
83    // only traverse PtNodes that are actually a part of the terminal we are searching, so each
84    // time we enter this loop we are one depth level further than last time.
85    // The only reason we count PtNodes is because we want to reduce the probability of infinite
86    // looping in case there is a bug. Since we know there is an upper bound to the depth we are
87    // supposed to traverse, it does not hurt to count iterations.
88    for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) {
89        int lastCandidatePtNodePos = 0;
90        // Let's loop through PtNodes in this PtNode array searching for either the terminal
91        // or one of its ascendants.
92        if (pos < 0 || pos >= mDictBufferSize) {
93            AKLOGE("PtNode array position is invalid. pos: %d, dict size: %d",
94                    pos, mDictBufferSize);
95            mIsCorrupted = true;
96            ASSERT(false);
97            *outUnigramProbability = NOT_A_PROBABILITY;
98            return 0;
99        }
100        for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(
101                mDictRoot, &pos); ptNodeCount > 0; --ptNodeCount) {
102            const int startPos = pos;
103            if (pos < 0 || pos >= mDictBufferSize) {
104                AKLOGE("PtNode position is invalid. pos: %d, dict size: %d", pos, mDictBufferSize);
105                mIsCorrupted = true;
106                ASSERT(false);
107                *outUnigramProbability = NOT_A_PROBABILITY;
108                return 0;
109            }
110            const PatriciaTrieReadingUtils::NodeFlags flags =
111                    PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos);
112            const int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
113                    mDictRoot, &pos);
114            if (ptNodePos == startPos) {
115                // We found the position. Copy the rest of the code points in the buffer and return
116                // the length.
117                outCodePoints[wordPos] = character;
118                if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) {
119                    int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
120                            mDictRoot, &pos);
121                    // We count code points in order to avoid infinite loops if the file is broken
122                    // or if there is some other bug
123                    int charCount = maxCodePointCount;
124                    while (NOT_A_CODE_POINT != nextChar && --charCount > 0) {
125                        outCodePoints[++wordPos] = nextChar;
126                        nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
127                                mDictRoot, &pos);
128                    }
129                }
130                *outUnigramProbability =
131                        PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot,
132                                &pos);
133                return ++wordPos;
134            }
135            // We need to skip past this PtNode, so skip any remaining code points after the
136            // first and possibly the probability.
137            if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) {
138                PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos);
139            }
140            if (PatriciaTrieReadingUtils::isTerminal(flags)) {
141                PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos);
142            }
143            // The fact that this PtNode has children is very important. Since we already know
144            // that this PtNode does not match, if it has no children we know it is irrelevant
145            // to what we are searching for.
146            const bool hasChildren = PatriciaTrieReadingUtils::hasChildrenInFlags(flags);
147            // We will write in `found' whether we have passed the children position we are
148            // searching for. For example if we search for "beer", the children of b are less
149            // than the address we are searching for and the children of c are greater. When we
150            // come here for c, we realize this is too big, and that we should descend b.
151            bool found;
152            if (hasChildren) {
153                int currentPos = pos;
154                // Here comes the tricky part. First, read the children position.
155                const int childrenPos = PatriciaTrieReadingUtils
156                        ::readChildrenPositionAndAdvancePosition(mDictRoot, flags, &currentPos);
157                if (childrenPos > ptNodePos) {
158                    // If the children pos is greater than the position, it means the previous
159                    // PtNode, which position is stored in lastCandidatePtNodePos, was the right
160                    // one.
161                    found = true;
162                } else if (1 >= ptNodeCount) {
163                    // However if we are on the LAST PtNode of this array, and we have NOT shot the
164                    // position we should descend THIS PtNode. So we trick the
165                    // lastCandidatePtNodePos so that we will descend this PtNode, not the previous
166                    // one.
167                    lastCandidatePtNodePos = startPos;
168                    found = true;
169                } else {
170                    // Else, we should continue looking.
171                    found = false;
172                }
173            } else {
174                // Even if we don't have children here, we could still be on the last PtNode of
175                // this array. If this is the case, we should descend the last PtNode that had
176                // children, and their position is already in lastCandidatePtNodePos.
177                found = (1 >= ptNodeCount);
178            }
179
180            if (found) {
181                // Okay, we found the PtNode we should descend. Its position is in
182                // the lastCandidatePtNodePos variable, so we just re-read it.
183                if (0 != lastCandidatePtNodePos) {
184                    const PatriciaTrieReadingUtils::NodeFlags lastFlags =
185                            PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(
186                                    mDictRoot, &lastCandidatePtNodePos);
187                    const int lastChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
188                            mDictRoot, &lastCandidatePtNodePos);
189                    // We copy all the characters in this PtNode to the buffer
190                    outCodePoints[wordPos] = lastChar;
191                    if (PatriciaTrieReadingUtils::hasMultipleChars(lastFlags)) {
192                        int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
193                                mDictRoot, &lastCandidatePtNodePos);
194                        int charCount = maxCodePointCount;
195                        while (-1 != nextChar && --charCount > 0) {
196                            outCodePoints[++wordPos] = nextChar;
197                            nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
198                                    mDictRoot, &lastCandidatePtNodePos);
199                        }
200                    }
201                    ++wordPos;
202                    // Now we only need to branch to the children address. Skip the probability if
203                    // it's there, read pos, and break to resume the search at pos.
204                    if (PatriciaTrieReadingUtils::isTerminal(lastFlags)) {
205                        PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot,
206                                &lastCandidatePtNodePos);
207                    }
208                    pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(
209                            mDictRoot, lastFlags, &lastCandidatePtNodePos);
210                    break;
211                } else {
212                    // Here is a little tricky part: we come here if we found out that all children
213                    // addresses in this PtNode are bigger than the address we are searching for.
214                    // Should we conclude the word is not in the dictionary? No! It could still be
215                    // one of the remaining PtNodes in this array, so we have to keep looking in
216                    // this array until we find it (or we realize it's not there either, in which
217                    // case it's actually not in the dictionary). Pass the end of this PtNode,
218                    // ready to start the next one.
219                    if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) {
220                        PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(
221                                mDictRoot, flags, &pos);
222                    }
223                    if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) {
224                        mShortcutListPolicy.skipAllShortcuts(&pos);
225                    }
226                    if (PatriciaTrieReadingUtils::hasBigrams(flags)) {
227                        mBigramListPolicy.skipAllBigrams(&pos);
228                    }
229                }
230            } else {
231                // If we did not find it, we should record the last children address for the next
232                // iteration.
233                if (hasChildren) lastCandidatePtNodePos = startPos;
234                // Now skip the end of this PtNode (children pos and the attributes if any) so that
235                // our pos is after the end of this PtNode, at the start of the next one.
236                if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) {
237                    PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(
238                            mDictRoot, flags, &pos);
239                }
240                if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) {
241                    mShortcutListPolicy.skipAllShortcuts(&pos);
242                }
243                if (PatriciaTrieReadingUtils::hasBigrams(flags)) {
244                    mBigramListPolicy.skipAllBigrams(&pos);
245                }
246            }
247
248        }
249    }
250    // If we have looked through all the PtNodes and found no match, the ptNodePos is
251    // not the position of a terminal in this dictionary.
252    return 0;
253}
254
255// This function gets the position of the terminal PtNode of the exact matching word in the
256// dictionary. If no match is found, it returns NOT_A_DICT_POS.
257int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const inWord,
258        const int length, const bool forceLowerCaseSearch) const {
259    DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader);
260    readingHelper.initWithPtNodeArrayPos(getRootPosition());
261    const int ptNodePos =
262            readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch);
263    if (readingHelper.isError()) {
264        mIsCorrupted = true;
265        AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
266    }
267    return ptNodePos;
268}
269
270int PatriciaTriePolicy::getProbability(const int unigramProbability,
271        const int bigramProbability) const {
272    // Due to space constraints, the probability for bigrams is approximate - the lower the unigram
273    // probability, the worse the precision. The theoritical maximum error in resulting probability
274    // is 8 - although in the practice it's never bigger than 3 or 4 in very bad cases. This means
275    // that sometimes, we'll see some bigrams interverted here, but it can't get too bad.
276    if (unigramProbability == NOT_A_PROBABILITY) {
277        return NOT_A_PROBABILITY;
278    } else if (bigramProbability == NOT_A_PROBABILITY) {
279        return ProbabilityUtils::backoff(unigramProbability);
280    } else {
281        return ProbabilityUtils::computeProbabilityForBigram(unigramProbability,
282                bigramProbability);
283    }
284}
285
286int PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo,
287        const int ptNodePos) const {
288    if (ptNodePos == NOT_A_DICT_POS) {
289        return NOT_A_PROBABILITY;
290    }
291    const PtNodeParams ptNodeParams =
292            mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
293    if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) {
294        // If this is not a word, or if it's a blacklisted entry, it should behave as
295        // having no probability outside of the suggestion process (where it should be used
296        // for shortcuts).
297        return NOT_A_PROBABILITY;
298    }
299    if (prevWordsInfo) {
300        BinaryDictionaryBigramsIterator bigramsIt =
301                prevWordsInfo->getBigramsIteratorForPrediction(this /* dictStructurePolicy */);
302        while (bigramsIt.hasNext()) {
303            bigramsIt.next();
304            if (bigramsIt.getBigramPos() == ptNodePos
305                    && bigramsIt.getProbability() != NOT_A_PROBABILITY) {
306                return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability());
307            }
308        }
309        return NOT_A_PROBABILITY;
310    }
311    return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
312}
313
314int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const {
315    if (ptNodePos == NOT_A_DICT_POS) {
316        return NOT_A_DICT_POS;
317    }
318    return mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos).getShortcutPos();
319}
320
321BinaryDictionaryBigramsIterator PatriciaTriePolicy::getBigramsIteratorOfPtNode(
322        const int ptNodePos) const {
323    const int bigramsPosition = getBigramsPositionOfPtNode(ptNodePos);
324    return BinaryDictionaryBigramsIterator(&mBigramListPolicy, bigramsPosition);
325}
326
327int PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const {
328    if (ptNodePos == NOT_A_DICT_POS) {
329        return NOT_A_DICT_POS;
330    }
331    return mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos).getBigramsPos();
332}
333
334int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNode,
335        const int ptNodePos, DicNodeVector *childDicNodes) const {
336    PatriciaTrieReadingUtils::NodeFlags flags;
337    int mergedNodeCodePointCount = 0;
338    int mergedNodeCodePoints[MAX_WORD_LENGTH];
339    int probability = NOT_A_PROBABILITY;
340    int childrenPos = NOT_A_DICT_POS;
341    int shortcutPos = NOT_A_DICT_POS;
342    int bigramPos = NOT_A_DICT_POS;
343    int siblingPos = NOT_A_DICT_POS;
344    PatriciaTrieReadingUtils::readPtNodeInfo(mDictRoot, ptNodePos, getShortcutsStructurePolicy(),
345            &mBigramListPolicy, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints,
346            &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos);
347    // Skip PtNodes don't start with Unicode code point because they represent non-word information.
348    if (CharUtils::isInUnicodeSpace(mergedNodeCodePoints[0])) {
349        childDicNodes->pushLeavingChild(dicNode, ptNodePos, childrenPos, probability,
350                PatriciaTrieReadingUtils::isTerminal(flags),
351                PatriciaTrieReadingUtils::hasChildrenInFlags(flags),
352                PatriciaTrieReadingUtils::isBlacklisted(flags)
353                        || PatriciaTrieReadingUtils::isNotAWord(flags),
354                mergedNodeCodePointCount, mergedNodeCodePoints);
355    }
356    return siblingPos;
357}
358
359const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoints,
360        const int codePointCount) const {
361    const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount,
362            false /* forceLowerCaseSearch */);
363    if (ptNodePos == NOT_A_DICT_POS) {
364        AKLOGE("getWordProperty was called for invalid word.");
365        return WordProperty();
366    }
367    const PtNodeParams ptNodeParams =
368            mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
369    std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
370            ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
371    // Fetch bigram information.
372    std::vector<BigramProperty> bigrams;
373    const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos);
374    int bigramWord1CodePoints[MAX_WORD_LENGTH];
375    BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramListPos);
376    while (bigramsIt.hasNext()) {
377        // Fetch the next bigram information and forward the iterator.
378        bigramsIt.next();
379        // Skip the entry if the entry has been deleted. This never happens for ver2 dicts.
380        if (bigramsIt.getBigramPos() != NOT_A_DICT_POS) {
381            int word1Probability = NOT_A_PROBABILITY;
382            const int word1CodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
383                    bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramWord1CodePoints,
384                    &word1Probability);
385            const std::vector<int> word1(bigramWord1CodePoints,
386                    bigramWord1CodePoints + word1CodePointCount);
387            const int probability = getProbability(word1Probability, bigramsIt.getProbability());
388            bigrams.emplace_back(&word1, probability,
389                    NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */);
390        }
391    }
392    // Fetch shortcut information.
393    std::vector<UnigramProperty::ShortcutProperty> shortcuts;
394    int shortcutPos = getShortcutPositionOfPtNode(ptNodePos);
395    if (shortcutPos != NOT_A_DICT_POS) {
396        int shortcutTargetCodePoints[MAX_WORD_LENGTH];
397        ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mDictRoot, &shortcutPos);
398        bool hasNext = true;
399        while (hasNext) {
400            const ShortcutListReadingUtils::ShortcutFlags shortcutFlags =
401                    ShortcutListReadingUtils::getFlagsAndForwardPointer(mDictRoot, &shortcutPos);
402            hasNext = ShortcutListReadingUtils::hasNext(shortcutFlags);
403            const int shortcutTargetLength = ShortcutListReadingUtils::readShortcutTarget(
404                    mDictRoot, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos);
405            const std::vector<int> shortcutTarget(shortcutTargetCodePoints,
406                    shortcutTargetCodePoints + shortcutTargetLength);
407            const int shortcutProbability =
408                    ShortcutListReadingUtils::getProbabilityFromFlags(shortcutFlags);
409            shortcuts.emplace_back(&shortcutTarget, shortcutProbability);
410        }
411    }
412    const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
413            ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
414            NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts);
415    return WordProperty(&codePointVector, &unigramProperty, &bigrams);
416}
417
418int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints,
419        int *const outCodePointCount) {
420    *outCodePointCount = 0;
421    if (token == 0) {
422        // Start iterating the dictionary.
423        mTerminalPtNodePositionsForIteratingWords.clear();
424        DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy(
425                &mTerminalPtNodePositionsForIteratingWords);
426        DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader);
427        readingHelper.initWithPtNodeArrayPos(getRootPosition());
428        readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(&traversePolicy);
429    }
430    const int terminalPtNodePositionsVectorSize =
431            static_cast<int>(mTerminalPtNodePositionsForIteratingWords.size());
432    if (token < 0 || token >= terminalPtNodePositionsVectorSize) {
433        AKLOGE("Given token %d is invalid.", token);
434        return 0;
435    }
436    const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
437    int unigramProbability = NOT_A_PROBABILITY;
438    *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos,
439            MAX_WORD_LENGTH, outCodePoints, &unigramProbability);
440    const int nextToken = token + 1;
441    if (nextToken >= terminalPtNodePositionsVectorSize) {
442        // All words have been iterated.
443        mTerminalPtNodePositionsForIteratingWords.clear();
444        return 0;
445    }
446    return nextToken;
447}
448
449} // namespace latinime
450