1/*
2 * Javassist, a Java-bytecode translator toolkit.
3 * Copyright (C) 1999-2007 Shigeru Chiba. All Rights Reserved.
4 *
5 * The contents of this file are subject to the Mozilla Public License Version
6 * 1.1 (the "License"); you may not use this file except in compliance with
7 * the License.  Alternatively, the contents of this file may be used under
8 * the terms of the GNU Lesser General Public License Version 2.1 or later.
9 *
10 * Software distributed under the License is distributed on an "AS IS" basis,
11 * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License
12 * for the specific language governing rights and limitations under the
13 * License.
14 */
15
16package javassist.compiler;
17
18import javassist.*;
19import javassist.bytecode.*;
20import javassist.compiler.ast.*;
21
22/* Code generator accepting extended Java syntax for Javassist.
23 */
24
25public class JvstCodeGen extends MemberCodeGen {
26    String paramArrayName = null;
27    String paramListName = null;
28    CtClass[] paramTypeList = null;
29    private int paramVarBase = 0;       // variable index for $0 or $1.
30    private boolean useParam0 = false;  // true if $0 is used.
31    private String param0Type = null;   // JVM name
32    public static final String sigName = "$sig";
33    public static final String dollarTypeName = "$type";
34    public static final String clazzName = "$class";
35    private CtClass dollarType = null;
36    CtClass returnType = null;
37    String returnCastName = null;
38    private String returnVarName = null;        // null if $_ is not used.
39    public static final String wrapperCastName = "$w";
40    String proceedName = null;
41    public static final String cflowName = "$cflow";
42    ProceedHandler procHandler = null;  // null if not used.
43
44    public JvstCodeGen(Bytecode b, CtClass cc, ClassPool cp) {
45        super(b, cc, cp);
46        setTypeChecker(new JvstTypeChecker(cc, cp, this));
47    }
48
49    /* Index of $1.
50     */
51    private int indexOfParam1() {
52        return paramVarBase + (useParam0 ? 1 : 0);
53    }
54
55    /* Records a ProceedHandler obejct.
56     *
57     * @param name      the name of the special method call.
58     *                  it is usually $proceed.
59     */
60    public void setProceedHandler(ProceedHandler h, String name) {
61        proceedName = name;
62        procHandler = h;
63    }
64
65    /* If the type of the expression compiled last is void,
66     * add ACONST_NULL and change exprType, arrayDim, className.
67     */
68    public void addNullIfVoid() {
69        if (exprType == VOID) {
70            bytecode.addOpcode(ACONST_NULL);
71            exprType = CLASS;
72            arrayDim = 0;
73            className = jvmJavaLangObject;
74        }
75    }
76
77    /* To support $args, $sig, and $type.
78     * $args is an array of parameter list.
79     */
80    public void atMember(Member mem) throws CompileError {
81        String name = mem.get();
82        if (name.equals(paramArrayName)) {
83            compileParameterList(bytecode, paramTypeList, indexOfParam1());
84            exprType = CLASS;
85            arrayDim = 1;
86            className = jvmJavaLangObject;
87        }
88        else if (name.equals(sigName)) {
89            bytecode.addLdc(Descriptor.ofMethod(returnType, paramTypeList));
90            bytecode.addInvokestatic("javassist/runtime/Desc", "getParams",
91                                "(Ljava/lang/String;)[Ljava/lang/Class;");
92            exprType = CLASS;
93            arrayDim = 1;
94            className = "java/lang/Class";
95        }
96        else if (name.equals(dollarTypeName)) {
97            if (dollarType == null)
98                throw new CompileError(dollarTypeName + " is not available");
99
100            bytecode.addLdc(Descriptor.of(dollarType));
101            callGetType("getType");
102        }
103        else if (name.equals(clazzName)) {
104            if (param0Type == null)
105                throw new CompileError(clazzName + " is not available");
106
107            bytecode.addLdc(param0Type);
108            callGetType("getClazz");
109        }
110        else
111            super.atMember(mem);
112    }
113
114    private void callGetType(String method) {
115        bytecode.addInvokestatic("javassist/runtime/Desc", method,
116                                "(Ljava/lang/String;)Ljava/lang/Class;");
117        exprType = CLASS;
118        arrayDim = 0;
119        className = "java/lang/Class";
120    }
121
122    protected void atFieldAssign(Expr expr, int op, ASTree left,
123                        ASTree right, boolean doDup) throws CompileError
124    {
125        if (left instanceof Member
126            && ((Member)left).get().equals(paramArrayName)) {
127            if (op != '=')
128                throw new CompileError("bad operator for " + paramArrayName);
129
130            right.accept(this);
131            if (arrayDim != 1 || exprType != CLASS)
132                throw new CompileError("invalid type for " + paramArrayName);
133
134            atAssignParamList(paramTypeList, bytecode);
135            if (!doDup)
136                bytecode.addOpcode(POP);
137        }
138        else
139            super.atFieldAssign(expr, op, left, right, doDup);
140    }
141
142    protected void atAssignParamList(CtClass[] params, Bytecode code)
143        throws CompileError
144    {
145        if (params == null)
146            return;
147
148        int varNo = indexOfParam1();
149        int n = params.length;
150        for (int i = 0; i < n; ++i) {
151            code.addOpcode(DUP);
152            code.addIconst(i);
153            code.addOpcode(AALOAD);
154            compileUnwrapValue(params[i], code);
155            code.addStore(varNo, params[i]);
156            varNo += is2word(exprType, arrayDim) ? 2 : 1;
157        }
158    }
159
160    public void atCastExpr(CastExpr expr) throws CompileError {
161        ASTList classname = expr.getClassName();
162        if (classname != null && expr.getArrayDim() == 0) {
163            ASTree p = classname.head();
164            if (p instanceof Symbol && classname.tail() == null) {
165                String typename = ((Symbol)p).get();
166                if (typename.equals(returnCastName)) {
167                    atCastToRtype(expr);
168                    return;
169                }
170                else if (typename.equals(wrapperCastName)) {
171                    atCastToWrapper(expr);
172                    return;
173                }
174            }
175        }
176
177        super.atCastExpr(expr);
178    }
179
180    /**
181     * Inserts a cast operator to the return type.
182     * If the return type is void, this does nothing.
183     */
184    protected void atCastToRtype(CastExpr expr) throws CompileError {
185        expr.getOprand().accept(this);
186        if (exprType == VOID || isRefType(exprType) || arrayDim > 0)
187            compileUnwrapValue(returnType, bytecode);
188        else if (returnType instanceof CtPrimitiveType) {
189            CtPrimitiveType pt = (CtPrimitiveType)returnType;
190            int destType = MemberResolver.descToType(pt.getDescriptor());
191            atNumCastExpr(exprType, destType);
192            exprType = destType;
193            arrayDim = 0;
194            className = null;
195        }
196        else
197            throw new CompileError("invalid cast");
198    }
199
200    protected void atCastToWrapper(CastExpr expr) throws CompileError {
201        expr.getOprand().accept(this);
202        if (isRefType(exprType) || arrayDim > 0)
203            return;     // Object type.  do nothing.
204
205        CtClass clazz = resolver.lookupClass(exprType, arrayDim, className);
206        if (clazz instanceof CtPrimitiveType) {
207            CtPrimitiveType pt = (CtPrimitiveType)clazz;
208            String wrapper = pt.getWrapperName();
209            bytecode.addNew(wrapper);           // new <wrapper>
210            bytecode.addOpcode(DUP);            // dup
211            if (pt.getDataSize() > 1)
212                bytecode.addOpcode(DUP2_X2);    // dup2_x2
213            else
214                bytecode.addOpcode(DUP2_X1);    // dup2_x1
215
216            bytecode.addOpcode(POP2);           // pop2
217            bytecode.addInvokespecial(wrapper, "<init>",
218                                      "(" + pt.getDescriptor() + ")V");
219                                                // invokespecial
220            exprType = CLASS;
221            arrayDim = 0;
222            className = jvmJavaLangObject;
223        }
224    }
225
226    /* Delegates to a ProcHandler object if the method call is
227     * $proceed().  It may process $cflow().
228     */
229    public void atCallExpr(CallExpr expr) throws CompileError {
230        ASTree method = expr.oprand1();
231        if (method instanceof Member) {
232            String name = ((Member)method).get();
233            if (procHandler != null && name.equals(proceedName)) {
234                procHandler.doit(this, bytecode, (ASTList)expr.oprand2());
235                return;
236            }
237            else if (name.equals(cflowName)) {
238                atCflow((ASTList)expr.oprand2());
239                return;
240            }
241        }
242
243        super.atCallExpr(expr);
244    }
245
246    /* To support $cflow().
247     */
248    protected void atCflow(ASTList cname) throws CompileError {
249        StringBuffer sbuf = new StringBuffer();
250        if (cname == null || cname.tail() != null)
251            throw new CompileError("bad " + cflowName);
252
253        makeCflowName(sbuf, cname.head());
254        String name = sbuf.toString();
255        Object[] names = resolver.getClassPool().lookupCflow(name);
256        if (names == null)
257            throw new CompileError("no such " + cflowName + ": " + name);
258
259        bytecode.addGetstatic((String)names[0], (String)names[1],
260                              "Ljavassist/runtime/Cflow;");
261        bytecode.addInvokevirtual("javassist.runtime.Cflow",
262                                  "value", "()I");
263        exprType = INT;
264        arrayDim = 0;
265        className = null;
266    }
267
268    /* Syntax:
269     *
270     * <cflow> : $cflow '(' <cflow name> ')'
271     * <cflow name> : <identifier> ('.' <identifier>)*
272     */
273    private static void makeCflowName(StringBuffer sbuf, ASTree name)
274        throws CompileError
275    {
276        if (name instanceof Symbol) {
277            sbuf.append(((Symbol)name).get());
278            return;
279        }
280        else if (name instanceof Expr) {
281            Expr expr = (Expr)name;
282            if (expr.getOperator() == '.') {
283                makeCflowName(sbuf, expr.oprand1());
284                sbuf.append('.');
285                makeCflowName(sbuf, expr.oprand2());
286                return;
287            }
288        }
289
290        throw new CompileError("bad " + cflowName);
291    }
292
293    /* To support $$.  ($$) is equivalent to ($1, ..., $n).
294     * It can be used only as a parameter list of method call.
295     */
296    public boolean isParamListName(ASTList args) {
297        if (paramTypeList != null
298            && args != null && args.tail() == null) {
299            ASTree left = args.head();
300            return (left instanceof Member
301                    && ((Member)left).get().equals(paramListName));
302        }
303        else
304            return false;
305    }
306
307    /*
308    public int getMethodArgsLength(ASTList args) {
309        if (!isParamListName(args))
310            return super.getMethodArgsLength(args);
311
312        return paramTypeList.length;
313    }
314    */
315
316    public int getMethodArgsLength(ASTList args) {
317        String pname = paramListName;
318        int n = 0;
319        while (args != null) {
320            ASTree a = args.head();
321            if (a instanceof Member && ((Member)a).get().equals(pname)) {
322                if (paramTypeList != null)
323                    n += paramTypeList.length;
324            }
325            else
326                ++n;
327
328            args = args.tail();
329        }
330
331        return n;
332    }
333
334    public void atMethodArgs(ASTList args, int[] types, int[] dims,
335                                String[] cnames) throws CompileError {
336        CtClass[] params = paramTypeList;
337        String pname = paramListName;
338        int i = 0;
339        while (args != null) {
340            ASTree a = args.head();
341            if (a instanceof Member && ((Member)a).get().equals(pname)) {
342                if (params != null) {
343                    int n = params.length;
344                    int regno = indexOfParam1();
345                    for (int k = 0; k < n; ++k) {
346                        CtClass p = params[k];
347                        regno += bytecode.addLoad(regno, p);
348                        setType(p);
349                        types[i] = exprType;
350                        dims[i] = arrayDim;
351                        cnames[i] = className;
352                        ++i;
353                    }
354                }
355            }
356            else {
357                a.accept(this);
358                types[i] = exprType;
359                dims[i] = arrayDim;
360                cnames[i] = className;
361                ++i;
362            }
363
364            args = args.tail();
365        }
366    }
367
368    /*
369    public void atMethodArgs(ASTList args, int[] types, int[] dims,
370                                String[] cnames) throws CompileError {
371        if (!isParamListName(args)) {
372            super.atMethodArgs(args, types, dims, cnames);
373            return;
374        }
375
376        CtClass[] params = paramTypeList;
377        if (params == null)
378            return;
379
380        int n = params.length;
381        int regno = indexOfParam1();
382        for (int i = 0; i < n; ++i) {
383            CtClass p = params[i];
384            regno += bytecode.addLoad(regno, p);
385            setType(p);
386            types[i] = exprType;
387            dims[i] = arrayDim;
388            cnames[i] = className;
389        }
390    }
391    */
392
393    /* called by Javac#recordSpecialProceed().
394     */
395    void compileInvokeSpecial(ASTree target, String classname,
396                              String methodname, String descriptor,
397                              ASTList args)
398        throws CompileError
399    {
400        target.accept(this);
401        int nargs = getMethodArgsLength(args);
402        atMethodArgs(args, new int[nargs], new int[nargs],
403                     new String[nargs]);
404        bytecode.addInvokespecial(classname, methodname, descriptor);
405        setReturnType(descriptor, false, false);
406        addNullIfVoid();
407    }
408
409    /*
410     * Makes it valid to write "return <expr>;" for a void method.
411     */
412    protected void atReturnStmnt(Stmnt st) throws CompileError {
413        ASTree result = st.getLeft();
414        if (result != null && returnType == CtClass.voidType) {
415            compileExpr(result);
416            if (is2word(exprType, arrayDim))
417                bytecode.addOpcode(POP2);
418            else if (exprType != VOID)
419                bytecode.addOpcode(POP);
420
421            result = null;
422        }
423
424        atReturnStmnt2(result);
425    }
426
427    /**
428     * Makes a cast to the return type ($r) available.
429     * It also enables $_.
430     *
431     * <p>If the return type is void, ($r) does nothing.
432     * The type of $_ is java.lang.Object.
433     *
434     * @param resultName        null if $_ is not used.
435     * @return          -1 or the variable index assigned to $_.
436     */
437    public int recordReturnType(CtClass type, String castName,
438                 String resultName, SymbolTable tbl) throws CompileError
439    {
440        returnType = type;
441        returnCastName = castName;
442        returnVarName = resultName;
443        if (resultName == null)
444            return -1;
445        else {
446            int varNo = getMaxLocals();
447            int locals = varNo + recordVar(type, resultName, varNo, tbl);
448            setMaxLocals(locals);
449            return varNo;
450        }
451    }
452
453    /**
454     * Makes $type available.
455     */
456    public void recordType(CtClass t) {
457        dollarType = t;
458    }
459
460    /**
461     * Makes method parameters $0, $1, ..., $args, $$, and $class available.
462     * $0 is equivalent to THIS if the method is not static.  Otherwise,
463     * if the method is static, then $0 is not available.
464     */
465    public int recordParams(CtClass[] params, boolean isStatic,
466                             String prefix, String paramVarName,
467                             String paramsName, SymbolTable tbl)
468        throws CompileError
469    {
470        return recordParams(params, isStatic, prefix, paramVarName,
471                            paramsName, !isStatic, 0, getThisName(), tbl);
472    }
473
474    /**
475     * Makes method parameters $0, $1, ..., $args, $$, and $class available.
476     * $0 is available only if use0 is true.  It might not be equivalent
477     * to THIS.
478     *
479     * @param params    the parameter types (the types of $1, $2, ..)
480     * @param prefix    it must be "$" (the first letter of $0, $1, ...)
481     * @param paramVarName      it must be "$args"
482     * @param paramsName        it must be "$$"
483     * @param use0      true if $0 is used.
484     * @param paramBase the register number of $0 (use0 is true)
485     *                          or $1 (otherwise).
486     * @param target    the class of $0.  If use0 is false, target
487     *                  can be null.  The value of "target" is also used
488     *                  as the name of the type represented by $class.
489     * @param isStatic  true if the method in which the compiled bytecode
490     *                  is embedded is static.
491     */
492    public int recordParams(CtClass[] params, boolean isStatic,
493                            String prefix, String paramVarName,
494                            String paramsName, boolean use0,
495                            int paramBase, String target,
496                            SymbolTable tbl)
497        throws CompileError
498    {
499        int varNo;
500
501        paramTypeList = params;
502        paramArrayName = paramVarName;
503        paramListName = paramsName;
504        paramVarBase = paramBase;
505        useParam0 = use0;
506
507        if (target != null)
508            param0Type = MemberResolver.jvmToJavaName(target);
509
510        inStaticMethod = isStatic;
511        varNo = paramBase;
512        if (use0) {
513            String varName = prefix + "0";
514            Declarator decl
515                = new Declarator(CLASS, MemberResolver.javaToJvmName(target),
516                                 0, varNo++, new Symbol(varName));
517            tbl.append(varName, decl);
518        }
519
520        for (int i = 0; i < params.length; ++i)
521            varNo += recordVar(params[i], prefix + (i + 1), varNo, tbl);
522
523        if (getMaxLocals() < varNo)
524            setMaxLocals(varNo);
525
526        return varNo;
527    }
528
529    /**
530     * Makes the given variable name available.
531     *
532     * @param type      variable type
533     * @param varName   variable name
534     */
535    public int recordVariable(CtClass type, String varName, SymbolTable tbl)
536        throws CompileError
537    {
538        if (varName == null)
539            return -1;
540        else {
541            int varNo = getMaxLocals();
542            int locals = varNo + recordVar(type, varName, varNo, tbl);
543            setMaxLocals(locals);
544            return varNo;
545        }
546    }
547
548    private int recordVar(CtClass cc, String varName, int varNo,
549                          SymbolTable tbl) throws CompileError
550    {
551        if (cc == CtClass.voidType) {
552            exprType = CLASS;
553            arrayDim = 0;
554            className = jvmJavaLangObject;
555        }
556        else
557            setType(cc);
558
559        Declarator decl
560            = new Declarator(exprType, className, arrayDim,
561                             varNo, new Symbol(varName));
562        tbl.append(varName, decl);
563        return is2word(exprType, arrayDim) ? 2 : 1;
564    }
565
566    /**
567     * Makes the given variable name available.
568     *
569     * @param typeDesc  the type descriptor of the variable
570     * @param varName   variable name
571     * @param varNo     an index into the local variable array
572     */
573    public void recordVariable(String typeDesc, String varName, int varNo,
574                               SymbolTable tbl) throws CompileError
575    {
576        char c;
577        int dim = 0;
578        while ((c = typeDesc.charAt(dim)) == '[')
579            ++dim;
580
581        int type = MemberResolver.descToType(c);
582        String cname = null;
583        if (type == CLASS) {
584            if (dim == 0)
585                cname = typeDesc.substring(1, typeDesc.length() - 1);
586            else
587                cname = typeDesc.substring(dim + 1, typeDesc.length() - 1);
588        }
589
590        Declarator decl
591            = new Declarator(type, cname, dim, varNo, new Symbol(varName));
592        tbl.append(varName, decl);
593    }
594
595    /* compileParameterList() returns the stack size used
596     * by the produced code.
597     *
598     * This method correctly computes the max_stack value.
599     *
600     * @param regno     the index of the local variable in which
601     *                  the first argument is received.
602     *                  (0: static method, 1: regular method.)
603     */
604    public static int compileParameterList(Bytecode code,
605                                CtClass[] params, int regno) {
606        if (params == null) {
607            code.addIconst(0);                          // iconst_0
608            code.addAnewarray(javaLangObject);          // anewarray Object
609            return 1;
610        }
611        else {
612            CtClass[] args = new CtClass[1];
613            int n = params.length;
614            code.addIconst(n);                          // iconst_<n>
615            code.addAnewarray(javaLangObject);          // anewarray Object
616            for (int i = 0; i < n; ++i) {
617                code.addOpcode(Bytecode.DUP);           // dup
618                code.addIconst(i);                      // iconst_<i>
619                if (params[i].isPrimitive()) {
620                    CtPrimitiveType pt = (CtPrimitiveType)params[i];
621                    String wrapper = pt.getWrapperName();
622                    code.addNew(wrapper);               // new <wrapper>
623                    code.addOpcode(Bytecode.DUP);       // dup
624                    int s = code.addLoad(regno, pt);    // ?load <regno>
625                    regno += s;
626                    args[0] = pt;
627                    code.addInvokespecial(wrapper, "<init>",
628                                Descriptor.ofMethod(CtClass.voidType, args));
629                                                        // invokespecial
630                }
631                else {
632                    code.addAload(regno);               // aload <regno>
633                    ++regno;
634                }
635
636                code.addOpcode(Bytecode.AASTORE);       // aastore
637            }
638
639            return 8;
640        }
641    }
642
643    protected void compileUnwrapValue(CtClass type, Bytecode code)
644        throws CompileError
645    {
646        if (type == CtClass.voidType) {
647            addNullIfVoid();
648            return;
649        }
650
651        if (exprType == VOID)
652            throw new CompileError("invalid type for " + returnCastName);
653
654        if (type instanceof CtPrimitiveType) {
655            CtPrimitiveType pt = (CtPrimitiveType)type;
656            // pt is not voidType.
657            String wrapper = pt.getWrapperName();
658            code.addCheckcast(wrapper);
659            code.addInvokevirtual(wrapper, pt.getGetMethodName(),
660                                  pt.getGetMethodDescriptor());
661            setType(type);
662        }
663        else {
664            code.addCheckcast(type);
665            setType(type);
666        }
667    }
668
669    /* Sets exprType, arrayDim, and className;
670     * If type is void, then this method does nothing.
671     */
672    public void setType(CtClass type) throws CompileError {
673        setType(type, 0);
674    }
675
676    private void setType(CtClass type, int dim) throws CompileError {
677        if (type.isPrimitive()) {
678            CtPrimitiveType pt = (CtPrimitiveType)type;
679            exprType = MemberResolver.descToType(pt.getDescriptor());
680            arrayDim = dim;
681            className = null;
682        }
683        else if (type.isArray())
684            try {
685                setType(type.getComponentType(), dim + 1);
686            }
687            catch (NotFoundException e) {
688                throw new CompileError("undefined type: " + type.getName());
689            }
690        else {
691            exprType = CLASS;
692            arrayDim = dim;
693            className = MemberResolver.javaToJvmName(type.getName());
694        }
695    }
696
697    /* Performs implicit coercion from exprType to type.
698     */
699    public void doNumCast(CtClass type) throws CompileError {
700        if (arrayDim == 0 && !isRefType(exprType))
701            if (type instanceof CtPrimitiveType) {
702                CtPrimitiveType pt = (CtPrimitiveType)type;
703                atNumCastExpr(exprType,
704                              MemberResolver.descToType(pt.getDescriptor()));
705            }
706            else
707                throw new CompileError("type mismatch");
708    }
709}
710