1# Copyright 2008 Google Inc, Martin J. Bligh <mbligh@google.com>,
2#                Benjamin Poirier, Ryan Stutsman
3# Released under the GPL v2
4"""
5Miscellaneous small functions.
6
7DO NOT import this file directly - it is mixed in by server/utils.py,
8import that instead
9"""
10
11import atexit, os, re, shutil, textwrap, sys, tempfile, types
12
13from autotest_lib.client.common_lib import barrier, utils
14from autotest_lib.server import subcommand
15
16
# A dictionary mapping a process id (os.getpid()) to the list of temporary
# directory paths that get_tmp_dir() created for that process; entries are
# removed by __clean_tmp_dirs() at interpreter exit / subcommand join time.
__tmp_dirs = {}
19
20
def scp_remote_escape(filename):
    """
    Escape a filename's special characters so it is safe to hand to scp
    (inside double quotes) as a remote path.

    scp requires "bis-quoting" (quoting twice) for remote files, and it
    does not support a newline in the filename.

    Args:
            filename: the filename string to escape.

    Returns:
            The escaped filename string. The enclosing double quotes are
            NOT added here; the caller must add them.
    """
    escape_chars = r' !"$&' "'" r'()*,:;<=>?[\]^`{|}'

    def _escape_char(character):
        # Backslash-protect any character scp treats specially.
        if character in escape_chars:
            return "\\" + character
        return character

    return utils.sh_escape("".join(map(_escape_char, filename)))
48
49
def get(location, local_copy = False):
    """Get a file or directory to a local temporary directory.

    Args:
            location: the source of the material to get. This source may
                    be one of:
                    * a local file or directory
                    * a URL (http or ftp)
                    * a python file-like object
            local_copy: if True, a local file or directory is copied into
                    the temporary directory instead of being returned
                    in place.

    Returns:
            The location of the file or directory where the requested
            content was saved. This will be contained in a temporary
            directory on the local host. If the material to get was a
            directory, the location will contain a trailing '/'
    """
    tmpdir = get_tmp_dir()

    # location is a file-like object
    if hasattr(location, "read"):
        tmpfile = os.path.join(tmpdir, "file")
        # open() instead of the deprecated file() builtin; close the
        # destination even if the copy raises midway.
        tmpfileobj = open(tmpfile, 'w')
        try:
            shutil.copyfileobj(location, tmpfileobj)
        finally:
            tmpfileobj.close()
        return tmpfile

    if isinstance(location, types.StringTypes):
        # location is a URL
        if location.startswith('http') or location.startswith('ftp'):
            tmpfile = os.path.join(tmpdir, os.path.basename(location))
            utils.urlretrieve(location, tmpfile)
            return tmpfile
        # location is a local path
        elif os.path.exists(os.path.abspath(location)):
            if not local_copy:
                if os.path.isdir(location):
                    return location.rstrip('/') + '/'
                else:
                    return location
            tmpfile = os.path.join(tmpdir, os.path.basename(location))
            if os.path.isdir(location):
                tmpfile += '/'
                shutil.copytree(location, tmpfile, symlinks=True)
                return tmpfile
            shutil.copyfile(location, tmpfile)
            return tmpfile
        # location is just a string, dump it to a file
        else:
            tmpfd, tmpfile = tempfile.mkstemp(dir=tmpdir)
            tmpfileobj = os.fdopen(tmpfd, 'w')
            try:
                tmpfileobj.write(location)
            finally:
                tmpfileobj.close()
            return tmpfile
103
104
def get_tmp_dir():
    """Return the pathname of a fresh directory on the host suitable
    for temporary file storage.

    The directory and its content are deleted automatically at the end
    of the program execution if they are still present (see
    __clean_tmp_dirs, which consumes the per-pid registry updated here).
    """
    dir_name = tempfile.mkdtemp(prefix="autoserv-")
    # Register the directory under the current pid for later cleanup.
    __tmp_dirs.setdefault(os.getpid(), []).append(dir_name)
    return dir_name
