1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16package com.android.test.runner.junit4;
17
import android.app.Instrumentation;
import android.content.Context;
import android.os.Bundle;
import android.util.Log;

import com.android.test.InjectBundle;
import com.android.test.InjectContext;
import com.android.test.InjectInstrumentation;

import org.junit.runners.BlockJUnit4ClassRunner;
import org.junit.runners.model.FrameworkField;
import org.junit.runners.model.InitializationError;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.util.List;
33
34/**
35 * A specialized {@link BlockJUnit4ClassRunner} that can handle {@link InjectContext} and
36 * {@link InjectInstrumentation}.
37 */
38class AndroidJUnit4ClassRunner extends BlockJUnit4ClassRunner {
39
40    private static final String LOG_TAG = "AndroidJUnit4ClassRunner";
41    private final Instrumentation mInstr;
42    private final Bundle mBundle;
43
44    @SuppressWarnings("serial")
45    private static class InvalidInjectException extends Exception {
46        InvalidInjectException(String message) {
47            super(message);
48        }
49    }
50
51    public AndroidJUnit4ClassRunner(Class<?> klass, Instrumentation instr, Bundle bundle)
52            throws InitializationError {
53        super(klass);
54        mInstr = instr;
55        mBundle = bundle;
56    }
57
58    @Override
59    protected Object createTest() throws Exception {
60        Object test = super.createTest();
61        inject(test);
62        return test;
63    }
64
65    @Override
66    protected void collectInitializationErrors(List<Throwable> errors) {
67        super.collectInitializationErrors(errors);
68
69        validateInjectFields(errors);
70    }
71
72    private void validateInjectFields(List<Throwable> errors) {
73        List<FrameworkField> instrFields = getTestClass().getAnnotatedFields(
74                InjectInstrumentation.class);
75        for (FrameworkField instrField : instrFields) {
76            validateInjectField(errors, instrField, Instrumentation.class);
77        }
78        List<FrameworkField> contextFields = getTestClass().getAnnotatedFields(
79                InjectContext.class);
80        for (FrameworkField contextField : contextFields) {
81            validateInjectField(errors, contextField, Context.class);
82        }
83        List<FrameworkField> bundleFields = getTestClass().getAnnotatedFields(
84                InjectBundle.class);
85        for (FrameworkField bundleField : bundleFields) {
86            validateInjectField(errors, bundleField, Bundle.class);
87        }
88    }
89
90    private void validateInjectField(List<Throwable> errors, FrameworkField instrField,
91            Class<?> expectedType) {
92        if (!instrField.isPublic()) {
93            errors.add(new InvalidInjectException(String.format(
94                    "field %s in class %s has an InjectInstrumentation annotation," +
95                    " but is not public", instrField.getName(), getTestClass().getName())));
96        }
97        if (!expectedType.isAssignableFrom(instrField.getType())) {
98            errors.add(new InvalidInjectException(String.format(
99                    "field %s in class %s has an InjectInstrumentation annotation," +
100                    " but its not of %s type", instrField.getName(),
101                    getTestClass().getName(), expectedType.getName())));
102        }
103    }
104
105    private void inject(Object test) {
106        List<FrameworkField> instrFields = getTestClass().getAnnotatedFields(
107                InjectInstrumentation.class);
108        for (FrameworkField instrField : instrFields) {
109            setFieldValue(test, instrField.getField(), mInstr);
110        }
111        List<FrameworkField> contextFields = getTestClass().getAnnotatedFields(
112                InjectContext.class);
113        for (FrameworkField contextField : contextFields) {
114            setFieldValue(test, contextField.getField(), mInstr.getTargetContext());
115        }
116        List<FrameworkField> bundleFields = getTestClass().getAnnotatedFields(
117                InjectBundle.class);
118        for (FrameworkField bundleField : bundleFields) {
119            setFieldValue(test, bundleField.getField(), mBundle);
120        }
121    }
122
123    private void setFieldValue(Object test, Field field, Object value) {
124        try {
125            field.set(test, value);
126        } catch (IllegalArgumentException e) {
127            Log.e(LOG_TAG, String.format(
128                    "Failed to inject value for field %s in class %s", field.getName(),
129                    test.getClass().getName()), e);
130        } catch (IllegalAccessException e) {
131            Log.e(LOG_TAG, String.format(
132                    "Failed to inject value for field %s in class %s", field.getName(),
133                    test.getClass().getName()), e);
134        }
135    }
136}
137