1package test.javassist.bytecode.analysis;
2
3import java.io.Serializable;
4import java.util.ArrayList;
5import java.util.Iterator;
6
7import javassist.ClassPool;
8import javassist.CtClass;
9import javassist.CtMethod;
10import javassist.bytecode.AccessFlag;
11import javassist.bytecode.BadBytecode;
12import javassist.bytecode.Bytecode;
13import javassist.bytecode.CodeIterator;
14import javassist.bytecode.MethodInfo;
15import javassist.bytecode.Opcode;
16import javassist.bytecode.analysis.Analyzer;
17import javassist.bytecode.analysis.Frame;
18import javassist.bytecode.analysis.Type;
19import junit.framework.TestCase;
20
21/**
22 * Tests Analyzer
23 *
24 * @author Jason T. Greene
25 */
26public class AnalyzerTest extends TestCase {
27
28    public void testCommonSupperArray() throws Exception {
29        ClassPool pool = ClassPool.getDefault();
30        CtClass clazz = pool.get(getClass().getName() + "$Dummy");
31        CtMethod method = clazz.getDeclaredMethod("commonSuperArray");
32        verifyArrayLoad(clazz, method, "java.lang.Number");
33    }
34
35    public  void testCommonInterfaceArray() throws Exception {
36        ClassPool pool = ClassPool.getDefault();
37        CtClass clazz = pool.get(getClass().getName() + "$Dummy");
38        CtMethod method = clazz.getDeclaredMethod("commonInterfaceArray");
39        verifyArrayLoad(clazz, method, "java.io.Serializable");
40    }
41
42    public  void testSharedInterfaceAndSuperClass() throws Exception {
43        CtMethod method = ClassPool.getDefault().getMethod(
44                getClass().getName() + "$Dummy", "sharedInterfaceAndSuperClass");
45        verifyReturn(method, "java.io.Serializable");
46
47        method = ClassPool.getDefault().getMethod(
48                getClass().getName() + "$Dummy", "sharedOffsetInterfaceAndSuperClass");
49        verifyReturn(method, "java.io.Serializable");
50
51        method = ClassPool.getDefault().getMethod(
52                getClass().getName() + "$Dummy", "sharedSuperWithSharedInterface");
53        verifyReturn(method, getClass().getName() + "$Dummy$A");
54    }
55
56    public  void testArrayDifferentDims() throws Exception {
57        CtMethod method = ClassPool.getDefault().getMethod(
58                getClass().getName() + "$Dummy", "arrayDifferentDimensions1");
59        verifyReturn(method, "java.lang.Cloneable[]");
60
61        method = ClassPool.getDefault().getMethod(
62                getClass().getName() + "$Dummy", "arrayDifferentDimensions2");
63        verifyReturn(method, "java.lang.Object[][]");
64    }
65
66    public  void testReusedLocalMerge() throws Exception {
67        CtMethod method = ClassPool.getDefault().getMethod(
68                getClass().getName() + "$Dummy", "reusedLocalMerge");
69
70        MethodInfo info = method.getMethodInfo2();
71        Analyzer analyzer = new Analyzer();
72        Frame[] frames = analyzer.analyze(method.getDeclaringClass(), info);
73        assertNotNull(frames);
74        int pos = findOpcode(info, Opcode.RETURN);
75        Frame frame = frames[pos];
76        assertEquals("java.lang.Object", frame.getLocal(2).getCtClass().getName());
77    }
78
79    private static int findOpcode(MethodInfo info, int opcode) throws BadBytecode {
80        CodeIterator iter = info.getCodeAttribute().iterator();
81
82        // find return
83        int pos = 0;
84        while (iter.hasNext()) {
85            pos = iter.next();
86            if (iter.byteAt(pos) == opcode)
87                break;
88        }
89        return pos;
90    }
91
92
93    private static void verifyReturn(CtMethod method, String expected) throws BadBytecode {
94        MethodInfo info = method.getMethodInfo2();
95        CodeIterator iter = info.getCodeAttribute().iterator();
96
97        // find areturn
98        int pos = 0;
99        while (iter.hasNext()) {
100            pos = iter.next();
101            if (iter.byteAt(pos) == Opcode.ARETURN)
102                break;
103        }
104
105        Analyzer analyzer = new Analyzer();
106        Frame[] frames = analyzer.analyze(method.getDeclaringClass(), info);
107        assertNotNull(frames);
108        Frame frame = frames[pos];
109        assertEquals(expected, frame.peek().getCtClass().getName());
110    }
111
112    private static void verifyArrayLoad(CtClass clazz, CtMethod method, String component)
113            throws BadBytecode {
114        MethodInfo info = method.getMethodInfo2();
115        CodeIterator iter = info.getCodeAttribute().iterator();
116
117        // find aaload
118        int pos = 0;
119        while (iter.hasNext()) {
120            pos = iter.next();
121            if (iter.byteAt(pos) == Opcode.AALOAD)
122                break;
123        }
124
125        Analyzer analyzer = new Analyzer();
126        Frame[] frames = analyzer.analyze(clazz, info);
127        assertNotNull(frames);
128        Frame frame = frames[pos];
129        assertNotNull(frame);
130
131        Type type = frame.getStack(frame.getTopIndex() - 1);
132        assertEquals(component + "[]", type.getCtClass().getName());
133
134        pos = iter.next();
135        frame = frames[pos];
136        assertNotNull(frame);
137
138        type = frame.getStack(frame.getTopIndex());
139        assertEquals(component, type.getCtClass().getName());
140    }
141
142    private static void addJump(Bytecode code, int opcode, int pos) {
143        int current = code.currentPc();
144        code.addOpcode(opcode);
145        code.addIndex(pos - current);
146    }
147
148    public void testDeadCode() throws Exception {
149        CtMethod method = generateDeadCode(ClassPool.getDefault());
150        Analyzer analyzer = new Analyzer();
151        Frame[] frames = analyzer.analyze(method.getDeclaringClass(), method.getMethodInfo2());
152        assertNotNull(frames);
153        assertNull(frames[4]);
154        assertNotNull(frames[5]);
155        verifyReturn(method, "java.lang.String");
156    }
157
158    public void testInvalidCode() throws Exception {
159        CtMethod method = generateInvalidCode(ClassPool.getDefault());
160        Analyzer analyzer = new Analyzer();
161        try {
162            analyzer.analyze(method.getDeclaringClass(), method.getMethodInfo2());
163        } catch (BadBytecode e) {
164            return;
165        }
166
167        fail("Invalid code should have triggered a BadBytecode exception");
168    }
169
170    public void testCodeFalloff() throws Exception {
171        CtMethod method = generateCodeFalloff(ClassPool.getDefault());
172        Analyzer analyzer = new Analyzer();
173        try {
174            analyzer.analyze(method.getDeclaringClass(), method.getMethodInfo2());
175        } catch (BadBytecode e) {
176            return;
177        }
178
179        fail("Code falloff should have triggered a BadBytecode exception");
180    }
181
182    public void testJsrMerge() throws Exception {
183        CtMethod method = generateJsrMerge(ClassPool.getDefault());
184        Analyzer analyzer = new Analyzer();
185        analyzer.analyze(method.getDeclaringClass(), method.getMethodInfo2());
186        verifyReturn(method, "java.lang.String");
187    }
188
189    public void testJsrMerge2() throws Exception {
190        CtMethod method = generateJsrMerge2(ClassPool.getDefault());
191        Analyzer analyzer = new Analyzer();
192        analyzer.analyze(method.getDeclaringClass(), method.getMethodInfo2());
193        verifyReturn(method, "java.lang.String");
194    }
195
196    private CtMethod generateDeadCode(ClassPool pool) throws Exception {
197        CtClass clazz = pool.makeClass(getClass().getName() + "$Generated0");
198        CtClass stringClass = pool.get("java.lang.String");
199        CtMethod method = new CtMethod(stringClass, "foo", new CtClass[0], clazz);
200        MethodInfo info = method.getMethodInfo2();
201        info.setAccessFlags(AccessFlag.PUBLIC | AccessFlag.STATIC);
202        Bytecode code = new Bytecode(info.getConstPool(), 1, 2);
203        /* 0 */ code.addIconst(1);
204        /* 1 */ addJump(code, Opcode.GOTO, 5);
205        /* 4 */ code.addIconst(0); // DEAD
206        /* 5 */ code.addIconst(1);
207        /* 6 */ code.addInvokestatic(stringClass, "valueOf", stringClass, new CtClass[]{CtClass.intType});
208        /* 9 */ code.addOpcode(Opcode.ARETURN);
209        info.setCodeAttribute(code.toCodeAttribute());
210        clazz.addMethod(method);
211
212        return method;
213    }
214
215    private CtMethod generateInvalidCode(ClassPool pool) throws Exception {
216        CtClass clazz = pool.makeClass(getClass().getName() + "$Generated4");
217        CtClass intClass = pool.get("java.lang.Integer");
218        CtClass stringClass = pool.get("java.lang.String");
219        CtMethod method = new CtMethod(stringClass, "foo", new CtClass[0], clazz);
220        MethodInfo info = method.getMethodInfo2();
221        info.setAccessFlags(AccessFlag.PUBLIC | AccessFlag.STATIC);
222        Bytecode code = new Bytecode(info.getConstPool(), 1, 2);
223        /* 0 */ code.addIconst(1);
224        /* 1 */ code.addInvokestatic(intClass, "valueOf", intClass, new CtClass[]{CtClass.intType});
225        /* 4 */ code.addOpcode(Opcode.ARETURN);
226        info.setCodeAttribute(code.toCodeAttribute());
227        clazz.addMethod(method);
228
229        return method;
230    }
231
232
233    private CtMethod generateCodeFalloff(ClassPool pool) throws Exception {
234        CtClass clazz = pool.makeClass(getClass().getName() + "$Generated3");
235        CtClass stringClass = pool.get("java.lang.String");
236        CtMethod method = new CtMethod(stringClass, "foo", new CtClass[0], clazz);
237        MethodInfo info = method.getMethodInfo2();
238        info.setAccessFlags(AccessFlag.PUBLIC | AccessFlag.STATIC);
239        Bytecode code = new Bytecode(info.getConstPool(), 1, 2);
240        /* 0 */ code.addIconst(1);
241        /* 1 */ code.addInvokestatic(stringClass, "valueOf", stringClass, new CtClass[]{CtClass.intType});
242        info.setCodeAttribute(code.toCodeAttribute());
243        clazz.addMethod(method);
244
245        return method;
246    }
247
248    private CtMethod generateJsrMerge(ClassPool pool) throws Exception {
249        CtClass clazz = pool.makeClass(getClass().getName() + "$Generated1");
250        CtClass stringClass = pool.get("java.lang.String");
251        CtMethod method = new CtMethod(stringClass, "foo", new CtClass[0], clazz);
252        MethodInfo info = method.getMethodInfo2();
253        info.setAccessFlags(AccessFlag.PUBLIC | AccessFlag.STATIC);
254        Bytecode code = new Bytecode(info.getConstPool(), 1, 2);
255        /* 0 */ code.addIconst(5);
256        /* 1 */ code.addIstore(0);
257        /* 2 */ addJump(code, Opcode.JSR, 7);
258        /* 5 */ code.addAload(0);
259        /* 6 */ code.addOpcode(Opcode.ARETURN);
260        /* 7 */ code.addAstore(1);
261        /* 8 */ code.addIconst(3);
262        /* 9 */ code.addInvokestatic(stringClass, "valueOf", stringClass, new CtClass[]{CtClass.intType});
263        /* 12 */ code.addAstore(0);
264        /* 12 */ code.addRet(1);
265        info.setCodeAttribute(code.toCodeAttribute());
266        clazz.addMethod(method);
267        //System.out.println(clazz.toClass().getMethod("foo", new Class[0]).invoke(null, new Object[0]));
268
269        return method;
270    }
271
272    private CtMethod generateJsrMerge2(ClassPool pool) throws Exception {
273        CtClass clazz = pool.makeClass(getClass().getName() + "$Generated2");
274        CtClass stringClass = pool.get("java.lang.String");
275        CtMethod method = new CtMethod(stringClass, "foo", new CtClass[0], clazz);
276        MethodInfo info = method.getMethodInfo2();
277        info.setAccessFlags(AccessFlag.PUBLIC | AccessFlag.STATIC);
278        Bytecode code = new Bytecode(info.getConstPool(), 1, 2);
279        /* 0 */ addJump(code, Opcode.JSR, 5);
280        /* 3 */ code.addAload(0);
281        /* 4 */ code.addOpcode(Opcode.ARETURN);
282        /* 5 */ code.addAstore(1);
283        /* 6 */ code.addIconst(4);
284        /* 7 */ code.addInvokestatic(stringClass, "valueOf", stringClass, new CtClass[]{CtClass.intType});
285        /* 10 */ code.addAstore(0);
286        /* 11 */ code.addRet(1);
287        info.setCodeAttribute(code.toCodeAttribute());
288        clazz.addMethod(method);
289
290        return method;
291    }
292
293    public static class Dummy {
294        public Serializable commonSuperArray(int x) {
295            Number[] n;
296
297            if (x > 5) {
298                n = new Long[10];
299            } else {
300                n = new Double[5];
301            }
302
303            return n[x];
304        }
305
306        public Serializable commonInterfaceArray(int x) {
307            Serializable[] n;
308
309            if (x > 5) {
310                n = new Long[10];
311            } else if (x > 3) {
312                n = new Double[5];
313            } else {
314                n = new String[3];
315            }
316
317            return n[x];
318        }
319
320
321        public static class A {};
322        public static class B1 extends A implements Serializable {};
323        public static class B2 extends A implements Serializable {};
324        public static class A2 implements Serializable, Cloneable {};
325        public static class A3 implements Serializable, Cloneable {};
326
327        public static class B3 extends A {};
328        public static class C31 extends B3 implements Serializable {};
329
330
331        public void dummy(Serializable s) {}
332
333        public Object sharedInterfaceAndSuperClass(int x) {
334            Serializable s;
335
336            if (x > 5) {
337                s = new B1();
338            } else {
339                s = new B2();
340            }
341
342            dummy(s);
343
344            return s;
345        }
346
347        public A sharedSuperWithSharedInterface(int x) {
348            A a;
349
350            if (x > 5) {
351                a = new B1();
352            } else if (x > 3) {
353                a = new B2();
354            } else {
355                a = new C31();
356            }
357
358            return a;
359        }
360
361
362        public void reusedLocalMerge() {
363             ArrayList list = new ArrayList();
364             try {
365               Iterator i = list.iterator();
366               i.hasNext();
367             } catch (Exception e) {
368             }
369        }
370
371        public Object sharedOffsetInterfaceAndSuperClass(int x) {
372            Serializable s;
373
374            if (x > 5) {
375                s = new B1();
376            } else {
377                s = new C31();
378            }
379
380            dummy(s);
381
382            return s;
383        }
384
385
386        public Object arrayDifferentDimensions1(int x) {
387            Object[] n;
388
389            if ( x > 5) {
390                n = new Number[1][1];
391            } else {
392                n = new Cloneable[1];
393            }
394
395
396            return n;
397        }
398
399        public Object arrayDifferentDimensions2(int x) {
400            Object[] n;
401
402            if ( x> 5) {
403                n = new String[1][1];
404            } else {
405                n = new Number[1][1][1][1];
406            }
407
408            return n;
409        }
410    }
411}
412