dic_node_state_scoring.h revision 9d618d1431ec78328bd0eecb90ade8bfcef9b025
1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#ifndef LATINIME_DIC_NODE_STATE_SCORING_H 18#define LATINIME_DIC_NODE_STATE_SCORING_H 19 20#include <stdint.h> 21 22#include "defines.h" 23#include "suggest/core/dictionary/digraph_utils.h" 24 25namespace latinime { 26 27class DicNodeStateScoring { 28 public: 29 AK_FORCE_INLINE DicNodeStateScoring() 30 : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER), 31 mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), 32 mEditCorrectionCount(0), mProximityCorrectionCount(0), 33 mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), 34 mRawLength(0.0f), mExactMatch(true) { 35 } 36 37 virtual ~DicNodeStateScoring() {} 38 39 void init() { 40 mEditCorrectionCount = 0; 41 mProximityCorrectionCount = 0; 42 mNormalizedCompoundDistance = 0.0f; 43 mSpatialDistance = 0.0f; 44 mLanguageDistance = 0.0f; 45 mRawLength = 0.0f; 46 mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; 47 mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; 48 mExactMatch = true; 49 } 50 51 AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) { 52 mEditCorrectionCount = scoring->mEditCorrectionCount; 53 mProximityCorrectionCount = scoring->mProximityCorrectionCount; 54 mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance; 55 mSpatialDistance = scoring->mSpatialDistance; 56 mLanguageDistance = scoring->mLanguageDistance; 57 mRawLength = scoring->mRawLength; 58 mDoubleLetterLevel = scoring->mDoubleLetterLevel; 59 mDigraphIndex = scoring->mDigraphIndex; 60 mExactMatch = scoring->mExactMatch; 61 } 62 63 void addCost(const float spatialCost, const float languageCost, const bool doNormalization, 64 const int inputSize, const int totalInputIndex, const ErrorType errorType) { 65 addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex); 66 switch (errorType) { 67 case ET_EDIT_CORRECTION: 68 ++mEditCorrectionCount; 69 mExactMatch = false; 70 break; 71 case ET_PROXIMITY_CORRECTION: 72 ++mProximityCorrectionCount; 73 mExactMatch = false; 74 break; 75 case ET_COMPLETION: 76 mExactMatch = false; 77 break; 78 case ET_NEW_WORD: 79 mExactMatch = false; 80 break; 81 case ET_INTENTIONAL_OMISSION: 82 mExactMatch = false; 83 break; 84 case ET_NOT_AN_ERROR: 85 break; 86 } 87 } 88 89 void addRawLength(const float rawLength) { 90 mRawLength += rawLength; 91 } 92 93 float getCompoundDistance() const { 94 return getCompoundDistance(1.0f); 95 } 96 97 float getCompoundDistance(const float languageWeight) const { 98 return mSpatialDistance + mLanguageDistance * languageWeight; 99 } 100 101 float getNormalizedCompoundDistance() const { 102 return mNormalizedCompoundDistance; 103 } 104 105 float getSpatialDistance() const { 106 return mSpatialDistance; 107 } 108 109 float getLanguageDistance() const { 110 return mLanguageDistance; 111 } 112 113 int16_t getEditCorrectionCount() const { 114 return mEditCorrectionCount; 115 } 116 117 int16_t getProximityCorrectionCount() const { 118 return mProximityCorrectionCount; 119 } 120 121 float getRawLength() const { 122 return mRawLength; 123 } 124 125 DoubleLetterLevel getDoubleLetterLevel() const { 126 return mDoubleLetterLevel; 127 } 128 129 void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { 130 switch(doubleLetterLevel) { 131 case NOT_A_DOUBLE_LETTER: 132 break; 133 case A_DOUBLE_LETTER: 134 if (mDoubleLetterLevel != A_STRONG_DOUBLE_LETTER) { 135 mDoubleLetterLevel = doubleLetterLevel; 136 } 137 break; 138 case A_STRONG_DOUBLE_LETTER: 139 mDoubleLetterLevel = doubleLetterLevel; 140 break; 141 } 142 } 143 144 DigraphUtils::DigraphCodePointIndex getDigraphIndex() const { 145 return mDigraphIndex; 146 } 147 148 void advanceDigraphIndex() { 149 switch(mDigraphIndex) { 150 case DigraphUtils::NOT_A_DIGRAPH_INDEX: 151 mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT; 152 break; 153 case DigraphUtils::FIRST_DIGRAPH_CODEPOINT: 154 mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT; 155 break; 156 case DigraphUtils::SECOND_DIGRAPH_CODEPOINT: 157 mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; 158 break; 159 } 160 } 161 162 bool isExactMatch() const { 163 return mExactMatch; 164 } 165 166 private: 167 // Caution!!! 168 // Use a default copy constructor and an assign operator because shallow copies are ok 169 // for this class 170 DoubleLetterLevel mDoubleLetterLevel; 171 DigraphUtils::DigraphCodePointIndex mDigraphIndex; 172 173 int16_t mEditCorrectionCount; 174 int16_t mProximityCorrectionCount; 175 176 float mNormalizedCompoundDistance; 177 float mSpatialDistance; 178 float mLanguageDistance; 179 float mRawLength; 180 bool mExactMatch; 181 182 AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance, 183 bool doNormalization, int inputSize, int totalInputIndex) { 184 mSpatialDistance += spatialDistance; 185 mLanguageDistance += languageDistance; 186 if (!doNormalization) { 187 mNormalizedCompoundDistance = mSpatialDistance + mLanguageDistance; 188 } else { 189 mNormalizedCompoundDistance = (mSpatialDistance + mLanguageDistance) 190 / static_cast<float>(max(1, totalInputIndex)); 191 } 192 } 193}; 194} // namespace latinime 195#endif // LATINIME_DIC_NODE_STATE_SCORING_H 196