weighting.h revision 2fa3693c264a4c150ac307d9bb7f6f8f18cc4ffc
1/*
2 * Copyright (C) 2013 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_WEIGHTING_H
18#define LATINIME_WEIGHTING_H
19
20#include "defines.h"
21#include "suggest/core/dictionary/error_type_utils.h"
22
23namespace latinime {
24
25class DicNode;
26class DicTraverseSession;
27struct DicNode_InputStateG;
28class MultiBigramMap;
29
30class Weighting {
31 public:
32    static void addCostAndForwardInputIndex(const Weighting *const weighting,
33            const CorrectionType correctionType,
34            const DicTraverseSession *const traverseSession,
35            const DicNode *const parentDicNode, DicNode *const dicNode,
36            MultiBigramMap *const multiBigramMap);
37
38 protected:
39    virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
40            const DicNode *const dicNode) const = 0;
41
42    virtual float getOmissionCost(
43         const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
44
45    virtual float getMatchedCost(
46            const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
47            DicNode_InputStateG *inputStateG) const = 0;
48
49    virtual bool isProximityDicNode(const DicTraverseSession *const traverseSession,
50            const DicNode *const dicNode) const = 0;
51
52    virtual float getTranspositionCost(
53            const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
54            const DicNode *const dicNode) const = 0;
55
56    virtual float getInsertionCost(
57            const DicTraverseSession *const traverseSession,
58            const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
59
60    virtual float getNewWordSpatialCost(const DicTraverseSession *const traverseSession,
61            const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const = 0;
62
63    virtual float getNewWordBigramLanguageCost(
64            const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
65            MultiBigramMap *const multiBigramMap) const = 0;
66
67    virtual float getCompletionCost(
68            const DicTraverseSession *const traverseSession,
69            const DicNode *const dicNode) const = 0;
70
71    virtual float getTerminalInsertionCost(
72            const DicTraverseSession *const traverseSession,
73            const DicNode *const dicNode) const = 0;
74
75    virtual float getTerminalLanguageCost(
76            const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
77            float dicNodeLanguageImprobability) const = 0;
78
79    virtual bool needsToNormalizeCompoundDistance() const = 0;
80
81    virtual float getAdditionalProximityCost() const = 0;
82
83    virtual float getSubstitutionCost() const = 0;
84
85    virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
86            const DicNode *const dicNode) const = 0;
87
88    virtual ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType,
89            const DicTraverseSession *const traverseSession,
90            const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
91
92    Weighting() {}
93    virtual ~Weighting() {}
94
95 private:
96    DISALLOW_COPY_AND_ASSIGN(Weighting);
97
98    static float getSpatialCost(const Weighting *const weighting,
99            const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
100            const DicNode *const parentDicNode, const DicNode *const dicNode,
101            DicNode_InputStateG *const inputStateG);
102    static float getLanguageCost(const Weighting *const weighting,
103            const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
104            const DicNode *const parentDicNode, const DicNode *const dicNode,
105            MultiBigramMap *const multiBigramMap);
106    // TODO: Move to TypingWeighting and GestureWeighting?
107    static int getForwardInputCount(const CorrectionType correctionType);
108};
109} // namespace latinime
110#endif // LATINIME_WEIGHTING_H
111