dic_node.h revision 35c62b2cc99761e97f57060ad5e3cdfad926aea7
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_DIC_NODE_H
18#define LATINIME_DIC_NODE_H
19
20#include "defines.h"
21#include "suggest/core/dicnode/dic_node_profiler.h"
22#include "suggest/core/dicnode/dic_node_utils.h"
23#include "suggest/core/dicnode/internal/dic_node_state.h"
24#include "suggest/core/dicnode/internal/dic_node_properties.h"
25#include "suggest/core/dictionary/digraph_utils.h"
26#include "suggest/core/dictionary/error_type_utils.h"
27#include "suggest/core/layout/proximity_info_state.h"
28#include "utils/char_utils.h"
29
// Debug logging helpers for DicNode. In release builds (DEBUG_DICT == 0) both expand to nothing,
// so a trailing ';' at the use site becomes an empty statement.
// In debug builds both macros call DicNode member functions unqualified, so they can only be
// expanded inside a DicNode member function. LOGI_SHOW_ADD_COST_PROP additionally reads a
// variable named `inputSize` that must be in scope at the expansion site.
// (No comments are placed inside the multi-line #define bodies: a '//' comment there would also
// swallow the following continued line after line splicing.)
#if !DEBUG_DICT
#define LOGI_SHOW_ADD_COST_PROP
#define DUMP_WORD_AND_SCORE(tag)
#else
#define LOGI_SHOW_ADD_COST_PROP \
        do { \
            char wordChars[50]; \
            INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), wordChars, \
                    NELEMS(wordChars)); \
            AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \
                    __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \
                    getInputIndex(0), getNormalizedCompoundDistance(), wordChars); \
        } while (0)
