1package com.xtremelabs.robolectric.bytecode;
2
3import android.net.Uri;
4import com.xtremelabs.robolectric.internal.DoNotInstrument;
5import com.xtremelabs.robolectric.internal.Instrument;
6import javassist.*;
7
8import java.io.IOException;
9import java.util.ArrayList;
10import java.util.List;
11
12@SuppressWarnings({"UnusedDeclaration"})
13public class AndroidTranslator implements Translator {
14    /**
15     * IMPORTANT -- increment this number when the bytecode generated for modified classes changes
16     * so the cache file can be invalidated.
17     */
18    public static final int CACHE_VERSION = 21;
19
20    private static final List<ClassHandler> CLASS_HANDLERS = new ArrayList<ClassHandler>();
21
22    private ClassHandler classHandler;
23    private ClassCache classCache;
24    private final List<String> instrumentingList = new ArrayList<String>();
25    private final List<String> instrumentingExcludeList = new ArrayList<String>();
26
27    public AndroidTranslator(ClassHandler classHandler, ClassCache classCache) {
28        this.classHandler = classHandler;
29        this.classCache = classCache;
30
31        // Initialize lists
32        instrumentingList.add("android.");
33        instrumentingList.add("com.google.android.maps");
34        instrumentingList.add("org.apache.http.impl.client.DefaultRequestDirector");
35
36        instrumentingExcludeList.add("android.support.v4.app.NotificationCompat");
37        instrumentingExcludeList.add("android.support.v4.content.LocalBroadcastManager");
38        instrumentingExcludeList.add("android.support.v4.util.LruCache");
39    }
40
41    public AndroidTranslator(ClassHandler classHandler, ClassCache classCache, List<String> customShadowClassNames) {
42        this(classHandler, classCache);
43        if (customShadowClassNames != null && !customShadowClassNames.isEmpty()) {
44            instrumentingList.addAll(customShadowClassNames);
45        }
46    }
47
48    public void addCustomShadowClass(String customShadowClassName) {
49        if (!instrumentingList.contains(customShadowClassName)) {
50            instrumentingList.add(customShadowClassName);
51        }
52    }
53
54    public static ClassHandler getClassHandler(int index) {
55        return CLASS_HANDLERS.get(index);
56    }
57
58    @Override
59    public void start(ClassPool classPool) throws NotFoundException, CannotCompileException {
60        injectClassHandlerToInstrumentedClasses(classPool);
61    }
62
63    private void injectClassHandlerToInstrumentedClasses(ClassPool classPool) throws NotFoundException, CannotCompileException {
64        int index;
65        synchronized (CLASS_HANDLERS) {
66            CLASS_HANDLERS.add(classHandler);
67            index = CLASS_HANDLERS.size() - 1;
68        }
69
70        CtClass robolectricInternalsCtClass = classPool.get(RobolectricInternals.class.getName());
71        robolectricInternalsCtClass.setModifiers(Modifier.PUBLIC);
72
73        robolectricInternalsCtClass.getClassInitializer().insertBefore("{\n" +
74                "classHandler = " + AndroidTranslator.class.getName() + ".getClassHandler(" + index + ");\n" +
75                "}");
76    }
77
78    @Override
79    public void onLoad(ClassPool classPool, String className) throws NotFoundException, CannotCompileException {
80        if (classCache.isWriting()) {
81            throw new IllegalStateException("shouldn't be modifying bytecode after we've started writing cache! class=" + className);
82        }
83
84        if (classHasFromAndroidEquivalent(className)) {
85            replaceClassWithFromAndroidEquivalent(classPool, className);
86            return;
87        }
88
89        CtClass ctClass;
90        try {
91            ctClass = classPool.get(className);
92        } catch (NotFoundException e) {
93            throw new IgnorableClassNotFoundException(e);
94        }
95
96        if (shouldInstrument(ctClass)) {
97            int modifiers = ctClass.getModifiers();
98            if (Modifier.isFinal(modifiers)) {
99                ctClass.setModifiers(modifiers & ~Modifier.FINAL);
100            }
101
102            classHandler.instrument(ctClass);
103
104            fixConstructors(ctClass);
105            fixMethods(ctClass);
106
107            try {
108                classCache.addClass(className, ctClass.toBytecode());
109            } catch (IOException e) {
110                throw new RuntimeException(e);
111            }
112        }
113    }
114
115    /* package */ boolean shouldInstrument(CtClass ctClass) {
116        if (ctClass.hasAnnotation(Instrument.class)) {
117            return true;
118        } else if (ctClass.isInterface() || ctClass.hasAnnotation(DoNotInstrument.class)) {
119            return false;
120        } else {
121            for (String klassName : instrumentingExcludeList) {
122                if (ctClass.getName().startsWith(klassName)) {
123                    return false;
124                }
125            }
126            for (String klassName : instrumentingList) {
127                if (ctClass.getName().startsWith(klassName)) {
128                    return true;
129                }
130            }
131            return false;
132        }
133    }
134
135    private boolean classHasFromAndroidEquivalent(String className) {
136        return className.startsWith(Uri.class.getName());
137    }
138
139    private void replaceClassWithFromAndroidEquivalent(ClassPool classPool, String className) throws NotFoundException {
140        FromAndroidClassNameParts classNameParts = new FromAndroidClassNameParts(className);
141        if (classNameParts.isFromAndroid()) return;
142
143        String from = classNameParts.getNameWithFromAndroid();
144        CtClass ctClass = classPool.getAndRename(from, className);
145
146        ClassMap map = new ClassMap() {
147            @Override
148            public Object get(Object jvmClassName) {
149                FromAndroidClassNameParts classNameParts = new FromAndroidClassNameParts(jvmClassName.toString());
150                if (classNameParts.isFromAndroid()) {
151                    return classNameParts.getNameWithoutFromAndroid();
152                } else {
153                    return jvmClassName;
154                }
155            }
156        };
157        ctClass.replaceClassName(map);
158    }
159
160    class FromAndroidClassNameParts {
161        private static final String TOKEN = "__FromAndroid";
162
163        private String prefix;
164        private String suffix;
165
166        FromAndroidClassNameParts(String name) {
167            int dollarIndex = name.indexOf("$");
168            prefix = name;
169            suffix = "";
170            if (dollarIndex > -1) {
171                prefix = name.substring(0, dollarIndex);
172                suffix = name.substring(dollarIndex);
173            }
174        }
175
176        public boolean isFromAndroid() {
177            return prefix.endsWith(TOKEN);
178        }
179
180        public String getNameWithFromAndroid() {
181            return prefix + TOKEN + suffix;
182        }
183
184        public String getNameWithoutFromAndroid() {
185            return prefix.replace(TOKEN, "") + suffix;
186        }
187    }
188
189    private void addBypassShadowField(CtClass ctClass, String fieldName) {
190        try {
191            try {
192                ctClass.getField(fieldName);
193            } catch (NotFoundException e) {
194                CtField field = new CtField(CtClass.booleanType, fieldName, ctClass);
195                field.setModifiers(java.lang.reflect.Modifier.PUBLIC | java.lang.reflect.Modifier.STATIC);
196                ctClass.addField(field);
197            }
198        } catch (CannotCompileException e) {
199            throw new RuntimeException(e);
200        }
201    }
202
203    private void fixConstructors(CtClass ctClass) throws CannotCompileException, NotFoundException {
204
205        if (ctClass.isEnum()) {
206            // skip enum constructors because they are not stubs in android.jar
207            return;
208        }
209
210        boolean hasDefault = false;
211
212        for (CtConstructor ctConstructor : ctClass.getDeclaredConstructors()) {
213            try {
214                fixConstructor(ctClass, hasDefault, ctConstructor);
215
216                if (ctConstructor.getParameterTypes().length == 0) {
217                    hasDefault = true;
218                }
219            } catch (Exception e) {
220                throw new RuntimeException("problem instrumenting " + ctConstructor, e);
221            }
222        }
223
224        if (!hasDefault) {
225            String methodBody = generateConstructorBody(ctClass, new CtClass[0]);
226            ctClass.addConstructor(CtNewConstructor.make(new CtClass[0], new CtClass[0], "{\n" + methodBody + "}\n", ctClass));
227        }
228    }
229
230    private boolean fixConstructor(CtClass ctClass, boolean needsDefault, CtConstructor ctConstructor) throws NotFoundException, CannotCompileException {
231        String methodBody = generateConstructorBody(ctClass, ctConstructor.getParameterTypes());
232        ctConstructor.setBody("{\n" + methodBody + "}\n");
233        return needsDefault;
234    }
235
236    private String generateConstructorBody(CtClass ctClass, CtClass[] parameterTypes) throws NotFoundException {
237        return generateMethodBody(ctClass,
238                new CtMethod(CtClass.voidType, "<init>", parameterTypes, ctClass),
239                CtClass.voidType,
240                Type.VOID,
241                false,
242                false);
243    }
244
245    private void fixMethods(CtClass ctClass) throws NotFoundException, CannotCompileException {
246        for (CtMethod ctMethod : ctClass.getDeclaredMethods()) {
247            fixMethod(ctClass, ctMethod, true);
248        }
249        CtMethod equalsMethod = ctClass.getMethod("equals", "(Ljava/lang/Object;)Z");
250        CtMethod hashCodeMethod = ctClass.getMethod("hashCode", "()I");
251        CtMethod toStringMethod = ctClass.getMethod("toString", "()Ljava/lang/String;");
252
253        fixMethod(ctClass, equalsMethod, false);
254        fixMethod(ctClass, hashCodeMethod, false);
255        fixMethod(ctClass, toStringMethod, false);
256    }
257
258    private String describe(CtMethod ctMethod) throws NotFoundException {
259        return Modifier.toString(ctMethod.getModifiers()) + " " + ctMethod.getReturnType().getSimpleName() + " " + ctMethod.getLongName();
260    }
261
262    private void fixMethod(CtClass ctClass, CtMethod ctMethod, boolean wasFoundInClass) throws NotFoundException {
263        String describeBefore = describe(ctMethod);
264        try {
265            CtClass declaringClass = ctMethod.getDeclaringClass();
266            int originalModifiers = ctMethod.getModifiers();
267
268            boolean wasNative = Modifier.isNative(originalModifiers);
269            boolean wasFinal = Modifier.isFinal(originalModifiers);
270            boolean wasAbstract = Modifier.isAbstract(originalModifiers);
271            boolean wasDeclaredInClass = ctClass == declaringClass;
272
273            if (wasFinal && ctClass.isEnum()) {
274                return;
275            }
276
277            int newModifiers = originalModifiers;
278            if (wasNative) {
279                newModifiers = Modifier.clear(newModifiers, Modifier.NATIVE);
280            }
281            if (wasFinal) {
282                newModifiers = Modifier.clear(newModifiers, Modifier.FINAL);
283            }
284            if (wasFoundInClass) {
285                ctMethod.setModifiers(newModifiers);
286            }
287
288            CtClass returnCtClass = ctMethod.getReturnType();
289            Type returnType = Type.find(returnCtClass);
290
291            String methodName = ctMethod.getName();
292            CtClass[] paramTypes = ctMethod.getParameterTypes();
293
294//            if (!isAbstract) {
295//                if (methodName.startsWith("set") && paramTypes.length == 1) {
296//                    String fieldName = "__" + methodName.substring(3);
297//                    if (declareField(ctClass, fieldName, paramTypes[0])) {
298//                        methodBody = fieldName + " = $1;\n" + methodBody;
299//                    }
300//                } else if (methodName.startsWith("get") && paramTypes.length == 0) {
301//                    String fieldName = "__" + methodName.substring(3);
302//                    if (declareField(ctClass, fieldName, returnType)) {
303//                        methodBody = "return " + fieldName + ";\n";
304//                    }
305//                }
306//            }
307
308            boolean isStatic = Modifier.isStatic(originalModifiers);
309            String methodBody = generateMethodBody(ctClass, ctMethod, wasNative, wasAbstract, returnCtClass, returnType, isStatic, !wasFoundInClass);
310
311            if (!wasFoundInClass) {
312                CtMethod newMethod = makeNewMethod(ctClass, ctMethod, returnCtClass, methodName, paramTypes, "{\n" + methodBody + generateCallToSuper(methodName, paramTypes) + "\n}");
313                newMethod.setModifiers(newModifiers);
314                if (wasDeclaredInClass) {
315                    ctMethod.insertBefore("{\n" + methodBody + "}\n");
316                } else {
317                    ctClass.addMethod(newMethod);
318                }
319            } else if (wasAbstract || wasNative) {
320                CtMethod newMethod = makeNewMethod(ctClass, ctMethod, returnCtClass, methodName, paramTypes, "{\n" + methodBody + "\n}");
321                ctMethod.setBody(newMethod, null);
322            } else {
323                ctMethod.insertBefore("{\n" + methodBody + "}\n");
324            }
325        } catch (Exception e) {
326            throw new RuntimeException("problem instrumenting " + describeBefore, e);
327        }
328    }
329
330    private CtMethod makeNewMethod(CtClass ctClass, CtMethod ctMethod, CtClass returnCtClass, String methodName, CtClass[] paramTypes, String methodBody) throws CannotCompileException, NotFoundException {
331        return CtNewMethod.make(
332                ctMethod.getModifiers(),
333                returnCtClass,
334                methodName,
335                paramTypes,
336                ctMethod.getExceptionTypes(),
337                methodBody,
338                ctClass);
339    }
340
341    public String generateCallToSuper(String methodName, CtClass[] paramTypes) {
342        return "return super." + methodName + "(" + makeParameterReplacementList(paramTypes.length) + ");";
343    }
344
345    public String makeParameterReplacementList(int length) {
346        if (length == 0) {
347            return "";
348        }
349
350        String parameterReplacementList = "$1";
351        for (int i = 2; i <= length; ++i) {
352            parameterReplacementList += ", $" + i;
353        }
354        return parameterReplacementList;
355    }
356
357    private String generateMethodBody(CtClass ctClass, CtMethod ctMethod, boolean wasNative, boolean wasAbstract, CtClass returnCtClass, Type returnType, boolean aStatic, boolean shouldGenerateCallToSuper) throws NotFoundException {
358        String methodBody;
359        if (wasAbstract) {
360            methodBody = returnType.isVoid() ? "" : "return " + returnType.defaultReturnString() + ";";
361        } else {
362            methodBody = generateMethodBody(ctClass, ctMethod, returnCtClass, returnType, aStatic, shouldGenerateCallToSuper);
363        }
364
365        if (wasNative) {
366            methodBody += returnType.isVoid() ? "" : "return " + returnType.defaultReturnString() + ";";
367        }
368        return methodBody;
369    }
370
371    public String generateMethodBody(CtClass ctClass, CtMethod ctMethod, CtClass returnCtClass, Type returnType, boolean isStatic, boolean shouldGenerateCallToSuper) throws NotFoundException {
372        boolean returnsVoid = returnType.isVoid();
373        String className = ctClass.getName();
374
375        String methodBody;
376        StringBuilder buf = new StringBuilder();
377        buf.append("if (!");
378        buf.append(RobolectricInternals.class.getName());
379        buf.append(".shouldCallDirectly(");
380        buf.append(isStatic ? className + ".class" : "this");
381        buf.append(")) {\n");
382
383        if (!returnsVoid) {
384            buf.append("Object x = ");
385        }
386        buf.append(RobolectricInternals.class.getName());
387        buf.append(".methodInvoked(\n  ");
388        buf.append(className);
389        buf.append(".class, \"");
390        buf.append(ctMethod.getName());
391        buf.append("\", ");
392        if (!isStatic) {
393            buf.append("this");
394        } else {
395            buf.append("null");
396        }
397        buf.append(", ");
398
399        appendParamTypeArray(buf, ctMethod);
400        buf.append(", ");
401        appendParamArray(buf, ctMethod);
402
403        buf.append(")");
404        buf.append(";\n");
405
406        if (!returnsVoid) {
407            buf.append("if (x != null) return ((");
408            buf.append(returnType.nonPrimitiveClassName(returnCtClass));
409            buf.append(") x)");
410            buf.append(returnType.unboxString());
411            buf.append(";\n");
412            if (shouldGenerateCallToSuper) {
413                buf.append(generateCallToSuper(ctMethod.getName(), ctMethod.getParameterTypes()));
414            } else {
415                buf.append("return ");
416                buf.append(returnType.defaultReturnString());
417                buf.append(";\n");
418            }
419        } else {
420            buf.append("return;\n");
421        }
422
423        buf.append("}\n");
424
425        methodBody = buf.toString();
426        return methodBody;
427    }
428
429    private void appendParamTypeArray(StringBuilder buf, CtMethod ctMethod) throws NotFoundException {
430        CtClass[] parameterTypes = ctMethod.getParameterTypes();
431        if (parameterTypes.length == 0) {
432            buf.append("new String[0]");
433        } else {
434            buf.append("new String[] {");
435            for (int i = 0; i < parameterTypes.length; i++) {
436                if (i > 0) buf.append(", ");
437                buf.append("\"");
438                CtClass parameterType = parameterTypes[i];
439                buf.append(parameterType.getName());
440                buf.append("\"");
441            }
442            buf.append("}");
443        }
444    }
445
446    private void appendParamArray(StringBuilder buf, CtMethod ctMethod) throws NotFoundException {
447        int parameterCount = ctMethod.getParameterTypes().length;
448        if (parameterCount == 0) {
449            buf.append("new Object[0]");
450        } else {
451            buf.append("new Object[] {");
452            for (int i = 0; i < parameterCount; i++) {
453                if (i > 0) buf.append(", ");
454                buf.append(RobolectricInternals.class.getName());
455                buf.append(".autobox(");
456                buf.append("$").append(i + 1);
457                buf.append(")");
458            }
459            buf.append("}");
460        }
461    }
462
463    private boolean declareField(CtClass ctClass, String fieldName, CtClass fieldType) throws CannotCompileException, NotFoundException {
464        CtMethod ctMethod = getMethod(ctClass, "get" + fieldName, "");
465        if (ctMethod == null) {
466            return false;
467        }
468        CtClass getterFieldType = ctMethod.getReturnType();
469
470        if (!getterFieldType.equals(fieldType)) {
471            return false;
472        }
473
474        if (getField(ctClass, fieldName) == null) {
475            CtField field = new CtField(fieldType, fieldName, ctClass);
476            field.setModifiers(Modifier.PRIVATE);
477            ctClass.addField(field);
478        }
479
480        return true;
481    }
482
483    private CtField getField(CtClass ctClass, String fieldName) {
484        try {
485            return ctClass.getField(fieldName);
486        } catch (NotFoundException e) {
487            return null;
488        }
489    }
490
491    private CtMethod getMethod(CtClass ctClass, String methodName, String desc) {
492        try {
493            return ctClass.getMethod(methodName, desc);
494        } catch (NotFoundException e) {
495            return null;
496        }
497    }
498
499}
500