1/*
2 * Copyright (C) 2014, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
18#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
19
20#include <cstdio>
21#include <vector>
22
23#include "defines.h"
24#include "dictionary/property/word_attributes.h"
25#include "dictionary/structure/v4/content/language_model_dict_content_global_counters.h"
26#include "dictionary/structure/v4/content/probability_entry.h"
27#include "dictionary/structure/v4/content/terminal_position_lookup_table.h"
28#include "dictionary/structure/v4/ver4_dict_constants.h"
29#include "dictionary/utils/entry_counters.h"
30#include "dictionary/utils/trie_map.h"
31#include "utils/byte_array_view.h"
32#include "utils/int_array_view.h"
33
34namespace latinime {
35
36class HeaderPolicy;
37
38/**
39 * Class representing language model.
40 *
41 * This class provides methods to get and store unigram/n-gram probability information and flags.
42 */
43class LanguageModelDictContent {
44 public:
45    // Pair of word id and probability entry used for iteration.
46    class WordIdAndProbabilityEntry {
47     public:
48        WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry)
49                : mWordId(wordId), mProbabilityEntry(probabilityEntry) {}
50
51        int getWordId() const { return mWordId; }
52        const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; }
53
54     private:
55        DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry);
56        DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry);
57
58        const int mWordId;
59        const ProbabilityEntry mProbabilityEntry;
60    };
61
62    // Iterator.
63    class EntryIterator {
64     public:
65        EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator,
66                const bool hasHistoricalInfo)
67                : mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {}
68
69        const WordIdAndProbabilityEntry operator*() const {
70            const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator;
71            return WordIdAndProbabilityEntry(
72                    result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo));
73        }
74
75        bool operator!=(const EntryIterator &other) const {
76            return mTrieMapIterator != other.mTrieMapIterator;
77        }
78
79        const EntryIterator &operator++() {
80            ++mTrieMapIterator;
81            return *this;
82        }
83
84     private:
85        DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator);
86        DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator);
87
88        TrieMap::TrieMapIterator mTrieMapIterator;
89        const bool mHasHistoricalInfo;
90    };
91
92    // Class represents range to use range base for loops.
93    class EntryRange {
94     public:
95        EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo)
96                : mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {}
97
98        EntryIterator begin() const {
99            return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo);
100        }
101
102        EntryIterator end() const {
103            return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo);
104        }
105
106     private:
107        DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange);
108        DISALLOW_ASSIGNMENT_OPERATOR(EntryRange);
109
110        const TrieMap::TrieMapRange mTrieMapRange;
111        const bool mHasHistoricalInfo;
112    };
113
114    class DumppedFullEntryInfo {
115     public:
116        DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId,
117                const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
118                : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
119                  mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
120
121        const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
122        int getTargetWordId() const { return mTargetWordId; }
123        const WordAttributes &getWordAttributes() const { return mWordAttributes; }
124        const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
125
126     private:
127        DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo);
128
129        const std::vector<int> mPrevWordIds;
130        const int mTargetWordId;
131        const WordAttributes mWordAttributes;
132        const ProbabilityEntry mProbabilityEntry;
133    };
134
135    LanguageModelDictContent(const ReadWriteByteArrayView *const buffers,
136            const bool hasHistoricalInfo)
137            : mTrieMap(buffers[TRIE_MAP_BUFFER_INDEX]),
138              mGlobalCounters(buffers[GLOBAL_COUNTERS_BUFFER_INDEX]),
139              mHasHistoricalInfo(hasHistoricalInfo) {}
140
141    explicit LanguageModelDictContent(const bool hasHistoricalInfo)
142            : mTrieMap(), mGlobalCounters(), mHasHistoricalInfo(hasHistoricalInfo) {}
143
144    bool isNearSizeLimit() const {
145        return mTrieMap.isNearSizeLimit() || mGlobalCounters.needsToHalveCounters();
146    }
147
148    bool save(FILE *const file) const;
149
150    bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
151            const LanguageModelDictContent *const originalContent);
152
153    const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
154            const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const;
155
156    ProbabilityEntry getProbabilityEntry(const int wordId) const {
157        return getNgramProbabilityEntry(WordIdArrayView(), wordId);
158    }
159
160    bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) {
161        mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount());
162        return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry);
163    }
164
165    bool removeProbabilityEntry(const int wordId) {
166        return removeNgramProbabilityEntry(WordIdArrayView(), wordId);
167    }
168
169    ProbabilityEntry getNgramProbabilityEntry(const WordIdArrayView prevWordIds,
170            const int wordId) const;
171
172    bool setNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId,
173            const ProbabilityEntry *const probabilityEntry);
174
175    bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
176
177    EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
178
179    std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
180            const HeaderPolicy *const headerPolicy, const int wordId) const;
181
182    bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
183            MutableEntryCounters *const outEntryCounters) {
184        if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
185                0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(),
186                outEntryCounters)) {
187            return false;
188        }
189        if (mGlobalCounters.needsToHalveCounters()) {
190            mGlobalCounters.halveCounters();
191        }
192        return true;
193    }
194
195    // entryCounts should be created by updateAllProbabilityEntries.
196    bool truncateEntries(const EntryCounts &currentEntryCounts, const EntryCounts &maxEntryCounts,
197            const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
198
199    bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
200            const bool isValid, const HistoricalInfo historicalInfo,
201            const HeaderPolicy *const headerPolicy,
202            MutableEntryCounters *const entryCountersToUpdate);
203
204 private:
205    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
206
207    class EntryInfoToTurncate {
208     public:
209        class Comparator {
210         public:
211            bool operator()(const EntryInfoToTurncate &left,
212                    const EntryInfoToTurncate &right) const;
213         private:
214            DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
215        };
216
217        EntryInfoToTurncate(const int priority, const int count, const int key,
218                const int prevWordCount, const int *const prevWordIds);
219
220        int mPriority;
221        // TODO: Remove.
222        int mCount;
223        int mKey;
224        int mPrevWordCount;
225        int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
226
227     private:
228        DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
229    };
230
231    static const int TRIE_MAP_BUFFER_INDEX;
232    static const int GLOBAL_COUNTERS_BUFFER_INDEX;
233
234    TrieMap mTrieMap;
235    LanguageModelDictContentGlobalCounters mGlobalCounters;
236    const bool mHasHistoricalInfo;
237
238    bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
239            const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex);
240    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
241    int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
242    bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
243            const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters,
244            MutableEntryCounters *const outEntryCounters);
245    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
246            const int maxEntryCount, const int targetLevel, int *const outEntryCount);
247    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
248            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
249            std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
250    const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
251            const bool isValid, const HistoricalInfo historicalInfo,
252            const HeaderPolicy *const headerPolicy) const;
253    void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
254            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
255            std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const;
256};
257} // namespace latinime
258#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
259