dic_node.h revision 6da9b21191dc7d6049d96945366ec7e605e716e6
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_DIC_NODE_H
18#define LATINIME_DIC_NODE_H
19
20#include "defines.h"
21#include "suggest/core/dicnode/dic_node_profiler.h"
22#include "suggest/core/dicnode/dic_node_utils.h"
23#include "suggest/core/dicnode/internal/dic_node_state.h"
24#include "suggest/core/dicnode/internal/dic_node_properties.h"
25#include "suggest/core/dictionary/digraph_utils.h"
26#include "suggest/core/dictionary/error_type_utils.h"
27#include "suggest/core/layout/proximity_info_state.h"
28#include "utils/char_utils.h"
29#include "utils/int_array_view.h"
30
// Debug logging helpers for DicNode.
//
// Both macros are meant to be expanded inside DicNode member functions: they call DicNode
// accessors (getOutputWordBuf(), getNodeCodePointCount(), ...) without any qualification.
// Every variant, debug or not, is shaped as a single `do { ... } while (0)` statement so that
// `MACRO;` behaves as exactly one statement wherever it is used (for instance as the sole body
// of an if/else branch).
#if DEBUG_DICT
// Logs the current node while a cost is being added.
// NOTE: besides the DicNode accessors, this expects a variable named `inputSize` to be in scope
// at the expansion site.
#define LOGI_SHOW_ADD_COST_PROP \
        do { \
            char charBuf[50]; \
            INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf, NELEMS(charBuf)); \
            AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \
                    __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \
                    getInputIndex(0), getNormalizedCompoundDistance(), charBuf); \
        } while (0)
