1/*
2 * Copyright (C) 2014, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "dictionary/utils/trie_map.h"
18
19#include "dictionary/utils/dict_file_writing_utils.h"
20
21namespace latinime {
22
// Sentinel index returned by the lookup and allocation methods in this file on failure.
const int TrieMap::INVALID_INDEX = -1;
// Sizes of the two fields of an entry. MAX_VALUE below multiplies their sum by CHAR_BIT, so the
// unit is bytes.
const int TrieMap::FIELD0_SIZE = 4;
const int TrieMap::FIELD1_SIZE = 3;
const int TrieMap::ENTRY_SIZE = FIELD0_SIZE + FIELD1_SIZE;
// Bit 22 of field1. writeValue() ORs it into field1 when the value is stored in the terminal
// entry itself.
const uint32_t TrieMap::VALUE_FLAG = 0x400000;
// Low 22 bits of field1 (all bits below VALUE_FLAG).
const uint32_t TrieMap::VALUE_MASK = 0x3FFFFF;
// Reserved marker value. remove() and removeInner() write VALUE_FLAG ^ this value to field1 to
// invalidate a terminal entry, and writeValue() only stores values smaller than VALUE_MASK in a
// terminal entry, so a stored value never collides with the marker.
const uint32_t TrieMap::INVALID_VALUE_IN_KEY_VALUE_ENTRY = VALUE_MASK;
// Bit 23 of field1. It is ORed with the index of a value entry when a terminal entry links to a
// value entry instead of holding the value itself.
const uint32_t TrieMap::TERMINAL_LINK_FLAG = 0x800000;
// Low 23 bits of field1 (all bits below TERMINAL_LINK_FLAG).
const uint32_t TrieMap::TERMINAL_LINK_MASK = 0x7FFFFF;
// Each level of the trie consumes this many bits of the hashed key; LABEL_MASK covers exactly
// that many low bits.
const int TrieMap::NUM_OF_BITS_USED_FOR_ONE_LEVEL = 5;
const uint32_t TrieMap::LABEL_MASK = 0x1F;
// 1 << 5 = 32 possible labels, hence at most 32 entries in the table of one level.
const int TrieMap::MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL = 1 << NUM_OF_BITS_USED_FOR_ONE_LEVEL;
const int TrieMap::ROOT_BITMAP_ENTRY_INDEX = 0;
// The default constructor extends the buffer by this many bytes before writing the root entry.
// NOTE(review): the factor FIELD0_SIZE suggests this region holds one FIELD0_SIZE-byte empty
// table link per table size -- confirm against readEmptyTableLink() in the header.
const int TrieMap::ROOT_BITMAP_ENTRY_POS = MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL * FIELD0_SIZE;
// Bitmap entry with both fields zero, i.e. no label is set.
const TrieMap::Entry TrieMap::EMPTY_BITMAP_ENTRY = TrieMap::Entry(0, 0);
const int TrieMap::TERMINAL_LINKED_ENTRY_COUNT = 2; // Value entry and bitmap entry.
// Largest value accepted by put(): all bits of field0 and field1 set.
const uint64_t TrieMap::MAX_VALUE =
        (static_cast<uint64_t>(1) << ((FIELD0_SIZE + FIELD1_SIZE) * CHAR_BIT)) - 1;
// Passed to the buffer by the default constructor: TERMINAL_LINK_MASK entries of ENTRY_SIZE bytes.
const int TrieMap::MAX_BUFFER_SIZE = TERMINAL_LINK_MASK * ENTRY_SIZE;
42
/**
 * Create an empty map.
 *
 * The buffer is constructed with MAX_BUFFER_SIZE, extended by ROOT_BITMAP_ENTRY_POS bytes, and an
 * empty bitmap entry is then written at ROOT_BITMAP_ENTRY_INDEX as the root of the trie.
 * The results of extend() and writeEntry() are not checked here.
 */
TrieMap::TrieMap() : mBuffer(MAX_BUFFER_SIZE) {
    // The region must exist before the root entry is written.
    mBuffer.extend(ROOT_BITMAP_ENTRY_POS);
    writeEntry(EMPTY_BITMAP_ENTRY, ROOT_BITMAP_ENTRY_INDEX);
}
47
/**
 * Create a map that operates on the given buffer. Nothing is written to the buffer here.
 *
 * NOTE(review): presumably the buffer holds a map that was previously written with save() --
 * confirm with the callers.
 *
 * @param buffer the buffer view handed to the underlying extendable buffer.
 */
