1# Copyright (c) 2015, Google Inc.
2#
3# Permission to use, copy, modify, and/or distribute this software for any
4# purpose with or without fee is hereby granted, provided that the above
5# copyright notice and this permission notice appear in all copies.
6#
7# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
10# SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12# OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14
15"""Extracts archives."""
16
17
18import optparse
19import os
20import os.path
21import tarfile
22import shutil
23import sys
24import zipfile
25
26
def CheckedJoin(output, path):
  """
  CheckedJoin returns os.path.join(output, path). It does sanity checks to
  ensure the resulting path is under output, but shouldn't be used on untrusted
  input.

  path is normalized first. ValueError(path) is raised if the normalized path
  is absolute, carries a drive letter, refers to output itself ('.'), or climbs
  out of output via '..'. Names that merely begin with a dot, such as
  '.gitignore', are accepted.
  """
  path = os.path.normpath(path)
  # normpath collapses inner 'x/..' pairs, so any '..' that still escapes can
  # only be a leading component; checking the first component is sufficient.
  first = path.split(os.sep, 1)[0]
  if (os.path.isabs(path) or os.path.splitdrive(path)[0] or
      first in (os.curdir, os.pardir)):
    raise ValueError(path)
  return os.path.join(output, path)
37
38
def IterateZip(path):
  """
  IterateZip opens the zip file at path and returns a generator of
  (filename, mode, fileobj) tuples for each file in it.

  Directory entries (names ending in '/') are skipped. mode is always None;
  zip permission bits are not reported. Each fileobj is closed as soon as the
  generator advances past it (or is closed), so callers must finish reading
  it before requesting the next entry.
  """
  with zipfile.ZipFile(path, 'r') as zip_file:
    for info in zip_file.infolist():
      if info.filename.endswith('/'):
        continue
      member = zip_file.open(info)
      try:
        yield (info.filename, None, member)
      finally:
        # Release the member handle deterministically rather than leaving it
        # for the garbage collector.
        member.close()
49
50
def IterateTar(path):
  """
  IterateTar opens the tar.gz file at path and returns a generator of
  (filename, mode, fileobj) tuples for each file in it.

  Directory entries are skipped. mode is the entry's permission bits as
  stored in the archive. Any entry that is neither a directory nor a regular
  file (e.g. a symlink) raises ValueError. Each fileobj is closed as soon as
  the generator advances past it (or is closed), so callers must finish
  reading it before requesting the next entry.
  """
  with tarfile.open(path, 'r:gz') as tar_file:
    for info in tar_file:
      if info.isdir():
        continue
      if not info.isfile():
        raise ValueError('Unknown entry type "%s"' % (info.name, ))
      member = tar_file.extractfile(info)
      try:
        yield (info.name, info.mode, member)
      finally:
        # Release the member handle deterministically rather than leaving it
        # for the garbage collector.
        member.close()
63
64
def main(args):
  """
  main extracts the archive named on the command line into a directory.

  args is the command line without the program name: ARCHIVE OUTPUT, plus the
  optional --no-prefix flag. ARCHIVE must end in '.zip' or '.tar.gz'. Any
  existing OUTPUT is removed first. Unless --no-prefix is given, every entry
  must sit under one common top-level directory, which is stripped so its
  contents land directly in OUTPUT.

  Returns 1 (after printing help) on a usage error and 0 otherwise, including
  when ARCHIVE does not exist. Raises ValueError for an unsupported archive
  extension or for an entry path that is malformed or inconsistent with the
  others.
  """
  parser = optparse.OptionParser(usage='Usage: %prog ARCHIVE OUTPUT')
  parser.add_option('--no-prefix', dest='no_prefix', action='store_true',
                    help='Do not remove a prefix from paths in the archive.')
  options, args = parser.parse_args(args)

  if len(args) != 2:
    parser.print_help()
    return 1

  archive, output = args

  if not os.path.exists(archive):
    # Skip archives that weren't downloaded.
    return 0

  if archive.endswith('.zip'):
    entries = IterateZip(archive)
  elif archive.endswith('.tar.gz'):
    entries = IterateTar(archive)
  else:
    raise ValueError(archive)

  try:
    if os.path.exists(output):
      print("Removing %s" % (output, ))
      shutil.rmtree(output)

    print("Extracting %s to %s" % (archive, output))
    prefix = None
    num_extracted = 0
    for path, mode, inp in entries:
      # Even on Windows, zip files must always use forward slashes.
      if '\\' in path or path.startswith('/'):
        raise ValueError(path)

      if not options.no_prefix:
        new_prefix, sep, rest = path.partition('/')
        # A top-level file has no directory prefix to strip; report the
        # offending path instead of failing with an opaque unpacking error.
        if not sep:
          raise ValueError(path)

        # Ensure the archive is consistent.
        if prefix is None:
          prefix = new_prefix
        if prefix != new_prefix:
          raise ValueError((prefix, new_prefix))
      else:
        rest = path

      # Extract the file into the output directory.
      fixed_path = CheckedJoin(output, rest)
      if not os.path.isdir(os.path.dirname(fixed_path)):
        os.makedirs(os.path.dirname(fixed_path))
      with open(fixed_path, 'wb') as out:
        shutil.copyfileobj(inp, out)

      # Fix up permissions if need be.
      # TODO(davidben): To be extra tidy, this should only track the execute bit
      # as in git.
      if mode is not None:
        os.chmod(fixed_path, mode)

      # Print every 100 files, so bots do not time out on large archives.
      num_extracted += 1
      if num_extracted % 100 == 0:
        print("Extracted %d files..." % (num_extracted,))
  finally:
    entries.close()

  # Always report the final count. The old `% 100 == 0` guard printed this
  # only when it duplicated the last progress line, and never otherwise.
  print("Done. Extracted %d files." % (num_extracted,))

  return 0
136
137
if __name__ == '__main__':
  # Run the CLI on everything after the program name and hand its return
  # value to the interpreter as the process exit status.
  exit_status = main(sys.argv[1:])
  sys.exit(exit_status)
140