dic_node.h revision 67ff21f3217c9f2ff81beac6f29f5c35a83da228
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_DIC_NODE_H
18#define LATINIME_DIC_NODE_H
19
#include <functional>

#include "defines.h"
#include "suggest/core/dicnode/dic_node_profiler.h"
#include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/dicnode/internal/dic_node_state.h"
#include "suggest/core/dicnode/internal/dic_node_properties.h"
#include "suggest/core/dictionary/digraph_utils.h"
#include "suggest/core/dictionary/error_type_utils.h"
#include "suggest/core/layout/proximity_info_state.h"
#include "utils/char_utils.h"
29
#if DEBUG_DICT
// Debug-only logging helpers. Both macros expand to code that calls DicNode member functions
// unqualified, so they can only be used inside DicNode member functions.

// Logs the current function name, node code point, input indices, normalized compound distance
// and the current word. Besides the DicNode members it references, the expansion reads a
// variable named `inputSize`, which must be visible at the point of use.
#define LOGI_SHOW_ADD_COST_PROP \
        do { \
            char wordChars[50]; \
            INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), \
                    wordChars, NELEMS(wordChars)); \
            AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \
                    __FUNCTION__, \
                    getNodeCodePoint(), \
                    inputSize, \
                    getTotalInputIndex(), \
                    getInputIndex(0), \
                    getNormalizedCompoundDistance(), \
                    wordChars); \
        } while (0)

