dic_node.h revision a107dcaeb6302981974bab8284f6b7943673cf11
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_DIC_NODE_H
18#define LATINIME_DIC_NODE_H
19
20#include "char_utils.h"
21#include "defines.h"
22#include "dic_node_state.h"
23#include "dic_node_profiler.h"
24#include "dic_node_properties.h"
25#include "dic_node_release_listener.h"
26#include "digraph_utils.h"
27
// Debug logging helpers for DicNode. They are only meaningful inside DicNode member functions:
// both expand to calls of DicNode accessors (getOutputWordBuf(), getDepth(), ...), and
// LOGI_SHOW_ADD_COST_PROP additionally reads a local/parameter named |inputSize|.
// Each macro is a single statement meant to be followed by a semicolon.
#if DEBUG_DICT
#define LOGI_SHOW_ADD_COST_PROP \
        do { char charBuf[50]; \
        INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \
        AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \
                __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \
                getInputIndex(0), getNormalizedCompoundDistance(), charBuf); } while (0)
#define DUMP_WORD_AND_SCORE(header) \
        do { char charBuf[50]; char prevWordCharBuf[50]; \
        INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \
        INTS_TO_CHARS(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, \
                mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf); \
        AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d,,", header, \
                getSpatialDistanceForScoring(), getLanguageDistanceForScoring(), \
                getNormalizedCompoundDistance(), getRawLength(), prevWordCharBuf, charBuf, \
                getInputIndex(0)); \
        } while (0)
#else
// No-op variants. They still expand to one complete statement (like the debug variants) so that
// "MACRO;" never degrades into a bare empty statement (-Wempty-body / -Wextra-semi).
#define LOGI_SHOW_ADD_COST_PROP do { } while (0)
#define DUMP_WORD_AND_SCORE(header) do { } while (0)
#endif
49
50namespace latinime {
51
52// This struct is purely a bucket to return values. No instances of this struct should be kept.
53struct DicNode_InputStateG {
54    bool mNeedsToUpdateInputStateG;
55    int mPointerId;
56    int16_t mInputIndex;
57    int mPrevCodePoint;
58    float mTerminalDiffCost;
59    float mRawLength;
60    DoubleLetterLevel mDoubleLetterLevel;
61};
62
63class DicNode {
64    // Caveat: We define Weighting as a friend class of DicNode to let Weighting change
65    // the distance of DicNode.
66    // Caution!!! In general, we avoid using the "friend" access modifier.
67    // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting.
68    friend class Weighting;
69
70 public:
71#if DEBUG_DICT
72    DicNodeProfiler mProfiler;
73#endif
74    //////////////////
75    // Memory utils //
76    //////////////////
77    AK_FORCE_INLINE static void managedDelete(DicNode *node) {
78        node->remove();
79    }
80    // end
81    /////////////////
82
83    AK_FORCE_INLINE DicNode()
84            :
85#if DEBUG_DICT
86              mProfiler(),
87#endif
88              mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false),
89              mIsUsed(false), mReleaseListener(0) {}
90
91    DicNode(const DicNode &dicNode);
92    DicNode &operator=(const DicNode &dicNode);
93    virtual ~DicNode() {}
94
95    // TODO: minimize arguments by looking binary_format
96    // Init for copy
97    void initByCopy(const DicNode *dicNode) {
98        mIsUsed = true;
99        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
100        mDicNodeProperties.init(&dicNode->mDicNodeProperties);
101        mDicNodeState.init(&dicNode->mDicNodeState);
102        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
103    }
104
105    // TODO: minimize arguments by looking binary_format
106    // Init for root with prevWordNodePos which is used for bigram
107    void initAsRoot(const int pos, const int childrenPos, const int childrenCount,
108            const int prevWordNodePos) {
109        mIsUsed = true;
110        mIsCachedForNextSuggestion = false;
111        mDicNodeProperties.init(
112                pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0);
113        mDicNodeState.init(prevWordNodePos);
114        PROF_NODE_RESET(mProfiler);
115    }
116
117    void initAsPassingChild(DicNode *parentNode) {
118        mIsUsed = true;
119        mIsCachedForNextSuggestion = parentNode->mIsCachedForNextSuggestion;
120        const int c = parentNode->getNodeTypedCodePoint();
121        mDicNodeProperties.init(&parentNode->mDicNodeProperties, c);
122        mDicNodeState.init(&parentNode->mDicNodeState);
123        PROF_NODE_COPY(&parentNode->mProfiler, mProfiler);
124    }
125
126    // TODO: minimize arguments by looking binary_format
127    // Init for root with previous word
128    void initAsRootWithPreviousWord(DicNode *dicNode, const int pos, const int childrenPos,
129            const int childrenCount) {
130        mIsUsed = true;
131        mIsCachedForNextSuggestion = false;
132        mDicNodeProperties.init(
133                pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0);
134        // TODO: Move to dicNodeState?
135        mDicNodeState.mDicNodeStateOutput.init(); // reset for next word
136        mDicNodeState.mDicNodeStateInput.init(
137                &dicNode->mDicNodeState.mDicNodeStateInput, true /* resetTerminalDiffCost */);
138        mDicNodeState.mDicNodeStateScoring.init(
139                &dicNode->mDicNodeState.mDicNodeStateScoring);
140        mDicNodeState.mDicNodeStatePrevWord.init(
141                dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1,
142                dicNode->mDicNodeProperties.getProbability(),
143                dicNode->mDicNodeProperties.getPos(),
144                dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevWord,
145                dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(),
146                dicNode->getOutputWordBuf(),
147                dicNode->mDicNodeProperties.getDepth(),
148                dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevSpacePositions,
149                mDicNodeState.mDicNodeStateInput.getInputIndex(0) /* lastInputIndex */);
150        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
151    }
152
153    // TODO: minimize arguments by looking binary_format
154    void initAsChild(DicNode *dicNode, const int pos, const uint8_t flags, const int childrenPos,
155            const int attributesPos, const int siblingPos, const int nodeCodePoint,
156            const int childrenCount, const int probability, const int bigramProbability,
157            const bool isTerminal, const bool hasMultipleChars, const bool hasChildren,
158            const uint16_t additionalSubwordLength, const int *additionalSubword) {
159        mIsUsed = true;
160        uint16_t newDepth = static_cast<uint16_t>(dicNode->getDepth() + 1);
161        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
162        const uint16_t newLeavingDepth = static_cast<uint16_t>(
163                dicNode->mDicNodeProperties.getLeavingDepth() + additionalSubwordLength);
164        mDicNodeProperties.init(pos, flags, childrenPos, attributesPos, siblingPos, nodeCodePoint,
165                childrenCount, probability, bigramProbability, isTerminal, hasMultipleChars,
166                hasChildren, newDepth, newLeavingDepth);
167        mDicNodeState.init(&dicNode->mDicNodeState, additionalSubwordLength, additionalSubword);
168        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
169    }
170
171    AK_FORCE_INLINE void remove() {
172        mIsUsed = false;
173        if (mReleaseListener) {
174            mReleaseListener->onReleased(this);
175        }
176    }
177
178    bool isUsed() const {
179        return mIsUsed;
180    }
181
182    bool isRoot() const {
183        return getDepth() == 0;
184    }
185
186    bool hasChildren() const {
187        return mDicNodeProperties.hasChildren();
188    }
189
190    bool isLeavingNode() const {
191        ASSERT(getDepth() <= getLeavingDepth());
192        return getDepth() == getLeavingDepth();
193    }
194
195    AK_FORCE_INLINE bool isFirstLetter() const {
196        return getDepth() == 1;
197    }
198
199    bool isCached() const {
200        return mIsCachedForNextSuggestion;
201    }
202
203    void setCached() {
204        mIsCachedForNextSuggestion = true;
205    }
206
207    // Used to expand the node in DicNodeUtils
208    int getNodeTypedCodePoint() const {
209        return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getDepth());
210    }
211
212    bool isImpossibleBigramWord() const {
213        if (mDicNodeProperties.hasBlacklistedOrNotAWordFlag()) {
214            return true;
215        }
216        const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength()
217                - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1;
218        const int currentWordLen = getDepth();
219        return (prevWordLen == 1 && currentWordLen == 1);
220    }
221
222    bool isFirstCharUppercase() const {
223        const int c = getOutputWordBuf()[0];
224        return isAsciiUpper(c);
225    }
226
227    bool isFirstWord() const {
228        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_VALID_WORD;
229    }
230
231    bool isCompletion(const int inputSize) const {
232        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize;
233    }
234
235    bool canDoLookAheadCorrection(const int inputSize) const {
236        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1;
237    }
238
239    // Used to get bigram probability in DicNodeUtils
240    int getPos() const {
241        return mDicNodeProperties.getPos();
242    }
243
244    // Used to get bigram probability in DicNodeUtils
245    int getPrevWordPos() const {
246        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos();
247    }
248
249    // Used in DicNodeUtils
250    int getChildrenPos() const {
251        return mDicNodeProperties.getChildrenPos();
252    }
253
254    // Used in DicNodeUtils
255    int getChildrenCount() const {
256        return mDicNodeProperties.getChildrenCount();
257    }
258
259    // Used in DicNodeUtils
260    int getProbability() const {
261        return mDicNodeProperties.getProbability();
262    }
263
264    AK_FORCE_INLINE bool isTerminalWordNode() const {
265        const bool isTerminalNodes = mDicNodeProperties.isTerminal();
266        const int currentNodeDepth = getDepth();
267        const int terminalNodeDepth = mDicNodeProperties.getLeavingDepth();
268        return isTerminalNodes && currentNodeDepth > 0 && currentNodeDepth == terminalNodeDepth;
269    }
270
271    bool shouldBeFilterdBySafetyNetForBigram() const {
272        const uint16_t currentDepth = getDepth();
273        const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength()
274                - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1;
275        return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1));
276    }
277
278    uint16_t getLeavingDepth() const {
279        return mDicNodeProperties.getLeavingDepth();
280    }
281
282    bool isTotalInputSizeExceedingLimit() const {
283        const int prevWordsLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength();
284        const int currentWordDepth = getDepth();
285        // TODO: 3 can be 2? Needs to be investigated.
286        // TODO: Have a const variable for 3 (or 2)
287        return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3;
288    }
289
290    // TODO: This may be defective. Needs to be revised.
291    bool truncateNode(const DicNode *const topNode, const int inputCommitPoint) {
292        const int prevWordLenOfTop = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength();
293        int newPrevWordStartIndex = inputCommitPoint;
294        int charCount = 0;
295        // Find new word start index
296        for (int i = 0; i < prevWordLenOfTop; ++i) {
297            const int c = mDicNodeState.mDicNodeStatePrevWord.getPrevWordCodePointAt(i);
298            // TODO: Check other separators.
299            if (c != KEYCODE_SPACE && c != KEYCODE_SINGLE_QUOTE) {
300                if (charCount == inputCommitPoint) {
301                    newPrevWordStartIndex = i;
302                    break;
303                }
304                ++charCount;
305            }
306        }
307        if (!mDicNodeState.mDicNodeStatePrevWord.startsWith(
308                &topNode->mDicNodeState.mDicNodeStatePrevWord, newPrevWordStartIndex - 1)) {
309            // Node mismatch.
310            return false;
311        }
312        mDicNodeState.mDicNodeStateInput.truncate(inputCommitPoint);
313        mDicNodeState.mDicNodeStatePrevWord.truncate(newPrevWordStartIndex);
314        return true;
315    }
316
317    void outputResult(int *dest) const {
318        const uint16_t prevWordLength = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength();
319        const uint16_t currentDepth = getDepth();
320        DicNodeUtils::appendTwoWords(mDicNodeState.mDicNodeStatePrevWord.mPrevWord,
321                   prevWordLength, getOutputWordBuf(), currentDepth, dest);
322        DUMP_WORD_AND_SCORE("OUTPUT");
323    }
324
325    void outputSpacePositionsResult(int *spaceIndices) const {
326        mDicNodeState.mDicNodeStatePrevWord.outputSpacePositions(spaceIndices);
327    }
328
329    bool hasMultipleWords() const {
330        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() > 0;
331    }
332
333    float getProximityCorrectionCount() const {
334        return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount());
335    }
336
337    float getEditCorrectionCount() const {
338        return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount());
339    }
340
341    // Used to prune nodes
342    float getNormalizedCompoundDistance() const {
343        return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance();
344    }
345
346    // Used to prune nodes
347    float getNormalizedSpatialDistance() const {
348        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance()
349                / static_cast<float>(getInputIndex(0) + 1);
350    }
351
352    // Used to prune nodes
353    float getCompoundDistance() const {
354        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance();
355    }
356
357    // Used to prune nodes
358    float getCompoundDistance(const float languageWeight) const {
359        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight);
360    }
361
362    // Used to commit input partially
363    int getPrevWordNodePos() const {
364        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos();
365    }
366
367    AK_FORCE_INLINE const int *getOutputWordBuf() const {
368        return mDicNodeState.mDicNodeStateOutput.mWordBuf;
369    }
370
371    int getPrevCodePointG(int pointerId) const {
372        return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId);
373    }
374
375    // Whether the current codepoint can be an intentional omission, in which case the traversal
376    // algorithm will always check for a possible omission here.
377    bool canBeIntentionalOmission() const {
378        return isIntentionalOmissionCodePoint(getNodeCodePoint());
379    }
380
381    // Whether the omission is so frequent that it should incur zero cost.
382    bool isZeroCostOmission() const {
383        // TODO: do not hardcode and read from header
384        return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE);
385    }
386
387    // TODO: remove
388    float getTerminalDiffCostG(int path) const {
389        return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path);
390    }
391
392    //////////////////////
393    // Temporary getter //
394    // TODO: Remove     //
395    //////////////////////
396    // TODO: Remove once touch path is merged into ProximityInfoState
397    // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph.
398    int getNodeCodePoint() const {
399        const int codePoint = mDicNodeProperties.getNodeCodePoint();
400        const DigraphUtils::DigraphCodePointIndex digraphIndex =
401                mDicNodeState.mDicNodeStateScoring.getDigraphIndex();
402        if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) {
403            return codePoint;
404        }
405        return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex);
406    }
407
408    ////////////////////////////////
409    // Utils for cost calculation //
410    ////////////////////////////////
411    AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const {
412        return mDicNodeProperties.getNodeCodePoint()
413                == dicNode->mDicNodeProperties.getNodeCodePoint();
414    }
415
416    // TODO: remove
417    // TODO: rename getNextInputIndex
418    int16_t getInputIndex(int pointerId) const {
419        return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId);
420    }
421
422    ////////////////////////////////////
423    // Getter of features for scoring //
424    ////////////////////////////////////
425    float getSpatialDistanceForScoring() const {
426        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance();
427    }
428
429    float getLanguageDistanceForScoring() const {
430        return mDicNodeState.mDicNodeStateScoring.getLanguageDistance();
431    }
432
433    float getLanguageDistanceRatePerWordForScoring() const {
434        const float langDist = getLanguageDistanceForScoring();
435        const float totalWordCount =
436                static_cast<float>(mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1);
437        return langDist / totalWordCount;
438    }
439
440    float getRawLength() const {
441        return mDicNodeState.mDicNodeStateScoring.getRawLength();
442    }
443
444    bool isLessThanOneErrorForScoring() const {
445        return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount()
446                + mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount() <= 1;
447    }
448
449    DoubleLetterLevel getDoubleLetterLevel() const {
450        return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel();
451    }
452
453    void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
454        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel);
455    }
456
457    bool isInDigraph() const {
458        return mDicNodeState.mDicNodeStateScoring.getDigraphIndex()
459                != DigraphUtils::NOT_A_DIGRAPH_INDEX;
460    }
461
462    void advanceDigraphIndex() {
463        mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
464    }
465
466    bool isExactMatch() const {
467        return mDicNodeState.mDicNodeStateScoring.isExactMatch();
468    }
469
470    uint8_t getFlags() const {
471        return mDicNodeProperties.getFlags();
472    }
473
474    int getAttributesPos() const {
475        return mDicNodeProperties.getAttributesPos();
476    }
477
478    inline uint16_t getDepth() const {
479        return mDicNodeProperties.getDepth();
480    }
481
482    AK_FORCE_INLINE void dump(const char *tag) const {
483#if DEBUG_DICT
484        DUMP_WORD_AND_SCORE(tag);
485#if DEBUG_DUMP_ERROR
486        mProfiler.dump();
487#endif
488#endif
489    }
490
491    void setReleaseListener(DicNodeReleaseListener *releaseListener) {
492        mReleaseListener = releaseListener;
493    }
494
495    AK_FORCE_INLINE bool compare(const DicNode *right) {
496        if (!isUsed() && !right->isUsed()) {
497            // Compare pointer values here for stable comparison
498            return this > right;
499        }
500        if (!isUsed()) {
501            return true;
502        }
503        if (!right->isUsed()) {
504            return false;
505        }
506        const float diff =
507                right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance();
508        static const float MIN_DIFF = 0.000001f;
509        if (diff > MIN_DIFF) {
510            return true;
511        } else if (diff < -MIN_DIFF) {
512            return false;
513        }
514        const int depth = getDepth();
515        const int depthDiff = right->getDepth() - depth;
516        if (depthDiff != 0) {
517            return depthDiff > 0;
518        }
519        for (int i = 0; i < depth; ++i) {
520            const int codePoint = mDicNodeState.mDicNodeStateOutput.getCodePointAt(i);
521            const int rightCodePoint = right->mDicNodeState.mDicNodeStateOutput.getCodePointAt(i);
522            if (codePoint != rightCodePoint) {
523                return rightCodePoint > codePoint;
524            }
525        }
526        // Compare pointer values here for stable comparison
527        return this > right;
528    }
529
530 private:
531    DicNodeProperties mDicNodeProperties;
532    DicNodeState mDicNodeState;
533    // TODO: Remove
534    bool mIsCachedForNextSuggestion;
535    bool mIsUsed;
536    DicNodeReleaseListener *mReleaseListener;
537
538    AK_FORCE_INLINE int getTotalInputIndex() const {
539        int index = 0;
540        for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
541            index += mDicNodeState.mDicNodeStateInput.getInputIndex(i);
542        }
543        return index;
544    }
545
546    // Caveat: Must not be called outside Weighting
547    // This restriction is guaranteed by "friend"
548    AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
549            const bool doNormalization, const int inputSize, const ErrorType errorType) {
550        if (DEBUG_GEO_FULL) {
551            LOGI_SHOW_ADD_COST_PROP;
552        }
553        mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
554                inputSize, getTotalInputIndex(), errorType);
555    }
556
557    // Caveat: Must not be called outside Weighting
558    // This restriction is guaranteed by "friend"
559    AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count,
560            const bool overwritesPrevCodePointByNodeCodePoint) {
561        if (count == 0) {
562            return;
563        }
564        mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count);
565        if (overwritesPrevCodePointByNodeCodePoint) {
566            mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint());
567        }
568    }
569
570    AK_FORCE_INLINE void updateInputIndexG(DicNode_InputStateG *inputStateG) {
571        mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId,
572                inputStateG->mInputIndex, inputStateG->mPrevCodePoint,
573                inputStateG->mTerminalDiffCost, inputStateG->mRawLength);
574        mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength);
575        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel);
576    }
577};
578} // namespace latinime
579#endif // LATINIME_DIC_NODE_H
580