118
119
120def __clean_tmp_dirs():
121    """Erase temporary directories that were created by the get_tmp_dir()
122    function and that are still present.
123    """
124    pid = os.getpid()
125    if pid not in __tmp_dirs:
126        return
127    for dir in __tmp_dirs[pid]:
128        try:
129            shutil.rmtree(dir)
130        except OSError, e:
131            if e.errno == 2:
132                pass
133    __tmp_dirs[pid] = []
134atexit.register(__clean_tmp_dirs)
135subcommand.subcommand.register_join_hook(lambda _: __clean_tmp_dirs())
136
137
def unarchive(host, source_material):
    """Uncompress and untar an archive on a host.

    If the "source_material" is compressed (according to the file
    extension) it will be uncompressed. Supported compression formats
    are gzip and bzip2. Afterwards, if the source_material is a tar
    archive, it will be untarred.

    Args:
            host: the host object on which the archive is located
            source_material: the path of the archive on the host

    Returns:
            The file or directory name of the unarchived source material.
            If the material is a tar archive, it will be extracted in the
            directory where it is and the path returned will be the first
            entry in the archive, assuming it is the topmost directory.
            If the material is not an archive, nothing will be done so this
            function is "harmless" when it is "useless".
    """
    # uncompress
    if (source_material.endswith(".gz") or
            source_material.endswith(".gzip")):
        host.run('gunzip "%s"' % (utils.sh_escape(source_material)))
        # drop the compression extension
        source_material= ".".join(source_material.split(".")[:-1])
    # Fix: match the ".bz2" extension, not any name merely ending in the
    # letters "bz2" (e.g. "foobz2" previously triggered bunzip2).
    elif source_material.endswith(".bz2"):
        host.run('bunzip2 "%s"' % (utils.sh_escape(source_material)))
        source_material= ".".join(source_material.split(".")[:-1])

    # untar; -v prints each entry, the first of which is assumed to be the
    # archive's topmost directory.
    if source_material.endswith(".tar"):
        retval= host.run('tar -C "%s" -xvf "%s"' % (
                utils.sh_escape(os.path.dirname(source_material)),
                utils.sh_escape(source_material),))
        source_material= os.path.join(os.path.dirname(source_material),
                retval.stdout.split()[0])

    return source_material
176
177
def get_server_dir():
    """Return the absolute path of the directory that contains the
    autotest_lib.server.utils module."""
    server_utils = sys.modules['autotest_lib.server.utils']
    return os.path.abspath(os.path.dirname(server_utils.__file__))
181
182
def find_pid(command):
    """Return the pid of the first running process whose command line
    matches the regular expression `command`, or None if nothing matches.
    """
    ps_output = utils.system_output('ps -eo pid,cmd')
    for ps_line in ps_output.rstrip().split('\n'):
        pid, cmd = ps_line.split(None, 1)
        if re.search(command, cmd):
            return int(pid)
    return None
189
190
def nohup(command, stdout='/dev/null', stderr='/dev/null', background=True,
          env=None):
    """Run `command` under nohup with its output redirected.

    Args:
            command: shell command string to run.
            stdout: path the command's stdout is redirected to.
            stderr: path the command's stderr is redirected to; when it
                    equals stdout a single '2>&1' redirection is used.
            background: if True, append '&' so the command runs in the
                    background.
            env: optional dict of NAME -> value environment assignments
                    prepended to the command line.  (A None sentinel is
                    used instead of a mutable {} default.)
    """
    if env is None:
        env = {}
    cmd = ' '.join(key+'='+val for key, val in env.iteritems())
    cmd += ' nohup ' + command
    cmd += ' > %s' % stdout
    if stdout == stderr:
        cmd += ' 2>&1'
    else:
        cmd += ' 2> %s' % stderr
    if background:
        cmd += ' &'
    utils.system(cmd)
203
204
def default_mappings(machines):
    """
    Return a trivial mapping in which all machines are assigned to the
    same 'ident' key.  Provides the default behavior for
    form_ntuples_from_machines.

    Args:
            machines: list of machine names.

    Returns:
            A (mappings, failures) tuple: mappings is {'ident': machines}
            (or {} when machines is empty), failures is always [].
    """
    mappings = {}
    failures = []
    # Fix: the original indexed machines[0] unconditionally and raised
    # IndexError on an empty list; an empty input now yields ({}, []).
    if machines:
        mappings['ident'] = list(machines)
    return (mappings, failures)