// Logs the output word (previous words included) together with its distances, prefixed by
// `header`.
#define DUMP_WORD_AND_SCORE(header) \
        do { \
            char charBuf[50]; \
            INTS_TO_CHARS(getOutputWordBuf(), \
                    getNodeCodePointCount() \
                            + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength(), \
                    charBuf, NELEMS(charBuf)); \
            AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %d, %5f,", header, \
                    getSpatialDistanceForScoring(), \
                    mDicNodeState.mDicNodeStateScoring.getLanguageDistance(), \
                    getNormalizedCompoundDistance(), getRawLength(), charBuf, \
                    getInputIndex(0), getNormalizedCompoundDistanceAfterFirstWord()); \
        } while (0)
#else
// Non-debug builds: no-ops that still consume the trailing semicolon like a real statement.
#define LOGI_SHOW_ADD_COST_PROP do { } while (0)
#define DUMP_WORD_AND_SCORE(header) do { } while (0)
#endif
57
58namespace latinime {
59
60// This struct is purely a bucket to return values. No instances of this struct should be kept.
61struct DicNode_InputStateG {
62    DicNode_InputStateG()
63            : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0),
64              mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f),
65              mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {}
66
67    bool mNeedsToUpdateInputStateG;
68    int mPointerId;
69    int16_t mInputIndex;
70    int mPrevCodePoint;
71    float mTerminalDiffCost;
72    float mRawLength;
73    DoubleLetterLevel mDoubleLetterLevel;
74};
75
76class DicNode {
77    // Caveat: We define Weighting as a friend class of DicNode to let Weighting change
78    // the distance of DicNode.
79    // Caution!!! In general, we avoid using the "friend" access modifier.
80    // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting.
81    friend class Weighting;
82
83 public:
84#if DEBUG_DICT
85    DicNodeProfiler mProfiler;
86#endif
87
88    AK_FORCE_INLINE DicNode()
89            :
90#if DEBUG_DICT
91              mProfiler(),
92#endif
93              mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false) {}
94
95    DicNode(const DicNode &dicNode);
96    DicNode &operator=(const DicNode &dicNode);
97    ~DicNode() {}
98
99    // Init for copy
100    void initByCopy(const DicNode *const dicNode) {
101        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
102        mDicNodeProperties.initByCopy(&dicNode->mDicNodeProperties);
103        mDicNodeState.initByCopy(&dicNode->mDicNodeState);
104        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
105    }
106
107    // Init for root with prevWordIds which is used for n-gram
108    void initAsRoot(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) {
109        mIsCachedForNextSuggestion = false;
110        mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds);
111        mDicNodeState.init();
112        PROF_NODE_RESET(mProfiler);
113    }
114
115    // Init for root with previous word
116    void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) {
117        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
118        WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> newPrevWordIds;
119        newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId();
120        dicNode->getPrevWordIds().limit(newPrevWordIds.size() - 1)
121                .copyToArray(&newPrevWordIds, 1 /* offset */);
122        mDicNodeProperties.init(rootPtNodeArrayPos, WordIdArrayView::fromArray(newPrevWordIds));
123        mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState,
124                dicNode->mDicNodeProperties.getDepth());
125        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
126    }
127
128    void initAsPassingChild(const DicNode *parentDicNode) {
129        mIsCachedForNextSuggestion = parentDicNode->mIsCachedForNextSuggestion;
130        const int codePoint =
131                parentDicNode->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(
132                            parentDicNode->getNodeCodePointCount());
133        mDicNodeProperties.init(&parentDicNode->mDicNodeProperties, codePoint);
134        mDicNodeState.initByCopy(&parentDicNode->mDicNodeState);
135        PROF_NODE_COPY(&parentDicNode->mProfiler, mProfiler);
136    }
137
138    void initAsChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos,
139            const int wordId, const CodePointArrayView mergedCodePoints) {
140        uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1);
141        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
142        const uint16_t newLeavingDepth = static_cast<uint16_t>(
143                dicNode->mDicNodeProperties.getLeavingDepth() + mergedCodePoints.size());
144        mDicNodeProperties.init(childrenPtNodeArrayPos, mergedCodePoints[0],
145                wordId, newDepth, newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds());
146        mDicNodeState.init(&dicNode->mDicNodeState, mergedCodePoints.size(),
147                mergedCodePoints.data());
148        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
149    }
150
151    bool isRoot() const {
152        return getNodeCodePointCount() == 0;
153    }
154
155    bool hasChildren() const {
156        return mDicNodeProperties.hasChildren();
157    }
158
159    bool isLeavingNode() const {
160        ASSERT(getNodeCodePointCount() <= mDicNodeProperties.getLeavingDepth());
161        return getNodeCodePointCount() == mDicNodeProperties.getLeavingDepth();
162    }
163
164    AK_FORCE_INLINE bool isFirstLetter() const {
165        return getNodeCodePointCount() == 1;
166    }
167
168    bool isCached() const {
169        return mIsCachedForNextSuggestion;
170    }
171
172    void setCached() {
173        mIsCachedForNextSuggestion = true;
174    }
175
176    // Check if the current word and the previous word can be considered as a valid multiple word
177    // suggestion.
178    bool isValidMultipleWordSuggestion() const {
179        // Treat suggestion as invalid if the current and the previous word are single character
180        // words.
181        const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()
182                - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1;
183        const int currentWordLen = getNodeCodePointCount();
184        return (prevWordLen != 1 || currentWordLen != 1);
185    }
186
187    bool isFirstCharUppercase() const {
188        const int c = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(0);
189        return CharUtils::isAsciiUpper(c);
190    }
191
192    bool isCompletion(const int inputSize) const {
193        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize;
194    }
195
196    bool canDoLookAheadCorrection(const int inputSize) const {
197        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1;
198    }
199
200    // Used to get n-gram probability in DicNodeUtils.
201    int getWordId() const {
202        return mDicNodeProperties.getWordId();
203    }
204
205    const WordIdArrayView getPrevWordIds() const {
206        return mDicNodeProperties.getPrevWordIds();
207    }
208
209    // Used in DicNodeUtils
210    int getChildrenPtNodeArrayPos() const {
211        return mDicNodeProperties.getChildrenPtNodeArrayPos();
212    }
213
214    AK_FORCE_INLINE bool isTerminalDicNode() const {
215        const bool isTerminalPtNode = mDicNodeProperties.isTerminal();
216        const int currentDicNodeDepth = getNodeCodePointCount();
217        const int terminalDicNodeDepth = mDicNodeProperties.getLeavingDepth();
218        return isTerminalPtNode && currentDicNodeDepth > 0
219                && currentDicNodeDepth == terminalDicNodeDepth;
220    }
221
222    bool shouldBeFilteredBySafetyNetForBigram() const {
223        const uint16_t currentDepth = getNodeCodePointCount();
224        const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()
225                - mDicNodeState.mDicNodeStateOutput.getPrevWordStart() - 1;
226        return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1));
227    }
228
229    bool hasMatchedOrProximityCodePoints() const {
230        // This DicNode does not have matched or proximity code points when all code points have
231        // been handled as edit corrections or completion so far.
232        const int editCorrectionCount = mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount();
233        const int completionCount = mDicNodeState.mDicNodeStateScoring.getCompletionCount();
234        return (editCorrectionCount + completionCount) < getNodeCodePointCount();
235    }
236
237    bool isTotalInputSizeExceedingLimit() const {
238        // TODO: 3 can be 2? Needs to be investigated.
239        // TODO: Have a const variable for 3 (or 2)
240        return getTotalNodeCodePointCount() > MAX_WORD_LENGTH - 3;
241    }
242
243    void outputResult(int *dest) const {
244        memmove(dest, getOutputWordBuf(), getTotalNodeCodePointCount() * sizeof(dest[0]));
245        DUMP_WORD_AND_SCORE("OUTPUT");
246    }
247
248    // "Total" in this context (and other methods in this class) means the whole suggestion. When
249    // this represents a multi-word suggestion, the referenced PtNode (in mDicNodeState) is only
250    // the one that corresponds to the last word of the suggestion, and all the previous words
251    // are concatenated together in mDicNodeStateOutput.
252    int getTotalNodeSpaceCount() const {
253        if (!hasMultipleWords()) {
254            return 0;
255        }
256        return CharUtils::getSpaceCount(mDicNodeState.mDicNodeStateOutput.getCodePointBuf(),
257                mDicNodeState.mDicNodeStateOutput.getPrevWordsLength());
258    }
259
260    int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const {
261        const int inputIndex = mDicNodeState.mDicNodeStateOutput.getSecondWordFirstInputIndex();
262        if (inputIndex == NOT_AN_INDEX) {
263            return NOT_AN_INDEX;
264        } else {
265            return pInfoState->getInputIndexOfSampledPoint(inputIndex);
266        }
267    }
268
269    bool hasMultipleWords() const {
270        return mDicNodeState.mDicNodeStateOutput.getPrevWordCount() > 0;
271    }
272
273    int getProximityCorrectionCount() const {
274        return mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount();
275    }
276
277    int getEditCorrectionCount() const {
278        return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount();
279    }
280
281    // Used to prune nodes
282    float getNormalizedCompoundDistance() const {
283        return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance();
284    }
285
286    // Used to prune nodes
287    float getNormalizedSpatialDistance() const {
288        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance()
289                / static_cast<float>(getInputIndex(0) + 1);
290    }
291
292    // Used to prune nodes
293    float getCompoundDistance() const {
294        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance();
295    }
296
297    // Used to prune nodes
298    float getCompoundDistance(const float weightOfLangModelVsSpatialModel) const {
299        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(
300                weightOfLangModelVsSpatialModel);
301    }
302
303    AK_FORCE_INLINE const int *getOutputWordBuf() const {
304        return mDicNodeState.mDicNodeStateOutput.getCodePointBuf();
305    }
306
307    int getPrevCodePointG(int pointerId) const {
308        return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId);
309    }
310
311    // Whether the current codepoint can be an intentional omission, in which case the traversal
312    // algorithm will always check for a possible omission here.
313    bool canBeIntentionalOmission() const {
314        return CharUtils::isIntentionalOmissionCodePoint(getNodeCodePoint());
315    }
316
317    // Whether the omission is so frequent that it should incur zero cost.
318    bool isZeroCostOmission() const {
319        // TODO: do not hardcode and read from header
320        return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE);
321    }
322
323    // TODO: remove
324    float getTerminalDiffCostG(int path) const {
325        return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path);
326    }
327
328    //////////////////////
329    // Temporary getter //
330    // TODO: Remove     //
331    //////////////////////
332    // TODO: Remove once touch path is merged into ProximityInfoState
333    // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph.
334    int getNodeCodePoint() const {
335        const int codePoint = mDicNodeProperties.getDicNodeCodePoint();
336        const DigraphUtils::DigraphCodePointIndex digraphIndex =
337                mDicNodeState.mDicNodeStateScoring.getDigraphIndex();
338        if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) {
339            return codePoint;
340        }
341        return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex);
342    }
343
344    ////////////////////////////////
345    // Utils for cost calculation //
346    ////////////////////////////////
347    AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const {
348        return mDicNodeProperties.getDicNodeCodePoint()
349                == dicNode->mDicNodeProperties.getDicNodeCodePoint();
350    }
351
352    // TODO: remove
353    // TODO: rename getNextInputIndex
354    int16_t getInputIndex(int pointerId) const {
355        return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId);
356    }
357
358    ////////////////////////////////////
359    // Getter of features for scoring //
360    ////////////////////////////////////
361    float getSpatialDistanceForScoring() const {
362        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance();
363    }
364
365    // For space-aware gestures, we store the normalized distance at the char index
366    // that ends the first word of the suggestion. We call this the distance after
367    // first word.
368    float getNormalizedCompoundDistanceAfterFirstWord() const {
369        return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistanceAfterFirstWord();
370    }
371
372    float getRawLength() const {
373        return mDicNodeState.mDicNodeStateScoring.getRawLength();
374    }
375
376    DoubleLetterLevel getDoubleLetterLevel() const {
377        return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel();
378    }
379
380    void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
381        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel);
382    }
383
384    bool isInDigraph() const {
385        return mDicNodeState.mDicNodeStateScoring.getDigraphIndex()
386                != DigraphUtils::NOT_A_DIGRAPH_INDEX;
387    }
388
389    void advanceDigraphIndex() {
390        mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
391    }
392
393    ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
394        return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes();
395    }
396
397    inline uint16_t getNodeCodePointCount() const {
398        return mDicNodeProperties.getDepth();
399    }
400
401    // Returns code point count including spaces
402    inline uint16_t getTotalNodeCodePointCount() const {
403        return getNodeCodePointCount() + mDicNodeState.mDicNodeStateOutput.getPrevWordsLength();
404    }
405
406    AK_FORCE_INLINE void dump(const char *tag) const {
407#if DEBUG_DICT
408        DUMP_WORD_AND_SCORE(tag);
409#if DEBUG_DUMP_ERROR
410        mProfiler.dump();
411#endif
412#endif
413    }
414
415    AK_FORCE_INLINE bool compare(const DicNode *right) const {
416        // Promote exact matches to prevent them from being pruned.
417        const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes());
418        const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes());
419        if (leftExactMatch != rightExactMatch) {
420            return leftExactMatch;
421        }
422        const float diff =
423                right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance();
424        static const float MIN_DIFF = 0.000001f;
425        if (diff > MIN_DIFF) {
426            return true;
427        } else if (diff < -MIN_DIFF) {
428            return false;
429        }
430        const int depth = getNodeCodePointCount();
431        const int depthDiff = right->getNodeCodePointCount() - depth;
432        if (depthDiff != 0) {
433            return depthDiff > 0;
434        }
435        for (int i = 0; i < depth; ++i) {
436            const int codePoint = mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i);
437            const int rightCodePoint =
438                    right->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(i);
439            if (codePoint != rightCodePoint) {
440                return rightCodePoint > codePoint;
441            }
442        }
443        // Compare pointer values here for stable comparison
444        return this > right;
445    }
446
447 private:
448    DicNodeProperties mDicNodeProperties;
449    DicNodeState mDicNodeState;
450    // TODO: Remove
451    bool mIsCachedForNextSuggestion;
452
453    AK_FORCE_INLINE int getTotalInputIndex() const {
454        int index = 0;
455        for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
456            index += mDicNodeState.mDicNodeStateInput.getInputIndex(i);
457        }
458        return index;
459    }
460
461    // Caveat: Must not be called outside Weighting
462    // This restriction is guaranteed by "friend"
463    AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
464            const bool doNormalization, const int inputSize,
465            const ErrorTypeUtils::ErrorType errorType) {
466        if (DEBUG_GEO_FULL) {
467            LOGI_SHOW_ADD_COST_PROP;
468        }
469        mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
470                inputSize, getTotalInputIndex(), errorType);
471    }
472
473    // Saves the current normalized compound distance for space-aware gestures.
474    // See getNormalizedCompoundDistanceAfterFirstWord for details.
475    AK_FORCE_INLINE void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() {
476        mDicNodeState.mDicNodeStateScoring.saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
477    }
478
479    // Caveat: Must not be called outside Weighting
480    // This restriction is guaranteed by "friend"
481    AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count,
482            const bool overwritesPrevCodePointByNodeCodePoint) {
483        if (count == 0) {
484            return;
485        }
486        mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count);
487        if (overwritesPrevCodePointByNodeCodePoint) {
488            mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint());
489        }
490    }
491
492    AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) {
493        if (mDicNodeState.mDicNodeStateOutput.getPrevWordCount() == 1 && isFirstLetter()) {
494            mDicNodeState.mDicNodeStateOutput.setSecondWordFirstInputIndex(
495                    inputStateG->mInputIndex);
496        }
497        mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId,
498                inputStateG->mInputIndex, inputStateG->mPrevCodePoint,
499                inputStateG->mTerminalDiffCost, inputStateG->mRawLength);
500        mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength);
501        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel);
502    }
503};
504} // namespace latinime
505#endif // LATINIME_DIC_NODE_H
506