dic_node.h revision 38c26dd0bf8cd5c4511e4a02d5eeae4b3553f03a
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_DIC_NODE_H
18#define LATINIME_DIC_NODE_H
19
20#include "char_utils.h"
21#include "defines.h"
22#include "dic_node_state.h"
23#include "dic_node_profiler.h"
24#include "dic_node_properties.h"
25#include "dic_node_release_listener.h"
26
// Debug-only logging helpers. Both macros must be expanded inside DicNode member functions:
// they call DicNode getters (getOutputWordBuf(), getDepth(), getInputIndex(), ...) and
// LOGI_SHOW_ADD_COST_PROP additionally reads a local variable named |inputSize|.
// Every variant expands to a single statement (do { ... } while (0)) so that
// "MACRO;" is always exactly one statement, including in the disabled build where an empty
// expansion would otherwise leave a stray empty statement (e.g. "if (c) ;" -> -Wempty-body).
#if DEBUG_DICT
#define LOGI_SHOW_ADD_COST_PROP \
        do { char charBuf[50]; \
        INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \
        AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \
                __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \
                getInputIndex(0), getNormalizedCompoundDistance(), charBuf); } while (0)
#define DUMP_WORD_AND_SCORE(header) \
        do { char charBuf[50]; char prevWordCharBuf[50]; \
        INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \
        INTS_TO_CHARS(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, \
                mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf); \
        AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d,,", header, \
                getSpatialDistanceForScoring(), getLanguageDistanceForScoring(), \
                getNormalizedCompoundDistance(), getRawLength(), prevWordCharBuf, charBuf, \
                getInputIndex(0)); \
        } while (0)
#else
// Disabled build: no-op statements; |header| is intentionally not evaluated.
#define LOGI_SHOW_ADD_COST_PROP do {} while (0)
#define DUMP_WORD_AND_SCORE(header) do {} while (0)
#endif
48
49namespace latinime {
50
51// Naming convention
52// - Distance: "Weighted" edit distance -- used both for spatial and language.
53// - Compound Distance: Spatial Distance + Language Distance -- used for pruning and scoring
54// - Cost: delta/diff for Distance -- used both for spatial and language
55// - Length: "Non-weighted" -- used only for spatial
56// - Probability: "Non-weighted" -- used only for language
57
// This struct is purely a bucket to return values. No instances of this struct should be kept.
// It carries one pointer's input-state update, which DicNode::updateInputIndexG() applies to
// the node's input and scoring state. It has no constructor, so fields are uninitialized until
// the producer fills them in.
struct DicNode_InputStateG {
    // Not read by DicNode::updateInputIndexG() itself.
    // NOTE(review): presumably checked by the caller to decide whether to apply the update.
    bool mNeedsToUpdateInputStateG;
    // Pointer whose input state is being updated.
    int mPointerId;
    // New input index for mPointerId.
    int16_t mInputIndex;
    // New "previous code point" for mPointerId.
    int mPrevCodePoint;
    // Terminal diff cost forwarded to the input state for mPointerId.
    float mTerminalDiffCost;
    // Raw length forwarded to the input state and also added to the scoring raw length.
    float mRawLength;
    // Replaces the node's double-letter level in its scoring state.
    DoubleLetterLevel mDoubleLetterLevel;
};
68
69class DicNode {
70    // Caveat: We define Weighting as a friend class of DicNode to let Weighting change
71    // the distance of DicNode.
72    // Caution!!! In general, we avoid using the "friend" access modifier.
73    // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting.
74    friend class Weighting;
75
76 public:
77#if DEBUG_DICT
78    DicNodeProfiler mProfiler;
79#endif
80    //////////////////
81    // Memory utils //
82    //////////////////
83    AK_FORCE_INLINE static void managedDelete(DicNode *node) {
84        node->remove();
85    }
86    // end
87    /////////////////
88
89    AK_FORCE_INLINE DicNode()
90            :
91#if DEBUG_DICT
92              mProfiler(),
93#endif
94              mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false),
95              mIsUsed(false), mReleaseListener(0) {}
96
97    DicNode(const DicNode &dicNode);
98    DicNode &operator=(const DicNode &dicNode);
99    virtual ~DicNode() {}
100
101    // TODO: minimize arguments by looking binary_format
102    // Init for copy
103    void initByCopy(const DicNode *dicNode) {
104        mIsUsed = true;
105        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
106        mDicNodeProperties.init(&dicNode->mDicNodeProperties);
107        mDicNodeState.init(&dicNode->mDicNodeState);
108        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
109    }
110
111    // TODO: minimize arguments by looking binary_format
112    // Init for root with prevWordNodePos which is used for bigram
113    void initAsRoot(const int pos, const int childrenPos, const int childrenCount,
114            const int prevWordNodePos) {
115        mIsUsed = true;
116        mIsCachedForNextSuggestion = false;
117        mDicNodeProperties.init(
118                pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0);
119        mDicNodeState.init(prevWordNodePos);
120        PROF_NODE_RESET(mProfiler);
121    }
122
123    void initAsPassingChild(DicNode *parentNode) {
124        mIsUsed = true;
125        mIsCachedForNextSuggestion = parentNode->mIsCachedForNextSuggestion;
126        const int c = parentNode->getNodeTypedCodePoint();
127        mDicNodeProperties.init(&parentNode->mDicNodeProperties, c);
128        mDicNodeState.init(&parentNode->mDicNodeState);
129        PROF_NODE_COPY(&parentNode->mProfiler, mProfiler);
130    }
131
132    // TODO: minimize arguments by looking binary_format
133    // Init for root with previous word
134    void initAsRootWithPreviousWord(DicNode *dicNode, const int pos, const int childrenPos,
135            const int childrenCount) {
136        mIsUsed = true;
137        mIsCachedForNextSuggestion = false;
138        mDicNodeProperties.init(
139                pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0);
140        // TODO: Move to dicNodeState?
141        mDicNodeState.mDicNodeStateOutput.init(); // reset for next word
142        mDicNodeState.mDicNodeStateInput.init(
143                &dicNode->mDicNodeState.mDicNodeStateInput, true /* resetTerminalDiffCost */);
144        mDicNodeState.mDicNodeStateScoring.init(
145                &dicNode->mDicNodeState.mDicNodeStateScoring);
146        mDicNodeState.mDicNodeStatePrevWord.init(
147                dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1,
148                dicNode->mDicNodeProperties.getProbability(),
149                dicNode->mDicNodeProperties.getPos(),
150                dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevWord,
151                dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(),
152                dicNode->getOutputWordBuf(),
153                dicNode->mDicNodeProperties.getDepth(),
154                dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevSpacePositions,
155                mDicNodeState.mDicNodeStateInput.getInputIndex(0) /* lastInputIndex */);
156        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
157    }
158
159    // TODO: minimize arguments by looking binary_format
160    void initAsChild(DicNode *dicNode, const int pos, const uint8_t flags, const int childrenPos,
161            const int attributesPos, const int siblingPos, const int nodeCodePoint,
162            const int childrenCount, const int probability, const int bigramProbability,
163            const bool isTerminal, const bool hasMultipleChars, const bool hasChildren,
164            const uint16_t additionalSubwordLength, const int *additionalSubword) {
165        mIsUsed = true;
166        uint16_t newDepth = static_cast<uint16_t>(dicNode->getDepth() + 1);
167        mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
168        const uint16_t newLeavingDepth = static_cast<uint16_t>(
169                dicNode->mDicNodeProperties.getLeavingDepth() + additionalSubwordLength);
170        mDicNodeProperties.init(pos, flags, childrenPos, attributesPos, siblingPos, nodeCodePoint,
171                childrenCount, probability, bigramProbability, isTerminal, hasMultipleChars,
172                hasChildren, newDepth, newLeavingDepth);
173        mDicNodeState.init(&dicNode->mDicNodeState, additionalSubwordLength, additionalSubword);
174        PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
175    }
176
177    AK_FORCE_INLINE void remove() {
178        mIsUsed = false;
179        if (mReleaseListener) {
180            mReleaseListener->onReleased(this);
181        }
182    }
183
184    bool isUsed() const {
185        return mIsUsed;
186    }
187
188    bool isRoot() const {
189        return getDepth() == 0;
190    }
191
192    bool hasChildren() const {
193        return mDicNodeProperties.hasChildren();
194    }
195
196    bool isLeavingNode() const {
197        ASSERT(getDepth() <= getLeavingDepth());
198        return getDepth() == getLeavingDepth();
199    }
200
201    AK_FORCE_INLINE bool isFirstLetter() const {
202        return getDepth() == 1;
203    }
204
205    bool isCached() const {
206        return mIsCachedForNextSuggestion;
207    }
208
209    void setCached() {
210        mIsCachedForNextSuggestion = true;
211    }
212
213    // Used to expand the node in DicNodeUtils
214    int getNodeTypedCodePoint() const {
215        return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getDepth());
216    }
217
218    bool isImpossibleBigramWord() const {
219        const int probability = mDicNodeProperties.getProbability();
220        if (probability == 0) {
221            return true;
222        }
223        const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength()
224                - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1;
225        const int currentWordLen = getDepth();
226        return (prevWordLen == 1 && currentWordLen == 1);
227    }
228
229    bool isCapitalized() const {
230        const int c = getOutputWordBuf()[0];
231        return isAsciiUpper(c);
232    }
233
234    bool isFirstWord() const {
235        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_VALID_WORD;
236    }
237
238    bool isCompletion(const int inputSize) const {
239        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize;
240    }
241
242    bool canDoLookAheadCorrection(const int inputSize) const {
243        return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1;
244    }
245
246    // Used to get bigram probability in DicNodeUtils
247    int getPos() const {
248        return mDicNodeProperties.getPos();
249    }
250
251    // Used to get bigram probability in DicNodeUtils
252    int getPrevWordPos() const {
253        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos();
254    }
255
256    // Used in DicNodeUtils
257    int getChildrenPos() const {
258        return mDicNodeProperties.getChildrenPos();
259    }
260
261    // Used in DicNodeUtils
262    int getChildrenCount() const {
263        return mDicNodeProperties.getChildrenCount();
264    }
265
266    // Used in DicNodeUtils
267    int getProbability() const {
268        return mDicNodeProperties.getProbability();
269    }
270
271    AK_FORCE_INLINE bool isTerminalWordNode() const {
272        const bool isTerminalNodes = mDicNodeProperties.isTerminal();
273        const int currentNodeDepth = getDepth();
274        const int terminalNodeDepth = mDicNodeProperties.getLeavingDepth();
275        return isTerminalNodes && currentNodeDepth > 0 && currentNodeDepth == terminalNodeDepth;
276    }
277
278    bool shouldBeFilterdBySafetyNetForBigram() const {
279        const uint16_t currentDepth = getDepth();
280        const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength()
281                - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1;
282        return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1));
283    }
284
285    uint16_t getLeavingDepth() const {
286        return mDicNodeProperties.getLeavingDepth();
287    }
288
289    bool isTotalInputSizeExceedingLimit() const {
290        const int prevWordsLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength();
291        const int currentWordDepth = getDepth();
292        // TODO: 3 can be 2? Needs to be investigated.
293        // TODO: Have a const variable for 3 (or 2)
294        return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3;
295    }
296
297    // TODO: This may be defective. Needs to be revised.
298    bool truncateNode(const DicNode *const topNode, const int inputCommitPoint) {
299        const int prevWordLenOfTop = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength();
300        int newPrevWordStartIndex = inputCommitPoint;
301        int charCount = 0;
302        // Find new word start index
303        for (int i = 0; i < prevWordLenOfTop; ++i) {
304            const int c = mDicNodeState.mDicNodeStatePrevWord.getPrevWordCodePointAt(i);
305            // TODO: Check other separators.
306            if (c != KEYCODE_SPACE && c != KEYCODE_SINGLE_QUOTE) {
307                if (charCount == inputCommitPoint) {
308                    newPrevWordStartIndex = i;
309                    break;
310                }
311                ++charCount;
312            }
313        }
314        if (!mDicNodeState.mDicNodeStatePrevWord.startsWith(
315                &topNode->mDicNodeState.mDicNodeStatePrevWord, newPrevWordStartIndex - 1)) {
316            // Node mismatch.
317            return false;
318        }
319        mDicNodeState.mDicNodeStateInput.truncate(inputCommitPoint);
320        mDicNodeState.mDicNodeStatePrevWord.truncate(newPrevWordStartIndex);
321        return true;
322    }
323
324    void outputResult(int *dest) const {
325        const uint16_t prevWordLength = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength();
326        const uint16_t currentDepth = getDepth();
327        DicNodeUtils::appendTwoWords(mDicNodeState.mDicNodeStatePrevWord.mPrevWord,
328                   prevWordLength, getOutputWordBuf(), currentDepth, dest);
329        DUMP_WORD_AND_SCORE("OUTPUT");
330    }
331
332    void outputSpacePositionsResult(int *spaceIndices) const {
333        mDicNodeState.mDicNodeStatePrevWord.outputSpacePositions(spaceIndices);
334    }
335
336    bool hasMultipleWords() const {
337        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() > 0;
338    }
339
340    float getProximityCorrectionCount() const {
341        return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount());
342    }
343
344    float getEditCorrectionCount() const {
345        return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount());
346    }
347
348    // Used to prune nodes
349    float getNormalizedCompoundDistance() const {
350        return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance();
351    }
352
353    // Used to prune nodes
354    float getNormalizedSpatialDistance() const {
355        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance()
356                / static_cast<float>(getInputIndex(0) + 1);
357    }
358
359    // Used to prune nodes
360    float getCompoundDistance() const {
361        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance();
362    }
363
364    // Used to prune nodes
365    float getCompoundDistance(const float languageWeight) const {
366        return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight);
367    }
368
369    // Note that "cost" means delta for "distance" that is weighted.
370    float getTotalPrevWordsLanguageCost() const {
371        return mDicNodeState.mDicNodeStateScoring.getTotalPrevWordsLanguageCost();
372    }
373
374    // Used to commit input partially
375    int getPrevWordNodePos() const {
376        return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos();
377    }
378
379    AK_FORCE_INLINE const int *getOutputWordBuf() const {
380        return mDicNodeState.mDicNodeStateOutput.mWordBuf;
381    }
382
383    int getPrevCodePointG(int pointerId) const {
384        return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId);
385    }
386
387    // Whether the current codepoint can be an intentional omission, in which case the traversal
388    // algorithm will always check for a possible omission here.
389    bool canBeIntentionalOmission() const {
390        return isIntentionalOmissionCodePoint(getNodeCodePoint());
391    }
392
393    // Whether the omission is so frequent that it should incur zero cost.
394    bool isZeroCostOmission() const {
395        // TODO: do not hardcode and read from header
396        return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE);
397    }
398
399    // TODO: remove
400    float getTerminalDiffCostG(int path) const {
401        return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path);
402    }
403
404    //////////////////////
405    // Temporary getter //
406    // TODO: Remove     //
407    //////////////////////
408    // TODO: Remove once touch path is merged into ProximityInfoState
409    int getNodeCodePoint() const {
410        return mDicNodeProperties.getNodeCodePoint();
411    }
412
413    ////////////////////////////////
414    // Utils for cost calculation //
415    ////////////////////////////////
416    AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const {
417        return mDicNodeProperties.getNodeCodePoint()
418                == dicNode->mDicNodeProperties.getNodeCodePoint();
419    }
420
421    // TODO: remove
422    // TODO: rename getNextInputIndex
423    int16_t getInputIndex(int pointerId) const {
424        return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId);
425    }
426
427    ////////////////////////////////////
428    // Getter of features for scoring //
429    ////////////////////////////////////
430    float getSpatialDistanceForScoring() const {
431        return mDicNodeState.mDicNodeStateScoring.getSpatialDistance();
432    }
433
434    float getLanguageDistanceForScoring() const {
435        return mDicNodeState.mDicNodeStateScoring.getLanguageDistance();
436    }
437
438    float getLanguageDistanceRatePerWordForScoring() const {
439        const float langDist = getLanguageDistanceForScoring();
440        const float totalWordCount =
441                static_cast<float>(mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1);
442        return langDist / totalWordCount;
443    }
444
445    float getRawLength() const {
446        return mDicNodeState.mDicNodeStateScoring.getRawLength();
447    }
448
449    bool isLessThanOneErrorForScoring() const {
450        return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount()
451                + mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount() <= 1;
452    }
453
454    DoubleLetterLevel getDoubleLetterLevel() const {
455        return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel();
456    }
457
458    void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
459        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel);
460    }
461
462    uint8_t getFlags() const {
463        return mDicNodeProperties.getFlags();
464    }
465
466    int getAttributesPos() const {
467        return mDicNodeProperties.getAttributesPos();
468    }
469
470    inline uint16_t getDepth() const {
471        return mDicNodeProperties.getDepth();
472    }
473
474    AK_FORCE_INLINE void dump(const char *tag) const {
475#if DEBUG_DICT
476        DUMP_WORD_AND_SCORE(tag);
477#if DEBUG_DUMP_ERROR
478        mProfiler.dump();
479#endif
480#endif
481    }
482
483    void setReleaseListener(DicNodeReleaseListener *releaseListener) {
484        mReleaseListener = releaseListener;
485    }
486
487    AK_FORCE_INLINE bool compare(const DicNode *right) {
488        if (!isUsed() && !right->isUsed()) {
489            // Compare pointer values here for stable comparison
490            return this > right;
491        }
492        if (!isUsed()) {
493            return true;
494        }
495        if (!right->isUsed()) {
496            return false;
497        }
498        const float diff =
499                right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance();
500        static const float MIN_DIFF = 0.000001f;
501        if (diff > MIN_DIFF) {
502            return true;
503        } else if (diff < -MIN_DIFF) {
504            return false;
505        }
506        const int depth = getDepth();
507        const int depthDiff = right->getDepth() - depth;
508        if (depthDiff != 0) {
509            return depthDiff > 0;
510        }
511        for (int i = 0; i < depth; ++i) {
512            const int codePoint = mDicNodeState.mDicNodeStateOutput.getCodePointAt(i);
513            const int rightCodePoint = right->mDicNodeState.mDicNodeStateOutput.getCodePointAt(i);
514            if (codePoint != rightCodePoint) {
515                return rightCodePoint > codePoint;
516            }
517        }
518        // Compare pointer values here for stable comparison
519        return this > right;
520    }
521
522 private:
523    DicNodeProperties mDicNodeProperties;
524    DicNodeState mDicNodeState;
525    // TODO: Remove
526    bool mIsCachedForNextSuggestion;
527    bool mIsUsed;
528    DicNodeReleaseListener *mReleaseListener;
529
530    AK_FORCE_INLINE int getTotalInputIndex() const {
531        int index = 0;
532        for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
533            index += mDicNodeState.mDicNodeStateInput.getInputIndex(i);
534        }
535        return index;
536    }
537
538    // Caveat: Must not be called outside Weighting
539    // This restriction is guaranteed by "friend"
540    AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
541            const bool doNormalization, const int inputSize, const bool isEditCorrection,
542            const bool isProximityCorrection) {
543        if (DEBUG_GEO_FULL) {
544            LOGI_SHOW_ADD_COST_PROP;
545        }
546        mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
547                inputSize, getTotalInputIndex(), isEditCorrection, isProximityCorrection);
548    }
549
550    // Caveat: Must not be called outside Weighting
551    // This restriction is guaranteed by "friend"
552    AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count,
553            const bool overwritesPrevCodePointByNodeCodePoint) {
554        if (count == 0) {
555            return;
556        }
557        mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count);
558        if (overwritesPrevCodePointByNodeCodePoint) {
559            mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint());
560        }
561    }
562
563    AK_FORCE_INLINE void updateInputIndexG(DicNode_InputStateG *inputStateG) {
564        mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId,
565                inputStateG->mInputIndex, inputStateG->mPrevCodePoint,
566                inputStateG->mTerminalDiffCost, inputStateG->mRawLength);
567        mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength);
568        mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel);
569    }
570};
571} // namespace latinime
572#endif // LATINIME_DIC_NODE_H
573