TrieMap::TrieMap(const ReadWriteByteArrayView buffer)
        : mBuffer(buffer, BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {}
50
51void TrieMap::dump(const int from, const int to) const {
52    AKLOGI("BufSize: %d", mBuffer.getTailPosition());
53    for (int i = from; i < to; ++i) {
54        AKLOGI("Entry[%d]: %x, %x", i, readField0(i), readField1(i));
55    }
56    int unusedRegionSize = 0;
57    for (int i = 1; i <= MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL; ++i) {
58        int index = readEmptyTableLink(i);
59        while (index != ROOT_BITMAP_ENTRY_INDEX) {
60            index = readField0(index);
61            unusedRegionSize += i;
62        }
63    }
64    AKLOGI("Unused Size: %d", unusedRegionSize);
65}
66
67int TrieMap::getNextLevelBitmapEntryIndex(const int key, const int bitmapEntryIndex) {
68    const Entry bitmapEntry = readEntry(bitmapEntryIndex);
69    const uint32_t unsignedKey = static_cast<uint32_t>(key);
70    const int terminalEntryIndex = getTerminalEntryIndex(
71            unsignedKey, getBitShuffledKey(unsignedKey), bitmapEntry, 0 /* level */);
72    if (terminalEntryIndex == INVALID_INDEX) {
73        // Not found.
74        return INVALID_INDEX;
75    }
76    const Entry terminalEntry = readEntry(terminalEntryIndex);
77    if (terminalEntry.hasTerminalLink()) {
78        return terminalEntry.getValueEntryIndex() + 1;
79    }
80    // Create a value entry and a bitmap entry.
81    const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT);
82    if (valueEntryIndex == INVALID_INDEX) {
83        return INVALID_INDEX;
84    }
85    if (!writeEntry(Entry(0, terminalEntry.getValue()), valueEntryIndex)) {
86        return INVALID_INDEX;
87    }
88    if (!writeEntry(EMPTY_BITMAP_ENTRY, valueEntryIndex + 1)) {
89        return INVALID_INDEX;
90    }
91    if (!writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, terminalEntryIndex)) {
92        return INVALID_INDEX;
93    }
94    return valueEntryIndex + 1;
95}
96
97const TrieMap::Result TrieMap::get(const int key, const int bitmapEntryIndex) const {
98    const uint32_t unsignedKey = static_cast<uint32_t>(key);
99    return getInternal(unsignedKey, getBitShuffledKey(unsignedKey), bitmapEntryIndex,
100            0 /* level */);
101}
102
103bool TrieMap::put(const int key, const uint64_t value, const int bitmapEntryIndex) {
104    if (value > MAX_VALUE) {
105        return false;
106    }
107    const uint32_t unsignedKey = static_cast<uint32_t>(key);
108    return putInternal(unsignedKey, value, getBitShuffledKey(unsignedKey), bitmapEntryIndex,
109            readEntry(bitmapEntryIndex), 0 /* level */);
110}
111
/**
 * Write the content of the buffer to the file.
 *
 * @param file the file handed to DictFileWritingUtils::writeBufferToFileTail().
 * @return the result reported by DictFileWritingUtils::writeBufferToFileTail().
 */
