1package junitparams.internal.parameters;
2
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;

import org.junit.runners.model.FrameworkMethod;

import junitparams.Parameters;
12
13class ParamsFromMethodCommon {
14    private FrameworkMethod frameworkMethod;
15
16    ParamsFromMethodCommon(FrameworkMethod frameworkMethod) {
17        this.frameworkMethod = frameworkMethod;
18    }
19
20    Object[] paramsFromMethod(Class<?> sourceClass) {
21        String methodAnnotation = frameworkMethod.getAnnotation(Parameters.class).method();
22
23        if (methodAnnotation.isEmpty()) {
24            return invokeMethodWithParams(defaultMethodName(), sourceClass);
25        }
26
27        List<Object> result = new ArrayList<Object>();
28        for (String methodName : methodAnnotation.split(",")) {
29            for (Object param : invokeMethodWithParams(methodName.trim(), sourceClass))
30                result.add(param);
31        }
32
33        return result.toArray();
34    }
35
36    Object[] getDataFromMethod(Method providerMethod) throws IllegalAccessException, InvocationTargetException {
37        return encapsulateParamsIntoArrayIfSingleParamsetPassed((Object[]) providerMethod.invoke(null));
38    }
39
40    boolean containsDefaultParametersProvidingMethod(Class<?> sourceClass) {
41        return findMethodInTestClassHierarchy(defaultMethodName(), sourceClass) != null;
42    }
43
44    private String defaultMethodName() {
45        return "parametersFor" + frameworkMethod.getName().substring(0, 1).toUpperCase()
46                + this.frameworkMethod.getName().substring(1);
47    }
48
49    private Object[] invokeMethodWithParams(String methodName, Class<?> sourceClass) {
50        Method providerMethod = findMethodInTestClassHierarchy(methodName, sourceClass);
51        if (providerMethod == null) {
52            throw new RuntimeException("Could not find method: " + methodName + " so no params were used.");
53        }
54
55        return invokeParamsProvidingMethod(providerMethod, sourceClass);
56    }
57
58    @SuppressWarnings("unchecked")
59    private Object[] invokeParamsProvidingMethod(Method provideMethod, Class<?> sourceClass) {
60        try {
61            Object testObject = sourceClass.newInstance();
62            provideMethod.setAccessible(true);
63            Object result = provideMethod.invoke(testObject);
64
65            if (Object[].class.isAssignableFrom(result.getClass())) {
66                Object[] params = (Object[]) result;
67                return encapsulateParamsIntoArrayIfSingleParamsetPassed(params);
68            }
69
70            if (Iterable.class.isAssignableFrom(result.getClass())) {
71                try {
72                    ArrayList<Object[]> res = new ArrayList<Object[]>();
73                    for (Object[] paramSet : (Iterable<Object[]>) result)
74                        res.add(paramSet);
75                    return res.toArray();
76                } catch (ClassCastException e1) {
77                    // Iterable with consecutive paramsets, each of one param
78                    ArrayList<Object> res = new ArrayList<Object>();
79                    for (Object param : (Iterable<?>) result)
80                        res.add(new Object[]{param});
81                    return res.toArray();
82                }
83            }
84
85            if (Iterator.class.isAssignableFrom(result.getClass())) {
86                Object iteratedElement = null;
87                try {
88                    ArrayList<Object[]> res = new ArrayList<Object[]>();
89                    Iterator<Object[]> iterator = (Iterator<Object[]>) result;
90                    while (iterator.hasNext()) {
91                        iteratedElement = iterator.next();
92                        // ClassCastException will occur in the following line
93                        // if the iterator is actually Iterator<Object> in Java 7
94                        res.add((Object[]) iteratedElement);
95                    }
96                    return res.toArray();
97                } catch (ClassCastException e1) {
98                    // Iterator with consecutive paramsets, each of one param
99                    ArrayList<Object> res = new ArrayList<Object>();
100                    Iterator<?> iterator = (Iterator<?>) result;
101                    // The first element is already stored in iteratedElement
102                    res.add(iteratedElement);
103                    while (iterator.hasNext()) {
104                        res.add(new Object[]{iterator.next()});
105                    }
106                    return res.toArray();
107                }
108            }
109
110            throw new ClassCastException();
111
112        } catch (ClassCastException e) {
113            throw new RuntimeException("The return type of: " + provideMethod.getName() + " defined in class " +
114                    sourceClass + " is not Object[][] nor Iterable<Object[]>. Fix it!", e);
115        } catch (Exception e) {
116            throw new RuntimeException("Could not invoke method: " + provideMethod.getName() + " defined in class " +
117                    sourceClass + " so no params were used.", e);
118        }
119    }
120
121    private Method findMethodInTestClassHierarchy(String methodName, Class<?> sourceClass) {
122        Class<?> declaringClass = sourceClass;
123        while (declaringClass.getSuperclass() != null) {
124            try {
125                return declaringClass.getDeclaredMethod(methodName);
126            } catch (Exception ignore) {
127            }
128            declaringClass = declaringClass.getSuperclass();
129        }
130        return null;
131    }
132
133    private Object[] encapsulateParamsIntoArrayIfSingleParamsetPassed(Object[] params) {
134        if (frameworkMethod.getMethod().getParameterTypes().length != params.length) {
135            return params;
136        }
137
138        if (params.length == 0) {
139            return params;
140        }
141
142        Object param = params[0];
143        if (param == null || !param.getClass().isArray()) {
144            return new Object[]{params};
145        }
146
147        return params;
148    }
149
150}
151