1package org.testng.internal;
2
3import java.io.File;
4import java.io.FileFilter;
5import java.io.IOException;
6import java.io.UnsupportedEncodingException;
7import java.lang.reflect.Method;
8import java.net.JarURLConnection;
9import java.net.URL;
10import java.net.URLConnection;
11import java.net.URLDecoder;
12import java.util.Enumeration;
13import java.util.Iterator;
14import java.util.List;
15import java.util.Vector;
16import java.util.jar.JarEntry;
17import java.util.jar.JarFile;
18import java.util.regex.Pattern;
19
20import org.testng.TestNG;
21import org.testng.collections.Lists;
22
23/**
24 * Utility class that finds all the classes in a given package.
25 *
26 * Created on Feb 24, 2006
27 * @author <a href="mailto:cedric@beust.com">Cedric Beust</a>
28 */
29public class PackageUtils {
30  private static String[] s_testClassPaths;
31
32  /** The additional class loaders to find classes in. */
33  private static final List<ClassLoader> m_classLoaders = new Vector<>();
34
35  /** Add a class loader to the searchable loaders. */
36  public static void addClassLoader(final ClassLoader loader) {
37    m_classLoaders.add(loader);
38  }
39
40  /**
41   *
42   * @param packageName
43   * @return The list of all the classes inside this package
44   * @throws IOException
45   */
46  public static String[] findClassesInPackage(String packageName,
47      List<String> included, List<String> excluded)
48    throws IOException
49  {
50    String packageOnly = packageName;
51    boolean recursive = false;
52    if (packageName.endsWith(".*")) {
53      packageOnly = packageName.substring(0, packageName.lastIndexOf(".*"));
54      recursive = true;
55    }
56
57    List<String> vResult = Lists.newArrayList();
58    String packageDirName = packageOnly.replace('.', '/') + (packageOnly.length() > 0 ? "/" : "");
59
60
61    Vector<URL> dirs = new Vector<>();
62    // go through additional class loaders
63    Vector<ClassLoader> allClassLoaders = new Vector<>();
64    ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
65    if (contextClassLoader != null) {
66      allClassLoaders.add(contextClassLoader);
67    }
68    if (m_classLoaders != null) {
69      allClassLoaders.addAll(m_classLoaders);
70    }
71
72    int count = 0;
73    for (ClassLoader classLoader : allClassLoaders) {
74      ++count;
75      if (null == classLoader) {
76        continue;
77      }
78      Enumeration<URL> dirEnumeration = classLoader.getResources(packageDirName);
79      while(dirEnumeration.hasMoreElements()){
80        URL dir = dirEnumeration.nextElement();
81        dirs.add(dir);
82      }
83    }
84
85    Iterator<URL> dirIterator = dirs.iterator();
86    while (dirIterator.hasNext()) {
87      URL url = dirIterator.next();
88      String protocol = url.getProtocol();
89      if(!matchTestClasspath(url, packageDirName, recursive)) {
90        continue;
91      }
92
93      if ("file".equals(protocol)) {
94        findClassesInDirPackage(packageOnly, included, excluded,
95                                URLDecoder.decode(url.getFile(), "UTF-8"),
96                                recursive,
97                                vResult);
98      }
99      else if ("jar".equals(protocol)) {
100        JarFile jar = ((JarURLConnection) url.openConnection()).getJarFile();
101        Enumeration<JarEntry> entries = jar.entries();
102        while (entries.hasMoreElements()) {
103          JarEntry entry = entries.nextElement();
104          String name = entry.getName();
105          if (name.charAt(0) == '/') {
106            name = name.substring(1);
107          }
108          if (name.startsWith(packageDirName)) {
109            int idx = name.lastIndexOf('/');
110            if (idx != -1) {
111              packageName = name.substring(0, idx).replace('/', '.');
112            }
113
114            if (recursive || packageName.equals(packageOnly)) {
115              //it's not inside a deeper dir
116              Utils.log("PackageUtils", 4, "Package name is " + packageName);
117              if (name.endsWith(".class") && !entry.isDirectory()) {
118                String className = name.substring(packageName.length() + 1, name.length() - 6);
119                Utils.log("PackageUtils", 4, "Found class " + className + ", seeing it if it's included or excluded");
120                includeOrExcludeClass(packageName, className, included, excluded, vResult);
121              }
122            }
123          }
124        }
125      }
126      else if ("bundleresource".equals(protocol)) {
127        try {
128          Class params[] = {};
129          // BundleURLConnection
130          URLConnection connection = url.openConnection();
131          Method thisMethod = url.openConnection().getClass()
132              .getDeclaredMethod("getFileURL", params);
133          Object paramsObj[] = {};
134          URL fileUrl = (URL) thisMethod.invoke(connection, paramsObj);
135          findClassesInDirPackage(packageOnly, included, excluded,
136              URLDecoder.decode(fileUrl.getFile(), "UTF-8"), recursive, vResult);
137        } catch (Exception ex) {
138          // ignore - probably not an Eclipse OSGi bundle
139        }
140      }
141    }
142
143    String[] result = vResult.toArray(new String[vResult.size()]);
144    return result;
145  }
146
147  private static String[] getTestClasspath() {
148    if (null != s_testClassPaths) {
149      return s_testClassPaths;
150    }
151
152    String testClasspath = System.getProperty(TestNG.TEST_CLASSPATH);
153    if (null == testClasspath) {
154      return null;
155    }
156
157    String[] classpathFragments= Utils.split(testClasspath, File.pathSeparator);
158    s_testClassPaths= new String[classpathFragments.length];
159
160    for(int i= 0; i < classpathFragments.length; i++)  {
161      String path= null;
162      if(classpathFragments[i].toLowerCase().endsWith(".jar") || classpathFragments[i].toLowerCase().endsWith(".zip")) {
163        path= classpathFragments[i] + "!/";
164      }
165      else {
166        if(classpathFragments[i].endsWith(File.separator)) {
167          path= classpathFragments[i];
168        }
169        else {
170          path= classpathFragments[i] + "/";
171        }
172      }
173
174      s_testClassPaths[i]= path.replace('\\', '/');
175    }
176
177    return s_testClassPaths;
178  }
179
180  private static boolean matchTestClasspath(URL url, String lastFragment, boolean recursive) {
181    String[] classpathFragments= getTestClasspath();
182    if(null == classpathFragments) {
183      return true;
184    }
185
186    String fileName= null;
187    try {
188      fileName= URLDecoder.decode(url.getFile(), "UTF-8");
189    }
190    catch(UnsupportedEncodingException ueex) {
191      ; // ignore. should never happen
192    }
193
194    for(String classpathFrag: classpathFragments) {
195      String path=  classpathFrag + lastFragment;
196      int idx= fileName.indexOf(path);
197      if((idx == -1) || (idx > 0 && fileName.charAt(idx-1) != '/')) {
198        continue;
199      }
200
201      if(fileName.endsWith(classpathFrag + lastFragment)
202          || (recursive && fileName.charAt(idx + path.length()) == '/')) {
203        return true;
204      }
205    }
206
207    return false;
208  }
209
210  private static void findClassesInDirPackage(String packageName,
211                                              List<String> included,
212                                              List<String> excluded,
213                                              String packagePath,
214                                              final boolean recursive,
215                                              List<String> classes) {
216    File dir = new File(packagePath);
217
218    if (!dir.exists() || !dir.isDirectory()) {
219      return;
220    }
221
222    File[] dirfiles = dir.listFiles(new FileFilter() {
223          @Override
224          public boolean accept(File file) {
225            return (recursive && file.isDirectory())
226              || (file.getName().endsWith(".class"))
227              || (file.getName().endsWith(".groovy"));
228          }
229        });
230
231    Utils.log("PackageUtils", 4, "Looking for test classes in the directory: " + dir);
232    for (File file : dirfiles) {
233      if (file.isDirectory()) {
234        findClassesInDirPackage(makeFullClassName(packageName, file.getName()),
235                                included,
236                                excluded,
237                                file.getAbsolutePath(),
238                                recursive,
239                                classes);
240      }
241      else {
242        String className = file.getName().substring(0, file.getName().lastIndexOf("."));
243        Utils.log("PackageUtils", 4, "Found class " + className
244            + ", seeing it if it's included or excluded");
245        includeOrExcludeClass(packageName, className, included, excluded, classes);
246      }
247    }
248  }
249
250  private static String makeFullClassName(String pkg, String cls) {
251    return pkg.length() > 0 ? pkg + "." + cls : cls;
252  }
253
254  private static void includeOrExcludeClass(String packageName, String className,
255      List<String> included, List<String> excluded, List<String> classes)
256  {
257    if (isIncluded(packageName, included, excluded)) {
258      Utils.log("PackageUtils", 4, "... Including class " + className);
259      classes.add(makeFullClassName(packageName, className));
260    }
261    else {
262      Utils.log("PackageUtils", 4, "... Excluding class " + className);
263    }
264  }
265
266  /**
267   * @return true if name should be included.
268   */
269  private static boolean isIncluded(String name,
270      List<String> included, List<String> excluded)
271  {
272    boolean result = false;
273
274    //
275    // If no includes nor excludes were specified, return true.
276    //
277    if (included.size() == 0 && excluded.size() == 0) {
278      result = true;
279    }
280    else {
281      boolean isIncluded = PackageUtils.find(name, included);
282      boolean isExcluded = PackageUtils.find(name, excluded);
283      if (isIncluded && !isExcluded) {
284        result = true;
285      }
286      else if (isExcluded) {
287        result = false;
288      }
289      else {
290        result = included.size() == 0;
291      }
292    }
293    return result;
294  }
295
296  private static boolean find(String name, List<String> list) {
297    for (String regexpStr : list) {
298      if (Pattern.matches(regexpStr, name)) {
299        return true;
300      }
301    }
302    return false;
303  }
304}
305