bool TrieMap::save(FILE *const file) const {
    return DictFileWritingUtils::writeBufferToFileTail(file, &mBuffer);
}
115
116bool TrieMap::remove(const int key, const int bitmapEntryIndex) {
117    const Entry bitmapEntry = readEntry(bitmapEntryIndex);
118    const uint32_t unsignedKey = static_cast<uint32_t>(key);
119    const int terminalEntryIndex = getTerminalEntryIndex(
120            unsignedKey, getBitShuffledKey(unsignedKey), bitmapEntry, 0 /* level */);
121    if (terminalEntryIndex == INVALID_INDEX) {
122        // Not found.
123        return false;
124    }
125    const Entry terminalEntry = readEntry(terminalEntryIndex);
126    if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , terminalEntryIndex)) {
127        return false;
128    }
129    if (terminalEntry.hasTerminalLink()) {
130        const Entry nextLevelBitmapEntry = readEntry(terminalEntry.getValueEntryIndex() + 1);
131        if (!freeTable(terminalEntry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) {
132            return false;
133        }
134        if (!removeInner(nextLevelBitmapEntry)){
135            return false;
136        }
137    }
138    return true;
139}
140
141/**
142 * Iterate next entry in a certain level.
143 *
144 * @param iterationState the iteration state that will be read and updated in this method.
145 * @param outKey the output key
146 * @return Result instance. mIsValid is false when all entries are iterated.
147 */
148const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *const iterationState,
149        int *const outKey) const {
150    while (!iterationState->empty()) {
151        TableIterationState &state = iterationState->back();
152        if (state.mTableSize <= state.mCurrentIndex) {
153            // Move to parent.
154            iterationState->pop_back();
155        } else {
156            const int entryIndex = state.mTableIndex + state.mCurrentIndex;
157            state.mCurrentIndex += 1;
158            const Entry entry = readEntry(entryIndex);
159            if (entry.isBitmapEntry()) {
160                // Move to child.
161                iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex());
162            } else if (entry.isValidTerminalEntry()) {
163                if (outKey) {
164                    *outKey = entry.getKey();
165                }
166                if (!entry.hasTerminalLink()) {
167                    return Result(entry.getValue(), true, INVALID_INDEX);
168                }
169                const int valueEntryIndex = entry.getValueEntryIndex();
170                const Entry valueEntry = readEntry(valueEntryIndex);
171                return Result(valueEntry.getValueOfValueEntry(), true, valueEntryIndex + 1);
172            }
173        }
174    }
175    // Visited all entries.
176    return Result(0, false, INVALID_INDEX);
177}
178
179/**
180 * Shuffle bits of the key in the fixed order.
181 *
182 * This method is used as a hash function. This returns different values for different inputs.
183 */
184uint32_t TrieMap::getBitShuffledKey(const uint32_t key) const {
185    uint32_t shuffledKey = 0;
186    for (int i = 0; i < 4; ++i) {
187        const uint32_t keyPiece = (key >> (i * 8)) & 0xFF;
188        shuffledKey ^= ((keyPiece ^ (keyPiece << 7) ^ (keyPiece << 14) ^ (keyPiece << 21))
189                & 0x11111111) << i;
190    }
191    return shuffledKey;
192}
193
194bool TrieMap::writeValue(const uint64_t value, const int terminalEntryIndex) {
195    if (value < VALUE_MASK) {
196        // Write value into the terminal entry.
197        return writeField1(value | VALUE_FLAG, terminalEntryIndex);
198    }
199    // Create value entry and write value.
200    const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT);
201    if (valueEntryIndex == INVALID_INDEX) {
202        return false;
203    }
204    if (!writeEntry(Entry(value >> (FIELD1_SIZE * CHAR_BIT), value), valueEntryIndex)) {
205        return false;
206    }
207    if (!writeEntry(EMPTY_BITMAP_ENTRY, valueEntryIndex + 1)) {
208        return false;
209    }
210    return writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, terminalEntryIndex);
211}
212
213bool TrieMap::updateValue(const Entry &terminalEntry, const uint64_t value,
214        const int terminalEntryIndex) {
215    if (!terminalEntry.hasTerminalLink()) {
216        return writeValue(value, terminalEntryIndex);
217    }
218    const int valueEntryIndex = terminalEntry.getValueEntryIndex();
219    return writeEntry(Entry(value >> (FIELD1_SIZE * CHAR_BIT), value), valueEntryIndex);
220}
221
222bool TrieMap::freeTable(const int tableIndex, const int entryCount) {
223    if (!writeField0(readEmptyTableLink(entryCount), tableIndex)) {
224        return false;
225    }
226    return writeEmptyTableLink(tableIndex, entryCount);
227}
228
229/**
230 * Allocate table with entryCount-entries. Reuse freed table if possible.
231 */
232int TrieMap::allocateTable(const int entryCount) {
233    if (entryCount > 0 && entryCount <= MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL) {
234        const int tableIndex = readEmptyTableLink(entryCount);
235        if (tableIndex > 0) {
236            if (!writeEmptyTableLink(readField0(tableIndex), entryCount)) {
237                return INVALID_INDEX;
238            }
239            // Reuse the table.
240            return tableIndex;
241        }
242    }
243    // Allocate memory space at tail position of the buffer.
244    const int mapIndex = getTailEntryIndex();
245    if (!mBuffer.extend(entryCount * ENTRY_SIZE)) {
246        return INVALID_INDEX;
247    }
248    return mapIndex;
249}
250
251int TrieMap::getTerminalEntryIndex(const uint32_t key, const uint32_t hashedKey,
252        const Entry &bitmapEntry, const int level) const {
253    const int label = getLabel(hashedKey, level);
254    if (!exists(bitmapEntry.getBitmap(), label)) {
255        return INVALID_INDEX;
256    }
257    const int entryIndex = bitmapEntry.getTableIndex() + popCount(bitmapEntry.getBitmap(), label);
258    const Entry entry = readEntry(entryIndex);
259    if (entry.isBitmapEntry()) {
260        // Move to the next level.
261        return getTerminalEntryIndex(key, hashedKey, entry, level + 1);
262    }
263    if (!entry.isValidTerminalEntry()) {
264        return INVALID_INDEX;
265    }
266    if (entry.getKey() == key) {
267        // Terminal entry is found.
268        return entryIndex;
269    }
270    return INVALID_INDEX;
271}
272
273/**
274 * Get Result corresponding to the key.
275 *
276 * @param key the key.
277 * @param hashedKey the hashed key.
278 * @param bitmapEntryIndex the index of bitmap entry
279 * @param level current level
280 * @return Result instance corresponding to the key. mIsValid indicates whether the key is in the
281 * map.
282 */
283const TrieMap::Result TrieMap::getInternal(const uint32_t key, const uint32_t hashedKey,
284        const int bitmapEntryIndex, const int level) const {
285    const int terminalEntryIndex = getTerminalEntryIndex(key, hashedKey,
286            readEntry(bitmapEntryIndex), level);
287    if (terminalEntryIndex == INVALID_INDEX) {
288        // Not found.
289        return Result(0, false, INVALID_INDEX);
290    }
291    const Entry terminalEntry = readEntry(terminalEntryIndex);
292    if (!terminalEntry.hasTerminalLink()) {
293        return Result(terminalEntry.getValue(), true, INVALID_INDEX);
294    }
295    const int valueEntryIndex = terminalEntry.getValueEntryIndex();
296    const Entry valueEntry = readEntry(valueEntryIndex);
297    return Result(valueEntry.getValueOfValueEntry(), true, valueEntryIndex + 1);
298}
299
300/**
301 * Put key to value mapping to the map.
302 *
303 * @param key the key.
304 * @param value the value
305 * @param hashedKey the hashed key.
306 * @param bitmapEntryIndex the index of bitmap entry
307 * @param bitmapEntry the bitmap entry
308 * @param level current level
309 * @return whether the key-value has been correctly inserted to the map or not.
310 */
311bool TrieMap::putInternal(const uint32_t key, const uint64_t value, const uint32_t hashedKey,
312        const int bitmapEntryIndex, const Entry &bitmapEntry, const int level) {
313    const int label = getLabel(hashedKey, level);
314    const uint32_t bitmap = bitmapEntry.getBitmap();
315    const int mapIndex = bitmapEntry.getTableIndex();
316    if (!exists(bitmap, label)) {
317        // Current map doesn't contain the label.
318        return addNewEntryByExpandingTable(key, value, mapIndex, bitmap, bitmapEntryIndex, label);
319    }
320    const int entryIndex = mapIndex + popCount(bitmap, label);
321    const Entry entry = readEntry(entryIndex);
322    if (entry.isBitmapEntry()) {
323        // Bitmap entry is found. Go to the next level.
324        return putInternal(key, value, hashedKey, entryIndex, entry, level + 1);
325    }
326    if (!entry.isValidTerminalEntry()) {
327        // Overwrite invalid terminal entry.
328        return writeTerminalEntry(key, value, entryIndex);
329    }
330    if (entry.getKey() == key) {
331        // Terminal entry for the key is found. Update the value.
332        return updateValue(entry, value, entryIndex);
333    }
334    // Conflict with the existing key.
335    return addNewEntryByResolvingConflict(key, value, hashedKey, entry, entryIndex, level);
336}
337
338/**
339 * Resolve a conflict in the current level and add new entry.
340 *
341 * @param key the key
342 * @param value the value
343 * @param hashedKey the hashed key
344 * @param conflictedEntry the existing conflicted entry
345 * @param conflictedEntryIndex the index of existing conflicted entry
346 * @param level current level
347 * @return whether the key-value has been correctly inserted to the map or not.
348 */
349bool TrieMap::addNewEntryByResolvingConflict(const uint32_t key, const uint64_t value,
350        const uint32_t hashedKey, const Entry &conflictedEntry, const int conflictedEntryIndex,
351        const int level) {
352    const int conflictedKeyNextLabel =
353            getLabel(getBitShuffledKey(conflictedEntry.getKey()), level + 1);
354    const int nextLabel = getLabel(hashedKey, level + 1);
355    if (conflictedKeyNextLabel == nextLabel) {
356        // Conflicted again in the next level.
357        const int newTableIndex = allocateTable(1 /* entryCount */);
358        if (newTableIndex == INVALID_INDEX) {
359            return false;
360        }
361        if (!writeEntry(conflictedEntry, newTableIndex)) {
362            return false;
363        }
364        const Entry newBitmapEntry(setExist(0 /* bitmap */, nextLabel), newTableIndex);
365        if (!writeEntry(newBitmapEntry, conflictedEntryIndex)) {
366            return false;
367        }
368        return putInternal(key, value, hashedKey, conflictedEntryIndex, newBitmapEntry, level + 1);
369    }
370    // The conflict has been resolved. Create a table that contains 2 entries.
371    const int newTableIndex = allocateTable(2 /* entryCount */);
372    if (newTableIndex == INVALID_INDEX) {
373        return false;
374    }
375    if (nextLabel < conflictedKeyNextLabel) {
376        if (!writeTerminalEntry(key, value, newTableIndex)) {
377            return false;
378        }
379        if (!writeEntry(conflictedEntry, newTableIndex + 1)) {
380            return false;
381        }
382    } else { // nextLabel > conflictedKeyNextLabel
383        if (!writeEntry(conflictedEntry, newTableIndex)) {
384            return false;
385        }
386        if (!writeTerminalEntry(key, value, newTableIndex + 1)) {
387            return false;
388        }
389    }
390    const uint32_t updatedBitmap =
391            setExist(setExist(0 /* bitmap */, nextLabel), conflictedKeyNextLabel);
392    return writeEntry(Entry(updatedBitmap, newTableIndex), conflictedEntryIndex);
393}
394
395/**
396 * Add new entry to the existing table.
397 */
398bool TrieMap::addNewEntryByExpandingTable(const uint32_t key, const uint64_t value,
399        const int tableIndex, const uint32_t bitmap, const int bitmapEntryIndex, const int label) {
400    // Current map doesn't contain the label.
401    const int entryCount = popCount(bitmap);
402    const int newTableIndex = allocateTable(entryCount + 1);
403    if (newTableIndex == INVALID_INDEX) {
404        return false;
405    }
406    const int newEntryIndexInTable = popCount(bitmap, label);
407    // Copy from existing table to the new table.
408    for (int i = 0; i < entryCount; ++i) {
409        if (!copyEntry(tableIndex + i, newTableIndex + i + (i >= newEntryIndexInTable ? 1 : 0))) {
410            return false;
411        }
412    }
413    // Write new terminal entry.
414    if (!writeTerminalEntry(key, value, newTableIndex + newEntryIndexInTable)) {
415        return false;
416    }
417    // Update bitmap.
418    if (!writeEntry(Entry(setExist(bitmap, label), newTableIndex), bitmapEntryIndex)) {
419        return false;
420    }
421    if (entryCount > 0) {
422        return freeTable(tableIndex, entryCount);
423    }
424    return true;
425}
426
427bool TrieMap::removeInner(const Entry &bitmapEntry) {
428    const int tableSize = popCount(bitmapEntry.getBitmap());
429    if (tableSize <= 0) {
430        // The table is empty. No need to remove any entries.
431        return true;
432    }
433    for (int i = 0; i < tableSize; ++i) {
434        const int entryIndex = bitmapEntry.getTableIndex() + i;
435        const Entry entry = readEntry(entryIndex);
436        if (entry.isBitmapEntry()) {
437            // Delete next bitmap entry recursively.
438            if (!removeInner(entry)) {
439                return false;
440            }
441        } else {
442            // Invalidate terminal entry just in case.
443            if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , entryIndex)) {
444                return false;
445            }
446            if (entry.hasTerminalLink()) {
447                const Entry nextLevelBitmapEntry = readEntry(entry.getValueEntryIndex() + 1);
448                if (!freeTable(entry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) {
449                    return false;
450                }
451                if (!removeInner(nextLevelBitmapEntry)) {
452                    return false;
453                }
454            }
455        }
456    }
457    return true;
458}
459
460}  // namespace latinime
461