patricia_trie_policy.cpp revision 8681bef03c1ca864d3de0ae27adb5cbfb63f0fef
1/*
2 * Copyright (C) 2013, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17
18#include "suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h"
19
20#include "defines.h"
21#include "suggest/core/dicnode/dic_node.h"
22#include "suggest/core/dicnode/dic_node_vector.h"
23#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
24#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
25#include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h"
26#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
27#include "utils/char_utils.h"
28
29namespace latinime {
30
31void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode,
32        DicNodeVector *const childDicNodes) const {
33    if (!dicNode->hasChildren()) {
34        return;
35    }
36    int nextPos = dicNode->getChildrenPtNodeArrayPos();
37    if (nextPos < 0 || nextPos >= mDictBufferSize) {
38        AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %d",
39                nextPos, mDictBufferSize);
40        mIsCorrupted = true;
41        ASSERT(false);
42        return;
43    }
44    const int childCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(
45            mDictRoot, &nextPos);
46    for (int i = 0; i < childCount; i++) {
47        if (nextPos < 0 || nextPos >= mDictBufferSize) {
48            AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %d, childCount: %d / %d",
49                    nextPos, mDictBufferSize, i, childCount);
50            mIsCorrupted = true;
51            ASSERT(false);
52            return;
53        }
54        nextPos = createAndGetLeavingChildNode(dicNode, nextPos, childDicNodes);
55    }
56}
57
58// This retrieves code points and the probability of the word by its terminal position.
59// Due to the fact that words are ordered in the dictionary in a strict breadth-first order,
60// it is possible to check for this with advantageous complexity. For each PtNode array, we search
61// for PtNodes with children and compare the children position with the position we look for.
62// When we shoot the position we look for, it means the word we look for is in the children
63// of the previous PtNode. The only tricky part is the fact that if we arrive at the end of a
64// PtNode array with the last PtNode's children position still less than what we are searching for,
65// we must descend the last PtNode's children (for example, if the word we are searching for starts
66// with a z, it's the last PtNode of the root array, so all children addresses will be smaller
67// than the position we look for, and we have to descend the z PtNode).
68/* Parameters :
69 * ptNodePos: the byte position of the terminal PtNode of the word we are searching for (this is
70 *   what is stored as the "bigram position" in each bigram)
71 * outCodePoints: an array to write the found word, with MAX_WORD_LENGTH size.
72 * outUnigramProbability: a pointer to an int to write the probability into.
73 * Return value : the code point count, of 0 if the word was not found.
74 */
75// TODO: Split this function to be more readable
76int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
77        const int ptNodePos, const int maxCodePointCount, int *const outCodePoints,
78        int *const outUnigramProbability) const {
79    int pos = getRootPosition();
80    int wordPos = 0;
81    // One iteration of the outer loop iterates through PtNode arrays. As stated above, we will
82    // only traverse PtNodes that are actually a part of the terminal we are searching, so each
83    // time we enter this loop we are one depth level further than last time.
84    // The only reason we count PtNodes is because we want to reduce the probability of infinite
85    // looping in case there is a bug. Since we know there is an upper bound to the depth we are
86    // supposed to traverse, it does not hurt to count iterations.
87    for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) {
88        int lastCandidatePtNodePos = 0;
89        // Let's loop through PtNodes in this PtNode array searching for either the terminal
90        // or one of its ascendants.
91        if (pos < 0 || pos >= mDictBufferSize) {
92            AKLOGE("PtNode array position is invalid. pos: %d, dict size: %d",
93                    pos, mDictBufferSize);
94            mIsCorrupted = true;
95            ASSERT(false);
96            *outUnigramProbability = NOT_A_PROBABILITY;
97            return 0;
98        }
99        for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(
100                mDictRoot, &pos); ptNodeCount > 0; --ptNodeCount) {
101            const int startPos = pos;
102            if (pos < 0 || pos >= mDictBufferSize) {
103                AKLOGE("PtNode position is invalid. pos: %d, dict size: %d", pos, mDictBufferSize);
104                mIsCorrupted = true;
105                ASSERT(false);
106                *outUnigramProbability = NOT_A_PROBABILITY;
107                return 0;
108            }
109            const PatriciaTrieReadingUtils::NodeFlags flags =
110                    PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos);
111            const int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
112                    mDictRoot, &pos);
113            if (ptNodePos == startPos) {
114                // We found the position. Copy the rest of the code points in the buffer and return
115                // the length.
116                outCodePoints[wordPos] = character;
117                if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) {
118                    int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
119                            mDictRoot, &pos);
120                    // We count code points in order to avoid infinite loops if the file is broken
121                    // or if there is some other bug
122                    int charCount = maxCodePointCount;
123                    while (NOT_A_CODE_POINT != nextChar && --charCount > 0) {
124                        outCodePoints[++wordPos] = nextChar;
125                        nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
126                                mDictRoot, &pos);
127                    }
128                }
129                *outUnigramProbability =
130                        PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot,
131                                &pos);
132                return ++wordPos;
133            }
134            // We need to skip past this PtNode, so skip any remaining code points after the
135            // first and possibly the probability.
136            if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) {
137                PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos);
138            }
139            if (PatriciaTrieReadingUtils::isTerminal(flags)) {
140                PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos);
141            }
142            // The fact that this PtNode has children is very important. Since we already know
143            // that this PtNode does not match, if it has no children we know it is irrelevant
144            // to what we are searching for.
145            const bool hasChildren = PatriciaTrieReadingUtils::hasChildrenInFlags(flags);
146            // We will write in `found' whether we have passed the children position we are
147            // searching for. For example if we search for "beer", the children of b are less
148            // than the address we are searching for and the children of c are greater. When we
149            // come here for c, we realize this is too big, and that we should descend b.
150            bool found;
151            if (hasChildren) {
152                int currentPos = pos;
153                // Here comes the tricky part. First, read the children position.
154                const int childrenPos = PatriciaTrieReadingUtils
155                        ::readChildrenPositionAndAdvancePosition(mDictRoot, flags, &currentPos);
156                if (childrenPos > ptNodePos) {
157                    // If the children pos is greater than the position, it means the previous
158                    // PtNode, which position is stored in lastCandidatePtNodePos, was the right
159                    // one.
160                    found = true;
161                } else if (1 >= ptNodeCount) {
162                    // However if we are on the LAST PtNode of this array, and we have NOT shot the
163                    // position we should descend THIS PtNode. So we trick the
164                    // lastCandidatePtNodePos so that we will descend this PtNode, not the previous
165                    // one.
166                    lastCandidatePtNodePos = startPos;
167                    found = true;
168                } else {
169                    // Else, we should continue looking.
170                    found = false;
171                }
172            } else {
173                // Even if we don't have children here, we could still be on the last PtNode of
174                // this array. If this is the case, we should descend the last PtNode that had
175                // children, and their position is already in lastCandidatePtNodePos.
176                found = (1 >= ptNodeCount);
177            }
178
179            if (found) {
180                // Okay, we found the PtNode we should descend. Its position is in
181                // the lastCandidatePtNodePos variable, so we just re-read it.
182                if (0 != lastCandidatePtNodePos) {
183                    const PatriciaTrieReadingUtils::NodeFlags lastFlags =
184                            PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(
185                                    mDictRoot, &lastCandidatePtNodePos);
186                    const int lastChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
187                            mDictRoot, &lastCandidatePtNodePos);
188                    // We copy all the characters in this PtNode to the buffer
189                    outCodePoints[wordPos] = lastChar;
190                    if (PatriciaTrieReadingUtils::hasMultipleChars(lastFlags)) {
191                        int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
192                                mDictRoot, &lastCandidatePtNodePos);
193                        int charCount = maxCodePointCount;
194                        while (-1 != nextChar && --charCount > 0) {
195                            outCodePoints[++wordPos] = nextChar;
196                            nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(
197                                    mDictRoot, &lastCandidatePtNodePos);
198                        }
199                    }
200                    ++wordPos;
201                    // Now we only need to branch to the children address. Skip the probability if
202                    // it's there, read pos, and break to resume the search at pos.
203                    if (PatriciaTrieReadingUtils::isTerminal(lastFlags)) {
204                        PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot,
205                                &lastCandidatePtNodePos);
206                    }
207                    pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(
208                            mDictRoot, lastFlags, &lastCandidatePtNodePos);
209                    break;
210                } else {
211                    // Here is a little tricky part: we come here if we found out that all children
212                    // addresses in this PtNode are bigger than the address we are searching for.
213                    // Should we conclude the word is not in the dictionary? No! It could still be
214                    // one of the remaining PtNodes in this array, so we have to keep looking in
215                    // this array until we find it (or we realize it's not there either, in which
216                    // case it's actually not in the dictionary). Pass the end of this PtNode,
217                    // ready to start the next one.
218                    if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) {
219                        PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(
220                                mDictRoot, flags, &pos);
221                    }
222                    if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) {
223                        mShortcutListPolicy.skipAllShortcuts(&pos);
224                    }
225                    if (PatriciaTrieReadingUtils::hasBigrams(flags)) {
226                        mBigramListPolicy.skipAllBigrams(&pos);
227                    }
228                }
229            } else {
230                // If we did not find it, we should record the last children address for the next
231                // iteration.
232                if (hasChildren) lastCandidatePtNodePos = startPos;
233                // Now skip the end of this PtNode (children pos and the attributes if any) so that
234                // our pos is after the end of this PtNode, at the start of the next one.
235                if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) {
236                    PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(
237                            mDictRoot, flags, &pos);
238                }
239                if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) {
240                    mShortcutListPolicy.skipAllShortcuts(&pos);
241                }
242                if (PatriciaTrieReadingUtils::hasBigrams(flags)) {
243                    mBigramListPolicy.skipAllBigrams(&pos);
244                }
245            }
246
247        }
248    }
249    // If we have looked through all the PtNodes and found no match, the ptNodePos is
250    // not the position of a terminal in this dictionary.
251    return 0;
252}
253
254// This function gets the position of the terminal PtNode of the exact matching word in the
255// dictionary. If no match is found, it returns NOT_A_DICT_POS.
256int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const inWord,
257        const int length, const bool forceLowerCaseSearch) const {
258    DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader);
259    readingHelper.initWithPtNodeArrayPos(getRootPosition());
260    const int ptNodePos =
261            readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch);
262    if (readingHelper.isError()) {
263        mIsCorrupted = true;
264        AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
265    }
266    return ptNodePos;
267}
268
269int PatriciaTriePolicy::getProbability(const int unigramProbability,
270        const int bigramProbability) const {
271    // Due to space constraints, the probability for bigrams is approximate - the lower the unigram
272    // probability, the worse the precision. The theoritical maximum error in resulting probability
273    // is 8 - although in the practice it's never bigger than 3 or 4 in very bad cases. This means
274    // that sometimes, we'll see some bigrams interverted here, but it can't get too bad.
275    if (unigramProbability == NOT_A_PROBABILITY) {
276        return NOT_A_PROBABILITY;
277    } else if (bigramProbability == NOT_A_PROBABILITY) {
278        return ProbabilityUtils::backoff(unigramProbability);
279    } else {
280        return ProbabilityUtils::computeProbabilityForBigram(unigramProbability,
281                bigramProbability);
282    }
283}
284
285int PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const {
286    if (ptNodePos == NOT_A_DICT_POS) {
287        return NOT_A_PROBABILITY;
288    }
289    const PtNodeParams ptNodeParams =
290            mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
291    if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) {
292        // If this is not a word, or if it's a blacklisted entry, it should behave as
293        // having no probability outside of the suggestion process (where it should be used
294        // for shortcuts).
295        return NOT_A_PROBABILITY;
296    }
297    return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
298}
299
300int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const {
301    if (ptNodePos == NOT_A_DICT_POS) {
302        return NOT_A_DICT_POS;
303    }
304    return mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos).getShortcutPos();
305}
306
307int PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const {
308    if (ptNodePos == NOT_A_DICT_POS) {
309        return NOT_A_DICT_POS;
310    }
311    return mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos).getBigramsPos();
312}
313
314int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNode,
315        const int ptNodePos, DicNodeVector *childDicNodes) const {
316    PatriciaTrieReadingUtils::NodeFlags flags;
317    int mergedNodeCodePointCount = 0;
318    int mergedNodeCodePoints[MAX_WORD_LENGTH];
319    int probability = NOT_A_PROBABILITY;
320    int childrenPos = NOT_A_DICT_POS;
321    int shortcutPos = NOT_A_DICT_POS;
322    int bigramPos = NOT_A_DICT_POS;
323    int siblingPos = NOT_A_DICT_POS;
324    PatriciaTrieReadingUtils::readPtNodeInfo(mDictRoot, ptNodePos, getShortcutsStructurePolicy(),
325            getBigramsStructurePolicy(), &flags, &mergedNodeCodePointCount, mergedNodeCodePoints,
326            &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos);
327    // Skip PtNodes don't start with Unicode code point because they represent non-word information.
328    if (CharUtils::isInUnicodeSpace(mergedNodeCodePoints[0])) {
329        childDicNodes->pushLeavingChild(dicNode, ptNodePos, childrenPos, probability,
330                PatriciaTrieReadingUtils::isTerminal(flags),
331                PatriciaTrieReadingUtils::hasChildrenInFlags(flags),
332                PatriciaTrieReadingUtils::isBlacklisted(flags)
333                        || PatriciaTrieReadingUtils::isNotAWord(flags),
334                mergedNodeCodePointCount, mergedNodeCodePoints);
335    }
336    return siblingPos;
337}
338
339const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoints,
340        const int codePointCount) const {
341    const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount,
342            false /* forceLowerCaseSearch */);
343    if (ptNodePos == NOT_A_DICT_POS) {
344        AKLOGE("getWordProperty was called for invalid word.");
345        return WordProperty();
346    }
347    const PtNodeParams ptNodeParams =
348            mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
349    std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
350            ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
351    // Fetch bigram information.
352    std::vector<BigramProperty> bigrams;
353    const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos);
354    int bigramWord1CodePoints[MAX_WORD_LENGTH];
355    BinaryDictionaryBigramsIterator bigramsIt(getBigramsStructurePolicy(), bigramListPos);
356    while (bigramsIt.hasNext()) {
357        // Fetch the next bigram information and forward the iterator.
358        bigramsIt.next();
359        // Skip the entry if the entry has been deleted. This never happens for ver2 dicts.
360        if (bigramsIt.getBigramPos() != NOT_A_DICT_POS) {
361            int word1Probability = NOT_A_PROBABILITY;
362            const int word1CodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
363                    bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramWord1CodePoints,
364                    &word1Probability);
365            const std::vector<int> word1(bigramWord1CodePoints,
366                    bigramWord1CodePoints + word1CodePointCount);
367            const int probability = getProbability(word1Probability, bigramsIt.getProbability());
368            bigrams.emplace_back(&word1, probability,
369                    NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */);
370        }
371    }
372    // Fetch shortcut information.
373    std::vector<UnigramProperty::ShortcutProperty> shortcuts;
374    int shortcutPos = getShortcutPositionOfPtNode(ptNodePos);
375    if (shortcutPos != NOT_A_DICT_POS) {
376        int shortcutTargetCodePoints[MAX_WORD_LENGTH];
377        ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mDictRoot, &shortcutPos);
378        bool hasNext = true;
379        while (hasNext) {
380            const ShortcutListReadingUtils::ShortcutFlags shortcutFlags =
381                    ShortcutListReadingUtils::getFlagsAndForwardPointer(mDictRoot, &shortcutPos);
382            hasNext = ShortcutListReadingUtils::hasNext(shortcutFlags);
383            const int shortcutTargetLength = ShortcutListReadingUtils::readShortcutTarget(
384                    mDictRoot, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos);
385            const std::vector<int> shortcutTarget(shortcutTargetCodePoints,
386                    shortcutTargetCodePoints + shortcutTargetLength);
387            const int shortcutProbability =
388                    ShortcutListReadingUtils::getProbabilityFromFlags(shortcutFlags);
389            shortcuts.emplace_back(&shortcutTarget, shortcutProbability);
390        }
391    }
392    const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
393            ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
394            NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts);
395    return WordProperty(&codePointVector, &unigramProperty, &bigrams);
396}
397
398int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints,
399        int *const outCodePointCount) {
400    *outCodePointCount = 0;
401    if (token == 0) {
402        // Start iterating the dictionary.
403        mTerminalPtNodePositionsForIteratingWords.clear();
404        DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy(
405                &mTerminalPtNodePositionsForIteratingWords);
406        DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader);
407        readingHelper.initWithPtNodeArrayPos(getRootPosition());
408        readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(&traversePolicy);
409    }
410    const int terminalPtNodePositionsVectorSize =
411            static_cast<int>(mTerminalPtNodePositionsForIteratingWords.size());
412    if (token < 0 || token >= terminalPtNodePositionsVectorSize) {
413        AKLOGE("Given token %d is invalid.", token);
414        return 0;
415    }
416    const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
417    int unigramProbability = NOT_A_PROBABILITY;
418    *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos,
419            MAX_WORD_LENGTH, outCodePoints, &unigramProbability);
420    const int nextToken = token + 1;
421    if (nextToken >= terminalPtNodePositionsVectorSize) {
422        // All words have been iterated.
423        mTerminalPtNodePositionsForIteratingWords.clear();
424        return 0;
425    }
426    return nextToken;
427}
428
429} // namespace latinime
430