// Logs `header` followed by the scoring distances, raw length, the output word (previous words
// plus the current word), the input index of pointer 0 and the distance after the first word.
#define DUMP_WORD_AND_SCORE(header) \
        do { \
            char wordChars[50]; \
            INTS_TO_CHARS(getOutputWordBuf(), \
                    mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() \
                            + getNodeCodePointCount(), \
                    wordChars, NELEMS(wordChars)); \
            AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %d, %5f,", header, \
                    getSpatialDistanceForScoring(), \
                    getLanguageDistanceForScoring(), \
                    getNormalizedCompoundDistance(), \
                    getRawLength(), \
                    wordChars, \
                    getInputIndex(0), \
                    getNormalizedCompoundDistanceAfterFirstWord()); \
        } while (0)
#else
// With DEBUG_DICT disabled both macros expand to nothing.
#define LOGI_SHOW_ADD_COST_PROP
#define DUMP_WORD_AND_SCORE(header)
#endif
55
56namespace latinime {
57
58// This struct is purely a bucket to return values. No instances of this struct should be kept.
59struct DicNode_InputStateG {
60    DicNode_InputStateG()
61            : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0),
62              mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f),
63              mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {}
64
65    bool mNeedsToUpdateInputStateG;
66    int mPointerId;
67    int16_t mInputIndex;
68    int mPrevCodePoint;
69    float mTerminalDiffCost;
70    float mRawLength;
71    DoubleLetterLevel mDoubleLetterLevel;
72};
73
74class DicNode {
75    // Caveat: We define Weighting as a friend class of DicNode to let Weighting change
76    // the distance of DicNode.
77    // Caution!!! In general, we avoid using the "friend" access modifier.
78    // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting.
79    friend class Weighting;
80
81 public:
82#if DEBUG_DICT
83    DicNodeProfiler mProfiler;
84#endif
85
86    AK_FORCE_INLINE DicNode()
87            :
88#if DEBUG_DICT
89              mProfiler(),
90#endif
91              mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false) {}
92
93    DicNode(const DicNode &dicNode);
94    DicNode &operator=(const DicNode &dicNode);
95    ~DicNode() {}
96
97    // Init for copy
98    void initByCopy(const DicNode *const dicNode) {
99        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
100        mDicNodeProperties.initByCopy(&dicNode->mDicNodeProperties);
101        mDicNodeState.initByCopy(&dicNode->mDicNodeState);
102        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
103    }
104
105    // Init for root with prevWordPtNodePos which is used for bigram
106    void initAsRoot(const int rootPtNodeArrayPos, const int prevWordPtNodePos) {
107        mIsCachedForNextSuggestion = false;
108        mDicNodeProperties.init(rootPtNodeArrayPos, prevWordPtNodePos);
109        mDicNodeState.init();
110        PROF_NODE_RESET(mProfiler);
111    }
112
113    // Init for root with previous word
114    void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) {
115        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
116        mDicNodeProperties.init(rootPtNodeArrayPos, dicNode->mDicNodeProperties.getPtNodePos());
117        mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState,
118                dicNode->mDicNodeProperties.getDepth());
119        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
120    }
121
122    void initAsPassingChild(DicNode *parentDicNode) {
123        mIsCachedForNextSuggestion = parentDicNode->mIsCachedForNextSuggestion;
124        const int parentCodePoint = parentDicNode->getNodeTypedCodePoint();
125        mDicNodeProperties.init(&parentDicNode->mDicNodeProperties, parentCodePoint);
126        mDicNodeState.initByCopy(&parentDicNode->mDicNodeState);
127        PROF_NODE_COPY(&parentDicNode->mProfiler, mProfiler);
128    }
129
130    void initAsChild(const DicNode *const dicNode, const int ptNodePos,
131            const int childrenPtNodeArrayPos, const int probability, const bool isTerminal,
132            const bool hasChildren, const bool isBlacklistedOrNotAWord,
133            const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) {
134        uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1);
135        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
136        const uint16_t newLeavingDepth = static_cast<uint16_t>(
137                dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount);
138        mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0],
139                probability, isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth,
140                newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordTerminalPtNodePos());
141        mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount,
142                mergedNodeCodePoints);
143        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
144    }
145
146    bool isRoot() const {
147        return getNodeCodePointCount() == 0;
148    }
149
150    bool hasChildren() const {
151        return mDicNodeProperties.hasChildren();
152    }
153
154    bool isLeavingNode() const {
155        ASSERT(getNodeCodePointCount() <= mDicNodeProperties.getLeavingDepth());
156        return getNodeCodePointCount() == mDicNodeProperties.getLeavingDepth();
157    }
158
159    AK_FORCE_INLINE bool isFirstLetter() const {
160        return getNodeCodePointCount() == 1;
161    }
162
163    bool isCached() const {
164        return mIsCachedForNextSuggestion;
165    }
166
167    void setCached() {
168        mIsCachedForNextSuggestion = true;
169    }
170
171    // Used to expand the node in DicNodeUtils
172    int getNodeTypedCodePoint() const {
173        return mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(getNodeCodePointCount());
174    }
175
176    // Check if the current word and the previous word can be considered as a valid multiple word
177    // suggestion.
178    bool isValidMultipleWordSuggestion() const {
179        if (isBlacklistedOrNotAWord()) {
180            return false;
181        }
182        // Treat suggestion as invalid if the current and the previous word are single character
183        // words.
184        const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()
185                - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1;
186        const int currentWordLen = getNodeCodePointCount();
187        return (prevWordLen != 1 || currentWordLen != 1);
188    }
189
190    bool isFirstCharUppercase() const {
191        const int c = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(0);
192        return CharUtils::isAsciiUpper(c);
193    }
194
195    bool isCompletion(const int inputSize) const {
196        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize;
197    }
198
199    bool canDoLookAheadCorrection(const int inputSize) const {
200        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1;
201    }
202
203    // Used to get bigram probability in DicNodeUtils
204    int getPtNodePos() const {
205        return mDicNodeProperties.getPtNodePos();
206    }
207
208    // Used to get bigram probability in DicNodeUtils
209    int getPrevWordTerminalPtNodePos() const {
210        return mDicNodeProperties.getPrevWordTerminalPtNodePos();
211    }
212
213    // Used in DicNodeUtils
214    int getChildrenPtNodeArrayPos() const {
215        return mDicNodeProperties.getChildrenPtNodeArrayPos();
216    }
217
218    int getProbability() const {
219        return mDicNodeProperties.getProbability();
220    }
221
222    AK_FORCE_INLINE bool isTerminalDicNode() const {
223        const bool isTerminalPtNode = mDicNodeProperties.isTerminal();
224        const int currentDicNodeDepth = getNodeCodePointCount();
225        const int terminalDicNodeDepth = mDicNodeProperties.getLeavingDepth();
226        return isTerminalPtNode && currentDicNodeDepth > 0
227                && currentDicNodeDepth == terminalDicNodeDepth;
228    }
229
230    bool shouldBeFilteredBySafetyNetForBigram() const {
231        const uint16_t currentDepth = getNodeCodePointCount();
232        const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()
233                - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1;
234        return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1));
235    }
236
237    bool hasMatchedOrProximityCodePoints() const {
238        // This DicNode does not have matched or proximity code points when all code points have
239        // been handled as edit corrections or completion so far.
240        const int editCorrectionCount = mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount();
241        const int completionCount = mDicNodeState.mDicNodeStateScoring.getCompletionCount();
242        return (editCorrectionCount + completionCount) < getNodeCodePointCount();
243    }
244
245    bool isTotalInputSizeExceedingLimit() const {
246        const int prevWordsLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength();
247        const int currentWordDepth = getNodeCodePointCount();
248        // TODO: 3 can be 2? Needs to be investigated.
249        // TODO: Have a const variable for 3 (or 2)
250        return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3;
251    }
252
253    void outputResult(int *dest) const {
254        const uint16_t prevWordLength = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength();
255        const uint16_t currentDepth = getNodeCodePointCount();
256        memmove(dest, getOutputWordBuf(), (prevWordLength + currentDepth) * sizeof(dest[0]));
257        DUMP_WORD_AND_SCORE("OUTPUT");
258    }
259
260    // "Total" in this context (and other methods in this class) means the whole suggestion. When
261    // this represents a multi-word suggestion, the referenced PtNode (in mDicNodeState) is only
262    // the one that corresponds to the last word of the suggestion, and all the previous words
263    // are concatenated together in mDicNodeStateOutput.
264    int getTotalNodeSpaceCount() const {
265        if (!hasMultipleWords()) {
266            return 0;
267        }
268        return CharUtils::getSpaceCount(mDicNodeState.mDicNodeStateOutput.getCodePointBuf(),
269                mDicNodeState.mDicNodeStateOutput.getPrevWordsLength());
270    }
271
272    int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const {
273        const int inputIndex = mDicNodeState.mDicNodeStateOutput.getSecondWordFirstInputIndex();
274        if (inputIndex == NOT_AN_INDEX) {
275            return NOT_AN_INDEX;
276        } else {
277            return pInfoState->getInputIndexOfSampledPoint(inputIndex);
278        }
279    }
280
281    bool hasMultipleWords() const {
282        return mDicNodeState.mDicNodeStateOutput.getPrevWordCount() > 0;
283    }
284
285    int getProximityCorrectionCount() const {
286        return mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount();
287    }
288
289    int getEditCorrectionCount() const {
290        return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount();
291    }
292
293    // Used to prune nodes
294    float getNormalizedCompoundDistance() const {
295        return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance();
296    }
297
298    // Used to prune nodes
299    float getNormalizedSpatialDistance() const {
300        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance()
301                / static_cast<float>(getInputIndex(0) + 1);
302    }
303
304    // Used to prune nodes
305    float getCompoundDistance() const {
306        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance();
307    }
308
309    // Used to prune nodes
310    float getCompoundDistance(const float languageWeight) const {
311        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight);
312    }
313
314    // Used to commit input partially
315    int getPrevWordPtNodePos() const {
316        return mDicNodeProperties.getPrevWordTerminalPtNodePos();
317    }
318
319    AK_FORCE_INLINE const int *getOutputWordBuf() const {
320        return mDicNodeState.mDicNodeStateOutput.getCodePointBuf();
321    }
322
323    int getPrevCodePointG(int pointerId) const {
324        return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId);
325    }
326
327    // Whether the current codepoint can be an intentional omission, in which case the traversal
328    // algorithm will always check for a possible omission here.
329    bool canBeIntentionalOmission() const {
330        return CharUtils::isIntentionalOmissionCodePoint(getNodeCodePoint());
331    }
332
333    // Whether the omission is so frequent that it should incur zero cost.
334    bool isZeroCostOmission() const {
335        // TODO: do not hardcode and read from header
336        return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE);
337    }
338
339    // TODO: remove
340    float getTerminalDiffCostG(int path) const {
341        return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path);
342    }
343
344    //////////////////////
345    // Temporary getter //
346    // TODO: Remove     //
347    //////////////////////
348    // TODO: Remove once touch path is merged into ProximityInfoState
349    // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph.
350    int getNodeCodePoint() const {
351        const int codePoint = mDicNodeProperties.getDicNodeCodePoint();
352        const DigraphUtils::DigraphCodePointIndex digraphIndex =
353                mDicNodeState.mDicNodeStateScoring.getDigraphIndex();
354        if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) {
355            return codePoint;
356        }
357        return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex);
358    }
359
360    ////////////////////////////////
361    // Utils for cost calculation //
362    ////////////////////////////////
363    AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const {
364        return mDicNodeProperties.getDicNodeCodePoint()
365                == dicNode->mDicNodeProperties.getDicNodeCodePoint();
366    }
367
368    // TODO: remove
369    // TODO: rename getNextInputIndex
370    int16_t getInputIndex(int pointerId) const {
371        return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId);
372    }
373
374    ////////////////////////////////////
375    // Getter of features for scoring //
376    ////////////////////////////////////
377    float getSpatialDistanceForScoring() const {
378        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance();
379    }
380
381    float getLanguageDistanceForScoring() const {
382        return mDicNodeState.mDicNodeStateScoring.getLanguageDistance();
383    }
384
385    // For space-aware gestures, we store the normalized distance at the char index
386    // that ends the first word of the suggestion. We call this the distance after
387    // first word.
388    float getNormalizedCompoundDistanceAfterFirstWord() const {
389        return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistanceAfterFirstWord();
390    }
391
392    float getLanguageDistanceRatePerWordForScoring() const {
393        const float langDist = getLanguageDistanceForScoring();
394        const float totalWordCount =
395                static_cast<float>(mDicNodeState.mDicNodeStateOutput.getPrevWordCount() + 1);
396        return langDist / totalWordCount;
397    }
398
399    float getRawLength() const {
400        return mDicNodeState.mDicNodeStateScoring.getRawLength();
401    }
402
403    bool isLessThanOneErrorForScoring() const {
404        return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount()
405                + mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount() <= 1;
406    }
407
408    DoubleLetterLevel getDoubleLetterLevel() const {
409        return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel();
410    }
411
412    void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
413        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel);
414    }
415
416    bool isInDigraph() const {
417        return mDicNodeState.mDicNodeStateScoring.getDigraphIndex()
418                != DigraphUtils::NOT_A_DIGRAPH_INDEX;
419    }
420
421    void advanceDigraphIndex() {
422        mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
423    }
424
425    ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
426        return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes();
427    }
428
429    bool isBlacklistedOrNotAWord() const {
430        return mDicNodeProperties.isBlacklistedOrNotAWord();
431    }
432
433    inline uint16_t getNodeCodePointCount() const {
434        return mDicNodeProperties.getDepth();
435    }
436
437    // Returns code point count including spaces
438    inline uint16_t getTotalNodeCodePointCount() const {
439        return getNodeCodePointCount() + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength();
440    }
441
442    AK_FORCE_INLINE void dump(const char *tag) const {
443#if DEBUG_DICT
444        DUMP_WORD_AND_SCORE(tag);
445#if DEBUG_DUMP_ERROR
446        mProfiler.dump();
447#endif
448#endif
449    }
450
451    AK_FORCE_INLINE bool compare(const DicNode *right) const {
452        // Promote exact matches to prevent them from being pruned.
453        const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes());
454        const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes());
455        if (leftExactMatch != rightExactMatch) {
456            return leftExactMatch;
457        }
458        const float diff =
459                right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance();
460        static const float MIN_DIFF = 0.000001f;
461        if (diff > MIN_DIFF) {
462            return true;
463        } else if (diff < -MIN_DIFF) {
464            return false;
465        }
466        const int depth = getNodeCodePointCount();
467        const int depthDiff = right->getNodeCodePointCount() - depth;
468        if (depthDiff != 0) {
469            return depthDiff > 0;
470        }
471        for (int i = 0; i < depth; ++i) {
472            const int codePoint = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i);
473            const int rightCodePoint =
474                    right->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i);
475            if (codePoint != rightCodePoint) {
476                return rightCodePoint > codePoint;
477            }
478        }
479        // Compare pointer values here for stable comparison
480        return this > right;
481    }
482
483 private:
484    DicNodeProperties mDicNodeProperties;
485    DicNodeState mDicNodeState;
486    // TODO: Remove
487    bool mIsCachedForNextSuggestion;
488
489    AK_FORCE_INLINE int getTotalInputIndex() const {
490        int index = 0;
491        for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
492            index += mDicNodeState.mDicNodeStateInput.getInputIndex(i);
493        }
494        return index;
495    }
496
497    // Caveat: Must not be called outside Weighting
498    // This restriction is guaranteed by "friend"
499    AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
500            const bool doNormalization, const int inputSize,
501            const ErrorTypeUtils::ErrorType errorType) {
502        if (DEBUG_GEO_FULL) {
503            LOGI_SHOW_ADD_COST_PROP;
504        }
505        mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
506                inputSize, getTotalInputIndex(), errorType);
507    }
508
509    // Saves the current normalized compound distance for space-aware gestures.
510    // See getNormalizedCompoundDistanceAfterFirstWord for details.
511    AK_FORCE_INLINE void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() {
512        mDicNodeState.mDicNodeStateScoring.saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
513    }
514
515    // Caveat: Must not be called outside Weighting
516    // This restriction is guaranteed by "friend"
517    AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count,
518            const bool overwritesPrevCodePointByNodeCodePoint) {
519        if (count == 0) {
520            return;
521        }
522        mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count);
523        if (overwritesPrevCodePointByNodeCodePoint) {
524            mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint());
525        }
526    }
527
528    AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) {
529        if (mDicNodeState.mDicNodeStateOutput.getPrevWordCount() == 1 && isFirstLetter()) {
530            mDicNodeState.mDicNodeStateOutput.setSecondWordFirstInputIndex(
531                    inputStateG->mInputIndex);
532        }
533        mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId,
534                inputStateG->mInputIndex, inputStateG->mPrevCodePoint,
535                inputStateG->mTerminalDiffCost, inputStateG->mRawLength);
536        mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength);
537        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel);
538    }
539};
540} // namespace latinime
541#endif // LATINIME_DIC_NODE_H
542