#define DUMP_WORD_AND_SCORE(tag) \
        do { \
            char wordChars[50]; \
            const int totalLength = getNodeCodePointCount() \
                    + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(); \
            INTS_TO_CHARS(getOutputWordBuf(), totalLength, wordChars, NELEMS(wordChars)); \
            AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %d, %5f,", tag, \
                    getSpatialDistanceForScoring(), \
                    mDicNodeState.mDicNodeStateScoring.getLanguageDistance(), \
                    getNormalizedCompoundDistance(), getRawLength(), wordChars, \
                    getInputIndex(0), getNormalizedCompoundDistanceAfterFirstWord()); \
        } while (0)
#endif
56
57namespace latinime {
58
59// This struct is purely a bucket to return values. No instances of this struct should be kept.
60struct DicNode_InputStateG {
61    DicNode_InputStateG()
62            : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0),
63              mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f),
64              mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {}
65
66    bool mNeedsToUpdateInputStateG;
67    int mPointerId;
68    int16_t mInputIndex;
69    int mPrevCodePoint;
70    float mTerminalDiffCost;
71    float mRawLength;
72    DoubleLetterLevel mDoubleLetterLevel;
73};
74
75class DicNode {
76    // Caveat: We define Weighting as a friend class of DicNode to let Weighting change
77    // the distance of DicNode.
78    // Caution!!! In general, we avoid using the "friend" access modifier.
79    // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting.
80    friend class Weighting;
81
82 public:
83#if DEBUG_DICT
84    DicNodeProfiler mProfiler;
85#endif
86
87    AK_FORCE_INLINE DicNode()
88            :
89#if DEBUG_DICT
90              mProfiler(),
91#endif
92              mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false) {}
93
94    DicNode(const DicNode &dicNode);
95    DicNode &operator=(const DicNode &dicNode);
96    ~DicNode() {}
97
98    // Init for copy
99    void initByCopy(const DicNode *const dicNode) {
100        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
101        mDicNodeProperties.initByCopy(&dicNode->mDicNodeProperties);
102        mDicNodeState.initByCopy(&dicNode->mDicNodeState);
103        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
104    }
105
106    // Init for root with prevWordsPtNodePos which is used for n-gram
107    void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordsPtNodePos) {
108        mIsCachedForNextSuggestion = false;
109        mDicNodeProperties.init(rootPtNodeArrayPos, prevWordsPtNodePos);
110        mDicNodeState.init();
111        PROF_NODE_RESET(mProfiler);
112    }
113
114    // Init for root with previous word
115    void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) {
116        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
117        int newPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
118        newPrevWordsPtNodePos[0] = dicNode->mDicNodeProperties.getPtNodePos();
119        for (size_t i = 1; i < NELEMS(newPrevWordsPtNodePos); ++i) {
120            newPrevWordsPtNodePos[i] = dicNode->getPrevWordsTerminalPtNodePos()[i - 1];
121        }
122        mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordsPtNodePos);
123        mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState,
124                dicNode->mDicNodeProperties.getDepth());
125        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
126    }
127
128    void initAsPassingChild(const DicNode *parentDicNode) {
129        mIsCachedForNextSuggestion = parentDicNode->mIsCachedForNextSuggestion;
130        const int codePoint =
131                parentDicNode->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(
132                            parentDicNode->getNodeCodePointCount());
133        mDicNodeProperties.init(&parentDicNode->mDicNodeProperties, codePoint);
134        mDicNodeState.initByCopy(&parentDicNode->mDicNodeState);
135        PROF_NODE_COPY(&parentDicNode->mProfiler, mProfiler);
136    }
137
138    void initAsChild(const DicNode *const dicNode, const int ptNodePos,
139            const int childrenPtNodeArrayPos, const int probability, const bool isTerminal,
140            const bool hasChildren, const bool isBlacklistedOrNotAWord,
141            const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) {
142        uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1);
143        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
144        const uint16_t newLeavingDepth = static_cast<uint16_t>(
145                dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount);
146        mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0],
147                probability, isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth,
148                newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordsTerminalPtNodePos());
149        mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount,
150                mergedNodeCodePoints);
151        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
152    }
153
154    bool isRoot() const {
155        return getNodeCodePointCount() == 0;
156    }
157
158    bool hasChildren() const {
159        return mDicNodeProperties.hasChildren();
160    }
161
162    bool isLeavingNode() const {
163        ASSERT(getNodeCodePointCount() <= mDicNodeProperties.getLeavingDepth());
164        return getNodeCodePointCount() == mDicNodeProperties.getLeavingDepth();
165    }
166
167    AK_FORCE_INLINE bool isFirstLetter() const {
168        return getNodeCodePointCount() == 1;
169    }
170
171    bool isCached() const {
172        return mIsCachedForNextSuggestion;
173    }
174
175    void setCached() {
176        mIsCachedForNextSuggestion = true;
177    }
178
179    // Check if the current word and the previous word can be considered as a valid multiple word
180    // suggestion.
181    bool isValidMultipleWordSuggestion() const {
182        if (isBlacklistedOrNotAWord()) {
183            return false;
184        }
185        // Treat suggestion as invalid if the current and the previous word are single character
186        // words.
187        const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()
188                - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1;
189        const int currentWordLen = getNodeCodePointCount();
190        return (prevWordLen != 1 || currentWordLen != 1);
191    }
192
193    bool isFirstCharUppercase() const {
194        const int c = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(0);
195        return CharUtils::isAsciiUpper(c);
196    }
197
198    bool isCompletion(const int inputSize) const {
199        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize;
200    }
201
202    bool canDoLookAheadCorrection(const int inputSize) const {
203        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1;
204    }
205
206    // Used to get n-gram probability in DicNodeUtils.
207    int getPtNodePos() const {
208        return mDicNodeProperties.getPtNodePos();
209    }
210
211    // TODO: Use view class to return PtNodePos array.
212    const int *getPrevWordsTerminalPtNodePos() const {
213        return mDicNodeProperties.getPrevWordsTerminalPtNodePos();
214    }
215
216    // Used in DicNodeUtils
217    int getChildrenPtNodeArrayPos() const {
218        return mDicNodeProperties.getChildrenPtNodeArrayPos();
219    }
220
221    int getProbability() const {
222        return mDicNodeProperties.getProbability();
223    }
224
225    AK_FORCE_INLINE bool isTerminalDicNode() const {
226        const bool isTerminalPtNode = mDicNodeProperties.isTerminal();
227        const int currentDicNodeDepth = getNodeCodePointCount();
228        const int terminalDicNodeDepth = mDicNodeProperties.getLeavingDepth();
229        return isTerminalPtNode && currentDicNodeDepth > 0
230                && currentDicNodeDepth == terminalDicNodeDepth;
231    }
232
233    bool shouldBeFilteredBySafetyNetForBigram() const {
234        const uint16_t currentDepth = getNodeCodePointCount();
235        const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()
236                - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1;
237        return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1));
238    }
239
240    bool hasMatchedOrProximityCodePoints() const {
241        // This DicNode does not have matched or proximity code points when all code points have
242        // been handled as edit corrections or completion so far.
243        const int editCorrectionCount = mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount();
244        const int completionCount = mDicNodeState.mDicNodeStateScoring.getCompletionCount();
245        return (editCorrectionCount + completionCount) < getNodeCodePointCount();
246    }
247
248    bool isTotalInputSizeExceedingLimit() const {
249        // TODO: 3 can be 2? Needs to be investigated.
250        // TODO: Have a const variable for 3 (or 2)
251        return getTotalNodeCodePointCount() > MAX_WORD_LENGTH - 3;
252    }
253
254    void outputResult(int *dest) const {
255        memmove(dest, getOutputWordBuf(), getTotalNodeCodePointCount() * sizeof(dest[0]));
256        DUMP_WORD_AND_SCORE("OUTPUT");
257    }
258
259    // "Total" in this context (and other methods in this class) means the whole suggestion. When
260    // this represents a multi-word suggestion, the referenced PtNode (in mDicNodeState) is only
261    // the one that corresponds to the last word of the suggestion, and all the previous words
262    // are concatenated together in mDicNodeStateOutput.
263    int getTotalNodeSpaceCount() const {
264        if (!hasMultipleWords()) {
265            return 0;
266        }
267        return CharUtils::getSpaceCount(mDicNodeState.mDicNodeStateOutput.getCodePointBuf(),
268                mDicNodeState.mDicNodeStateOutput.getPrevWordsLength());
269    }
270
271    int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const {
272        const int inputIndex = mDicNodeState.mDicNodeStateOutput.getSecondWordFirstInputIndex();
273        if (inputIndex == NOT_AN_INDEX) {
274            return NOT_AN_INDEX;
275        } else {
276            return pInfoState->getInputIndexOfSampledPoint(inputIndex);
277        }
278    }
279
280    bool hasMultipleWords() const {
281        return mDicNodeState.mDicNodeStateOutput.getPrevWordCount() > 0;
282    }
283
284    int getProximityCorrectionCount() const {
285        return mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount();
286    }
287
288    int getEditCorrectionCount() const {
289        return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount();
290    }
291
292    // Used to prune nodes
293    float getNormalizedCompoundDistance() const {
294        return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance();
295    }
296
297    // Used to prune nodes
298    float getNormalizedSpatialDistance() const {
299        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance()
300                / static_cast<float>(getInputIndex(0) + 1);
301    }
302
303    // Used to prune nodes
304    float getCompoundDistance() const {
305        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance();
306    }
307
308    // Used to prune nodes
309    float getCompoundDistance(const float languageWeight) const {
310        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight);
311    }
312
313    AK_FORCE_INLINE const int *getOutputWordBuf() const {
314        return mDicNodeState.mDicNodeStateOutput.getCodePointBuf();
315    }
316
317    int getPrevCodePointG(int pointerId) const {
318        return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId);
319    }
320
321    // Whether the current codepoint can be an intentional omission, in which case the traversal
322    // algorithm will always check for a possible omission here.
323    bool canBeIntentionalOmission() const {
324        return CharUtils::isIntentionalOmissionCodePoint(getNodeCodePoint());
325    }
326
327    // Whether the omission is so frequent that it should incur zero cost.
328    bool isZeroCostOmission() const {
329        // TODO: do not hardcode and read from header
330        return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE);
331    }
332
333    // TODO: remove
334    float getTerminalDiffCostG(int path) const {
335        return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path);
336    }
337
338    //////////////////////
339    // Temporary getter //
340    // TODO: Remove     //
341    //////////////////////
342    // TODO: Remove once touch path is merged into ProximityInfoState
343    // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph.
344    int getNodeCodePoint() const {
345        const int codePoint = mDicNodeProperties.getDicNodeCodePoint();
346        const DigraphUtils::DigraphCodePointIndex digraphIndex =
347                mDicNodeState.mDicNodeStateScoring.getDigraphIndex();
348        if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) {
349            return codePoint;
350        }
351        return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex);
352    }
353
354    ////////////////////////////////
355    // Utils for cost calculation //
356    ////////////////////////////////
357    AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const {
358        return mDicNodeProperties.getDicNodeCodePoint()
359                == dicNode->mDicNodeProperties.getDicNodeCodePoint();
360    }
361
362    // TODO: remove
363    // TODO: rename getNextInputIndex
364    int16_t getInputIndex(int pointerId) const {
365        return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId);
366    }
367
368    ////////////////////////////////////
369    // Getter of features for scoring //
370    ////////////////////////////////////
371    float getSpatialDistanceForScoring() const {
372        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance();
373    }
374
375    // For space-aware gestures, we store the normalized distance at the char index
376    // that ends the first word of the suggestion. We call this the distance after
377    // first word.
378    float getNormalizedCompoundDistanceAfterFirstWord() const {
379        return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistanceAfterFirstWord();
380    }
381
382    float getRawLength() const {
383        return mDicNodeState.mDicNodeStateScoring.getRawLength();
384    }
385
386    DoubleLetterLevel getDoubleLetterLevel() const {
387        return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel();
388    }
389
390    void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
391        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel);
392    }
393
394    bool isInDigraph() const {
395        return mDicNodeState.mDicNodeStateScoring.getDigraphIndex()
396                != DigraphUtils::NOT_A_DIGRAPH_INDEX;
397    }
398
399    void advanceDigraphIndex() {
400        mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
401    }
402
403    ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
404        return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes();
405    }
406
407    bool isBlacklistedOrNotAWord() const {
408        return mDicNodeProperties.isBlacklistedOrNotAWord();
409    }
410
411    inline uint16_t getNodeCodePointCount() const {
412        return mDicNodeProperties.getDepth();
413    }
414
415    // Returns code point count including spaces
416    inline uint16_t getTotalNodeCodePointCount() const {
417        return getNodeCodePointCount() + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength();
418    }
419
420    AK_FORCE_INLINE void dump(const char *tag) const {
421#if DEBUG_DICT
422        DUMP_WORD_AND_SCORE(tag);
423#if DEBUG_DUMP_ERROR
424        mProfiler.dump();
425#endif
426#endif
427    }
428
429    AK_FORCE_INLINE bool compare(const DicNode *right) const {
430        // Promote exact matches to prevent them from being pruned.
431        const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes());
432        const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes());
433        if (leftExactMatch != rightExactMatch) {
434            return leftExactMatch;
435        }
436        const float diff =
437                right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance();
438        static const float MIN_DIFF = 0.000001f;
439        if (diff > MIN_DIFF) {
440            return true;
441        } else if (diff < -MIN_DIFF) {
442            return false;
443        }
444        const int depth = getNodeCodePointCount();
445        const int depthDiff = right->getNodeCodePointCount() - depth;
446        if (depthDiff != 0) {
447            return depthDiff > 0;
448        }
449        for (int i = 0; i < depth; ++i) {
450            const int codePoint = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i);
451            const int rightCodePoint =
452                    right->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i);
453            if (codePoint != rightCodePoint) {
454                return rightCodePoint > codePoint;
455            }
456        }
457        // Compare pointer values here for stable comparison
458        return this > right;
459    }
460
461 private:
462    DicNodeProperties mDicNodeProperties;
463    DicNodeState mDicNodeState;
464    // TODO: Remove
465    bool mIsCachedForNextSuggestion;
466
467    AK_FORCE_INLINE int getTotalInputIndex() const {
468        int index = 0;
469        for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
470            index += mDicNodeState.mDicNodeStateInput.getInputIndex(i);
471        }
472        return index;
473    }
474
475    // Caveat: Must not be called outside Weighting
476    // This restriction is guaranteed by "friend"
477    AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
478            const bool doNormalization, const int inputSize,
479            const ErrorTypeUtils::ErrorType errorType) {
480        if (DEBUG_GEO_FULL) {
481            LOGI_SHOW_ADD_COST_PROP;
482        }
483        mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
484                inputSize, getTotalInputIndex(), errorType);
485    }
486
487    // Saves the current normalized compound distance for space-aware gestures.
488    // See getNormalizedCompoundDistanceAfterFirstWord for details.
489    AK_FORCE_INLINE void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() {
490        mDicNodeState.mDicNodeStateScoring.saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
491    }
492
493    // Caveat: Must not be called outside Weighting
494    // This restriction is guaranteed by "friend"
495    AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count,
496            const bool overwritesPrevCodePointByNodeCodePoint) {
497        if (count == 0) {
498            return;
499        }
500        mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count);
501        if (overwritesPrevCodePointByNodeCodePoint) {
502            mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint());
503        }
504    }
505
506    AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) {
507        if (mDicNodeState.mDicNodeStateOutput.getPrevWordCount() == 1 && isFirstLetter()) {
508            mDicNodeState.mDicNodeStateOutput.setSecondWordFirstInputIndex(
509                    inputStateG->mInputIndex);
510        }
511        mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId,
512                inputStateG->mInputIndex, inputStateG->mPrevCodePoint,
513                inputStateG->mTerminalDiffCost, inputStateG->mRawLength);
514        mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength);
515        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel);
516    }
517};
518} // namespace latinime
519#endif // LATINIME_DIC_NODE_H
520