1/*
2 * Copyright (C) 2014, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "dictionary/structure/v4/content/language_model_dict_content.h"
18
19#include <algorithm>
20#include <cstring>
21
22#include "dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
23#include "dictionary/utils/probability_utils.h"
24#include "utils/ngram_utils.h"
25
26namespace latinime {
27
// Buffer indices for the two components of this content.
// NOTE(review): presumably these are positions in the buffer array from which the trie map and
// the global counters are constructed; the code that consumes them is not in this file -- confirm
// against the class declaration.
const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0;
const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1;
30
31bool LanguageModelDictContent::save(FILE *const file) const {
32    return mTrieMap.save(file) && mGlobalCounters.save(file);
33}
34
35bool LanguageModelDictContent::runGC(
36        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
37        const LanguageModelDictContent *const originalContent) {
38    return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
39            0 /* nextLevelBitmapEntryIndex */);
40}
41
42const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
43        const int wordId, const bool mustMatchAllPrevWords,
44        const HeaderPolicy *const headerPolicy) const {
45    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
46    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
47    int maxPrevWordCount = 0;
48    for (size_t i = 0; i < prevWordIds.size(); ++i) {
49        const int nextBitmapEntryIndex =
50                mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
51        if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
52            break;
53        }
54        maxPrevWordCount = i + 1;
55        bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
56    }
57
58    const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
59    if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) {
60        // The word should be treated as a invalid word.
61        return WordAttributes();
62    }
63    for (int i = maxPrevWordCount; i >= 0; --i) {
64        if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) {
65            break;
66        }
67        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
68        if (!result.mIsValid) {
69            continue;
70        }
71        const ProbabilityEntry probabilityEntry =
72                ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
73        int probability = NOT_A_PROBABILITY;
74        if (mHasHistoricalInfo) {
75            const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
76            int contextCount = 0;
77            if (i == 0) {
78                // unigram
79                contextCount = mGlobalCounters.getTotalCount();
80            } else {
81                const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
82                        prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
83                if (!prevWordProbabilityEntry.isValid()) {
84                    continue;
85                }
86                if (prevWordProbabilityEntry.representsBeginningOfSentence()
87                        && historicalInfo->getCount() == 1) {
88                    // BoS ngram requires multiple contextCount.
89                    continue;
90                }
91                contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
92            }
93            const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(i + 1);
94            const float rawProbability =
95                    DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
96                            historicalInfo->getCount(), contextCount, ngramType);
97            const int encodedRawProbability =
98                    ProbabilityUtils::encodeRawProbability(rawProbability);
99            const int decayedProbability =
100                    DynamicLanguageModelProbabilityUtils::getDecayedProbability(
101                            encodedRawProbability, *historicalInfo);
102            probability = DynamicLanguageModelProbabilityUtils::backoff(
103                    decayedProbability, ngramType);
104        } else {
105            probability = probabilityEntry.getProbability();
106        }
107        // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in
108        // probabilityEntry.
109        return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(),
110                unigramProbabilityEntry.isNotAWord(),
111                unigramProbabilityEntry.isPossiblyOffensive());
112    }
113    // Cannot find the word.
114    return WordAttributes();
115}
116
117ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
118        const WordIdArrayView prevWordIds, const int wordId) const {
119    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
120    if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
121        return ProbabilityEntry();
122    }
123    const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndex);
124    if (!result.mIsValid) {
125        // Not found.
126        return ProbabilityEntry();
127    }
128    return ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
129}
130
131bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds,
132        const int wordId, const ProbabilityEntry *const probabilityEntry) {
133    if (wordId == Ver4DictConstants::NOT_A_TERMINAL_ID) {
134        return false;
135    }
136    const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
137    if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
138        return false;
139    }
140    return mTrieMap.put(wordId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex);
141}
142
143bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView prevWordIds,
144        const int wordId) {
145    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
146    if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
147        // Cannot find bitmap entry for the probability entry. The entry doesn't exist.
148        return false;
149    }
150    return mTrieMap.remove(wordId, bitmapEntryIndex);
151}
152
153LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries(
154        const WordIdArrayView prevWordIds) const {
155    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
156    return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
157}
158
159std::vector<LanguageModelDictContent::DumppedFullEntryInfo>
160        LanguageModelDictContent::exportAllNgramEntriesRelatedToWord(
161                const HeaderPolicy *const headerPolicy, const int wordId) const {
162    const TrieMap::Result result = mTrieMap.getRoot(wordId);
163    if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) {
164        // The word doesn't have any related ngram entries.
165        return std::vector<DumppedFullEntryInfo>();
166    }
167    std::vector<int> prevWordIds = { wordId };
168    std::vector<DumppedFullEntryInfo> entries;
169    exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex,
170            &prevWordIds, &entries);
171    return entries;
172}
173
174void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
175        const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex,
176        std::vector<int> *const prevWordIds,
177        std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const {
178    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
179        const int wordId = entry.key();
180        const ProbabilityEntry probabilityEntry =
181                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
182        if (probabilityEntry.isValid()) {
183            const WordAttributes wordAttributes = getWordAttributes(
184                    WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */,
185                    headerPolicy);
186            outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
187                    wordAttributes, probabilityEntry);
188        }
189        if (entry.hasNextLevelMap()) {
190            prevWordIds->push_back(wordId);
191            exportAllNgramEntriesRelatedToWordInner(headerPolicy,
192                    entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo);
193            prevWordIds->pop_back();
194        }
195    }
196}
197
198bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
199        const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
200        MutableEntryCounters *const outEntryCounters) {
201    for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
202        const int totalWordCount = prevWordCount + 1;
203        const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(totalWordCount);
204        if (currentEntryCounts.getNgramCount(ngramType)
205                <= maxEntryCounts.getNgramCount(ngramType)) {
206            outEntryCounters->setNgramCount(ngramType,
207                    currentEntryCounts.getNgramCount(ngramType));
208            continue;
209        }
210        int entryCount = 0;
211        if (!turncateEntriesInSpecifiedLevel(headerPolicy,
212                maxEntryCounts.getNgramCount(ngramType), prevWordCount, &entryCount)) {
213            return false;
214        }
215        outEntryCounters->setNgramCount(ngramType, entryCount);
216    }
217    return true;
218}
219
220bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds,
221        const int wordId, const bool isValid, const HistoricalInfo historicalInfo,
222        const HeaderPolicy *const headerPolicy, MutableEntryCounters *const entryCountersToUpdate) {
223    if (!mHasHistoricalInfo) {
224        AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info.");
225        return false;
226    }
227    const ProbabilityEntry originalUnigramProbabilityEntry = getProbabilityEntry(wordId);
228    const ProbabilityEntry updatedUnigramProbabilityEntry = createUpdatedEntryFrom(
229            originalUnigramProbabilityEntry, isValid, historicalInfo, headerPolicy);
230    if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) {
231        return false;
232    }
233    mGlobalCounters.incrementTotalCount();
234    mGlobalCounters.updateMaxValueOfCounters(
235            updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount());
236    for (size_t i = 0; i < prevWordIds.size(); ++i) {
237        if (prevWordIds[i] == NOT_A_WORD_ID) {
238            break;
239        }
240        // TODO: Optimize this code.
241        const WordIdArrayView limitedPrevWordIds = prevWordIds.limit(i + 1);
242        const ProbabilityEntry originalNgramProbabilityEntry = getNgramProbabilityEntry(
243                limitedPrevWordIds, wordId);
244        const ProbabilityEntry updatedNgramProbabilityEntry = createUpdatedEntryFrom(
245                originalNgramProbabilityEntry, isValid, historicalInfo, headerPolicy);
246        if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) {
247            return false;
248        }
249        mGlobalCounters.updateMaxValueOfCounters(
250                updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
251        if (!originalNgramProbabilityEntry.isValid()) {
252            // (i + 2) words are used in total because the prevWords consists of (i + 1) words when
253            // looking at its i-th element.
254            entryCountersToUpdate->incrementNgramCount(
255                    NgramUtils::getNgramTypeFromWordCount(i + 2));
256        }
257    }
258    return true;
259}
260
261const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
262        const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
263        const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
264    const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(),
265            0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount()
266                    + historicalInfo.getCount());
267    if (originalProbabilityEntry.isValid()) {
268        return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo);
269    } else {
270        return ProbabilityEntry(0 /* flags */, &updatedHistoricalInfo);
271    }
272}
273
274bool LanguageModelDictContent::runGCInner(
275        const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
276        const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) {
277    for (auto &entry : trieMapRange) {
278        const auto it = terminalIdMap->find(entry.key());
279        if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
280            // The word has been removed.
281            continue;
282        }
283        if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
284            return false;
285        }
286        if (entry.hasNextLevelMap()) {
287            if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
288                    mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) {
289                return false;
290            }
291        }
292    }
293    return true;
294}
295
296int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
297    int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
298    for (const int wordId : prevWordIds) {
299        const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex);
300        if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) {
301            lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
302            continue;
303        }
304        if (!result.mIsValid) {
305            if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo),
306                    lastBitmapEntryIndex)) {
307                AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId,
308                        lastBitmapEntryIndex);
309                return TrieMap::INVALID_INDEX;
310            }
311        }
312        lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId,
313                lastBitmapEntryIndex);
314    }
315    return lastBitmapEntryIndex;
316}
317
318int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
319    int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
320    for (const int wordId : prevWordIds) {
321        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndex);
322        if (!result.mIsValid) {
323            return TrieMap::INVALID_INDEX;
324        }
325        bitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
326    }
327    return bitmapEntryIndex;
328}
329
330bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
331        const int prevWordCount, const HeaderPolicy *const headerPolicy,
332        const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) {
333    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
334        if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
335            AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
336                    prevWordCount, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
337            return false;
338        }
339        const ProbabilityEntry probabilityEntry =
340                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
341        if (prevWordCount > 0 && probabilityEntry.isValid()
342                && !mTrieMap.getRoot(entry.key()).mIsValid) {
343            // The entry is related to a word that has been removed. Remove the entry.
344            if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
345                return false;
346            }
347            continue;
348        }
349        if (mHasHistoricalInfo && probabilityEntry.isValid()) {
350            const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo();
351            if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC(
352                    *originalHistoricalInfo)) {
353                // Remove the entry.
354                if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
355                    return false;
356                }
357                continue;
358            }
359            if (needsToHalveCounters) {
360                const int updatedCount = originalHistoricalInfo->getCount() / 2;
361                if (updatedCount == 0) {
362                    // Remove the entry.
363                    if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
364                        return false;
365                    }
366                    continue;
367                }
368                const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(),
369                        originalHistoricalInfo->getLevel(), updatedCount);
370                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(),
371                        &historicalInfoToSave);
372                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
373                        bitmapEntryIndex)) {
374                    return false;
375                }
376            }
377        }
378        outEntryCounters->incrementNgramCount(
379                NgramUtils::getNgramTypeFromWordCount(prevWordCount + 1));
380        if (!entry.hasNextLevelMap()) {
381            continue;
382        }
383        if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
384                prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) {
385            return false;
386        }
387    }
388    return true;
389}
390
391bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
392        const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel,
393        int *const outEntryCount) {
394    std::vector<int> prevWordIds;
395    std::vector<EntryInfoToTurncate> entryInfoVector;
396    if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
397            &prevWordIds, &entryInfoVector)) {
398        return false;
399    }
400    if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
401        *outEntryCount = static_cast<int>(entryInfoVector.size());
402        return true;
403    }
404    *outEntryCount = maxEntryCount;
405    const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
406    std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
407            entryInfoVector.end(),
408            EntryInfoToTurncate::Comparator());
409    for (int i = 0; i < entryCountToRemove; ++i) {
410        const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
411        if (!removeNgramProbabilityEntry(
412                WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount),
413                entryInfo.mKey)) {
414            return false;
415        }
416    }
417    return true;
418}
419
420bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
421        const int targetLevel, const int bitmapEntryIndex,  std::vector<int> *const prevWordIds,
422        std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
423    const int prevWordCount = prevWordIds->size();
424    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
425        if (prevWordCount < targetLevel) {
426            if (!entry.hasNextLevelMap()) {
427                continue;
428            }
429            prevWordIds->push_back(entry.key());
430            if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(),
431                    prevWordIds, outEntryInfo)) {
432                return false;
433            }
434            prevWordIds->pop_back();
435            continue;
436        }
437        const ProbabilityEntry probabilityEntry =
438                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
439        const int priority = mHasHistoricalInfo
440                ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction(
441                        *probabilityEntry.getHistoricalInfo())
442                : probabilityEntry.getProbability();
443        outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(),
444                entry.key(), targetLevel, prevWordIds->data());
445    }
446    return true;
447}
448
449bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
450        const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
451    if (left.mPriority != right.mPriority) {
452        return left.mPriority < right.mPriority;
453    }
454    if (left.mCount != right.mCount) {
455        return left.mCount < right.mCount;
456    }
457    if (left.mKey != right.mKey) {
458        return left.mKey < right.mKey;
459    }
460    if (left.mPrevWordCount != right.mPrevWordCount) {
461        return left.mPrevWordCount > right.mPrevWordCount;
462    }
463    for (int i = 0; i < left.mPrevWordCount; ++i) {
464        if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
465            return left.mPrevWordIds[i] < right.mPrevWordIds[i];
466        }
467    }
468    // left and rigth represent the same entry.
469    return false;
470}
471
472LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority,
473        const int count, const int key, const int prevWordCount, const int *const prevWordIds)
474        : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) {
475    memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0]));
476}
477
478} // namespace latinime
479