1package org.junit.experimental;
2
3import java.util.ArrayList;
4import java.util.List;
5import java.util.concurrent.Callable;
6import java.util.concurrent.ExecutorService;
7import java.util.concurrent.Executors;
8import java.util.concurrent.Future;
9
10import org.junit.runner.Computer;
11import org.junit.runner.Runner;
12import org.junit.runners.ParentRunner;
13import org.junit.runners.model.InitializationError;
14import org.junit.runners.model.RunnerBuilder;
15import org.junit.runners.model.RunnerScheduler;
16
17public class ParallelComputer extends Computer {
18	private final boolean fClasses;
19
20	private final boolean fMethods;
21
22	public ParallelComputer(boolean classes, boolean methods) {
23		fClasses= classes;
24		fMethods= methods;
25	}
26
27	public static Computer classes() {
28		return new ParallelComputer(true, false);
29	}
30
31	public static Computer methods() {
32		return new ParallelComputer(false, true);
33	}
34
35	private static <T> Runner parallelize(Runner runner) {
36		if (runner instanceof ParentRunner<?>) {
37			((ParentRunner<?>) runner).setScheduler(new RunnerScheduler() {
38				private final List<Future<Object>> fResults= new ArrayList<Future<Object>>();
39
40				private final ExecutorService fService= Executors
41						.newCachedThreadPool();
42
43				public void schedule(final Runnable childStatement) {
44					fResults.add(fService.submit(new Callable<Object>() {
45						public Object call() throws Exception {
46							childStatement.run();
47							return null;
48						}
49					}));
50				}
51
52				public void finished() {
53					for (Future<Object> each : fResults)
54						try {
55							each.get();
56						} catch (Exception e) {
57							e.printStackTrace();
58						}
59				}
60			});
61		}
62		return runner;
63	}
64
65	@Override
66	public Runner getSuite(RunnerBuilder builder, java.lang.Class<?>[] classes)
67			throws InitializationError {
68		Runner suite= super.getSuite(builder, classes);
69		return fClasses ? parallelize(suite) : suite;
70	}
71
72	@Override
73	protected Runner getRunner(RunnerBuilder builder, Class<?> testClass)
74			throws Throwable {
75		Runner runner= super.getRunner(builder, testClass);
76		return fMethods ? parallelize(runner) : runner;
77	}
78}
79