1/*
2 * LZMADecoder
3 *
4 * Authors: Lasse Collin <lasse.collin@tukaani.org>
5 *          Igor Pavlov <http://7-zip.org/>
6 *
7 * This file has been put into the public domain.
8 * You can do whatever you want with this file.
9 */
10
11package org.tukaani.xz.lzma;
12
13import java.io.IOException;
14import org.tukaani.xz.lz.LZDecoder;
15import org.tukaani.xz.rangecoder.RangeDecoder;
16
17public final class LZMADecoder extends LZMACoder {
18    private final LZDecoder lz;
19    private final RangeDecoder rc;
20    private final LiteralDecoder literalDecoder;
21    private final LengthDecoder matchLenDecoder = new LengthDecoder();
22    private final LengthDecoder repLenDecoder = new LengthDecoder();
23
24    public LZMADecoder(LZDecoder lz, RangeDecoder rc, int lc, int lp, int pb) {
25        super(pb);
26        this.lz = lz;
27        this.rc = rc;
28        this.literalDecoder = new LiteralDecoder(lc, lp);
29        reset();
30    }
31
32    public void reset() {
33        super.reset();
34        literalDecoder.reset();
35        matchLenDecoder.reset();
36        repLenDecoder.reset();
37    }
38
39    /**
40     * Returns true if LZMA end marker was detected. It is encoded as
41     * the maximum match distance which with signed ints becomes -1. This
42     * function is needed only for LZMA1. LZMA2 doesn't use the end marker
43     * in the LZMA layer.
44     */
45    public boolean endMarkerDetected() {
46        return reps[0] == -1;
47    }
48
49    public void decode() throws IOException {
50        lz.repeatPending();
51
52        while (lz.hasSpace()) {
53            int posState = lz.getPos() & posMask;
54
55            if (rc.decodeBit(isMatch[state.get()], posState) == 0) {
56                literalDecoder.decode();
57            } else {
58                int len = rc.decodeBit(isRep, state.get()) == 0
59                          ? decodeMatch(posState)
60                          : decodeRepMatch(posState);
61
62                // NOTE: With LZMA1 streams that have the end marker,
63                // this will throw CorruptedInputException. LZMAInputStream
64                // handles it specially.
65                lz.repeat(reps[0], len);
66            }
67        }
68
69        rc.normalize();
70    }
71
72    private int decodeMatch(int posState) throws IOException {
73        state.updateMatch();
74
75        reps[3] = reps[2];
76        reps[2] = reps[1];
77        reps[1] = reps[0];
78
79        int len = matchLenDecoder.decode(posState);
80        int distSlot = rc.decodeBitTree(distSlots[getDistState(len)]);
81
82        if (distSlot < DIST_MODEL_START) {
83            reps[0] = distSlot;
84        } else {
85            int limit = (distSlot >> 1) - 1;
86            reps[0] = (2 | (distSlot & 1)) << limit;
87
88            if (distSlot < DIST_MODEL_END) {
89                reps[0] |= rc.decodeReverseBitTree(
90                        distSpecial[distSlot - DIST_MODEL_START]);
91            } else {
92                reps[0] |= rc.decodeDirectBits(limit - ALIGN_BITS)
93                           << ALIGN_BITS;
94                reps[0] |= rc.decodeReverseBitTree(distAlign);
95            }
96        }
97
98        return len;
99    }
100
101    private int decodeRepMatch(int posState) throws IOException {
102        if (rc.decodeBit(isRep0, state.get()) == 0) {
103            if (rc.decodeBit(isRep0Long[state.get()], posState) == 0) {
104                state.updateShortRep();
105                return 1;
106            }
107        } else {
108            int tmp;
109
110            if (rc.decodeBit(isRep1, state.get()) == 0) {
111                tmp = reps[1];
112            } else {
113                if (rc.decodeBit(isRep2, state.get()) == 0) {
114                    tmp = reps[2];
115                } else {
116                    tmp = reps[3];
117                    reps[3] = reps[2];
118                }
119
120                reps[2] = reps[1];
121            }
122
123            reps[1] = reps[0];
124            reps[0] = tmp;
125        }
126
127        state.updateLongRep();
128
129        return repLenDecoder.decode(posState);
130    }
131
132
133    private class LiteralDecoder extends LiteralCoder {
134        private final LiteralSubdecoder[] subdecoders;
135
136        LiteralDecoder(int lc, int lp) {
137            super(lc, lp);
138
139            subdecoders = new LiteralSubdecoder[1 << (lc + lp)];
140            for (int i = 0; i < subdecoders.length; ++i)
141                subdecoders[i] = new LiteralSubdecoder();
142        }
143
144        void reset() {
145            for (int i = 0; i < subdecoders.length; ++i)
146                subdecoders[i].reset();
147        }
148
149        void decode() throws IOException {
150            int i = getSubcoderIndex(lz.getByte(0), lz.getPos());
151            subdecoders[i].decode();
152        }
153
154
155        private class LiteralSubdecoder extends LiteralSubcoder {
156            void decode() throws IOException {
157                int symbol = 1;
158
159                if (state.isLiteral()) {
160                    do {
161                        symbol = (symbol << 1) | rc.decodeBit(probs, symbol);
162                    } while (symbol < 0x100);
163
164                } else {
165                    int matchByte = lz.getByte(reps[0]);
166                    int offset = 0x100;
167                    int matchBit;
168                    int bit;
169
170                    do {
171                        matchByte <<= 1;
172                        matchBit = matchByte & offset;
173                        bit = rc.decodeBit(probs, offset + matchBit + symbol);
174                        symbol = (symbol << 1) | bit;
175                        offset &= (0 - bit) ^ ~matchBit;
176                    } while (symbol < 0x100);
177                }
178
179                lz.putByte((byte)symbol);
180                state.updateLiteral();
181            }
182        }
183    }
184
185
186    private class LengthDecoder extends LengthCoder {
187        int decode(int posState) throws IOException {
188            if (rc.decodeBit(choice, 0) == 0)
189                return rc.decodeBitTree(low[posState]) + MATCH_LEN_MIN;
190
191            if (rc.decodeBit(choice, 1) == 0)
192                return rc.decodeBitTree(mid[posState])
193                       + MATCH_LEN_MIN + LOW_SYMBOLS;
194
195            return rc.decodeBitTree(high)
196                   + MATCH_LEN_MIN + LOW_SYMBOLS + MID_SYMBOLS;
197        }
198    }
199}
200