dic_node_state_scoring.h revision cafab169cdb21244c82b99c09983c98066113d87
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_DIC_NODE_STATE_SCORING_H
18#define LATINIME_DIC_NODE_STATE_SCORING_H
19
20#include <algorithm>
21#include <cstdint>
22
23#include "defines.h"
24#include "suggest/core/dictionary/digraph_utils.h"
25#include "suggest/core/dictionary/error_type_utils.h"
26
27namespace latinime {
28
29class DicNodeStateScoring {
30 public:
31    AK_FORCE_INLINE DicNodeStateScoring()
32            : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER),
33              mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
34              mEditCorrectionCount(0), mProximityCorrectionCount(0), mCompletionCount(0),
35              mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
36              mRawLength(0.0f), mContainedErrorTypes(ErrorTypeUtils::NOT_AN_ERROR),
37              mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) {
38    }
39
40    ~DicNodeStateScoring() {}
41
42    void init() {
43        mEditCorrectionCount = 0;
44        mProximityCorrectionCount = 0;
45        mCompletionCount = 0;
46        mNormalizedCompoundDistance = 0.0f;
47        mSpatialDistance = 0.0f;
48        mLanguageDistance = 0.0f;
49        mRawLength = 0.0f;
50        mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
51        mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
52        mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING;
53        mContainedErrorTypes = ErrorTypeUtils::NOT_AN_ERROR;
54    }
55
56    AK_FORCE_INLINE void initByCopy(const DicNodeStateScoring *const scoring) {
57        mEditCorrectionCount = scoring->mEditCorrectionCount;
58        mProximityCorrectionCount = scoring->mProximityCorrectionCount;
59        mCompletionCount = scoring->mCompletionCount;
60        mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance;
61        mSpatialDistance = scoring->mSpatialDistance;
62        mLanguageDistance = scoring->mLanguageDistance;
63        mRawLength = scoring->mRawLength;
64        mDoubleLetterLevel = scoring->mDoubleLetterLevel;
65        mDigraphIndex = scoring->mDigraphIndex;
66        mContainedErrorTypes = scoring->mContainedErrorTypes;
67        mNormalizedCompoundDistanceAfterFirstWord =
68                scoring->mNormalizedCompoundDistanceAfterFirstWord;
69    }
70
71    void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
72            const int inputSize, const int totalInputIndex,
73            const ErrorTypeUtils::ErrorType errorType) {
74        addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
75        mContainedErrorTypes = mContainedErrorTypes | errorType;
76        if (ErrorTypeUtils::isEditCorrectionError(errorType)) {
77            ++mEditCorrectionCount;
78        }
79        if (ErrorTypeUtils::isProximityCorrectionError(errorType)) {
80            ++mProximityCorrectionCount;
81        }
82        if (ErrorTypeUtils::isCompletion(errorType)) {
83            ++mCompletionCount;
84        }
85    }
86
87    // Saves the current normalized distance for space-aware gestures.
88    // See getNormalizedCompoundDistanceAfterFirstWord for details.
89    void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() {
90        // We get called here after each word. We only want to store the distance after
91        // the first word, so if we already have a distance we skip saving -- hence "IfNoneYet"
92        // in the method name.
93        if (mNormalizedCompoundDistanceAfterFirstWord >= MAX_VALUE_FOR_WEIGHTING) {
94            mNormalizedCompoundDistanceAfterFirstWord = getNormalizedCompoundDistance();
95        }
96    }
97
98    void addRawLength(const float rawLength) {
99        mRawLength += rawLength;
100    }
101
102    float getCompoundDistance() const {
103        return getCompoundDistance(1.0f);
104    }
105
106    float getCompoundDistance(const float languageWeight) const {
107        return mSpatialDistance + mLanguageDistance * languageWeight;
108    }
109
110    float getNormalizedCompoundDistance() const {
111        return mNormalizedCompoundDistance;
112    }
113
114    // For space-aware gestures, we store the normalized distance at the char index
115    // that ends the first word of the suggestion. We call this the distance after
116    // first word.
117    float getNormalizedCompoundDistanceAfterFirstWord() const {
118        return mNormalizedCompoundDistanceAfterFirstWord;
119    }
120
121    float getSpatialDistance() const {
122        return mSpatialDistance;
123    }
124
125    float getLanguageDistance() const {
126        return mLanguageDistance;
127    }
128
129    int16_t getEditCorrectionCount() const {
130        return mEditCorrectionCount;
131    }
132
133    int16_t getProximityCorrectionCount() const {
134        return mProximityCorrectionCount;
135    }
136
137    int16_t getCompletionCount() const {
138        return mCompletionCount;
139    }
140
141    float getRawLength() const {
142        return mRawLength;
143    }
144
145    DoubleLetterLevel getDoubleLetterLevel() const {
146        return mDoubleLetterLevel;
147    }
148
149    void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
150        switch(doubleLetterLevel) {
151            case NOT_A_DOUBLE_LETTER:
152                break;
153            case A_DOUBLE_LETTER:
154                if (mDoubleLetterLevel != A_STRONG_DOUBLE_LETTER) {
155                    mDoubleLetterLevel = doubleLetterLevel;
156                }
157                break;
158            case A_STRONG_DOUBLE_LETTER:
159                mDoubleLetterLevel = doubleLetterLevel;
160                break;
161        }
162    }
163
164    DigraphUtils::DigraphCodePointIndex getDigraphIndex() const {
165        return mDigraphIndex;
166    }
167
168    void advanceDigraphIndex() {
169        switch(mDigraphIndex) {
170            case DigraphUtils::NOT_A_DIGRAPH_INDEX:
171                mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT;
172                break;
173            case DigraphUtils::FIRST_DIGRAPH_CODEPOINT:
174                mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT;
175                break;
176            case DigraphUtils::SECOND_DIGRAPH_CODEPOINT:
177                mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
178                break;
179        }
180    }
181
182    ErrorTypeUtils::ErrorType getContainedErrorTypes() const {
183        return mContainedErrorTypes;
184    }
185
186 private:
187    DISALLOW_COPY_AND_ASSIGN(DicNodeStateScoring);
188
189    DoubleLetterLevel mDoubleLetterLevel;
190    DigraphUtils::DigraphCodePointIndex mDigraphIndex;
191
192    int16_t mEditCorrectionCount;
193    int16_t mProximityCorrectionCount;
194    int16_t mCompletionCount;
195
196    float mNormalizedCompoundDistance;
197    float mSpatialDistance;
198    float mLanguageDistance;
199    float mRawLength;
200    // All accumulated error types so far
201    ErrorTypeUtils::ErrorType mContainedErrorTypes;
202    float mNormalizedCompoundDistanceAfterFirstWord;
203
204    AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
205            bool doNormalization, int inputSize, int totalInputIndex) {
206        mSpatialDistance += spatialDistance;
207        mLanguageDistance += languageDistance;
208        if (!doNormalization) {
209            mNormalizedCompoundDistance = mSpatialDistance + mLanguageDistance;
210        } else {
211            mNormalizedCompoundDistance = (mSpatialDistance + mLanguageDistance)
212                    / static_cast<float>(std::max(1, totalInputIndex));
213        }
214    }
215};
216} // namespace latinime
217#endif // LATINIME_DIC_NODE_STATE_SCORING_H
218