221
222
def form_ntuples_from_machines(machines, n=2, mapping_func=default_mappings):
    """Return a set of ntuples from machines where the machines in an
       ntuple are in the same mapping, and a set of failures which are
       (machine name, reason) tuples.

    Args:
            machines: list of machine names to group.
            n: size of each tuple formed.
            mapping_func: callable taking the machine list and returning
                    (mappings, failures), where mappings maps a key to a
                    list of machines that may be grouped together.

    Returns:
            (ntuples, failures): ntuples is a list of n-sized machine
            lists; failures additionally gains a (machine, reason) entry
            for every leftover machine that could not complete a tuple.
    """
    ntuples = []
    (mappings, failures) = mapping_func(machines)

    # Slice each mapping's machine list into consecutive n-sized chunks.
    # (The original kept an unused `total_machines` local and repeatedly
    # resliced the list; chunking by index does one pass.)
    for key_machines in mappings.itervalues():
        whole = len(key_machines) // n
        for i in range(whole):
            ntuples.append(key_machines[i * n:(i + 1) * n])
        # the leftover machines cannot form a complete tuple
        for mach in key_machines[whole * n:]:
            failures.append((mach, "machine can not be tupled"))

    return (ntuples, failures)
245
246
def parse_machine(machine, user='root', password='', port=22):
    """
    Parse the machine string user:pass@host:port and return it separately;
    if the machine string is not complete, use the given default parameters
    when appropriate.

    Args:
            machine: string of the form [user[:password]@]host[:port].
            user: default user when none is embedded in `machine`.
            password: default password when none is embedded.
            port: default ssh port when none is embedded.

    Returns:
            (hostname, user, password, port) with port as an int.

    Raises:
            ValueError: if the hostname or the user ends up empty.
    """
    if '@' in machine:
        user, machine = machine.split('@', 1)

    if ':' in user:
        user, password = user.split(':', 1)

    # Guard before machine[0] is indexed below: the original raised an
    # opaque IndexError for inputs like "user@" (empty hostname).
    if not machine or not user:
        raise ValueError('empty user or hostname in machine string')

    # Brackets are required to protect an IPv6 address whenever a
    # [xx::xx]:port number (or a file [xx::xx]:/path/) is appended to
    # it. Do not attempt to extract a (non-existent) port number from
    # an unprotected/bare IPv6 address "xx::xx".
    # In the Python >= 3.3 future, 'import ipaddress' will parse
    # addresses; and maybe more.
    bare_ipv6 = machine[0] != '[' and re.search(r':.*:', machine)

    # Extract trailing :port number if any.
    if not bare_ipv6 and re.search(r':\d*$', machine):
        machine, port = machine.rsplit(':', 1)
        port = int(port)

    # Strip any IPv6 brackets (ssh does not support them).
    # We'll add them back later for rsync, scp, etc.
    if machine[0] == '[' and machine[-1] == ']':
        machine = machine[1:-1]

    # Re-check: stripping brackets can leave an empty hostname ("[]").
    if not machine or not user:
        raise ValueError('empty user or hostname in machine string')

    return machine, user, password, port
