1/*
2 * Javassist, a Java-bytecode translator toolkit.
3 * Copyright (C) 1999-2007 Shigeru Chiba. All Rights Reserved.
4 *
5 * The contents of this file are subject to the Mozilla Public License Version
6 * 1.1 (the "License"); you may not use this file except in compliance with
7 * the License.  Alternatively, the contents of this file may be used under
8 * the terms of the GNU Lesser General Public License Version 2.1 or later.
9 *
10 * Software distributed under the License is distributed on an "AS IS" basis,
11 * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
12 * for the specific language governing rights and limitations under the
13 * License.
14 */
15
16package javassist.bytecode.stackmap;
17
18import javassist.bytecode.*;
19import java.util.HashMap;
20import java.util.ArrayList;
21
22/**
23 * A basic block is a sequence of bytecode that does not contain jump/branch
24 * instructions except at the last bytecode.
25 * Since Java6 or later does not allow JSR, this class deals with JSR as a
26 * non-branch instruction.
27 */
28public class BasicBlock {
29    public int position, length;
30    public int incoming;        // the number of incoming branches.
31    public BasicBlock[] exit;   // null if the block is a leaf.
32    public boolean stop;        // true if the block ends with an unconditional jump.
33    public Catch toCatch;
34
35    protected BasicBlock(int pos) {
36        position = pos;
37        length = 0;
38        incoming = 0;
39    }
40
41    public static BasicBlock find(BasicBlock[] blocks, int pos)
42        throws BadBytecode
43    {
44        for (int i = 0; i < blocks.length; i++) {
45            int iPos = blocks[i].position;
46            if (iPos <= pos && pos < iPos + blocks[i].length)
47                return blocks[i];
48        }
49
50        throw new BadBytecode("no basic block at " + pos);
51    }
52
53    public static class Catch {
54        Catch next;
55        BasicBlock body;
56        int typeIndex;
57        Catch(BasicBlock b, int i, Catch c) {
58            body = b;
59            typeIndex = i;
60            next = c;
61        }
62    }
63
64    public String toString() {
65        StringBuffer sbuf = new StringBuffer();
66        String cname = this.getClass().getName();
67        int i = cname.lastIndexOf('.');
68        sbuf.append(i < 0 ? cname : cname.substring(i + 1));
69        sbuf.append("[");
70        toString2(sbuf);
71        sbuf.append("]");
72        return sbuf.toString();
73    }
74
75    protected void toString2(StringBuffer sbuf) {
76        sbuf.append("pos=").append(position).append(", len=")
77            .append(length).append(", in=").append(incoming)
78            .append(", exit{");
79        if (exit != null) {
80            for (int i = 0; i < exit.length; i++)
81                sbuf.append(exit[i].position).append(", ");
82        }
83
84        sbuf.append("}, {");
85        Catch th = toCatch;
86        while (th != null) {
87            sbuf.append("(").append(th.body.position).append(", ")
88                .append(th.typeIndex).append("), ");
89            th = th.next;
90        }
91
92        sbuf.append("}");
93    }
94
95    static class Mark implements Comparable {
96        int position;
97        BasicBlock block;
98        BasicBlock[] jump;
99        boolean alwaysJmp;     // true if a unconditional branch.
100        int size;       // 0 unless the mark indicates RETURN etc.
101        Catch catcher;
102
103        Mark(int p) {
104            position = p;
105            block = null;
106            jump = null;
107            alwaysJmp = false;
108            size = 0;
109            catcher = null;
110        }
111
112        public int compareTo(Object obj) {
113            if (obj instanceof Mark) {
114                int pos = ((Mark)obj).position;
115                return position - pos;
116            }
117
118            return -1;
119        }
120
121        void setJump(BasicBlock[] bb, int s, boolean always) {
122            jump = bb;
123            size = s;
124            alwaysJmp = always;
125        }
126    }
127
128    public static class Maker {
129        /* Override these two methods if a subclass of BasicBlock must be
130         * instantiated.
131         */
132        protected BasicBlock makeBlock(int pos) {
133            return new BasicBlock(pos);
134        }
135
136        protected BasicBlock[] makeArray(int size) {
137            return new BasicBlock[size];
138        }
139
140        private BasicBlock[] makeArray(BasicBlock b) {
141            BasicBlock[] array = makeArray(1);
142            array[0] = b;
143            return array;
144        }
145
146        private BasicBlock[] makeArray(BasicBlock b1, BasicBlock b2) {
147            BasicBlock[] array = makeArray(2);
148            array[0] = b1;
149            array[1] = b2;
150            return array;
151        }
152
153        public BasicBlock[] make(MethodInfo minfo) throws BadBytecode {
154            CodeAttribute ca = minfo.getCodeAttribute();
155            if (ca == null)
156                return null;
157
158            CodeIterator ci = ca.iterator();
159            return make(ci, 0, ci.getCodeLength(), ca.getExceptionTable());
160        }
161
162        public BasicBlock[] make(CodeIterator ci, int begin, int end,
163                                 ExceptionTable et)
164            throws BadBytecode
165        {
166            HashMap marks = makeMarks(ci, begin, end, et);
167            BasicBlock[] bb = makeBlocks(marks);
168            addCatchers(bb, et);
169            return bb;
170        }
171
172        /* Branch target
173         */
174        private Mark makeMark(HashMap table, int pos) {
175            return makeMark0(table, pos, true, true);
176        }
177
178        /* Branch instruction.
179         * size > 0
180         */
181        private Mark makeMark(HashMap table, int pos, BasicBlock[] jump,
182                              int size, boolean always) {
183            Mark m = makeMark0(table, pos, false, false);
184            m.setJump(jump, size, always);
185            return m;
186        }
187
188        private Mark makeMark0(HashMap table, int pos,
189                               boolean isBlockBegin, boolean isTarget) {
190            Integer p = new Integer(pos);
191            Mark m = (Mark)table.get(p);
192            if (m == null) {
193                m = new Mark(pos);
194                table.put(p, m);
195            }
196
197            if (isBlockBegin) {
198                if (m.block == null)
199                    m.block = makeBlock(pos);
200
201                if (isTarget)
202                    m.block.incoming++;
203            }
204
205            return m;
206        }
207
208        private HashMap makeMarks(CodeIterator ci, int begin, int end,
209                                  ExceptionTable et)
210            throws BadBytecode
211        {
212            ci.begin();
213            ci.move(begin);
214            HashMap marks = new HashMap();
215            while (ci.hasNext()) {
216                int index = ci.next();
217                if (index >= end)
218                    break;
219
220                int op = ci.byteAt(index);
221                if ((Opcode.IFEQ <= op && op <= Opcode.IF_ACMPNE)
222                        || op == Opcode.IFNULL || op == Opcode.IFNONNULL) {
223                    Mark to = makeMark(marks, index + ci.s16bitAt(index + 1));
224                    Mark next = makeMark(marks, index + 3);
225                    makeMark(marks, index, makeArray(to.block, next.block), 3, false);
226                }
227                else if (Opcode.GOTO <= op && op <= Opcode.LOOKUPSWITCH)
228                    switch (op) {
229                    case Opcode.GOTO :
230                        makeGoto(marks, index, index + ci.s16bitAt(index + 1), 3);
231                        break;
232                    case Opcode.JSR :
233                        makeJsr(marks, index, index + ci.s16bitAt(index + 1), 3);
234                        break;
235                    case Opcode.RET :
236                        makeMark(marks, index, null, 2, true);
237                        break;
238                    case Opcode.TABLESWITCH : {
239                        int pos = (index & ~3) + 4;
240                        int low = ci.s32bitAt(pos + 4);
241                        int high = ci.s32bitAt(pos + 8);
242                        int ncases = high - low + 1;
243                        BasicBlock[] to = makeArray(ncases + 1);
244                        to[0] = makeMark(marks, index + ci.s32bitAt(pos)).block;   // default branch target
245                        int p = pos + 12;
246                        int n = p + ncases * 4;
247                        int k = 1;
248                        while (p < n) {
249                            to[k++] = makeMark(marks, index + ci.s32bitAt(p)).block;
250                            p += 4;
251                        }
252                        makeMark(marks, index, to, n - index, true);
253                        break; }
254                    case Opcode.LOOKUPSWITCH : {
255                        int pos = (index & ~3) + 4;
256                        int ncases = ci.s32bitAt(pos + 4);
257                        BasicBlock[] to = makeArray(ncases + 1);
258                        to[0] = makeMark(marks, index + ci.s32bitAt(pos)).block;   // default branch target
259                        int p = pos + 8 + 4;
260                        int n = p + ncases * 8 - 4;
261                        int k = 1;
262                        while (p < n) {
263                            to[k++] = makeMark(marks, index + ci.s32bitAt(p)).block;
264                            p += 8;
265                        }
266                        makeMark(marks, index, to, n - index, true);
267                        break; }
268                    }
269                else if ((Opcode.IRETURN <= op && op <= Opcode.RETURN) || op == Opcode.ATHROW)
270                    makeMark(marks, index, null, 1, true);
271                else if (op == Opcode.GOTO_W)
272                    makeGoto(marks, index, index + ci.s32bitAt(index + 1), 5);
273                else if (op == Opcode.JSR_W)
274                    makeJsr(marks, index, index + ci.s32bitAt(index + 1), 5);
275                else if (op == Opcode.WIDE && ci.byteAt(index + 1) == Opcode.RET)
276                    makeMark(marks, index, null, 1, true);
277            }
278
279            if (et != null) {
280                int i = et.size();
281                while (--i >= 0) {
282                    makeMark0(marks, et.startPc(i), true, false);
283                    makeMark(marks, et.handlerPc(i));
284                }
285            }
286
287            return marks;
288        }
289
290        private void makeGoto(HashMap marks, int pos, int target, int size) {
291            Mark to = makeMark(marks, target);
292            BasicBlock[] jumps = makeArray(to.block);
293            makeMark(marks, pos, jumps, size, true);
294        }
295
296        /**
297         * We ignore JSR since Java 6 or later does not allow it.
298         */
299        protected void makeJsr(HashMap marks, int pos, int target, int size) {
300        /*
301            Mark to = makeMark(marks, target);
302            Mark next = makeMark(marks, pos + size);
303            BasicBlock[] jumps = makeArray(to.block, next.block);
304            makeMark(marks, pos, jumps, size, false);
305        */
306        }
307
308        private BasicBlock[] makeBlocks(HashMap markTable) {
309            Mark[] marks = (Mark[])markTable.values()
310                                            .toArray(new Mark[markTable.size()]);
311            java.util.Arrays.sort(marks);
312            ArrayList blocks = new ArrayList();
313            int i = 0;
314            BasicBlock prev;
315            if (marks.length > 0 && marks[0].position == 0 && marks[0].block != null)
316                prev = getBBlock(marks[i++]);
317            else
318                prev = makeBlock(0);
319
320            blocks.add(prev);
321            while (i < marks.length) {
322                Mark m = marks[i++];
323                BasicBlock bb = getBBlock(m);
324                if (bb == null) {
325                    // the mark indicates a branch instruction
326                    if (prev.length > 0) {
327                        // the previous mark already has exits.
328                        prev = makeBlock(prev.position + prev.length);
329                        blocks.add(prev);
330                    }
331
332                    prev.length = m.position + m.size - prev.position;
333                    prev.exit = m.jump;
334                    prev.stop = m.alwaysJmp;
335                }
336                else {
337                    // the mark indicates a branch target
338                    if (prev.length == 0) {
339                        prev.length = m.position - prev.position;
340                        bb.incoming++;
341                        prev.exit = makeArray(bb);
342                    }
343                    else {
344                        // the previous mark already has exits.
345                        int prevPos = prev.position;
346                        if (prevPos + prev.length < m.position) {
347                            prev = makeBlock(prevPos + prev.length);
348                            prev.length = m.position - prevPos;
349                            // the incoming flow from dead code is not counted
350                            // bb.incoming++;
351                            prev.exit = makeArray(bb);
352                        }
353                    }
354
355                    blocks.add(bb);
356                    prev = bb;
357                }
358            }
359
360            return (BasicBlock[])blocks.toArray(makeArray(blocks.size()));
361        }
362
363        private static BasicBlock getBBlock(Mark m) {
364            BasicBlock b = m.block;
365            if (b != null && m.size > 0) {
366                b.exit = m.jump;
367                b.length = m.size;
368                b.stop = m.alwaysJmp;
369            }
370
371            return b;
372        }
373
374        private void addCatchers(BasicBlock[] blocks, ExceptionTable et)
375            throws BadBytecode
376        {
377            if (et == null)
378                return;
379
380            int i = et.size();
381            while (--i >= 0) {
382                BasicBlock handler = find(blocks, et.handlerPc(i));
383                int start = et.startPc(i);
384                int end = et.endPc(i);
385                int type = et.catchType(i);
386                handler.incoming--;
387                for (int k = 0; k < blocks.length; k++) {
388                    BasicBlock bb = blocks[k];
389                    int iPos = bb.position;
390                    if (start <= iPos && iPos < end) {
391                        bb.toCatch = new Catch(handler, type, bb.toCatch);
392                        handler.incoming++;
393                    }
394                }
395            }
396        }
397    }
398}
399