dic_node.h revision 25e8eda9afb5c36703bd50b263ab0dd3a3b38d31
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_DIC_NODE_H
18#define LATINIME_DIC_NODE_H
19
20#include "char_utils.h"
21#include "defines.h"
22#include "dic_node_state.h"
23#include "dic_node_profiler.h"
24#include "dic_node_properties.h"
25#include "dic_node_release_listener.h"
26#include "digraph_utils.h"
27
// Debug-only logging helpers. Both macros call DicNode member functions unqualified, so they
// are meant to be expanded inside DicNode member functions. When DEBUG_DICT is off they expand
// to nothing.
#if DEBUG_DICT
// Logs the node's output word together with its input indices and normalized compound distance.
// A variable named |inputSize| must be in scope at the expansion site.
#define LOGI_SHOW_ADD_COST_PROP \
        do { \
        char outputWordChars[50]; \
        INTS_TO_CHARS(getOutputWordBuf(), getDepth(), outputWordChars); \
        AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \
                __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \
                getInputIndex(0), getNormalizedCompoundDistance(), outputWordChars); \
        } while (0)
// Logs |header| followed by the node's scoring distances, raw length, previous word(s),
// current output word and the input index of pointer 0.
#define DUMP_WORD_AND_SCORE(header) \
        do { \
        char outputWordChars[50]; \
        char prevWordChars[50]; \
        INTS_TO_CHARS(getOutputWordBuf(), getDepth(), outputWordChars); \
        INTS_TO_CHARS(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, \
                mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordChars); \
        AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d,,", header, \
                getSpatialDistanceForScoring(), getLanguageDistanceForScoring(), \
                getNormalizedCompoundDistance(), getRawLength(), prevWordChars, \
                outputWordChars, getInputIndex(0)); \
        } while (0)
#else
// No-op variants for non-debug builds.
#define LOGI_SHOW_ADD_COST_PROP
#define DUMP_WORD_AND_SCORE(header)
#endif
49
50namespace latinime {
51
52// This struct is purely a bucket to return values. No instances of this struct should be kept.
53struct DicNode_InputStateG {
54    bool mNeedsToUpdateInputStateG;
55    int mPointerId;
56    int16_t mInputIndex;
57    int mPrevCodePoint;
58    float mTerminalDiffCost;
59    float mRawLength;
60    DoubleLetterLevel mDoubleLetterLevel;
61};
62
63class DicNode {
64    // Caveat: We define Weighting as a friend class of DicNode to let Weighting change
65    // the distance of DicNode.
66    // Caution!!! In general, we avoid using the "friend" access modifier.
67    // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting.
68    friend class Weighting;
69
70 public:
71#if DEBUG_DICT
72    DicNodeProfiler mProfiler;
73#endif
74    //////////////////
75    // Memory utils //
76    //////////////////
77    AK_FORCE_INLINE static void managedDelete(DicNode *node) {
78        node->remove();
79    }
80    // end
81    /////////////////
82
83    AK_FORCE_INLINE DicNode()
84            :
85#if DEBUG_DICT
86              mProfiler(),
87#endif
88              mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false),
89              mIsUsed(false), mReleaseListener(0) {}
90
91    DicNode(const DicNode &dicNode);
92    DicNode &operator=(const DicNode &dicNode);
93    virtual ~DicNode() {}
94
95    // TODO: minimize arguments by looking binary_format
96    // Init for copy
97    void initByCopy(const DicNode *dicNode) {
98        mIsUsed = true;
99        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
100        mDicNodeProperties.init(&dicNode->mDicNodeProperties);
101        mDicNodeState.init(&dicNode->mDicNodeState);
102        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
103    }
104
105    // TODO: minimize arguments by looking binary_format
106    // Init for root with prevWordNodePos which is used for bigram
107    void initAsRoot(const int pos, const int childrenPos, const int childrenCount,
108            const int prevWordNodePos) {
109        mIsUsed = true;
110        mIsCachedForNextSuggestion = false;
111        mDicNodeProperties.init(
112                pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0);
113        mDicNodeState.init(prevWordNodePos);
114        PROF_NODE_RESET(mProfiler);
115    }
116
117    void initAsPassingChild(DicNode *parentNode) {
118        mIsUsed = true;
119        mIsCachedForNextSuggestion = parentNode->mIsCachedForNextSuggestion;
120        const int c = parentNode->getNodeTypedCodePoint();
121        mDicNodeProperties.init(&parentNode->mDicNodeProperties, c);
122        mDicNodeState.init(&parentNode->mDicNodeState);
123        PROF_NODE_COPY(&parentNode->mProfiler, mProfiler);
124    }
125
126    // TODO: minimize arguments by looking binary_format
127    // Init for root with previous word
128    void initAsRootWithPreviousWord(DicNode *dicNode, const int pos, const int childrenPos,
129            const int childrenCount) {
130        mIsUsed = true;
131        mIsCachedForNextSuggestion = false;
132        mDicNodeProperties.init(
133                pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0);
134        // TODO: Move to dicNodeState?
135        mDicNodeState.mDicNodeStateOutput.init(); // reset for next word
136        mDicNodeState.mDicNodeStateInput.init(
137                &dicNode->mDicNodeState.mDicNodeStateInput, true /* resetTerminalDiffCost */);
138        mDicNodeState.mDicNodeStateScoring.init(
139                &dicNode->mDicNodeState.mDicNodeStateScoring);
140        mDicNodeState.mDicNodeStatePrevWord.init(
141                dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1,
142                dicNode->mDicNodeProperties.getProbability(),
143                dicNode->mDicNodeProperties.getPos(),
144                dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevWord,
145                dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(),
146                dicNode->getOutputWordBuf(),
147                dicNode->mDicNodeProperties.getDepth(),
148                dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevSpacePositions,
149                mDicNodeState.mDicNodeStateInput.getInputIndex(0) /* lastInputIndex */);
150        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
151    }
152
153    // TODO: minimize arguments by looking binary_format
154    void initAsChild(DicNode *dicNode, const int pos, const uint8_t flags, const int childrenPos,
155            const int attributesPos, const int siblingPos, const int nodeCodePoint,
156            const int childrenCount, const int probability, const int bigramProbability,
157            const bool isTerminal, const bool hasMultipleChars, const bool hasChildren,
158            const uint16_t additionalSubwordLength, const int *additionalSubword) {
159        mIsUsed = true;
160        uint16_t newDepth = static_cast<uint16_t>(dicNode->getDepth() + 1);
161        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
162        const uint16_t newLeavingDepth = static_cast<uint16_t>(
163                dicNode->mDicNodeProperties.getLeavingDepth() + additionalSubwordLength);
164        mDicNodeProperties.init(pos, flags, childrenPos, attributesPos, siblingPos, nodeCodePoint,
165                childrenCount, probability, bigramProbability, isTerminal, hasMultipleChars,
166                hasChildren, newDepth, newLeavingDepth);
167        mDicNodeState.init(&dicNode->mDicNodeState, additionalSubwordLength, additionalSubword);
168        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
169    }
170
171    AK_FORCE_INLINE void remove() {
172        mIsUsed = false;
173        if (mReleaseListener) {
174            mReleaseListener->onReleased(this);
175        }
176    }
177
178    bool isUsed() const {
179        return mIsUsed;
180    }
181
182    bool isRoot() const {
183        return getDepth() == 0;
184    }
185
186    bool hasChildren() const {
187        return mDicNodeProperties.hasChildren();
188    }
189
190    bool isLeavingNode() const {
191        ASSERT(getDepth() <= getLeavingDepth());
192        return getDepth() == getLeavingDepth();
193    }
194
195    AK_FORCE_INLINE bool isFirstLetter() const {
196        return getDepth() == 1;
197    }
198
199    bool isCached() const {
200        return mIsCachedForNextSuggestion;
201    }
202
203    void setCached() {
204        mIsCachedForNextSuggestion = true;
205    }
206
207    // Used to expand the node in DicNodeUtils
208    int getNodeTypedCodePoint() const {
209        return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getDepth());
210    }
211
212    bool isImpossibleBigramWord() const {
213        const int probability = mDicNodeProperties.getProbability();
214        if (probability == 0) {
215            return true;
216        }
217        const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength()
218                - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1;
219        const int currentWordLen = getDepth();
220        return (prevWordLen == 1 && currentWordLen == 1);
221    }
222
223    bool isCapitalized() const {
224        const int c = getOutputWordBuf()[0];
225        return isAsciiUpper(c);
226    }
227
228    bool isFirstWord() const {
229        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_VALID_WORD;
230    }
231
232    bool isCompletion(const int inputSize) const {
233        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize;
234    }
235
236    bool canDoLookAheadCorrection(const int inputSize) const {
237        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1;
238    }
239
240    // Used to get bigram probability in DicNodeUtils
241    int getPos() const {
242        return mDicNodeProperties.getPos();
243    }
244
245    // Used to get bigram probability in DicNodeUtils
246    int getPrevWordPos() const {
247        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos();
248    }
249
250    // Used in DicNodeUtils
251    int getChildrenPos() const {
252        return mDicNodeProperties.getChildrenPos();
253    }
254
255    // Used in DicNodeUtils
256    int getChildrenCount() const {
257        return mDicNodeProperties.getChildrenCount();
258    }
259
260    // Used in DicNodeUtils
261    int getProbability() const {
262        return mDicNodeProperties.getProbability();
263    }
264
265    AK_FORCE_INLINE bool isTerminalWordNode() const {
266        const bool isTerminalNodes = mDicNodeProperties.isTerminal();
267        const int currentNodeDepth = getDepth();
268        const int terminalNodeDepth = mDicNodeProperties.getLeavingDepth();
269        return isTerminalNodes && currentNodeDepth > 0 && currentNodeDepth == terminalNodeDepth;
270    }
271
272    bool shouldBeFilterdBySafetyNetForBigram() const {
273        const uint16_t currentDepth = getDepth();
274        const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength()
275                - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1;
276        return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1));
277    }
278
279    uint16_t getLeavingDepth() const {
280        return mDicNodeProperties.getLeavingDepth();
281    }
282
283    bool isTotalInputSizeExceedingLimit() const {
284        const int prevWordsLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength();
285        const int currentWordDepth = getDepth();
286        // TODO: 3 can be 2? Needs to be investigated.
287        // TODO: Have a const variable for 3 (or 2)
288        return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3;
289    }
290
291    // TODO: This may be defective. Needs to be revised.
292    bool truncateNode(const DicNode *const topNode, const int inputCommitPoint) {
293        const int prevWordLenOfTop = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength();
294        int newPrevWordStartIndex = inputCommitPoint;
295        int charCount = 0;
296        // Find new word start index
297        for (int i = 0; i < prevWordLenOfTop; ++i) {
298            const int c = mDicNodeState.mDicNodeStatePrevWord.getPrevWordCodePointAt(i);
299            // TODO: Check other separators.
300            if (c != KEYCODE_SPACE && c != KEYCODE_SINGLE_QUOTE) {
301                if (charCount == inputCommitPoint) {
302                    newPrevWordStartIndex = i;
303                    break;
304                }
305                ++charCount;
306            }
307        }
308        if (!mDicNodeState.mDicNodeStatePrevWord.startsWith(
309                &topNode->mDicNodeState.mDicNodeStatePrevWord, newPrevWordStartIndex - 1)) {
310            // Node mismatch.
311            return false;
312        }
313        mDicNodeState.mDicNodeStateInput.truncate(inputCommitPoint);
314        mDicNodeState.mDicNodeStatePrevWord.truncate(newPrevWordStartIndex);
315        return true;
316    }
317
318    void outputResult(int *dest) const {
319        const uint16_t prevWordLength = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength();
320        const uint16_t currentDepth = getDepth();
321        DicNodeUtils::appendTwoWords(mDicNodeState.mDicNodeStatePrevWord.mPrevWord,
322                   prevWordLength, getOutputWordBuf(), currentDepth, dest);
323        DUMP_WORD_AND_SCORE("OUTPUT");
324    }
325
326    void outputSpacePositionsResult(int *spaceIndices) const {
327        mDicNodeState.mDicNodeStatePrevWord.outputSpacePositions(spaceIndices);
328    }
329
330    bool hasMultipleWords() const {
331        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() > 0;
332    }
333
334    float getProximityCorrectionCount() const {
335        return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount());
336    }
337
338    float getEditCorrectionCount() const {
339        return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount());
340    }
341
342    // Used to prune nodes
343    float getNormalizedCompoundDistance() const {
344        return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance();
345    }
346
347    // Used to prune nodes
348    float getNormalizedSpatialDistance() const {
349        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance()
350                / static_cast<float>(getInputIndex(0) + 1);
351    }
352
353    // Used to prune nodes
354    float getCompoundDistance() const {
355        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance();
356    }
357
358    // Used to prune nodes
359    float getCompoundDistance(const float languageWeight) const {
360        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight);
361    }
362
363    // Note that "cost" means delta for "distance" that is weighted.
364    float getTotalPrevWordsLanguageCost() const {
365        return mDicNodeState.mDicNodeStateScoring.getTotalPrevWordsLanguageCost();
366    }
367
368    // Used to commit input partially
369    int getPrevWordNodePos() const {
370        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos();
371    }
372
373    AK_FORCE_INLINE const int *getOutputWordBuf() const {
374        return mDicNodeState.mDicNodeStateOutput.mWordBuf;
375    }
376
377    int getPrevCodePointG(int pointerId) const {
378        return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId);
379    }
380
381    // Whether the current codepoint can be an intentional omission, in which case the traversal
382    // algorithm will always check for a possible omission here.
383    bool canBeIntentionalOmission() const {
384        return isIntentionalOmissionCodePoint(getNodeCodePoint());
385    }
386
387    // Whether the omission is so frequent that it should incur zero cost.
388    bool isZeroCostOmission() const {
389        // TODO: do not hardcode and read from header
390        return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE);
391    }
392
393    // TODO: remove
394    float getTerminalDiffCostG(int path) const {
395        return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path);
396    }
397
398    //////////////////////
399    // Temporary getter //
400    // TODO: Remove     //
401    //////////////////////
402    // TODO: Remove once touch path is merged into ProximityInfoState
403    // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph.
404    int getNodeCodePoint() const {
405        const int codePoint = mDicNodeProperties.getNodeCodePoint();
406        const DigraphUtils::DigraphCodePointIndex digraphIndex =
407                mDicNodeState.mDicNodeStateScoring.getDigraphIndex();
408        if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) {
409            return codePoint;
410        }
411        return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex);
412    }
413
414    ////////////////////////////////
415    // Utils for cost calculation //
416    ////////////////////////////////
417    AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const {
418        return mDicNodeProperties.getNodeCodePoint()
419                == dicNode->mDicNodeProperties.getNodeCodePoint();
420    }
421
422    // TODO: remove
423    // TODO: rename getNextInputIndex
424    int16_t getInputIndex(int pointerId) const {
425        return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId);
426    }
427
428    ////////////////////////////////////
429    // Getter of features for scoring //
430    ////////////////////////////////////
431    float getSpatialDistanceForScoring() const {
432        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance();
433    }
434
435    float getLanguageDistanceForScoring() const {
436        return mDicNodeState.mDicNodeStateScoring.getLanguageDistance();
437    }
438
439    float getLanguageDistanceRatePerWordForScoring() const {
440        const float langDist = getLanguageDistanceForScoring();
441        const float totalWordCount =
442                static_cast<float>(mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1);
443        return langDist / totalWordCount;
444    }
445
446    float getRawLength() const {
447        return mDicNodeState.mDicNodeStateScoring.getRawLength();
448    }
449
450    bool isLessThanOneErrorForScoring() const {
451        return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount()
452                + mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount() <= 1;
453    }
454
455    DoubleLetterLevel getDoubleLetterLevel() const {
456        return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel();
457    }
458
459    void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
460        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel);
461    }
462
463    bool isInDigraph() const {
464        return mDicNodeState.mDicNodeStateScoring.getDigraphIndex()
465                != DigraphUtils::NOT_A_DIGRAPH_INDEX;
466    }
467
468    void advanceDigraphIndex() {
469        mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
470    }
471
472    uint8_t getFlags() const {
473        return mDicNodeProperties.getFlags();
474    }
475
476    int getAttributesPos() const {
477        return mDicNodeProperties.getAttributesPos();
478    }
479
480    inline uint16_t getDepth() const {
481        return mDicNodeProperties.getDepth();
482    }
483
484    AK_FORCE_INLINE void dump(const char *tag) const {
485#if DEBUG_DICT
486        DUMP_WORD_AND_SCORE(tag);
487#if DEBUG_DUMP_ERROR
488        mProfiler.dump();
489#endif
490#endif
491    }
492
493    void setReleaseListener(DicNodeReleaseListener *releaseListener) {
494        mReleaseListener = releaseListener;
495    }
496
497    AK_FORCE_INLINE bool compare(const DicNode *right) {
498        if (!isUsed() && !right->isUsed()) {
499            // Compare pointer values here for stable comparison
500            return this > right;
501        }
502        if (!isUsed()) {
503            return true;
504        }
505        if (!right->isUsed()) {
506            return false;
507        }
508        const float diff =
509                right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance();
510        static const float MIN_DIFF = 0.000001f;
511        if (diff > MIN_DIFF) {
512            return true;
513        } else if (diff < -MIN_DIFF) {
514            return false;
515        }
516        const int depth = getDepth();
517        const int depthDiff = right->getDepth() - depth;
518        if (depthDiff != 0) {
519            return depthDiff > 0;
520        }
521        for (int i = 0; i < depth; ++i) {
522            const int codePoint = mDicNodeState.mDicNodeStateOutput.getCodePointAt(i);
523            const int rightCodePoint = right->mDicNodeState.mDicNodeStateOutput.getCodePointAt(i);
524            if (codePoint != rightCodePoint) {
525                return rightCodePoint > codePoint;
526            }
527        }
528        // Compare pointer values here for stable comparison
529        return this > right;
530    }
531
532 private:
533    DicNodeProperties mDicNodeProperties;
534    DicNodeState mDicNodeState;
535    // TODO: Remove
536    bool mIsCachedForNextSuggestion;
537    bool mIsUsed;
538    DicNodeReleaseListener *mReleaseListener;
539
540    AK_FORCE_INLINE int getTotalInputIndex() const {
541        int index = 0;
542        for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
543            index += mDicNodeState.mDicNodeStateInput.getInputIndex(i);
544        }
545        return index;
546    }
547
548    // Caveat: Must not be called outside Weighting
549    // This restriction is guaranteed by "friend"
550    AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
551            const bool doNormalization, const int inputSize, const bool isEditCorrection,
552            const bool isProximityCorrection) {
553        if (DEBUG_GEO_FULL) {
554            LOGI_SHOW_ADD_COST_PROP;
555        }
556        mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
557                inputSize, getTotalInputIndex(), isEditCorrection, isProximityCorrection);
558    }
559
560    // Caveat: Must not be called outside Weighting
561    // This restriction is guaranteed by "friend"
562    AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count,
563            const bool overwritesPrevCodePointByNodeCodePoint) {
564        if (count == 0) {
565            return;
566        }
567        mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count);
568        if (overwritesPrevCodePointByNodeCodePoint) {
569            mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint());
570        }
571    }
572
573    AK_FORCE_INLINE void updateInputIndexG(DicNode_InputStateG *inputStateG) {
574        mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId,
575                inputStateG->mInputIndex, inputStateG->mPrevCodePoint,
576                inputStateG->mTerminalDiffCost, inputStateG->mRawLength);
577        mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength);
578        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel);
579    }
580};
581} // namespace latinime
582#endif // LATINIME_DIC_NODE_H
583