1/*
2 * LZMAEncoder
3 *
4 * Authors: Lasse Collin <lasse.collin@tukaani.org>
5 *          Igor Pavlov <http://7-zip.org/>
6 *
7 * This file has been put into the public domain.
8 * You can do whatever you want with this file.
9 */
10
11package org.tukaani.xz.lzma;
12
13import org.tukaani.xz.lz.LZEncoder;
14import org.tukaani.xz.lz.Matches;
15import org.tukaani.xz.rangecoder.RangeEncoder;
16
17public abstract class LZMAEncoder extends LZMACoder {
18    public static final int MODE_FAST = 1;
19    public static final int MODE_NORMAL = 2;
20
21    /**
22     * LZMA2 chunk is considered full when its uncompressed size exceeds
23     * <code>LZMA2_UNCOMPRESSED_LIMIT</code>.
24     * <p>
25     * A compressed LZMA2 chunk can hold 2 MiB of uncompressed data.
26     * A single LZMA symbol may indicate up to MATCH_LEN_MAX bytes
27     * of data, so the LZMA2 chunk is considered full when there is
28     * less space than MATCH_LEN_MAX bytes.
29     */
30    private static final int LZMA2_UNCOMPRESSED_LIMIT
31            = (2 << 20) - MATCH_LEN_MAX;
32
33    /**
34     * LZMA2 chunk is considered full when its compressed size exceeds
35     * <code>LZMA2_COMPRESSED_LIMIT</code>.
36     * <p>
37     * The maximum compressed size of a LZMA2 chunk is 64 KiB.
38     * A single LZMA symbol might use 20 bytes of space even though
39     * it usually takes just one byte or so. Two more bytes are needed
40     * for LZMA2 uncompressed chunks (see LZMA2OutputStream.writeChunk).
41     * Leave a little safety margin and use 26 bytes.
42     */
43    private static final int LZMA2_COMPRESSED_LIMIT = (64 << 10) - 26;
44
45    private static final int DIST_PRICE_UPDATE_INTERVAL = FULL_DISTANCES;
46    private static final int ALIGN_PRICE_UPDATE_INTERVAL = ALIGN_SIZE;
47
48    private final RangeEncoder rc;
49    final LZEncoder lz;
50    final LiteralEncoder literalEncoder;
51    final LengthEncoder matchLenEncoder;
52    final LengthEncoder repLenEncoder;
53    final int niceLen;
54
55    private int distPriceCount = 0;
56    private int alignPriceCount = 0;
57
58    private final int distSlotPricesSize;
59    private final int[][] distSlotPrices;
60    private final int[][] fullDistPrices
61            = new int[DIST_STATES][FULL_DISTANCES];
62    private final int[] alignPrices = new int[ALIGN_SIZE];
63
64    int back = 0;
65    int readAhead = -1;
66    private int uncompressedSize = 0;
67
68    public static int getMemoryUsage(int mode, int dictSize,
69                                     int extraSizeBefore, int mf) {
70        int m = 80;
71
72        switch (mode) {
73            case MODE_FAST:
74                m += LZMAEncoderFast.getMemoryUsage(
75                        dictSize, extraSizeBefore, mf);
76                break;
77
78            case MODE_NORMAL:
79                m += LZMAEncoderNormal.getMemoryUsage(
80                        dictSize, extraSizeBefore, mf);
81                break;
82
83            default:
84                throw new IllegalArgumentException();
85        }
86
87        return m;
88    }
89
90    public static LZMAEncoder getInstance(
91                RangeEncoder rc, int lc, int lp, int pb, int mode,
92                int dictSize, int extraSizeBefore,
93                int niceLen, int mf, int depthLimit) {
94        switch (mode) {
95            case MODE_FAST:
96                return new LZMAEncoderFast(rc, lc, lp, pb,
97                                           dictSize, extraSizeBefore,
98                                           niceLen, mf, depthLimit);
99
100            case MODE_NORMAL:
101                return new LZMAEncoderNormal(rc, lc, lp, pb,
102                                             dictSize, extraSizeBefore,
103                                             niceLen, mf, depthLimit);
104        }
105
106        throw new IllegalArgumentException();
107    }
108
109    /**
110     * Gets an integer [0, 63] matching the highest two bits of an integer.
111     * This is like bit scan reverse (BSR) on x86 except that this also
112     * cares about the second highest bit.
113     */
114    public static int getDistSlot(int dist) {
115        if (dist <= DIST_MODEL_START)
116            return dist;
117
118        int n = dist;
119        int i = 31;
120
121        if ((n & 0xFFFF0000) == 0) {
122            n <<= 16;
123            i = 15;
124        }
125
126        if ((n & 0xFF000000) == 0) {
127            n <<= 8;
128            i -= 8;
129        }
130
131        if ((n & 0xF0000000) == 0) {
132            n <<= 4;
133            i -= 4;
134        }
135
136        if ((n & 0xC0000000) == 0) {
137            n <<= 2;
138            i -= 2;
139        }
140
141        if ((n & 0x80000000) == 0)
142            --i;
143
144        return (i << 1) + ((dist >>> (i - 1)) & 1);
145    }
146
147    /**
148     * Gets the next LZMA symbol.
149     * <p>
150     * There are three types of symbols: literal (a single byte),
151     * repeated match, and normal match. The symbol is indicated
152     * by the return value and by the variable <code>back</code>.
153     * <p>
154     * Literal: <code>back == -1</code> and return value is <code>1</code>.
155     * The literal itself needs to be read from <code>lz</code> separately.
156     * <p>
157     * Repeated match: <code>back</code> is in the range [0, 3] and
158     * the return value is the length of the repeated match.
159     * <p>
160     * Normal match: <code>back - REPS<code> (<code>back - 4</code>)
161     * is the distance of the match and the return value is the length
162     * of the match.
163     */
164    abstract int getNextSymbol();
165
166    LZMAEncoder(RangeEncoder rc, LZEncoder lz,
167                int lc, int lp, int pb, int dictSize, int niceLen) {
168        super(pb);
169        this.rc = rc;
170        this.lz = lz;
171        this.niceLen = niceLen;
172
173        literalEncoder = new LiteralEncoder(lc, lp);
174        matchLenEncoder = new LengthEncoder(pb, niceLen);
175        repLenEncoder = new LengthEncoder(pb, niceLen);
176
177        distSlotPricesSize = getDistSlot(dictSize - 1) + 1;
178        distSlotPrices = new int[DIST_STATES][distSlotPricesSize];
179
180        reset();
181    }
182
183    public LZEncoder getLZEncoder() {
184        return lz;
185    }
186
187    public void reset() {
188        super.reset();
189        literalEncoder.reset();
190        matchLenEncoder.reset();
191        repLenEncoder.reset();
192        distPriceCount = 0;
193        alignPriceCount = 0;
194
195        uncompressedSize += readAhead + 1;
196        readAhead = -1;
197    }
198
199    public int getUncompressedSize() {
200        return uncompressedSize;
201    }
202
203    public void resetUncompressedSize() {
204        uncompressedSize = 0;
205    }
206
207    /**
208     * Compresses for LZMA2.
209     *
210     * @return      true if the LZMA2 chunk became full, false otherwise
211     */
212    public boolean encodeForLZMA2() {
213        if (!lz.isStarted() && !encodeInit())
214            return false;
215
216        while (uncompressedSize <= LZMA2_UNCOMPRESSED_LIMIT
217                && rc.getPendingSize() <= LZMA2_COMPRESSED_LIMIT)
218            if (!encodeSymbol())
219                return false;
220
221        return true;
222    }
223
224    private boolean encodeInit() {
225        assert readAhead == -1;
226        if (!lz.hasEnoughData(0))
227            return false;
228
229        // The first symbol must be a literal unless using
230        // a preset dictionary. This code isn't run if using
231        // a preset dictionary.
232        skip(1);
233        rc.encodeBit(isMatch[state.get()], 0, 0);
234        literalEncoder.encodeInit();
235
236        --readAhead;
237        assert readAhead == -1;
238
239        ++uncompressedSize;
240        assert uncompressedSize == 1;
241
242        return true;
243    }
244
245    private boolean encodeSymbol() {
246        if (!lz.hasEnoughData(readAhead + 1))
247            return false;
248
249        int len = getNextSymbol();
250
251        assert readAhead >= 0;
252        int posState = (lz.getPos() - readAhead) & posMask;
253
254        if (back == -1) {
255            // Literal i.e. eight-bit byte
256            assert len == 1;
257            rc.encodeBit(isMatch[state.get()], posState, 0);
258            literalEncoder.encode();
259        } else {
260            // Some type of match
261            rc.encodeBit(isMatch[state.get()], posState, 1);
262            if (back < REPS) {
263                // Repeated match i.e. the same distance
264                // has been used earlier.
265                assert lz.getMatchLen(-readAhead, reps[back], len) == len;
266                rc.encodeBit(isRep, state.get(), 1);
267                encodeRepMatch(back, len, posState);
268            } else {
269                // Normal match
270                assert lz.getMatchLen(-readAhead, back - REPS, len) == len;
271                rc.encodeBit(isRep, state.get(), 0);
272                encodeMatch(back - REPS, len, posState);
273            }
274        }
275
276        readAhead -= len;
277        uncompressedSize += len;
278
279        return true;
280    }
281
282    private void encodeMatch(int dist, int len, int posState) {
283        state.updateMatch();
284        matchLenEncoder.encode(len, posState);
285
286        int distSlot = getDistSlot(dist);
287        rc.encodeBitTree(distSlots[getDistState(len)], distSlot);
288
289        if (distSlot >= DIST_MODEL_START) {
290            int footerBits = (distSlot >>> 1) - 1;
291            int base = (2 | (distSlot & 1)) << footerBits;
292            int distReduced = dist - base;
293
294            if (distSlot < DIST_MODEL_END) {
295                rc.encodeReverseBitTree(
296                        distSpecial[distSlot - DIST_MODEL_START],
297                        distReduced);
298            } else {
299                rc.encodeDirectBits(distReduced >>> ALIGN_BITS,
300                                    footerBits - ALIGN_BITS);
301                rc.encodeReverseBitTree(distAlign, distReduced & ALIGN_MASK);
302                --alignPriceCount;
303            }
304        }
305
306        reps[3] = reps[2];
307        reps[2] = reps[1];
308        reps[1] = reps[0];
309        reps[0] = dist;
310
311        --distPriceCount;
312    }
313
314    private void encodeRepMatch(int rep, int len, int posState) {
315        if (rep == 0) {
316            rc.encodeBit(isRep0, state.get(), 0);
317            rc.encodeBit(isRep0Long[state.get()], posState, len == 1 ? 0 : 1);
318        } else {
319            int dist = reps[rep];
320            rc.encodeBit(isRep0, state.get(), 1);
321
322            if (rep == 1) {
323                rc.encodeBit(isRep1, state.get(), 0);
324            } else {
325                rc.encodeBit(isRep1, state.get(), 1);
326                rc.encodeBit(isRep2, state.get(), rep - 2);
327
328                if (rep == 3)
329                    reps[3] = reps[2];
330
331                reps[2] = reps[1];
332            }
333
334            reps[1] = reps[0];
335            reps[0] = dist;
336        }
337
338        if (len == 1) {
339            state.updateShortRep();
340        } else {
341            repLenEncoder.encode(len, posState);
342            state.updateLongRep();
343        }
344    }
345
346    Matches getMatches() {
347        ++readAhead;
348        Matches matches = lz.getMatches();
349        assert lz.verifyMatches(matches);
350        return matches;
351    }
352
353    void skip(int len) {
354        readAhead += len;
355        lz.skip(len);
356    }
357
358    int getAnyMatchPrice(State state, int posState) {
359        return RangeEncoder.getBitPrice(isMatch[state.get()][posState], 1);
360    }
361
362    int getNormalMatchPrice(int anyMatchPrice, State state) {
363        return anyMatchPrice
364               + RangeEncoder.getBitPrice(isRep[state.get()], 0);
365    }
366
367    int getAnyRepPrice(int anyMatchPrice, State state) {
368        return anyMatchPrice
369               + RangeEncoder.getBitPrice(isRep[state.get()], 1);
370    }
371
372    int getShortRepPrice(int anyRepPrice, State state, int posState) {
373        return anyRepPrice
374               + RangeEncoder.getBitPrice(isRep0[state.get()], 0)
375               + RangeEncoder.getBitPrice(isRep0Long[state.get()][posState],
376                                          0);
377    }
378
379    int getLongRepPrice(int anyRepPrice, int rep, State state, int posState) {
380        int price = anyRepPrice;
381
382        if (rep == 0) {
383            price += RangeEncoder.getBitPrice(isRep0[state.get()], 0)
384                     + RangeEncoder.getBitPrice(
385                       isRep0Long[state.get()][posState], 1);
386        } else {
387            price += RangeEncoder.getBitPrice(isRep0[state.get()], 1);
388
389            if (rep == 1)
390                price += RangeEncoder.getBitPrice(isRep1[state.get()], 0);
391            else
392                price += RangeEncoder.getBitPrice(isRep1[state.get()], 1)
393                         + RangeEncoder.getBitPrice(isRep2[state.get()],
394                                                    rep - 2);
395        }
396
397        return price;
398    }
399
400    int getLongRepAndLenPrice(int rep, int len, State state, int posState) {
401        int anyMatchPrice = getAnyMatchPrice(state, posState);
402        int anyRepPrice = getAnyRepPrice(anyMatchPrice, state);
403        int longRepPrice = getLongRepPrice(anyRepPrice, rep, state, posState);
404        return longRepPrice + repLenEncoder.getPrice(len, posState);
405    }
406
407    int getMatchAndLenPrice(int normalMatchPrice,
408                            int dist, int len, int posState) {
409        int price = normalMatchPrice
410                    + matchLenEncoder.getPrice(len, posState);
411        int distState = getDistState(len);
412
413        if (dist < FULL_DISTANCES) {
414            price += fullDistPrices[distState][dist];
415        } else {
416            // Note that distSlotPrices includes also
417            // the price of direct bits.
418            int distSlot = getDistSlot(dist);
419            price += distSlotPrices[distState][distSlot]
420                     + alignPrices[dist & ALIGN_MASK];
421        }
422
423        return price;
424    }
425
426    private void updateDistPrices() {
427        distPriceCount = DIST_PRICE_UPDATE_INTERVAL;
428
429        for (int distState = 0; distState < DIST_STATES; ++distState) {
430            for (int distSlot = 0; distSlot < distSlotPricesSize; ++distSlot)
431                distSlotPrices[distState][distSlot]
432                        = RangeEncoder.getBitTreePrice(
433                          distSlots[distState], distSlot);
434
435            for (int distSlot = DIST_MODEL_END; distSlot < distSlotPricesSize;
436                    ++distSlot) {
437                int count = (distSlot >>> 1) - 1 - ALIGN_BITS;
438                distSlotPrices[distState][distSlot]
439                        += RangeEncoder.getDirectBitsPrice(count);
440            }
441
442            for (int dist = 0; dist < DIST_MODEL_START; ++dist)
443                fullDistPrices[distState][dist]
444                        = distSlotPrices[distState][dist];
445        }
446
447        int dist = DIST_MODEL_START;
448        for (int distSlot = DIST_MODEL_START; distSlot < DIST_MODEL_END;
449                ++distSlot) {
450            int footerBits = (distSlot >>> 1) - 1;
451            int base = (2 | (distSlot & 1)) << footerBits;
452
453            int limit = distSpecial[distSlot - DIST_MODEL_START].length;
454            for (int i = 0; i < limit; ++i) {
455                int distReduced = dist - base;
456                int price = RangeEncoder.getReverseBitTreePrice(
457                        distSpecial[distSlot - DIST_MODEL_START],
458                        distReduced);
459
460                for (int distState = 0; distState < DIST_STATES; ++distState)
461                    fullDistPrices[distState][dist]
462                            = distSlotPrices[distState][distSlot] + price;
463
464                ++dist;
465            }
466        }
467
468        assert dist == FULL_DISTANCES;
469    }
470
471    private void updateAlignPrices() {
472        alignPriceCount = ALIGN_PRICE_UPDATE_INTERVAL;
473
474        for (int i = 0; i < ALIGN_SIZE; ++i)
475            alignPrices[i] = RangeEncoder.getReverseBitTreePrice(distAlign,
476                                                                 i);
477    }
478
479    /**
480     * Updates the lookup tables used for calculating match distance
481     * and length prices. The updating is skipped for performance reasons
482     * if the tables haven't changed much since the previous update.
483     */
484    void updatePrices() {
485        if (distPriceCount <= 0)
486            updateDistPrices();
487
488        if (alignPriceCount <= 0)
489            updateAlignPrices();
490
491        matchLenEncoder.updatePrices();
492        repLenEncoder.updatePrices();
493    }
494
495
496    class LiteralEncoder extends LiteralCoder {
497        private final LiteralSubencoder[] subencoders;
498
499        LiteralEncoder(int lc, int lp) {
500            super(lc, lp);
501
502            subencoders = new LiteralSubencoder[1 << (lc + lp)];
503            for (int i = 0; i < subencoders.length; ++i)
504                subencoders[i] = new LiteralSubencoder();
505        }
506
507        void reset() {
508            for (int i = 0; i < subencoders.length; ++i)
509                subencoders[i].reset();
510        }
511
512        void encodeInit() {
513            // When encoding the first byte of the stream, there is
514            // no previous byte in the dictionary so the encode function
515            // wouldn't work.
516            assert readAhead >= 0;
517            subencoders[0].encode();
518        }
519
520        void encode() {
521            assert readAhead >= 0;
522            int i = getSubcoderIndex(lz.getByte(1 + readAhead),
523                                     lz.getPos() - readAhead);
524            subencoders[i].encode();
525        }
526
527        int getPrice(int curByte, int matchByte,
528                     int prevByte, int pos, State state) {
529            int price = RangeEncoder.getBitPrice(
530                    isMatch[state.get()][pos & posMask], 0);
531
532            int i = getSubcoderIndex(prevByte, pos);
533            price += state.isLiteral()
534                   ? subencoders[i].getNormalPrice(curByte)
535                   : subencoders[i].getMatchedPrice(curByte, matchByte);
536
537            return price;
538        }
539
540        private class LiteralSubencoder extends LiteralSubcoder {
541            void encode() {
542                int symbol = lz.getByte(readAhead) | 0x100;
543
544                if (state.isLiteral()) {
545                    int subencoderIndex;
546                    int bit;
547
548                    do {
549                        subencoderIndex = symbol >>> 8;
550                        bit = (symbol >>> 7) & 1;
551                        rc.encodeBit(probs, subencoderIndex, bit);
552                        symbol <<= 1;
553                    } while (symbol < 0x10000);
554
555                } else {
556                    int matchByte = lz.getByte(reps[0] + 1 + readAhead);
557                    int offset = 0x100;
558                    int subencoderIndex;
559                    int matchBit;
560                    int bit;
561
562                    do {
563                        matchByte <<= 1;
564                        matchBit = matchByte & offset;
565                        subencoderIndex = offset + matchBit + (symbol >>> 8);
566                        bit = (symbol >>> 7) & 1;
567                        rc.encodeBit(probs, subencoderIndex, bit);
568                        symbol <<= 1;
569                        offset &= ~(matchByte ^ symbol);
570                    } while (symbol < 0x10000);
571                }
572
573                state.updateLiteral();
574            }
575
576            int getNormalPrice(int symbol) {
577                int price = 0;
578                int subencoderIndex;
579                int bit;
580
581                symbol |= 0x100;
582
583                do {
584                    subencoderIndex = symbol >>> 8;
585                    bit = (symbol >>> 7) & 1;
586                    price += RangeEncoder.getBitPrice(probs[subencoderIndex],
587                                                      bit);
588                    symbol <<= 1;
589                } while (symbol < (0x100 << 8));
590
591                return price;
592            }
593
594            int getMatchedPrice(int symbol, int matchByte) {
595                int price = 0;
596                int offset = 0x100;
597                int subencoderIndex;
598                int matchBit;
599                int bit;
600
601                symbol |= 0x100;
602
603                do {
604                    matchByte <<= 1;
605                    matchBit = matchByte & offset;
606                    subencoderIndex = offset + matchBit + (symbol >>> 8);
607                    bit = (symbol >>> 7) & 1;
608                    price += RangeEncoder.getBitPrice(probs[subencoderIndex],
609                                                      bit);
610                    symbol <<= 1;
611                    offset &= ~(matchByte ^ symbol);
612                } while (symbol < (0x100 << 8));
613
614                return price;
615            }
616        }
617    }
618
619
620    class LengthEncoder extends LengthCoder {
621        /**
622         * The prices are updated after at least
623         * <code>PRICE_UPDATE_INTERVAL</code> many lengths
624         * have been encoded with the same posState.
625         */
626        private static final int PRICE_UPDATE_INTERVAL = 32; // FIXME?
627
628        private final int[] counters;
629        private final int[][] prices;
630
631        LengthEncoder(int pb, int niceLen) {
632            int posStates = 1 << pb;
633            counters = new int[posStates];
634
635            // Always allocate at least LOW_SYMBOLS + MID_SYMBOLS because
636            // it makes updatePrices slightly simpler. The prices aren't
637            // usually needed anyway if niceLen < 18.
638            int lenSymbols = Math.max(niceLen - MATCH_LEN_MIN + 1,
639                                      LOW_SYMBOLS + MID_SYMBOLS);
640            prices = new int[posStates][lenSymbols];
641        }
642
643        void reset() {
644            super.reset();
645
646            // Reset counters to zero to force price update before
647            // the prices are needed.
648            for (int i = 0; i < counters.length; ++i)
649                counters[i] = 0;
650        }
651
652        void encode(int len, int posState) {
653            len -= MATCH_LEN_MIN;
654
655            if (len < LOW_SYMBOLS) {
656                rc.encodeBit(choice, 0, 0);
657                rc.encodeBitTree(low[posState], len);
658            } else {
659                rc.encodeBit(choice, 0, 1);
660                len -= LOW_SYMBOLS;
661
662                if (len < MID_SYMBOLS) {
663                    rc.encodeBit(choice, 1, 0);
664                    rc.encodeBitTree(mid[posState], len);
665                } else {
666                    rc.encodeBit(choice, 1, 1);
667                    rc.encodeBitTree(high, len - MID_SYMBOLS);
668                }
669            }
670
671            --counters[posState];
672        }
673
674        int getPrice(int len, int posState) {
675            return prices[posState][len - MATCH_LEN_MIN];
676        }
677
678        void updatePrices() {
679            for (int posState = 0; posState < counters.length; ++posState) {
680                if (counters[posState] <= 0) {
681                    counters[posState] = PRICE_UPDATE_INTERVAL;
682                    updatePrices(posState);
683                }
684            }
685        }
686
687        private void updatePrices(int posState) {
688            int choice0Price = RangeEncoder.getBitPrice(choice[0], 0);
689
690            int i = 0;
691            for (; i < LOW_SYMBOLS; ++i)
692                prices[posState][i] = choice0Price
693                        + RangeEncoder.getBitTreePrice(low[posState], i);
694
695            choice0Price = RangeEncoder.getBitPrice(choice[0], 1);
696            int choice1Price = RangeEncoder.getBitPrice(choice[1], 0);
697
698            for (; i < LOW_SYMBOLS + MID_SYMBOLS; ++i)
699                prices[posState][i] = choice0Price + choice1Price
700                         + RangeEncoder.getBitTreePrice(mid[posState],
701                                                        i - LOW_SYMBOLS);
702
703            choice1Price = RangeEncoder.getBitPrice(choice[1], 1);
704
705            for (; i < prices[posState].length; ++i)
706                prices[posState][i] = choice0Price + choice1Price
707                         + RangeEncoder.getBitTreePrice(high, i - LOW_SYMBOLS
708                                                              - MID_SYMBOLS);
709        }
710    }
711}
712