1/**
2 *
3 */
4package org.junit.experimental.theories;
5
6import java.lang.reflect.Field;
7import java.lang.reflect.InvocationTargetException;
8import java.lang.reflect.Modifier;
9import java.util.ArrayList;
10import java.util.List;
11
12import org.junit.Assert;
13import org.junit.experimental.theories.PotentialAssignment.CouldNotGenerateValueException;
14import org.junit.experimental.theories.internal.Assignments;
15import org.junit.experimental.theories.internal.ParameterizedAssertionError;
16import org.junit.internal.AssumptionViolatedException;
17import org.junit.runners.BlockJUnit4ClassRunner;
18import org.junit.runners.model.FrameworkMethod;
19import org.junit.runners.model.InitializationError;
20import org.junit.runners.model.Statement;
21import org.junit.runners.model.TestClass;
22
23public class Theories extends BlockJUnit4ClassRunner {
24	public Theories(Class<?> klass) throws InitializationError {
25		super(klass);
26	}
27
28	@Override
29	protected void collectInitializationErrors(List<Throwable> errors) {
30		super.collectInitializationErrors(errors);
31		validateDataPointFields(errors);
32	}
33
34	private void validateDataPointFields(List<Throwable> errors) {
35		Field[] fields= getTestClass().getJavaClass().getDeclaredFields();
36
37		for (Field each : fields)
38			if (each.getAnnotation(DataPoint.class) != null && !Modifier.isStatic(each.getModifiers()))
39				errors.add(new Error("DataPoint field " + each.getName() + " must be static"));
40	}
41
42	@Override
43	protected void validateConstructor(List<Throwable> errors) {
44		validateOnlyOneConstructor(errors);
45	}
46
47	@Override
48	protected void validateTestMethods(List<Throwable> errors) {
49		for (FrameworkMethod each : computeTestMethods())
50			if(each.getAnnotation(Theory.class) != null)
51				each.validatePublicVoid(false, errors);
52			else
53				each.validatePublicVoidNoArg(false, errors);
54	}
55
56	@Override
57	protected List<FrameworkMethod> computeTestMethods() {
58		List<FrameworkMethod> testMethods= super.computeTestMethods();
59		List<FrameworkMethod> theoryMethods= getTestClass().getAnnotatedMethods(Theory.class);
60		testMethods.removeAll(theoryMethods);
61		testMethods.addAll(theoryMethods);
62		return testMethods;
63	}
64
65	@Override
66	public Statement methodBlock(final FrameworkMethod method) {
67		return new TheoryAnchor(method, getTestClass());
68	}
69
70	public static class TheoryAnchor extends Statement {
71		private int successes= 0;
72
73		private FrameworkMethod fTestMethod;
74        private TestClass fTestClass;
75
76		private List<AssumptionViolatedException> fInvalidParameters= new ArrayList<AssumptionViolatedException>();
77
78		public TheoryAnchor(FrameworkMethod method, TestClass testClass) {
79			fTestMethod= method;
80            fTestClass= testClass;
81		}
82
83        private TestClass getTestClass() {
84            return fTestClass;
85        }
86
87		@Override
88		public void evaluate() throws Throwable {
89			runWithAssignment(Assignments.allUnassigned(
90					fTestMethod.getMethod(), getTestClass()));
91
92			if (successes == 0)
93				Assert
94						.fail("Never found parameters that satisfied method assumptions.  Violated assumptions: "
95								+ fInvalidParameters);
96		}
97
98		protected void runWithAssignment(Assignments parameterAssignment)
99				throws Throwable {
100			if (!parameterAssignment.isComplete()) {
101				runWithIncompleteAssignment(parameterAssignment);
102			} else {
103				runWithCompleteAssignment(parameterAssignment);
104			}
105		}
106
107		protected void runWithIncompleteAssignment(Assignments incomplete)
108				throws InstantiationException, IllegalAccessException,
109				Throwable {
110			for (PotentialAssignment source : incomplete
111					.potentialsForNextUnassigned()) {
112				runWithAssignment(incomplete.assignNext(source));
113			}
114		}
115
116		protected void runWithCompleteAssignment(final Assignments complete)
117				throws InstantiationException, IllegalAccessException,
118				InvocationTargetException, NoSuchMethodException, Throwable {
119			new BlockJUnit4ClassRunner(getTestClass().getJavaClass()) {
120				@Override
121				protected void collectInitializationErrors(
122						List<Throwable> errors) {
123					// do nothing
124				}
125
126				@Override
127				public Statement methodBlock(FrameworkMethod method) {
128					final Statement statement= super.methodBlock(method);
129					return new Statement() {
130						@Override
131						public void evaluate() throws Throwable {
132							try {
133								statement.evaluate();
134								handleDataPointSuccess();
135							} catch (AssumptionViolatedException e) {
136								handleAssumptionViolation(e);
137							} catch (Throwable e) {
138								reportParameterizedError(e, complete
139										.getArgumentStrings(nullsOk()));
140							}
141						}
142
143					};
144				}
145
146				@Override
147				protected Statement methodInvoker(FrameworkMethod method, Object test) {
148					return methodCompletesWithParameters(method, complete, test);
149				}
150
151				@Override
152				public Object createTest() throws Exception {
153					return getTestClass().getOnlyConstructor().newInstance(
154							complete.getConstructorArguments(nullsOk()));
155				}
156			}.methodBlock(fTestMethod).evaluate();
157		}
158
159		private Statement methodCompletesWithParameters(
160				final FrameworkMethod method, final Assignments complete, final Object freshInstance) {
161			return new Statement() {
162				@Override
163				public void evaluate() throws Throwable {
164					try {
165						final Object[] values= complete.getMethodArguments(
166								nullsOk());
167						method.invokeExplosively(freshInstance, values);
168					} catch (CouldNotGenerateValueException e) {
169						// ignore
170					}
171				}
172			};
173		}
174
175		protected void handleAssumptionViolation(AssumptionViolatedException e) {
176			fInvalidParameters.add(e);
177		}
178
179		protected void reportParameterizedError(Throwable e, Object... params)
180				throws Throwable {
181			if (params.length == 0)
182				throw e;
183			throw new ParameterizedAssertionError(e, fTestMethod.getName(),
184					params);
185		}
186
187		private boolean nullsOk() {
188			Theory annotation= fTestMethod.getMethod().getAnnotation(
189					Theory.class);
190			if (annotation == null)
191				return false;
192			return annotation.nullsAccepted();
193		}
194
195		protected void handleDataPointSuccess() {
196			successes++;
197		}
198	}
199}
200