1/*
2 * Copyright (C) 2013, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_FORGETTING_CURVE_UTILS_H
18#define LATINIME_FORGETTING_CURVE_UTILS_H
19
20#include <vector>
21
22#include "defines.h"
23#include "dictionary/property/historical_info.h"
24#include "dictionary/utils/entry_counters.h"
25
26namespace latinime {
27
28class HeaderPolicy;
29
30class ForgettingCurveUtils {
31 public:
32    static const HistoricalInfo createUpdatedHistoricalInfo(
33            const HistoricalInfo *const originalHistoricalInfo, const int newProbability,
34            const HistoricalInfo *const newHistoricalInfo, const HeaderPolicy *const headerPolicy);
35
36    static const HistoricalInfo createHistoricalInfoToSave(
37            const HistoricalInfo *const originalHistoricalInfo,
38            const HeaderPolicy *const headerPolicy);
39
40    static int decodeProbability(const HistoricalInfo *const historicalInfo,
41            const HeaderPolicy *const headerPolicy);
42
43    static bool needsToKeep(const HistoricalInfo *const historicalInfo,
44            const HeaderPolicy *const headerPolicy);
45
46    static bool needsToDecay(const bool mindsBlockByDecay, const EntryCounts &entryCounters,
47            const HeaderPolicy *const headerPolicy);
48
49    // TODO: Improve probability computation method and remove this.
50    static int getProbabilityBiasForNgram(const int n) {
51        return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
52    }
53
54    AK_FORCE_INLINE static int getEntryCountHardLimit(const int maxEntryCount) {
55        return static_cast<int>(static_cast<float>(maxEntryCount)
56                * ENTRY_COUNT_HARD_LIMIT_WEIGHT);
57    }
58
59 private:
60    DISALLOW_IMPLICIT_CONSTRUCTORS(ForgettingCurveUtils);
61
62    class ProbabilityTable {
63     public:
64        ProbabilityTable();
65
66        int getProbability(const int tableId, const int level,
67                const int elapsedTimeStepCount) const {
68            return mTables[tableId][level][elapsedTimeStepCount];
69        }
70
71     private:
72        DISALLOW_COPY_AND_ASSIGN(ProbabilityTable);
73
74        static const int PROBABILITY_TABLE_COUNT;
75        static const int WEAK_PROBABILITY_TABLE_ID;
76        static const int MODEST_PROBABILITY_TABLE_ID;
77        static const int STRONG_PROBABILITY_TABLE_ID;
78        static const int AGGRESSIVE_PROBABILITY_TABLE_ID;
79
80        static const int WEAK_MAX_PROBABILITY;
81        static const int MODEST_BASE_PROBABILITY;
82        static const int STRONG_BASE_PROBABILITY;
83        static const int AGGRESSIVE_BASE_PROBABILITY;
84
85        std::vector<std::vector<std::vector<int>>> mTables;
86
87        static int getBaseProbabilityForLevel(const int tableId, const int level);
88    };
89
90    static const int MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
91    static const int DECAY_INTERVAL_SECONDS;
92
93    static const int MAX_LEVEL;
94    static const int MIN_VISIBLE_LEVEL;
95    static const int MAX_ELAPSED_TIME_STEP_COUNT;
96    static const int DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD;
97    static const int OCCURRENCES_TO_RAISE_THE_LEVEL;
98    static const int DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS;
99
100    static const float ENTRY_COUNT_HARD_LIMIT_WEIGHT;
101
102    static const ProbabilityTable sProbabilityTable;
103
104    static int backoff(const int unigramProbability);
105    static int getElapsedTimeStepCount(const int timestamp, const int durationToLevelDown);
106    static int clampToVisibleEntryLevelRange(const int level);
107    static int clampToValidLevelRange(const int level);
108    static int clampToValidCountRange(const int count, const HeaderPolicy *const headerPolicy);
109    static int clampToValidTimeStepCountRange(const int timeStepCount);
110};
111} // namespace latinime
112#endif /* LATINIME_FORGETTING_CURVE_UTILS_H */
113