prev_words_info.h revision b00973952f269ebee6d1d5f808fad7ca64fb9954
1/*
2 * Copyright (C) 2014 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LATINIME_PREV_WORDS_INFO_H
18#define LATINIME_PREV_WORDS_INFO_H
19
20#include "defines.h"
21#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
22#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
23#include "utils/char_utils.h"
24
25namespace latinime {
26
27// TODO: Support n-gram.
28class PrevWordsInfo {
29 public:
30    // No prev word information.
31    PrevWordsInfo() {
32        clear();
33    }
34
35    PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) {
36        for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
37            mPrevWordCodePointCount[i] = prevWordsInfo.mPrevWordCodePointCount[i];
38            memmove(mPrevWordCodePoints[i], prevWordsInfo.mPrevWordCodePoints[i],
39                    sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]);
40            mIsBeginningOfSentence[i] = prevWordsInfo.mIsBeginningOfSentence[i];
41        }
42    }
43
44    // Construct from previous words.
45    PrevWordsInfo(const int prevWordCodePoints[][MAX_WORD_LENGTH],
46            const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence,
47            const size_t prevWordCount) {
48        clear();
49        for (size_t i = 0; i < std::min(NELEMS(mPrevWordCodePoints), prevWordCount); ++i) {
50            if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) {
51                continue;
52            }
53            memmove(mPrevWordCodePoints[i], prevWordCodePoints[i],
54                    sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]);
55            mPrevWordCodePointCount[i] = prevWordCodePointCount[i];
56            mIsBeginningOfSentence[i] = isBeginningOfSentence[i];
57        }
58    }
59
60    // Construct from a previous word.
61    PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount,
62            const bool isBeginningOfSentence) {
63        clear();
64        if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) {
65            return;
66        }
67        memmove(mPrevWordCodePoints[0], prevWordCodePoints,
68                sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount);
69        mPrevWordCodePointCount[0] = prevWordCodePointCount;
70        mIsBeginningOfSentence[0] = isBeginningOfSentence;
71    }
72
73    bool isValid() const {
74        if (mPrevWordCodePointCount[0] > 0) {
75            return true;
76        }
77        if (mIsBeginningOfSentence[0]) {
78            return true;
79        }
80        return false;
81    }
82
83    void getPrevWordsTerminalPtNodePos(
84            const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
85            int *const outPrevWordsTerminalPtNodePos, const bool tryLowerCaseSearch) const {
86        for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
87            outPrevWordsTerminalPtNodePos[i] = getTerminalPtNodePosOfWord(dictStructurePolicy,
88                    mPrevWordCodePoints[i], mPrevWordCodePointCount[i],
89                    mIsBeginningOfSentence[i], tryLowerCaseSearch);
90        }
91    }
92
93    BinaryDictionaryBigramsIterator getBigramsIteratorForPrediction(
94            const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) const {
95        return getBigramsIteratorForWordWithTryingLowerCaseSearch(
96                dictStructurePolicy, mPrevWordCodePoints[0], mPrevWordCodePointCount[0],
97                mIsBeginningOfSentence[0]);
98    }
99
100    // n is 1-indexed.
101    const int *getNthPrevWordCodePoints(const int n) const {
102        if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
103            return nullptr;
104        }
105        return mPrevWordCodePoints[n - 1];
106    }
107
108    // n is 1-indexed.
109    int getNthPrevWordCodePointCount(const int n) const {
110        if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
111            return 0;
112        }
113        return mPrevWordCodePointCount[n - 1];
114    }
115
116    // n is 1-indexed.
117    bool isNthPrevWordBeginningOfSentence(const int n) const {
118        if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
119            return false;
120        }
121        return mIsBeginningOfSentence[n - 1];
122    }
123
124 private:
125    DISALLOW_COPY_AND_ASSIGN(PrevWordsInfo);
126
127    static int getTerminalPtNodePosOfWord(
128            const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
129            const int *const wordCodePoints, const int wordCodePointCount,
130            const bool isBeginningOfSentence, const bool tryLowerCaseSearch) {
131        if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) {
132            return NOT_A_DICT_POS;
133        }
134        int codePoints[MAX_WORD_LENGTH];
135        int codePointCount = wordCodePointCount;
136        memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount);
137        if (isBeginningOfSentence) {
138            codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints,
139                    codePointCount, MAX_WORD_LENGTH);
140            if (codePointCount <= 0) {
141                return NOT_A_DICT_POS;
142            }
143        }
144        const int wordPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord(
145                codePoints, codePointCount, false /* forceLowerCaseSearch */);
146        if (wordPtNodePos != NOT_A_DICT_POS || !tryLowerCaseSearch) {
147            // Return the position when when the word was found or doesn't try lower case
148            // search.
149            return wordPtNodePos;
150        }
151        // Check bigrams for lower-cased previous word if original was not found. Useful for
152        // auto-capitalized words like "The [current_word]".
153        return dictStructurePolicy->getTerminalPtNodePositionOfWord(
154                codePoints, codePointCount, true /* forceLowerCaseSearch */);
155    }
156
157    static BinaryDictionaryBigramsIterator getBigramsIteratorForWordWithTryingLowerCaseSearch(
158            const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
159            const int *const wordCodePoints, const int wordCodePointCount,
160            const bool isBeginningOfSentence) {
161        if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) {
162            return BinaryDictionaryBigramsIterator();
163        }
164        int codePoints[MAX_WORD_LENGTH];
165        int codePointCount = wordCodePointCount;
166        memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount);
167        if (isBeginningOfSentence) {
168            codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints,
169                    codePointCount, MAX_WORD_LENGTH);
170            if (codePointCount <= 0) {
171                return BinaryDictionaryBigramsIterator();
172            }
173        }
174        BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorForWord(dictStructurePolicy,
175                codePoints, codePointCount, false /* forceLowerCaseSearch */);
176        // getBigramsIteratorForWord returns an empty iterator if this word isn't in the dictionary
177        // or has no bigrams.
178        if (bigramsIt.hasNext()) {
179            return bigramsIt;
180        }
181        // If no bigrams for this exact word, search again in lower case.
182        return getBigramsIteratorForWord(dictStructurePolicy, codePoints,
183                codePointCount, true /* forceLowerCaseSearch */);
184    }
185
186    static BinaryDictionaryBigramsIterator getBigramsIteratorForWord(
187            const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
188            const int *wordCodePoints, const int wordCodePointCount,
189            const bool forceLowerCaseSearch) {
190        if (!wordCodePoints || wordCodePointCount <= 0) return BinaryDictionaryBigramsIterator();
191        const int terminalPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord(
192                wordCodePoints, wordCodePointCount, forceLowerCaseSearch);
193        if (NOT_A_DICT_POS == terminalPtNodePos) return BinaryDictionaryBigramsIterator();
194        return dictStructurePolicy->getBigramsIteratorOfPtNode(terminalPtNodePos);
195    }
196
197    void clear() {
198        for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
199            mPrevWordCodePointCount[i] = 0;
200            mIsBeginningOfSentence[i] = false;
201        }
202    }
203
204    int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
205    int mPrevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
206    bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
207};
208} // namespace latinime
209#endif // LATINIME_PREV_WORDS_INFO_H
210