282
283
284def get_public_key():
285    """
286    Return a valid string ssh public key for the user executing autoserv or
287    autotest. If there's no DSA or RSA public key, create a DSA keypair with
288    ssh-keygen and return it.
289    """
290
291    ssh_conf_path = os.path.expanduser('~/.ssh')
292
293    dsa_public_key_path = os.path.join(ssh_conf_path, 'id_dsa.pub')
294    dsa_private_key_path = os.path.join(ssh_conf_path, 'id_dsa')
295
296    rsa_public_key_path = os.path.join(ssh_conf_path, 'id_rsa.pub')
297    rsa_private_key_path = os.path.join(ssh_conf_path, 'id_rsa')
298
299    has_dsa_keypair = os.path.isfile(dsa_public_key_path) and \
300        os.path.isfile(dsa_private_key_path)
301    has_rsa_keypair = os.path.isfile(rsa_public_key_path) and \
302        os.path.isfile(rsa_private_key_path)
303
304    if has_dsa_keypair:
305        print 'DSA keypair found, using it'
306        public_key_path = dsa_public_key_path
307
308    elif has_rsa_keypair:
309        print 'RSA keypair found, using it'
310        public_key_path = rsa_public_key_path
311
312    else:
313        print 'Neither RSA nor DSA keypair found, creating DSA ssh key pair'
314        utils.system('ssh-keygen -t dsa -q -N "" -f %s' % dsa_private_key_path)
315        public_key_path = dsa_public_key_path
316
317    public_key = open(public_key_path, 'r')
318    public_key_str = public_key.read()
319    public_key.close()
320
321    return public_key_str
322
323
def get_sync_control_file(control, host_name, host_num,
                          instance, num_jobs, port_base=63100):
    """
    This function is used when there is a need to run more than one
    job simultaneously starting exactly at the same time. It basically returns
    a modified control file (containing the synchronization code prepended)
    whenever it is ready to run the control file. The synchronization
    is done using barriers to make sure that the jobs start at the same time.

    Here is how the synchronization is done to make sure that the tests
    start at exactly the same time on the client.
    sc_bar is a server barrier and s_bar, c_bar are the normal barriers

                      Job1              Job2         ......      JobN
     Server:   |                        sc_bar
     Server:   |                        s_bar        ......      s_bar
     Server:   |      at.run()         at.run()      ......      at.run()
     ----------|------------------------------------------------------
     Client    |      sc_bar
     Client    |      c_bar             c_bar        ......      c_bar
     Client    |    <run test>         <run test>    ......     <run test>

    @param control: The control file which to which the above synchronization
            code will be prepended.
    @param host_name: The host name on which the job is going to run.
    @param host_num: (non negative) A number to identify the machine so that
            we have different sets of s_bar_ports for each of the machines.
    @param instance: The number of the job
    @param num_jobs: Total number of jobs that are going to run in parallel
            with this job starting at the same time.
    @param port_base: Port number that is used to derive the actual barrier
            ports.

    @returns The modified control file.
    """
    # The setup barrier (sc_bar) and the client barrier (c_bar) share the
    # same base port; the server barrier port is offset per machine below.
    sc_bar_port = port_base
    c_bar_port = port_base
    if host_num < 0:
        print "Please provide a non negative number for the host"
        return None
    s_bar_port = port_base + 1 + host_num # The set of s_bar_ports are
                                          # the same for a given machine

    # barrier timeouts in seconds
    sc_bar_timeout = 180
    s_bar_timeout = c_bar_timeout = 120

    # The barrier code snippet is prepended into the control file
    # dynamically before at.run() is called finally.
    control_new = []

    # jobid is the unique name used to identify the processes
    # trying to reach the barriers
    jobid = "%s#%d" % (host_name, instance)

    rendv = []
    # rendvstr is a temp holder for the rendezvous list of the processes;
    # each entry is a quoted "<host>#<n>" jobid literal for the generated
    # control file.
    for n in range(num_jobs):
        rendv.append("'%s#%d'" % (host_name, n))
    rendvstr = ",".join(rendv)

    if instance == 0:
        # Do the setup and wait at the server barrier
        # Clean up the tmp and the control dirs for the first instance
        control_new.append('if os.path.exists(job.tmpdir):')
        # NOTE(review): the two adjacent string literals below concatenate
        # to "... > /dev/null2> /dev/null" -- a space appears to be missing
        # before "2>"; confirm whether this redirection is intended.
        control_new.append("\t system('umount -f %s > /dev/null"
                           "2> /dev/null' % job.tmpdir,"
                           "ignore_status=True)")
        control_new.append("\t system('rm -rf ' + job.tmpdir)")
        control_new.append(
            'b0 = job.barrier("%s", "sc_bar", %d, port=%d)'
            % (jobid, sc_bar_timeout, sc_bar_port))
        control_new.append(
        'b0.rendezvous_servers("PARALLEL_MASTER", "%s")'
         % jobid)

    elif instance == 1:
        # Wait at the server barrier to wait for instance=0
        # process to complete setup
        # NOTE: unlike the other branches, these barriers are executed
        # HERE, in the process generating the control file (they are not
        # appended to control_new), so this call blocks until the
        # instance-0 job has rendezvoused.
        b0 = barrier.barrier("PARALLEL_MASTER", "sc_bar", sc_bar_timeout,
                     port=sc_bar_port)
        b0.rendezvous_servers("PARALLEL_MASTER", jobid)

        if(num_jobs > 2):
            b1 = barrier.barrier(jobid, "s_bar", s_bar_timeout,
                         port=s_bar_port)
            b1.rendezvous(rendvstr)

    else:
        # For the rest of the clients
        b2 = barrier.barrier(jobid, "s_bar", s_bar_timeout, port=s_bar_port)
        b2.rendezvous(rendvstr)

    # Client side barrier for all the tests to start at the same time
    control_new.append('b1 = job.barrier("%s", "c_bar", %d, port=%d)'
                    % (jobid, c_bar_timeout, c_bar_port))
    control_new.append("b1.rendezvous(%s)" % rendvstr)

    # Stick in the rest of the control file
    control_new.append(control)

    return "\n".join(control_new)
425