1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_TYPING_WEIGHTING_H
18#define LATINIME_TYPING_WEIGHTING_H
19
20#include "defines.h"
21#include "suggest/core/dicnode/dic_node_utils.h"
22#include "suggest/core/dictionary/error_type_utils.h"
23#include "suggest/core/layout/touch_position_correction_utils.h"
24#include "suggest/core/policy/weighting.h"
25#include "suggest/core/session/dic_traverse_session.h"
26#include "suggest/policyimpl/typing/scoring_params.h"
27#include "utils/char_utils.h"
28
29namespace latinime {
30
31class DicNode;
32struct DicNode_InputStateG;
33class MultiBigramMap;
34
35class TypingWeighting : public Weighting {
36 public:
37    static const TypingWeighting *getInstance() { return &sInstance; }
38
39 protected:
40    float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
41            const DicNode *const dicNode) const {
42        float cost = 0.0f;
43        if (dicNode->hasMultipleWords()) {
44            cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST;
45        }
46        if (dicNode->getProximityCorrectionCount() > 0) {
47            cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST;
48        }
49        if (dicNode->getEditCorrectionCount() > 0) {
50            cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST;
51        }
52        return cost;
53    }
54
55    float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
56        const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
57        const bool isIntentionalOmission = parentDicNode->canBeIntentionalOmission();
58        const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
59        // If the traversal omitted the first letter then the dicNode should now be on the second.
60        const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2;
61        float cost = 0.0f;
62        if (isZeroCostOmission) {
63            cost = 0.0f;
64        } else if (isIntentionalOmission) {
65            cost = ScoringParams::INTENTIONAL_OMISSION_COST;
66        } else if (isFirstLetterOmission) {
67            cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
68        } else {
69            cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
70                    : ScoringParams::OMISSION_COST;
71        }
72        return cost;
73    }
74
75    float getMatchedCost(const DicTraverseSession *const traverseSession,
76            const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
77        const int pointIndex = dicNode->getInputIndex(0);
78        const float normalizedSquaredLength = traverseSession->getProximityInfoState(0)
79                ->getPointToKeyLength(pointIndex,
80                        CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
81        const float normalizedDistance = TouchPositionCorrectionUtils::getSweetSpotFactor(
82                traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength);
83        const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance;
84
85        const bool isFirstChar = pointIndex == 0;
86        const bool isProximity = isProximityDicNode(traverseSession, dicNode);
87        float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST
88                : ScoringParams::PROXIMITY_COST) : 0.0f;
89        if (isProximity && dicNode->getProximityCorrectionCount() == 0) {
90            cost += ScoringParams::FIRST_PROXIMITY_COST;
91        }
92        if (dicNode->getNodeCodePointCount() == 2) {
93            // At the second character of the current word, we check if the first char is uppercase
94            // and the word is a second or later word of a multiple word suggestion. We demote it
95            // if so.
96            const bool isSecondOrLaterWordFirstCharUppercase =
97                    dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase();
98            if (isSecondOrLaterWordFirstCharUppercase) {
99                cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
100            }
101        }
102        return weightedDistance + cost;
103    }
104
105    bool isProximityDicNode(const DicTraverseSession *const traverseSession,
106            const DicNode *const dicNode) const {
107        const int pointIndex = dicNode->getInputIndex(0);
108        const int primaryCodePoint = CharUtils::toBaseLowerCase(
109                traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex));
110        const int dicNodeChar = CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint());
111        return primaryCodePoint != dicNodeChar;
112    }
113
114    float getTranspositionCost(const DicTraverseSession *const traverseSession,
115            const DicNode *const parentDicNode, const DicNode *const dicNode) const {
116        const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
117        const int prevCodePoint = parentDicNode->getNodeCodePoint();
118        const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
119                parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint));
120        const int codePoint = dicNode->getNodeCodePoint();
121        const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
122                parentPointIndex, CharUtils::toBaseLowerCase(codePoint));
123        const float distance = distance1 + distance2;
124        const float weightedLengthDistance =
125                distance * ScoringParams::DISTANCE_WEIGHT_LENGTH;
126        return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
127    }
128
129    float getInsertionCost(const DicTraverseSession *const traverseSession,
130            const DicNode *const parentDicNode, const DicNode *const dicNode) const {
131        const int16_t insertedPointIndex = parentDicNode->getInputIndex(0);
132        const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(
133                insertedPointIndex);
134        const int currentCodePoint = dicNode->getNodeCodePoint();
135        const bool sameCodePoint = prevCodePoint == currentCodePoint;
136        const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0)
137                ->existsAdjacentProximityChars(insertedPointIndex);
138        const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
139                insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
140        const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH;
141        const bool singleChar = dicNode->getNodeCodePointCount() == 1;
142        float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f);
143        if (sameCodePoint) {
144            cost += ScoringParams::INSERTION_COST_SAME_CHAR;
145        } else if (existsAdjacentProximityChars) {
146            cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR;
147        } else {
148            cost += ScoringParams::INSERTION_COST;
149        }
150        return cost + weightedDistance;
151    }
152
153    float getNewWordSpatialCost(const DicTraverseSession *const traverseSession,
154            const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
155        return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier();
156    }
157
158    float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession,
159            const DicNode *const dicNode,
160            MultiBigramMap *const multiBigramMap) const {
161        return DicNodeUtils::getBigramNodeImprobability(
162                traverseSession->getDictionaryStructurePolicy(),
163                dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
164    }
165
166    float getCompletionCost(const DicTraverseSession *const traverseSession,
167            const DicNode *const dicNode) const {
168        // The auto completion starts when the input index is same as the input size
169        const bool firstCompletion = dicNode->getInputIndex(0)
170                == traverseSession->getInputSize();
171        // TODO: Change the cost for the first completion for the gesture?
172        const float cost = firstCompletion ? ScoringParams::COST_FIRST_COMPLETION
173                : ScoringParams::COST_COMPLETION;
174        return cost;
175    }
176
177    float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
178            const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
179        return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
180    }
181
182    float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
183            const DicNode *const dicNode) const {
184        const int inputIndex = dicNode->getInputIndex(0);
185        const int inputSize = traverseSession->getInputSize();
186        ASSERT(inputIndex < inputSize);
187        // TODO: Implement more efficient logic
188        return  ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex);
189    }
190
191    AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
192        return false;
193    }
194
195    AK_FORCE_INLINE float getAdditionalProximityCost() const {
196        return ScoringParams::ADDITIONAL_PROXIMITY_COST;
197    }
198
199    AK_FORCE_INLINE float getSubstitutionCost() const {
200        return ScoringParams::SUBSTITUTION_COST;
201    }
202
203    AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
204            const DicNode *const dicNode) const {
205        const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD;
206        return cost * traverseSession->getMultiWordCostMultiplier();
207    }
208
209    ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType,
210            const DicTraverseSession *const traverseSession,
211            const DicNode *const parentDicNode, const DicNode *const dicNode) const;
212
213 private:
214    DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
215    static const TypingWeighting sInstance;
216
217    TypingWeighting() {}
218    ~TypingWeighting() {}
219};
220} // namespace latinime
221#endif // LATINIME_TYPING_WEIGHTING_H
222