prev_words_info.h revision b00973952f269ebee6d1d5f808fad7ca64fb9954
1/* 2 * Copyright (C) 2014 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#ifndef LATINIME_PREV_WORDS_INFO_H 18#define LATINIME_PREV_WORDS_INFO_H 19 20#include "defines.h" 21#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" 22#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" 23#include "utils/char_utils.h" 24 25namespace latinime { 26 27// TODO: Support n-gram. 28class PrevWordsInfo { 29 public: 30 // No prev word information. 31 PrevWordsInfo() { 32 clear(); 33 } 34 35 PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) { 36 for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { 37 mPrevWordCodePointCount[i] = prevWordsInfo.mPrevWordCodePointCount[i]; 38 memmove(mPrevWordCodePoints[i], prevWordsInfo.mPrevWordCodePoints[i], 39 sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]); 40 mIsBeginningOfSentence[i] = prevWordsInfo.mIsBeginningOfSentence[i]; 41 } 42 } 43 44 // Construct from previous words. 45 PrevWordsInfo(const int prevWordCodePoints[][MAX_WORD_LENGTH], 46 const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, 47 const size_t prevWordCount) { 48 clear(); 49 for (size_t i = 0; i < std::min(NELEMS(mPrevWordCodePoints), prevWordCount); ++i) { 50 if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) { 51 continue; 52 } 53 memmove(mPrevWordCodePoints[i], prevWordCodePoints[i], 54 sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]); 55 mPrevWordCodePointCount[i] = prevWordCodePointCount[i]; 56 mIsBeginningOfSentence[i] = isBeginningOfSentence[i]; 57 } 58 } 59 60 // Construct from a previous word. 61 PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount, 62 const bool isBeginningOfSentence) { 63 clear(); 64 if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) { 65 return; 66 } 67 memmove(mPrevWordCodePoints[0], prevWordCodePoints, 68 sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount); 69 mPrevWordCodePointCount[0] = prevWordCodePointCount; 70 mIsBeginningOfSentence[0] = isBeginningOfSentence; 71 } 72 73 bool isValid() const { 74 if (mPrevWordCodePointCount[0] > 0) { 75 return true; 76 } 77 if (mIsBeginningOfSentence[0]) { 78 return true; 79 } 80 return false; 81 } 82 83 void getPrevWordsTerminalPtNodePos( 84 const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, 85 int *const outPrevWordsTerminalPtNodePos, const bool tryLowerCaseSearch) const { 86 for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { 87 outPrevWordsTerminalPtNodePos[i] = getTerminalPtNodePosOfWord(dictStructurePolicy, 88 mPrevWordCodePoints[i], mPrevWordCodePointCount[i], 89 mIsBeginningOfSentence[i], tryLowerCaseSearch); 90 } 91 } 92 93 BinaryDictionaryBigramsIterator getBigramsIteratorForPrediction( 94 const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) const { 95 return getBigramsIteratorForWordWithTryingLowerCaseSearch( 96 dictStructurePolicy, mPrevWordCodePoints[0], mPrevWordCodePointCount[0], 97 mIsBeginningOfSentence[0]); 98 } 99 100 // n is 1-indexed. 101 const int *getNthPrevWordCodePoints(const int n) const { 102 if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { 103 return nullptr; 104 } 105 return mPrevWordCodePoints[n - 1]; 106 } 107 108 // n is 1-indexed. 109 int getNthPrevWordCodePointCount(const int n) const { 110 if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { 111 return 0; 112 } 113 return mPrevWordCodePointCount[n - 1]; 114 } 115 116 // n is 1-indexed. 117 bool isNthPrevWordBeginningOfSentence(const int n) const { 118 if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { 119 return false; 120 } 121 return mIsBeginningOfSentence[n - 1]; 122 } 123 124 private: 125 DISALLOW_COPY_AND_ASSIGN(PrevWordsInfo); 126 127 static int getTerminalPtNodePosOfWord( 128 const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, 129 const int *const wordCodePoints, const int wordCodePointCount, 130 const bool isBeginningOfSentence, const bool tryLowerCaseSearch) { 131 if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { 132 return NOT_A_DICT_POS; 133 } 134 int codePoints[MAX_WORD_LENGTH]; 135 int codePointCount = wordCodePointCount; 136 memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount); 137 if (isBeginningOfSentence) { 138 codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, 139 codePointCount, MAX_WORD_LENGTH); 140 if (codePointCount <= 0) { 141 return NOT_A_DICT_POS; 142 } 143 } 144 const int wordPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord( 145 codePoints, codePointCount, false /* forceLowerCaseSearch */); 146 if (wordPtNodePos != NOT_A_DICT_POS || !tryLowerCaseSearch) { 147 // Return the position when when the word was found or doesn't try lower case 148 // search. 149 return wordPtNodePos; 150 } 151 // Check bigrams for lower-cased previous word if original was not found. Useful for 152 // auto-capitalized words like "The [current_word]". 153 return dictStructurePolicy->getTerminalPtNodePositionOfWord( 154 codePoints, codePointCount, true /* forceLowerCaseSearch */); 155 } 156 157 static BinaryDictionaryBigramsIterator getBigramsIteratorForWordWithTryingLowerCaseSearch( 158 const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, 159 const int *const wordCodePoints, const int wordCodePointCount, 160 const bool isBeginningOfSentence) { 161 if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { 162 return BinaryDictionaryBigramsIterator(); 163 } 164 int codePoints[MAX_WORD_LENGTH]; 165 int codePointCount = wordCodePointCount; 166 memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount); 167 if (isBeginningOfSentence) { 168 codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, 169 codePointCount, MAX_WORD_LENGTH); 170 if (codePointCount <= 0) { 171 return BinaryDictionaryBigramsIterator(); 172 } 173 } 174 BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorForWord(dictStructurePolicy, 175 codePoints, codePointCount, false /* forceLowerCaseSearch */); 176 // getBigramsIteratorForWord returns an empty iterator if this word isn't in the dictionary 177 // or has no bigrams. 178 if (bigramsIt.hasNext()) { 179 return bigramsIt; 180 } 181 // If no bigrams for this exact word, search again in lower case. 182 return getBigramsIteratorForWord(dictStructurePolicy, codePoints, 183 codePointCount, true /* forceLowerCaseSearch */); 184 } 185 186 static BinaryDictionaryBigramsIterator getBigramsIteratorForWord( 187 const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, 188 const int *wordCodePoints, const int wordCodePointCount, 189 const bool forceLowerCaseSearch) { 190 if (!wordCodePoints || wordCodePointCount <= 0) return BinaryDictionaryBigramsIterator(); 191 const int terminalPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord( 192 wordCodePoints, wordCodePointCount, forceLowerCaseSearch); 193 if (NOT_A_DICT_POS == terminalPtNodePos) return BinaryDictionaryBigramsIterator(); 194 return dictStructurePolicy->getBigramsIteratorOfPtNode(terminalPtNodePos); 195 } 196 197 void clear() { 198 for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { 199 mPrevWordCodePointCount[i] = 0; 200 mIsBeginningOfSentence[i] = false; 201 } 202 } 203 204 int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; 205 int mPrevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; 206 bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; 207}; 208} // namespace latinime 209#endif // LATINIME_PREV_WORDS_INFO_H 210