dic_node_state_scoring.h revision 9d618d1431ec78328bd0eecb90ade8bfcef9b025
1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_DIC_NODE_STATE_SCORING_H
18#define LATINIME_DIC_NODE_STATE_SCORING_H
19
20#include <stdint.h>
21
22#include "defines.h"
23#include "suggest/core/dictionary/digraph_utils.h"
24
25namespace latinime {
26
27class DicNodeStateScoring {
28 public:
29    AK_FORCE_INLINE DicNodeStateScoring()
30            : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER),
31              mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
32              mEditCorrectionCount(0), mProximityCorrectionCount(0),
33              mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
34              mRawLength(0.0f), mExactMatch(true) {
35    }
36
37    virtual ~DicNodeStateScoring() {}
38
39    void init() {
40        mEditCorrectionCount = 0;
41        mProximityCorrectionCount = 0;
42        mNormalizedCompoundDistance = 0.0f;
43        mSpatialDistance = 0.0f;
44        mLanguageDistance = 0.0f;
45        mRawLength = 0.0f;
46        mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
47        mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
48        mExactMatch = true;
49    }
50
51    AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
52        mEditCorrectionCount = scoring->mEditCorrectionCount;
53        mProximityCorrectionCount = scoring->mProximityCorrectionCount;
54        mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance;
55        mSpatialDistance = scoring->mSpatialDistance;
56        mLanguageDistance = scoring->mLanguageDistance;
57        mRawLength = scoring->mRawLength;
58        mDoubleLetterLevel = scoring->mDoubleLetterLevel;
59        mDigraphIndex = scoring->mDigraphIndex;
60        mExactMatch = scoring->mExactMatch;
61    }
62
63    void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
64            const int inputSize, const int totalInputIndex, const ErrorType errorType) {
65        addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
66        switch (errorType) {
67            case ET_EDIT_CORRECTION:
68                ++mEditCorrectionCount;
69                mExactMatch = false;
70                break;
71            case ET_PROXIMITY_CORRECTION:
72                ++mProximityCorrectionCount;
73                mExactMatch = false;
74                break;
75            case ET_COMPLETION:
76                mExactMatch = false;
77                break;
78            case ET_NEW_WORD:
79                mExactMatch = false;
80                break;
81            case ET_INTENTIONAL_OMISSION:
82                mExactMatch = false;
83                break;
84            case ET_NOT_AN_ERROR:
85                break;
86        }
87    }
88
89    void addRawLength(const float rawLength) {
90        mRawLength += rawLength;
91    }
92
93    float getCompoundDistance() const {
94        return getCompoundDistance(1.0f);
95    }
96
97    float getCompoundDistance(const float languageWeight) const {
98        return mSpatialDistance + mLanguageDistance * languageWeight;
99    }
100
101    float getNormalizedCompoundDistance() const {
102        return mNormalizedCompoundDistance;
103    }
104
105    float getSpatialDistance() const {
106        return mSpatialDistance;
107    }
108
109    float getLanguageDistance() const {
110        return mLanguageDistance;
111    }
112
113    int16_t getEditCorrectionCount() const {
114        return mEditCorrectionCount;
115    }
116
117    int16_t getProximityCorrectionCount() const {
118        return mProximityCorrectionCount;
119    }
120
121    float getRawLength() const {
122        return mRawLength;
123    }
124
125    DoubleLetterLevel getDoubleLetterLevel() const {
126        return mDoubleLetterLevel;
127    }
128
129    void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) {
130        switch(doubleLetterLevel) {
131            case NOT_A_DOUBLE_LETTER:
132                break;
133            case A_DOUBLE_LETTER:
134                if (mDoubleLetterLevel != A_STRONG_DOUBLE_LETTER) {
135                    mDoubleLetterLevel = doubleLetterLevel;
136                }
137                break;
138            case A_STRONG_DOUBLE_LETTER:
139                mDoubleLetterLevel = doubleLetterLevel;
140                break;
141        }
142    }
143
144    DigraphUtils::DigraphCodePointIndex getDigraphIndex() const {
145        return mDigraphIndex;
146    }
147
148    void advanceDigraphIndex() {
149        switch(mDigraphIndex) {
150            case DigraphUtils::NOT_A_DIGRAPH_INDEX:
151                mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT;
152                break;
153            case DigraphUtils::FIRST_DIGRAPH_CODEPOINT:
154                mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT;
155                break;
156            case DigraphUtils::SECOND_DIGRAPH_CODEPOINT:
157                mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
158                break;
159        }
160    }
161
162    bool isExactMatch() const {
163        return mExactMatch;
164    }
165
166 private:
167    // Caution!!!
168    // Use a default copy constructor and an assign operator because shallow copies are ok
169    // for this class
170    DoubleLetterLevel mDoubleLetterLevel;
171    DigraphUtils::DigraphCodePointIndex mDigraphIndex;
172
173    int16_t mEditCorrectionCount;
174    int16_t mProximityCorrectionCount;
175
176    float mNormalizedCompoundDistance;
177    float mSpatialDistance;
178    float mLanguageDistance;
179    float mRawLength;
180    bool mExactMatch;
181
182    AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
183            bool doNormalization, int inputSize, int totalInputIndex) {
184        mSpatialDistance += spatialDistance;
185        mLanguageDistance += languageDistance;
186        if (!doNormalization) {
187            mNormalizedCompoundDistance = mSpatialDistance + mLanguageDistance;
188        } else {
189            mNormalizedCompoundDistance = (mSpatialDistance + mLanguageDistance)
190                    / static_cast<float>(max(1, totalInputIndex));
191        }
192    }
193};
194} // namespace latinime
195#endif // LATINIME_DIC_NODE_STATE_SCORING_H
196