AndroidTranslator.java revision 4ac725f9b4cebbf46805fc5e9b2f0eaf3fdd9b29
1package com.xtremelabs.robolectric.bytecode;
2
3import android.net.Uri;
4import com.xtremelabs.robolectric.internal.DoNotStrip;
5import javassist.*;
6
7import java.io.IOException;
8import java.util.ArrayList;
9import java.util.List;
10
11@SuppressWarnings({"UnusedDeclaration"})
12public class AndroidTranslator implements Translator {
13    /**
14     * IMPORTANT -- increment this number when the bytecode generated for modified classes changes
15     * so the cache file can be invalidated.
16     */
17    public static final int CACHE_VERSION = 19;
18
19    private static final List<ClassHandler> CLASS_HANDLERS = new ArrayList<ClassHandler>();
20
21    private ClassHandler classHandler;
22    private ClassCache classCache;
23
24    public AndroidTranslator(ClassHandler classHandler, ClassCache classCache) {
25        this.classHandler = classHandler;
26        this.classCache = classCache;
27    }
28
29    public static ClassHandler getClassHandler(int index) {
30        return CLASS_HANDLERS.get(index);
31    }
32
33    @Override
34    public void start(ClassPool classPool) throws NotFoundException, CannotCompileException {
35        injectClassHandlerToInstrumentedClasses(classPool);
36    }
37
38    private void injectClassHandlerToInstrumentedClasses(ClassPool classPool) throws NotFoundException, CannotCompileException {
39        int index;
40        synchronized (CLASS_HANDLERS) {
41            CLASS_HANDLERS.add(classHandler);
42            index = CLASS_HANDLERS.size() - 1;
43        }
44
45        CtClass robolectricInternalsCtClass = classPool.get(RobolectricInternals.class.getName());
46        robolectricInternalsCtClass.setModifiers(Modifier.PUBLIC);
47
48        robolectricInternalsCtClass.getClassInitializer().insertBefore("{\n" +
49                "classHandler = " + AndroidTranslator.class.getName() + ".getClassHandler(" + index + ");\n" +
50                "}");
51    }
52
53    @Override
54    public void onLoad(ClassPool classPool, String className) throws NotFoundException, CannotCompileException {
55        if (classCache.isWriting()) {
56            throw new IllegalStateException("shouldn't be modifying bytecode after we've started writing cache! class=" + className);
57        }
58
59        if (classHasFromAndroidEquivalent(className)) {
60            replaceClassWithFromAndroidEquivalent(classPool, className);
61            return;
62        }
63
64        boolean needsStripping =
65                className.startsWith("android.")
66                        || className.startsWith("com.google.android.maps")
67                        || className.equals("org.apache.http.impl.client.DefaultRequestDirector");
68
69        CtClass ctClass = classPool.get(className);
70        if (needsStripping && !ctClass.hasAnnotation(DoNotStrip.class)) {
71            int modifiers = ctClass.getModifiers();
72            if (Modifier.isFinal(modifiers)) {
73                ctClass.setModifiers(modifiers & ~Modifier.FINAL);
74            }
75
76            if (ctClass.isInterface()) return;
77
78            classHandler.instrument(ctClass);
79
80            fixConstructors(ctClass);
81            fixMethods(ctClass);
82
83            try {
84                classCache.addClass(className, ctClass.toBytecode());
85            } catch (IOException e) {
86                throw new RuntimeException(e);
87            }
88        }
89    }
90
91    private boolean classHasFromAndroidEquivalent(String className) {
92        return className.startsWith(Uri.class.getName());
93    }
94
95    private void replaceClassWithFromAndroidEquivalent(ClassPool classPool, String className) throws NotFoundException {
96        FromAndroidClassNameParts classNameParts = new FromAndroidClassNameParts(className);
97        if (classNameParts.isFromAndroid()) return;
98
99        String from = classNameParts.getNameWithFromAndroid();
100        CtClass ctClass = classPool.getAndRename(from, className);
101
102        ClassMap map = new ClassMap() {
103            @Override
104            public Object get(Object jvmClassName) {
105                FromAndroidClassNameParts classNameParts = new FromAndroidClassNameParts(jvmClassName.toString());
106                if (classNameParts.isFromAndroid()) {
107                    return classNameParts.getNameWithoutFromAndroid();
108                } else {
109                    return jvmClassName;
110                }
111            }
112        };
113        ctClass.replaceClassName(map);
114    }
115
116    class FromAndroidClassNameParts {
117        private static final String TOKEN = "__FromAndroid";
118
119        private String prefix;
120        private String suffix;
121
122        FromAndroidClassNameParts(String name) {
123            int dollarIndex = name.indexOf("$");
124            prefix = name;
125            suffix = "";
126            if (dollarIndex > -1) {
127                prefix = name.substring(0, dollarIndex);
128                suffix = name.substring(dollarIndex);
129            }
130        }
131
132        public boolean isFromAndroid() {
133            return prefix.endsWith(TOKEN);
134        }
135
136        public String getNameWithFromAndroid() {
137            return prefix + TOKEN + suffix;
138        }
139
140        public String getNameWithoutFromAndroid() {
141            return prefix.replace(TOKEN, "") + suffix;
142        }
143    }
144
145    private void addBypassShadowField(CtClass ctClass, String fieldName) {
146        try {
147            try {
148                ctClass.getField(fieldName);
149            } catch (NotFoundException e) {
150                CtField field = new CtField(CtClass.booleanType, fieldName, ctClass);
151                field.setModifiers(java.lang.reflect.Modifier.PUBLIC | java.lang.reflect.Modifier.STATIC);
152                ctClass.addField(field);
153            }
154        } catch (CannotCompileException e) {
155            throw new RuntimeException(e);
156        }
157    }
158
159    private void fixConstructors(CtClass ctClass) throws CannotCompileException, NotFoundException {
160        boolean hasDefault = false;
161
162        for (CtConstructor ctConstructor : ctClass.getConstructors()) {
163            try {
164                fixConstructor(ctClass, hasDefault, ctConstructor);
165
166                if (ctConstructor.getParameterTypes().length == 0) {
167                    hasDefault = true;
168                }
169            } catch (Exception e) {
170                throw new RuntimeException("problem instrumenting " + ctConstructor, e);
171            }
172        }
173
174        if (!hasDefault) {
175            String methodBody = generateConstructorBody(ctClass, new CtClass[0]);
176            ctClass.addConstructor(CtNewConstructor.make(new CtClass[0], new CtClass[0], "{\n" + methodBody + "}\n", ctClass));
177        }
178    }
179
180    private boolean fixConstructor(CtClass ctClass, boolean needsDefault, CtConstructor ctConstructor) throws NotFoundException, CannotCompileException {
181        String methodBody = generateConstructorBody(ctClass, ctConstructor.getParameterTypes());
182        ctConstructor.setBody("{\n" + methodBody + "}\n");
183        return needsDefault;
184    }
185
186    private String generateConstructorBody(CtClass ctClass, CtClass[] parameterTypes) throws NotFoundException {
187        return generateMethodBody(ctClass,
188                new CtMethod(CtClass.voidType, "<init>", parameterTypes, ctClass),
189                CtClass.voidType,
190                Type.VOID,
191                false,
192                false);
193    }
194
195    private void fixMethods(CtClass ctClass) throws NotFoundException, CannotCompileException {
196        for (CtMethod ctMethod : ctClass.getDeclaredMethods()) {
197            fixMethod(ctClass, ctMethod, true);
198        }
199        CtMethod equalsMethod = ctClass.getMethod("equals", "(Ljava/lang/Object;)Z");
200        CtMethod hashCodeMethod = ctClass.getMethod("hashCode", "()I");
201        CtMethod toStringMethod = ctClass.getMethod("toString", "()Ljava/lang/String;");
202
203        fixMethod(ctClass, equalsMethod, false);
204        fixMethod(ctClass, hashCodeMethod, false);
205        fixMethod(ctClass, toStringMethod, false);
206    }
207
208    private String describe(CtMethod ctMethod) throws NotFoundException {
209        return Modifier.toString(ctMethod.getModifiers()) + " " + ctMethod.getReturnType().getSimpleName() + " " + ctMethod.getLongName();
210    }
211
212    private void fixMethod(CtClass ctClass, CtMethod ctMethod, boolean wasFoundInClass) throws NotFoundException {
213        String describeBefore = describe(ctMethod);
214        try {
215            CtClass declaringClass = ctMethod.getDeclaringClass();
216            int originalModifiers = ctMethod.getModifiers();
217
218            boolean wasNative = Modifier.isNative(originalModifiers);
219            boolean wasFinal = Modifier.isFinal(originalModifiers);
220            boolean wasAbstract = Modifier.isAbstract(originalModifiers);
221            boolean wasDeclaredInClass = ctClass == declaringClass;
222
223            if (wasFinal && ctClass.isEnum()) {
224                return;
225            }
226
227            int newModifiers = originalModifiers;
228            if (wasNative) {
229                newModifiers = Modifier.clear(newModifiers, Modifier.NATIVE);
230            }
231            if (wasFinal) {
232                newModifiers = Modifier.clear(newModifiers, Modifier.FINAL);
233            }
234            if (wasFoundInClass) {
235                ctMethod.setModifiers(newModifiers);
236            }
237
238            CtClass returnCtClass = ctMethod.getReturnType();
239            Type returnType = Type.find(returnCtClass);
240
241            String methodName = ctMethod.getName();
242            CtClass[] paramTypes = ctMethod.getParameterTypes();
243
244//            if (!isAbstract) {
245//                if (methodName.startsWith("set") && paramTypes.length == 1) {
246//                    String fieldName = "__" + methodName.substring(3);
247//                    if (declareField(ctClass, fieldName, paramTypes[0])) {
248//                        methodBody = fieldName + " = $1;\n" + methodBody;
249//                    }
250//                } else if (methodName.startsWith("get") && paramTypes.length == 0) {
251//                    String fieldName = "__" + methodName.substring(3);
252//                    if (declareField(ctClass, fieldName, returnType)) {
253//                        methodBody = "return " + fieldName + ";\n";
254//                    }
255//                }
256//            }
257
258            boolean isStatic = Modifier.isStatic(originalModifiers);
259            String methodBody = generateMethodBody(ctClass, ctMethod, wasNative, wasAbstract, returnCtClass, returnType, isStatic, !wasFoundInClass);
260
261            if (!wasFoundInClass) {
262                CtMethod newMethod = makeNewMethod(ctClass, ctMethod, returnCtClass, methodName, paramTypes, "{\n" + methodBody + generateCallToSuper(methodName, paramTypes) + "\n}");
263                newMethod.setModifiers(newModifiers);
264                if (wasDeclaredInClass) {
265                    ctMethod.insertBefore("{\n" + methodBody + "}\n");
266                } else {
267                    ctClass.addMethod(newMethod);
268                }
269            } else if (wasAbstract || wasNative) {
270                CtMethod newMethod = makeNewMethod(ctClass, ctMethod, returnCtClass, methodName, paramTypes, "{\n" + methodBody + "\n}");
271                ctMethod.setBody(newMethod, null);
272            } else {
273                ctMethod.insertBefore("{\n" + methodBody + "}\n");
274            }
275        } catch (Exception e) {
276            throw new RuntimeException("problem instrumenting " + describeBefore, e);
277        }
278    }
279
280    private CtMethod makeNewMethod(CtClass ctClass, CtMethod ctMethod, CtClass returnCtClass, String methodName, CtClass[] paramTypes, String methodBody) throws CannotCompileException, NotFoundException {
281        return CtNewMethod.make(
282                ctMethod.getModifiers(),
283                returnCtClass,
284                methodName,
285                paramTypes,
286                ctMethod.getExceptionTypes(),
287                methodBody,
288                ctClass);
289    }
290
291    public String generateCallToSuper(String methodName, CtClass[] paramTypes) {
292        return "return super." + methodName + "(" + makeParameterReplacementList(paramTypes.length) + ");";
293    }
294
295    public String makeParameterReplacementList(int length) {
296        if (length == 0) {
297            return "";
298        }
299
300        String parameterReplacementList = "$1";
301        for (int i = 2; i <= length; ++i) {
302            parameterReplacementList += ", $" + i;
303        }
304        return parameterReplacementList;
305    }
306
307    private String generateMethodBody(CtClass ctClass, CtMethod ctMethod, boolean wasNative, boolean wasAbstract, CtClass returnCtClass, Type returnType, boolean aStatic, boolean shouldGenerateCallToSuper) throws NotFoundException {
308        String methodBody;
309        if (wasAbstract) {
310            methodBody = returnType.isVoid() ? "" : "return " + returnType.defaultReturnString() + ";";
311        } else {
312            methodBody = generateMethodBody(ctClass, ctMethod, returnCtClass, returnType, aStatic, shouldGenerateCallToSuper);
313        }
314
315        if (wasNative) {
316            methodBody += returnType.isVoid() ? "" : "return " + returnType.defaultReturnString() + ";";
317        }
318        return methodBody;
319    }
320
321    public String generateMethodBody(CtClass ctClass, CtMethod ctMethod, CtClass returnCtClass, Type returnType, boolean isStatic, boolean shouldGenerateCallToSuper) throws NotFoundException {
322        boolean returnsVoid = returnType.isVoid();
323        String className = ctClass.getName();
324
325        String methodBody;
326        StringBuilder buf = new StringBuilder();
327        buf.append("if (!");
328        buf.append(RobolectricInternals.class.getName());
329        buf.append(".shouldCallDirectly(");
330        buf.append(isStatic ? className + ".class" : "this");
331        buf.append(")) {\n");
332
333        if (!returnsVoid) {
334            buf.append("Object x = ");
335        }
336        buf.append(RobolectricInternals.class.getName());
337        buf.append(".methodInvoked(\n  ");
338        buf.append(className);
339        buf.append(".class, \"");
340        buf.append(ctMethod.getName());
341        buf.append("\", ");
342        if (!isStatic) {
343            buf.append("this");
344        } else {
345            buf.append("null");
346        }
347        buf.append(", ");
348
349        appendParamTypeArray(buf, ctMethod);
350        buf.append(", ");
351        appendParamArray(buf, ctMethod);
352
353        buf.append(")");
354        buf.append(";\n");
355
356        if (!returnsVoid) {
357            buf.append("if (x != null) return ((");
358            buf.append(returnType.nonPrimitiveClassName(returnCtClass));
359            buf.append(") x)");
360            buf.append(returnType.unboxString());
361            buf.append(";\n");
362            if (shouldGenerateCallToSuper) {
363                buf.append(generateCallToSuper(ctMethod.getName(), ctMethod.getParameterTypes()));
364            } else {
365                buf.append("return ");
366                buf.append(returnType.defaultReturnString());
367                buf.append(";\n");
368            }
369        } else {
370            buf.append("return;\n");
371        }
372
373        buf.append("}\n");
374
375        methodBody = buf.toString();
376        return methodBody;
377    }
378
379    private void appendParamTypeArray(StringBuilder buf, CtMethod ctMethod) throws NotFoundException {
380        CtClass[] parameterTypes = ctMethod.getParameterTypes();
381        if (parameterTypes.length == 0) {
382            buf.append("new String[0]");
383        } else {
384            buf.append("new String[] {");
385            for (int i = 0; i < parameterTypes.length; i++) {
386                if (i > 0) buf.append(", ");
387                buf.append("\"");
388                CtClass parameterType = parameterTypes[i];
389                buf.append(parameterType.getName());
390                buf.append("\"");
391            }
392            buf.append("}");
393        }
394    }
395
396    private void appendParamArray(StringBuilder buf, CtMethod ctMethod) throws NotFoundException {
397        int parameterCount = ctMethod.getParameterTypes().length;
398        if (parameterCount == 0) {
399            buf.append("new Object[0]");
400        } else {
401            buf.append("new Object[] {");
402            for (int i = 0; i < parameterCount; i++) {
403                if (i > 0) buf.append(", ");
404                buf.append(RobolectricInternals.class.getName());
405                buf.append(".autobox(");
406                buf.append("$").append(i + 1);
407                buf.append(")");
408            }
409            buf.append("}");
410        }
411    }
412
413    private boolean declareField(CtClass ctClass, String fieldName, CtClass fieldType) throws CannotCompileException, NotFoundException {
414        CtMethod ctMethod = getMethod(ctClass, "get" + fieldName, "");
415        if (ctMethod == null) {
416            return false;
417        }
418        CtClass getterFieldType = ctMethod.getReturnType();
419
420        if (!getterFieldType.equals(fieldType)) {
421            return false;
422        }
423
424        if (getField(ctClass, fieldName) == null) {
425            CtField field = new CtField(fieldType, fieldName, ctClass);
426            field.setModifiers(Modifier.PRIVATE);
427            ctClass.addField(field);
428        }
429
430        return true;
431    }
432
433    private CtField getField(CtClass ctClass, String fieldName) {
434        try {
435            return ctClass.getField(fieldName);
436        } catch (NotFoundException e) {
437            return null;
438        }
439    }
440
441    private CtMethod getMethod(CtClass ctClass, String methodName, String desc) {
442        try {
443            return ctClass.getMethod(methodName, desc);
444        } catch (NotFoundException e) {
445            return null;
446        }
447    }
448}
449