1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "suggest/core/suggest.h"
18
19#include "dictionary/interface/dictionary_structure_with_buffer_policy.h"
20#include "dictionary/property/word_attributes.h"
21#include "suggest/core/dicnode/dic_node.h"
22#include "suggest/core/dicnode/dic_node_priority_queue.h"
23#include "suggest/core/dicnode/dic_node_vector.h"
24#include "suggest/core/dictionary/dictionary.h"
25#include "suggest/core/dictionary/digraph_utils.h"
26#include "suggest/core/layout/proximity_info.h"
27#include "suggest/core/policy/traversal.h"
28#include "suggest/core/policy/weighting.h"
29#include "suggest/core/result/suggestions_output_utils.h"
30#include "suggest/core/session/dic_traverse_session.h"
31#include "suggest/core/suggest_options.h"
32#include "utils/profiler.h"
33
34namespace latinime {
35
36// Initialization of class constants.
37const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2;
38
39/**
40 * Returns a set of suggestions for the given input touch points. The commitPoint argument indicates
41 * whether to prematurely commit the suggested words up to the given point for sentence-level
42 * suggestion.
43 *
44 * Note: Currently does not support concurrent calls across threads. Continuous suggestion is
45 * automatically activated for sequential calls that share the same starting input.
46 * TODO: Stop detecting continuous suggestion. Start using traverseSession instead.
47 */
48void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession,
49        int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints,
50        int inputSize, const float weightOfLangModelVsSpatialModel,
51        SuggestionResults *const outSuggestionResults) const {
52    PROF_INIT;
53    PROF_TIMER_START(0);
54    const float maxSpatialDistance = TRAVERSAL->getMaxSpatialDistance();
55    DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession);
56    tSession->setupForGetSuggestions(pInfo, inputCodePoints, inputSize, inputXs, inputYs, times,
57            pointerIds, maxSpatialDistance, TRAVERSAL->getMaxPointerCount());
58    // TODO: Add the way to evaluate cache
59
60    initializeSearch(tSession);
61    PROF_TIMER_END(0);
62    PROF_TIMER_START(1);
63
64    // keep expanding search dicNodes until all have terminated.
65    while (tSession->getDicTraverseCache()->activeSize() > 0) {
66        expandCurrentDicNodes(tSession);
67        tSession->getDicTraverseCache()->advanceActiveDicNodes();
68        tSession->getDicTraverseCache()->advanceInputIndex(inputSize);
69    }
70    PROF_TIMER_END(1);
71    PROF_TIMER_START(2);
72    SuggestionsOutputUtils::outputSuggestions(
73            SCORING, tSession, weightOfLangModelVsSpatialModel, outSuggestionResults);
74    PROF_TIMER_END(2);
75}
76
77/**
78 * Initializes the search at the root of the lexicon trie. Note that when possible the search will
79 * continue suggestion from where it left off during the last call.
80 */
81void Suggest::initializeSearch(DicTraverseSession *traverseSession) const {
82    if (!traverseSession->getProximityInfoState(0)->isUsed()) {
83        return;
84    }
85
86    if (traverseSession->getInputSize() > MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE
87            && traverseSession->isContinuousSuggestionPossible()) {
88        // Continue suggestion
89        traverseSession->getDicTraverseCache()->continueSearch();
90    } else {
91        // Restart recognition at the root.
92        traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(traverseSession->getInputSize(),
93                traverseSession->getSuggestOptions()->weightForLocale()),
94                TRAVERSAL->getTerminalCacheSize());
95        // Create a new dic node here
96        DicNode rootNode;
97        DicNodeUtils::initAsRoot(traverseSession->getDictionaryStructurePolicy(),
98                traverseSession->getPrevWordIds(), &rootNode);
99        traverseSession->getDicTraverseCache()->copyPushActive(&rootNode);
100    }
101}
102
103/**
104 * Expands the dicNodes in the current search priority queue by advancing to the possible child
105 * nodes based on the next touch point(s) (or no touch points for lookahead)
106 */
107void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
108    const int inputSize = traverseSession->getInputSize();
109    DicNodeVector childDicNodes(TRAVERSAL->getDefaultExpandDicNodeSize());
110    DicNode correctionDicNode;
111
112    // TODO: Find more efficient caching
113    const bool shouldDepthLevelCache = TRAVERSAL->shouldDepthLevelCache(traverseSession);
114    if (shouldDepthLevelCache) {
115        traverseSession->getDicTraverseCache()->updateLastCachedInputIndex();
116    }
117    if (DEBUG_CACHE) {
118        AKLOGI("expandCurrentDicNodes depth level cache = %d, inputSize = %d",
119                shouldDepthLevelCache, inputSize);
120    }
121    while (traverseSession->getDicTraverseCache()->activeSize() > 0) {
122        DicNode dicNode;
123        traverseSession->getDicTraverseCache()->popActive(&dicNode);
124        if (dicNode.isTotalInputSizeExceedingLimit()) {
125            return;
126        }
127        childDicNodes.clear();
128        const int point0Index = dicNode.getInputIndex(0);
129        const bool canDoLookAheadCorrection =
130                TRAVERSAL->canDoLookAheadCorrection(traverseSession, &dicNode);
131        const bool isLookAheadCorrection = canDoLookAheadCorrection
132                && traverseSession->getDicTraverseCache()->
133                        isLookAheadCorrectionInputIndex(static_cast<int>(point0Index));
134        const bool isCompletion = dicNode.isCompletion(inputSize);
135
136        const bool shouldNodeLevelCache =
137                TRAVERSAL->shouldNodeLevelCache(traverseSession, &dicNode);
138        if (shouldDepthLevelCache || shouldNodeLevelCache) {
139            if (DEBUG_CACHE) {
140                dicNode.dump("PUSH_CACHE");
141            }
142            traverseSession->getDicTraverseCache()->copyPushContinue(&dicNode);
143            dicNode.setCached();
144        }
145
146        if (dicNode.isInDigraph()) {
147            // Finish digraph handling if the node is in the middle of a digraph expansion.
148            processDicNodeAsDigraph(traverseSession, &dicNode);
149        } else if (isLookAheadCorrection) {
150            // The algorithm maintains a small set of "deferred" nodes that have not consumed the
151            // latest touch point yet. These are needed to apply look-ahead correction operations
152            // that require special handling of the latest touch point. For example, with insertions
153            // (e.g., "thiis" -> "this") the latest touch point should not be consumed at all.
154            processDicNodeAsTransposition(traverseSession, &dicNode);
155            processDicNodeAsInsertion(traverseSession, &dicNode);
156        } else { // !isLookAheadCorrection
157            // Only consider typing error corrections if the normalized compound distance is
158            // below a spatial distance threshold.
159            // NOTE: the threshold may need to be updated if scoring model changes.
160            // TODO: Remove. Do not prune node here.
161            const bool allowsErrorCorrections = TRAVERSAL->allowsErrorCorrections(&dicNode);
162            // Process for handling space substitution (e.g., hevis => he is)
163            if (TRAVERSAL->isSpaceSubstitutionTerminal(traverseSession, &dicNode)) {
164                createNextWordDicNode(traverseSession, &dicNode, true /* spaceSubstitution */);
165            }
166
167            DicNodeUtils::getAllChildDicNodes(
168                    &dicNode, traverseSession->getDictionaryStructurePolicy(), &childDicNodes);
169
170            const int childDicNodesSize = childDicNodes.getSizeAndLock();
171            for (int i = 0; i < childDicNodesSize; ++i) {
172                DicNode *const childDicNode = childDicNodes[i];
173                if (isCompletion) {
174                    // Handle forward lookahead when the lexicon letter exceeds the input size.
175                    processDicNodeAsMatch(traverseSession, childDicNode);
176                    continue;
177                }
178                if (DigraphUtils::hasDigraphForCodePoint(
179                        traverseSession->getDictionaryStructurePolicy()
180                                ->getHeaderStructurePolicy(),
181                        childDicNode->getNodeCodePoint())) {
182                    correctionDicNode.initByCopy(childDicNode);
183                    correctionDicNode.advanceDigraphIndex();
184                    processDicNodeAsDigraph(traverseSession, &correctionDicNode);
185                }
186                if (TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode,
187                        allowsErrorCorrections)) {
188                    // TODO: (Gesture) Change weight between omission and substitution errors
189                    // TODO: (Gesture) Terminal node should not be handled as omission
190                    correctionDicNode.initByCopy(childDicNode);
191                    processDicNodeAsOmission(traverseSession, &correctionDicNode);
192                }
193                const ProximityType proximityType = TRAVERSAL->getProximityType(
194                        traverseSession, &dicNode, childDicNode);
195                switch (proximityType) {
196                    // TODO: Consider the difference of proximityType here
197                    case MATCH_CHAR:
198                    case PROXIMITY_CHAR:
199                        processDicNodeAsMatch(traverseSession, childDicNode);
200                        break;
201                    case ADDITIONAL_PROXIMITY_CHAR:
202                        if (allowsErrorCorrections) {
203                            processDicNodeAsAdditionalProximityChar(traverseSession, &dicNode,
204                                    childDicNode);
205                        }
206                        break;
207                    case SUBSTITUTION_CHAR:
208                        if (allowsErrorCorrections) {
209                            processDicNodeAsSubstitution(traverseSession, &dicNode, childDicNode);
210                        }
211                        break;
212                    case UNRELATED_CHAR:
213                        // Just drop this dicNode and do nothing.
214                        break;
215                    default:
216                        // Just drop this dicNode and do nothing.
217                        break;
218                }
219            }
220
221            // Push the dicNode for look-ahead correction
222            if (allowsErrorCorrections && canDoLookAheadCorrection) {
223                traverseSession->getDicTraverseCache()->copyPushNextActive(&dicNode);
224            }
225        }
226    }
227}
228
229void Suggest::processTerminalDicNode(
230        DicTraverseSession *traverseSession, DicNode *dicNode) const {
231    if (dicNode->getCompoundDistance() >= static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) {
232        return;
233    }
234    if (!dicNode->isTerminalDicNode()) {
235        return;
236    }
237    if (dicNode->shouldBeFilteredBySafetyNetForBigram()) {
238        return;
239    }
240    if (!dicNode->hasMatchedOrProximityCodePoints()) {
241        return;
242    }
243    // Create a non-cached node here.
244    DicNode terminalDicNode(*dicNode);
245    if (TRAVERSAL->needsToTraverseAllUserInput()
246            && dicNode->getInputIndex(0) < traverseSession->getInputSize()) {
247        Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, 0,
248                &terminalDicNode, traverseSession->getMultiBigramMap());
249    }
250    Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0,
251            &terminalDicNode, traverseSession->getMultiBigramMap());
252    traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode);
253}
254
255/**
256 * Adds the expanded dicNode to the next search priority queue. Also creates an additional next word
257 * (by the space omission error correction) search path if input dicNode is on a terminal.
258 */
259void Suggest::processExpandedDicNode(
260        DicTraverseSession *traverseSession, DicNode *dicNode) const {
261    processTerminalDicNode(traverseSession, dicNode);
262    if (dicNode->getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) {
263        if (TRAVERSAL->isSpaceOmissionTerminal(traverseSession, dicNode)) {
264            createNextWordDicNode(traverseSession, dicNode, false /* spaceSubstitution */);
265        }
266        const int allowsLookAhead = !(dicNode->hasMultipleWords()
267                && dicNode->isCompletion(traverseSession->getInputSize()));
268        if (dicNode->hasChildren() && allowsLookAhead) {
269            traverseSession->getDicTraverseCache()->copyPushNextActive(dicNode);
270        }
271    }
272}
273
274void Suggest::processDicNodeAsMatch(DicTraverseSession *traverseSession,
275        DicNode *childDicNode) const {
276    weightChildNode(traverseSession, childDicNode);
277    processExpandedDicNode(traverseSession, childDicNode);
278}
279
280void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession,
281        DicNode *dicNode, DicNode *childDicNode) const {
282    // Note: Most types of corrections don't need to look up the bigram information since they do
283    // not treat the node as a terminal. There is no need to pass the bigram map in these cases.
284    Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY,
285            traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */);
286    processExpandedDicNode(traverseSession, childDicNode);
287}
288
289void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession,
290        DicNode *dicNode, DicNode *childDicNode) const {
291    Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession,
292            dicNode, childDicNode, 0 /* multiBigramMap */);
293    processExpandedDicNode(traverseSession, childDicNode);
294}
295
296// Process the DicNode codepoint as a digraph. This means that composite glyphs like the German
297// u-umlaut is expanded to the transliteration "ue". Note that this happens in parallel with
298// the normal non-digraph traversal, so both "uber" and "ueber" can be corrected to "[u-umlaut]ber".
299void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession,
300        DicNode *childDicNode) const {
301    weightChildNode(traverseSession, childDicNode);
302    childDicNode->advanceDigraphIndex();
303    processExpandedDicNode(traverseSession, childDicNode);
304}
305
306/**
307 * Handle the dicNode as an omission error (e.g., ths => this). Skip the current letter and consider
308 * matches for all possible next letters. Note that just skipping the current letter without any
309 * other conditions tends to flood the search DicNodes cache with omission DicNodes. Instead, check
310 * the possible *next* letters after the omission to better limit search to plausible omissions.
311 * Note that apostrophes are handled as omissions.
312 */
313void Suggest::processDicNodeAsOmission(
314        DicTraverseSession *traverseSession, DicNode *dicNode) const {
315    DicNodeVector childDicNodes;
316    DicNodeUtils::getAllChildDicNodes(
317            dicNode, traverseSession->getDictionaryStructurePolicy(), &childDicNodes);
318
319    const int size = childDicNodes.getSizeAndLock();
320    for (int i = 0; i < size; i++) {
321        DicNode *const childDicNode = childDicNodes[i];
322        // Treat this word as omission
323        Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
324                dicNode, childDicNode, 0 /* multiBigramMap */);
325        weightChildNode(traverseSession, childDicNode);
326        if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) {
327            continue;
328        }
329        processExpandedDicNode(traverseSession, childDicNode);
330    }
331}
332
333/**
334 * Handle the dicNode as an insertion error (e.g., thiis => this). Skip the current touch point and
335 * consider matches for the next touch point.
336 */
337void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession,
338        DicNode *dicNode) const {
339    const int16_t pointIndex = dicNode->getInputIndex(0);
340    DicNodeVector childDicNodes;
341    DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getDictionaryStructurePolicy(),
342            &childDicNodes);
343    const int size = childDicNodes.getSizeAndLock();
344    for (int i = 0; i < size; i++) {
345        if (traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex + 1)
346                != childDicNodes[i]->getNodeCodePoint()) {
347            continue;
348        }
349        DicNode *const childDicNode = childDicNodes[i];
350        Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession,
351                dicNode, childDicNode, 0 /* multiBigramMap */);
352        processExpandedDicNode(traverseSession, childDicNode);
353    }
354}
355
356/**
357 * Handle the dicNode as a transposition error (e.g., thsi => this). Swap the next two touch points.
358 */
359void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession,
360        DicNode *dicNode) const {
361    const int16_t pointIndex = dicNode->getInputIndex(0);
362    DicNodeVector childDicNodes1;
363    DicNodeVector childDicNodes2;
364    DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getDictionaryStructurePolicy(),
365            &childDicNodes1);
366    const int childSize1 = childDicNodes1.getSizeAndLock();
367    for (int i = 0; i < childSize1; i++) {
368        const ProximityType matchedId1 = traverseSession->getProximityInfoState(0)
369                ->getProximityType(pointIndex + 1, childDicNodes1[i]->getNodeCodePoint(),
370                        true /* checkProximityChars */);
371        if (!ProximityInfoUtils::isMatchOrProximityChar(matchedId1)) {
372            continue;
373        }
374        if (childDicNodes1[i]->hasChildren()) {
375            childDicNodes2.clear();
376            DicNodeUtils::getAllChildDicNodes(childDicNodes1[i],
377                    traverseSession->getDictionaryStructurePolicy(), &childDicNodes2);
378            const int childSize2 = childDicNodes2.getSizeAndLock();
379            for (int j = 0; j < childSize2; j++) {
380                DicNode *const childDicNode2 = childDicNodes2[j];
381                const ProximityType matchedId2 = traverseSession->getProximityInfoState(0)
382                        ->getProximityType(pointIndex, childDicNode2->getNodeCodePoint(),
383                                true /* checkProximityChars */);
384                if (!ProximityInfoUtils::isMatchOrProximityChar(matchedId2)) {
385                    continue;
386                }
387                Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION,
388                        traverseSession, childDicNodes1[i], childDicNode2, 0 /* multiBigramMap */);
389                processExpandedDicNode(traverseSession, childDicNode2);
390            }
391        }
392    }
393}
394
395/**
396 * Weight child dicNode by aligning it to the key
397 */
398void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicNode) const {
399    const int inputSize = traverseSession->getInputSize();
400    if (dicNode->isCompletion(inputSize)) {
401        Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession,
402                0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
403    } else {
404        Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession,
405                0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
406    }
407}
408
409/**
410 * Creates a new dicNode that represents a space insertion at the end of the input dicNode. Also
411 * incorporates the unigram / bigram score for the ending word into the new dicNode.
412 */
413void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode,
414        const bool spaceSubstitution) const {
415    const WordAttributes wordAttributes =
416            traverseSession->getDictionaryStructurePolicy()->getWordAttributesInContext(
417                    dicNode->getPrevWordIds(), dicNode->getWordId(),
418                    traverseSession->getMultiBigramMap());
419    if (SuggestionsOutputUtils::shouldBlockWord(traverseSession->getSuggestOptions(),
420            dicNode, wordAttributes, false /* isLastWord */)) {
421        return;
422    }
423
424    if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode, wordAttributes.getProbability())) {
425        return;
426    }
427
428    // Create a non-cached node here.
429    DicNode newDicNode;
430    DicNodeUtils::initAsRootWithPreviousWord(
431            traverseSession->getDictionaryStructurePolicy(), dicNode, &newDicNode);
432    const CorrectionType correctionType = spaceSubstitution ?
433            CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMISSION;
434    Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode,
435            &newDicNode, traverseSession->getMultiBigramMap());
436    if (newDicNode.getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) {
437        // newDicNode is worth continuing to traverse.
438        // CAVEAT: This pruning is important for speed. Remove this when we can afford not to prune
439        // here because here is not the right place to do pruning. Pruning should take place only
440        // in DicNodePriorityQueue.
441        traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode);
442    }
443}
444} // namespace latinime
445