1/*
2 * Javassist, a Java-bytecode translator toolkit.
3 * Copyright (C) 1999-2007 Shigeru Chiba. All Rights Reserved.
4 *
5 * The contents of this file are subject to the Mozilla Public License Version
6 * 1.1 (the "License"); you may not use this file except in compliance with
7 * the License.  Alternatively, the contents of this file may be used under
8 * the terms of the GNU Lesser General Public License Version 2.1 or later.
9 *
10 * Software distributed under the License is distributed on an "AS IS" basis,
11 * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
12 * for the specific language governing rights and limitations under the
13 * License.
14 */
15
16package javassist.bytecode;
17
18/**
19 * Utility for computing <code>max_stack</code>.
20 */
21class CodeAnalyzer implements Opcode {
22    private ConstPool constPool;
23    private CodeAttribute codeAttr;
24
25    public CodeAnalyzer(CodeAttribute ca) {
26        codeAttr = ca;
27        constPool = ca.getConstPool();
28    }
29
30    public int computeMaxStack()
31        throws BadBytecode
32    {
33        /* d = stack[i]
34         * d == 0: not visited
35         * d > 0: the depth is d - 1 after executing the bytecode at i.
36         * d < 0: not visited. the initial depth (before execution) is 1 - d.
37         */
38        CodeIterator ci = codeAttr.iterator();
39        int length = ci.getCodeLength();
40        int[] stack = new int[length];
41        constPool = codeAttr.getConstPool();
42        initStack(stack, codeAttr);
43        boolean repeat;
44        do {
45            repeat = false;
46            for (int i = 0; i < length; ++i)
47                if (stack[i] < 0) {
48                    repeat = true;
49                    visitBytecode(ci, stack, i);
50                }
51        } while (repeat);
52
53        int maxStack = 1;
54        for (int i = 0; i < length; ++i)
55            if (stack[i] > maxStack)
56                maxStack = stack[i];
57
58        return maxStack - 1;    // the base is 1.
59    }
60
61    private void initStack(int[] stack, CodeAttribute ca) {
62        stack[0] = -1;
63        ExceptionTable et = ca.getExceptionTable();
64        if (et != null) {
65            int size = et.size();
66            for (int i = 0; i < size; ++i)
67                stack[et.handlerPc(i)] = -2;    // an exception is on stack
68        }
69    }
70
71    private void visitBytecode(CodeIterator ci, int[] stack, int index)
72        throws BadBytecode
73    {
74        int codeLength = stack.length;
75        ci.move(index);
76        int stackDepth = -stack[index];
77        int[] jsrDepth = new int[1];
78        jsrDepth[0] = -1;
79        while (ci.hasNext()) {
80            index = ci.next();
81            stack[index] = stackDepth;
82            int op = ci.byteAt(index);
83            stackDepth = visitInst(op, ci, index, stackDepth);
84            if (stackDepth < 1)
85                throw new BadBytecode("stack underflow at " + index);
86
87            if (processBranch(op, ci, index, codeLength, stack, stackDepth, jsrDepth))
88                break;
89
90            if (isEnd(op))     // return, ireturn, athrow, ...
91                break;
92
93            if (op == JSR || op == JSR_W)
94                --stackDepth;
95        }
96    }
97
98    private boolean processBranch(int opcode, CodeIterator ci, int index,
99                                  int codeLength, int[] stack, int stackDepth, int[] jsrDepth)
100        throws BadBytecode
101    {
102        if ((IFEQ <= opcode && opcode <= IF_ACMPNE)
103                            || opcode == IFNULL || opcode == IFNONNULL) {
104            int target = index + ci.s16bitAt(index + 1);
105            checkTarget(index, target, codeLength, stack, stackDepth);
106        }
107        else {
108            int target, index2;
109            switch (opcode) {
110            case GOTO :
111                target = index + ci.s16bitAt(index + 1);
112                checkTarget(index, target, codeLength, stack, stackDepth);
113                return true;
114            case GOTO_W :
115                target = index + ci.s32bitAt(index + 1);
116                checkTarget(index, target, codeLength, stack, stackDepth);
117                return true;
118            case JSR :
119            case JSR_W :
120                if (opcode == JSR)
121                    target = index + ci.s16bitAt(index + 1);
122                else
123                    target = index + ci.s32bitAt(index + 1);
124
125                checkTarget(index, target, codeLength, stack, stackDepth);
126                /*
127                 * It is unknown which RET comes back to this JSR.
128                 * So we assume that if the stack depth at one JSR instruction
129                 * is N, then it is also N at other JSRs and N - 1 at all RET
130                 * instructions.  Note that STACK_GROW[JSR] is 1 since it pushes
131                 * a return address on the operand stack.
132                 */
133                if (jsrDepth[0] < 0) {
134                    jsrDepth[0] = stackDepth;
135                    return false;
136                }
137                else if (stackDepth == jsrDepth[0])
138                    return false;
139                else
140                    throw new BadBytecode(
141                        "sorry, cannot compute this data flow due to JSR: "
142                            + stackDepth + "," + jsrDepth[0]);
143            case RET :
144                if (jsrDepth[0] < 0) {
145                    jsrDepth[0] = stackDepth + 1;
146                    return false;
147                }
148                else if (stackDepth + 1 == jsrDepth[0])
149                    return true;
150                else
151                    throw new BadBytecode(
152                        "sorry, cannot compute this data flow due to RET: "
153                            + stackDepth + "," + jsrDepth[0]);
154            case LOOKUPSWITCH :
155            case TABLESWITCH :
156                index2 = (index & ~3) + 4;
157                target = index + ci.s32bitAt(index2);
158                checkTarget(index, target, codeLength, stack, stackDepth);
159                if (opcode == LOOKUPSWITCH) {
160                    int npairs = ci.s32bitAt(index2 + 4);
161                    index2 += 12;
162                    for (int i = 0; i < npairs; ++i) {
163                        target = index + ci.s32bitAt(index2);
164                        checkTarget(index, target, codeLength,
165                                    stack, stackDepth);
166                        index2 += 8;
167                    }
168                }
169                else {
170                    int low = ci.s32bitAt(index2 + 4);
171                    int high = ci.s32bitAt(index2 + 8);
172                    int n = high - low + 1;
173                    index2 += 12;
174                    for (int i = 0; i < n; ++i) {
175                        target = index + ci.s32bitAt(index2);
176                        checkTarget(index, target, codeLength,
177                                    stack, stackDepth);
178                        index2 += 4;
179                    }
180                }
181
182                return true;    // always branch.
183            }
184        }
185
186        return false;   // may not branch.
187    }
188
189    private void checkTarget(int opIndex, int target, int codeLength,
190                             int[] stack, int stackDepth)
191        throws BadBytecode
192    {
193        if (target < 0 || codeLength <= target)
194            throw new BadBytecode("bad branch offset at " + opIndex);
195
196        int d = stack[target];
197        if (d == 0)
198            stack[target] = -stackDepth;
199        else if (d != stackDepth && d != -stackDepth)
200            throw new BadBytecode("verification error (" + stackDepth +
201                                  "," + d + ") at " + opIndex);
202    }
203
204    private static boolean isEnd(int opcode) {
205        return (IRETURN <= opcode && opcode <= RETURN) || opcode == ATHROW;
206    }
207
208    /**
209     * Visits an instruction.
210     */
211    private int visitInst(int op, CodeIterator ci, int index, int stack)
212        throws BadBytecode
213    {
214        String desc;
215        switch (op) {
216        case GETFIELD :
217            stack += getFieldSize(ci, index) - 1;
218            break;
219        case PUTFIELD :
220            stack -= getFieldSize(ci, index) + 1;
221            break;
222        case GETSTATIC :
223            stack += getFieldSize(ci, index);
224            break;
225        case PUTSTATIC :
226            stack -= getFieldSize(ci, index);
227            break;
228        case INVOKEVIRTUAL :
229        case INVOKESPECIAL :
230            desc = constPool.getMethodrefType(ci.u16bitAt(index + 1));
231            stack += Descriptor.dataSize(desc) - 1;
232            break;
233        case INVOKESTATIC :
234            desc = constPool.getMethodrefType(ci.u16bitAt(index + 1));
235            stack += Descriptor.dataSize(desc);
236            break;
237        case INVOKEINTERFACE :
238            desc = constPool.getInterfaceMethodrefType(
239                                            ci.u16bitAt(index + 1));
240            stack += Descriptor.dataSize(desc) - 1;
241            break;
242        case ATHROW :
243            stack = 1;      // the stack becomes empty (1 means no values).
244            break;
245        case MULTIANEWARRAY :
246            stack += 1 - ci.byteAt(index + 3);
247            break;
248        case WIDE :
249            op = ci.byteAt(index + 1);
250            // don't break here.
251        default :
252            stack += STACK_GROW[op];
253        }
254
255        return stack;
256    }
257
258    private int getFieldSize(CodeIterator ci, int index) {
259        String desc = constPool.getFieldrefType(ci.u16bitAt(index + 1));
260        return Descriptor.dataSize(desc);
261    }
262}
263