utils.py revision cd63a21daf20993032f83ac6268d181c4c8e3457
1#!/usr/bin/python
2#
3# Copyright 2008 Google Inc. Released under the GPL v2
4
5import os, pickle, random, re, resource, select, shutil, signal, StringIO
6import socket, struct, subprocess, sys, time, textwrap, urllib, urlparse
7import warnings, smtplib, logging
8from autotest_lib.client.common_lib import error, barrier
9
def deprecated(func):
    """This is a decorator which can be used to mark functions as deprecated.
    It will result in a warning being emitted when the function is used."""
    def new_func(*args, **dargs):
        # stacklevel=2 attributes the warning to the *caller* of the
        # deprecated function, not to this wrapper.
        warnings.warn("Call to deprecated function %s." % func.__name__,
                      category=DeprecationWarning, stacklevel=2)
        return func(*args, **dargs)
    # Manually preserve the wrapped function's metadata so introspection
    # and documentation tools still see the original function.
    new_func.__name__ = func.__name__
    new_func.__doc__ = func.__doc__
    new_func.__dict__.update(func.__dict__)
    return new_func
21
22
class BgJob(object):
    """Runs a command in a background subprocess and captures its output.

    Output is captured into buffers supplied via output_prepare() and
    recorded into self.result (a CmdResult).  Callers drive the
    lifecycle: output_prepare() -> process_output() (repeatedly) ->
    cleanup(), as join_bg_jobs() does.
    """
    def __init__(self, command, stdout_tee=None, stderr_tee=None, verbose=True,
                 stdin=None):
        # stdout_tee/stderr_tee: optional file-like objects that also
        # receive output as it is read, in addition to the capture buffers.
        self.command = command
        self.stdout_tee = stdout_tee
        self.stderr_tee = stderr_tee
        self.result = CmdResult(command)
        if verbose:
            logging.debug("Running '%s'" % command)
        # shell=True with an explicit /bin/bash so the command string gets
        # full shell expansion; the child resets SIGPIPE to its default
        # disposition (see _reset_sigpipe).
        self.sp = subprocess.Popen(command, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE,
                                   preexec_fn=self._reset_sigpipe, shell=True,
                                   executable="/bin/bash",
                                   stdin=stdin)


    def output_prepare(self, stdout_file=None, stderr_file=None):
        # Capture buffers; must be file-like objects supporting write()
        # and getvalue() (e.g. StringIO), see cleanup().
        self.stdout_file = stdout_file
        self.stderr_file = stderr_file


    def process_output(self, stdout=True, final_read=False):
        """output_prepare must be called prior to calling this

        Reads available data from the child's stdout pipe (stdout=True)
        or stderr pipe (stdout=False) into the matching buffer and tee.
        With final_read=True, drains everything currently readable from
        the pipe; otherwise performs a single bounded read.  os.read is
        used rather than pipe.read() so the call cannot block forever.
        """
        if stdout:
            pipe, buf, tee = self.sp.stdout, self.stdout_file, self.stdout_tee
        else:
            pipe, buf, tee = self.sp.stderr, self.stderr_file, self.stderr_tee

        if final_read:
            # read in all the data we can from pipe and then stop
            data = []
            while select.select([pipe], [], [], 0)[0]:
                data.append(os.read(pipe.fileno(), 1024))
                if len(data[-1]) == 0:
                    break
            data = "".join(data)
        else:
            # perform a single read
            data = os.read(pipe.fileno(), 1024)
        buf.write(data)
        if tee:
            tee.write(data)
            tee.flush()


    def cleanup(self):
        # Close our ends of the pipes and snapshot the captured output
        # into self.result.  Assumes output_prepare() supplied buffers
        # that implement getvalue().
        self.sp.stdout.close()
        self.sp.stderr.close()
        self.result.stdout = self.stdout_file.getvalue()
        self.result.stderr = self.stderr_file.getvalue()


    def _reset_sigpipe(self):
        # Runs in the child between fork and exec: restore the default
        # SIGPIPE disposition so shell pipelines behave normally.
        signal.signal(signal.SIGPIPE, signal.SIG_DFL)
77
78
def ip_to_long(ip):
    """Convert a dotted-quad IP string to an unsigned 32-bit integer."""
    packed = socket.inet_aton(ip)
    # '!L' = unsigned 32-bit value in network (big-endian) byte order.
    return struct.unpack('!L', packed)[0]
82
83
def long_to_ip(number):
    """Convert an unsigned 32-bit integer back to a dotted-quad string."""
    # Inverse of ip_to_long: pack in network byte order, then format.
    packed = struct.pack('!L', number)
    return socket.inet_ntoa(packed)
87
88
def create_subnet_mask(bits):
    """Return the integer netmask for a network prefix of 'bits' bits."""
    # 'bits' ones followed by (32 - bits) zeros.
    return ((1 << bits) - 1) << (32 - bits)
91
92
def format_ip_with_mask(ip, mask_bits):
    """Return ip masked down to its network portion, as 'a.b.c.d/bits'."""
    mask = create_subnet_mask(mask_bits)
    network = long_to_ip(ip_to_long(ip) & mask)
    return "%s/%s" % (network, mask_bits)
96
97
def normalize_hostname(alias):
    """Resolve alias to its canonical hostname (forward then reverse DNS)."""
    address = socket.gethostbyname(alias)
    hostname, _aliases, _addresses = socket.gethostbyaddr(address)
    return hostname
101
102
def get_ip_local_port_range():
    """Return the kernel's (low, high) ephemeral port range as an int tuple."""
    contents = read_one_line('/proc/sys/net/ipv4/ip_local_port_range')
    # The proc file holds two whitespace-separated integers.
    match = re.match(r'\s*(\d+)\s*(\d+)\s*$', contents)
    return (int(match.group(1)), int(match.group(2)))
107
108
def set_ip_local_port_range(lower, upper):
    """Set the kernel's ephemeral port range to [lower, upper]."""
    line = '%d %d\n' % (lower, upper)
    write_one_line('/proc/sys/net/ipv4/ip_local_port_range', line)
112
113
114
115def send_email(mail_from, mail_to, subject, body):
116    """
117    Sends an email via smtp
118
119    mail_from: string with email address of sender
120    mail_to: string or list with email address(es) of recipients
121    subject: string with subject of email
122    body: (multi-line) string with body of email
123    """
124    if isinstance(mail_to, str):
125        mail_to = [mail_to]
126    msg = "From: %s\nTo: %s\nSubject: %s\n\n%s" % (mail_from, ','.join(mail_to),
127                                                   subject, body)
128    try:
129        mailer = smtplib.SMTP('localhost')
130        try:
131            mailer.sendmail(mail_from, mail_to, msg)
132        finally:
133            mailer.quit()
134    except Exception, e:
135        # Emails are non-critical, not errors, but don't raise them
136        print "Sending email failed. Reason: %s" % repr(e)
137
138
def read_one_line(filename):
    """Return the first line of filename, without its trailing newline."""
    f = open(filename, 'r')
    try:
        return f.readline().rstrip('\n')
    finally:
        # Close explicitly instead of leaking the handle until GC.
        f.close()
141
142
def write_one_line(filename, line):
    """Replace filename's contents with line, terminated by one newline."""
    data = line.rstrip('\n') + '\n'
    open_write_close(filename, data)
145
146
def open_write_close(filename, data):
    """Overwrite filename with data, guaranteeing the handle is closed."""
    outfile = open(filename, 'w')
    try:
        outfile.write(data)
    finally:
        outfile.close()
153
154
def read_keyval(path):
    """
    Read a key-value pair format file into a dictionary, and return it.
    Takes either a filename or directory name as input. If it's a
    directory name, we assume you want the file to be called keyval.

    Values that look like integers or floats are converted; everything
    else stays a string.  Blank and comment-only lines are skipped.
    Raises ValueError on any other line not of the form key=value.
    """
    if os.path.isdir(path):
        path = os.path.join(path, 'keyval')
    keyval = {}
    if os.path.exists(path):
        f = open(path)
        try:
            for line in f:
                # Strip comments and trailing whitespace.
                line = re.sub('#.*', '', line).rstrip()
                if not line:
                    # Skip blank/comment-only lines instead of raising
                    # (the original treated them as invalid lines).
                    continue
                if not re.search(r'^[-\.\w]+=', line):
                    raise ValueError('Invalid format line: %s' % line)
                key, value = line.split('=', 1)
                if re.search('^\d+$', value):
                    value = int(value)
                elif re.search('^(\d+\.)?\d+$', value):
                    value = float(value)
                keyval[key] = value
        finally:
            f.close()
    return keyval
176
177
def write_keyval(path, dictionary, type_tag=None):
    """
    Write a key-value pair format file out to a file. This uses append
    mode to open the file, so existing text will not be overwritten or
    reparsed.

    If type_tag is None, then the key must be composed of alphanumeric
    characters (or dashes+underscores). However, if type-tag is not
    null then the keys must also have "{type_tag}" as a suffix. At
    the moment the only valid values of type_tag are "attr" and "perf".

    Raises ValueError for an invalid type_tag or a non-matching key.
    """
    if os.path.isdir(path):
        path = os.path.join(path, 'keyval')

    # Validate type_tag and build the key pattern *before* opening the
    # file, so a bad tag does not leak an open file handle.
    if type_tag is None:
        key_regex = re.compile(r'^[-\.\w]+$')
    else:
        if type_tag not in ('attr', 'perf'):
            raise ValueError('Invalid type tag: %s' % type_tag)
        escaped_tag = re.escape(type_tag)
        key_regex = re.compile(r'^[-\.\w]+\{%s\}$' % escaped_tag)

    keyval = open(path, 'a')
    try:
        # items() instead of iteritems(): identical behavior here and
        # forward-compatible with Python 3.
        for key, value in dictionary.items():
            if not key_regex.search(key):
                raise ValueError('Invalid key: %s' % key)
            keyval.write('%s=%s\n' % (key, value))
    finally:
        keyval.close()
207
208
def is_url(path):
    """Return true if path looks like a URL"""
    # for now, just handle http and ftp
    scheme = urlparse.urlparse(path)[0]
    return scheme in ('http', 'ftp')
214
215
def urlopen(url, data=None, proxies=None, timeout=5):
    """Wrapper to urllib.urlopen with timeout addition."""

    # Save old timeout
    old_timeout = socket.getdefaulttimeout()
    # urllib.urlopen has no timeout parameter of its own, so temporarily
    # install a process-wide default socket timeout around the call.
    # NOTE(review): this is global state — concurrent threads doing socket
    # work during this window see the temporary timeout too.
    socket.setdefaulttimeout(timeout)
    try:
        return urllib.urlopen(url, data=data, proxies=proxies)
    finally:
        # Always restore whatever timeout was in effect before.
        socket.setdefaulttimeout(old_timeout)
226
227
def urlretrieve(url, filename=None, reporthook=None, data=None, timeout=300):
    """Wrapper to urllib.urlretrieve with timeout addition."""
    # Same trick as urlopen() above: urllib has no per-call timeout, so
    # swap the global default socket timeout around the call and restore
    # it afterwards no matter what.
    old_timeout = socket.getdefaulttimeout()
    socket.setdefaulttimeout(timeout)
    try:
        return urllib.urlretrieve(url, filename=filename,
                                  reporthook=reporthook, data=data)
    finally:
        socket.setdefaulttimeout(old_timeout)
237
238
239def get_file(src, dest, permissions=None):
240    """Get a file from src, which can be local or a remote URL"""
241    if (src == dest):
242        return
243    if (is_url(src)):
244        print 'PWD: ' + os.getcwd()
245        print 'Fetching \n\t', src, '\n\t->', dest
246        try:
247            urllib.urlretrieve(src, dest)
248        except IOError, e:
249            raise error.AutotestError('Unable to retrieve %s (to %s)'
250                                % (src, dest), e)
251    else:
252        shutil.copyfile(src, dest)
253    if permissions:
254        os.chmod(dest, permissions)
255    return dest
256
257
def unmap_url(srcdir, src, destdir='.'):
    """
    Receives either a path to a local file or a URL.
    returns either the path to the local file, or the fetched URL

    unmap_url('/usr/src', 'foo.tar', '/tmp')
                            = '/usr/src/foo.tar'
    unmap_url('/usr/src', 'http://site/file', '/tmp')
                            = '/tmp/file'
                            (after retrieving it)
    """
    if not is_url(src):
        # Plain local file: just anchor it under srcdir.
        return os.path.join(srcdir, src)
    # URL: download into destdir, named after the last path component.
    filename = os.path.basename(urlparse.urlparse(src)[2])
    return get_file(src, os.path.join(destdir, filename))
276
277
def update_version(srcdir, preserve_srcdir, new_version, install,
                   *args, **dargs):
    """
    Make sure srcdir is version new_version

    If not, delete it and install() the new version.

    In the preserve_srcdir case, we just check it's up to date,
    and if not, we rerun install, without removing srcdir
    """
    versionfile = os.path.join(srcdir, '.version')
    install_needed = True

    if os.path.exists(versionfile):
        # Binary mode: pickle streams are bytes (required by the pickle
        # docs; text mode happened to work on Py2/Unix only).
        f = open(versionfile, 'rb')
        try:
            old_version = pickle.load(f)
        finally:
            f.close()
        if old_version == new_version:
            install_needed = False

    if install_needed:
        if not preserve_srcdir and os.path.exists(srcdir):
            shutil.rmtree(srcdir)
        install(*args, **dargs)
        # install() may legitimately not create srcdir; only record the
        # version when it exists.
        if os.path.exists(srcdir):
            f = open(versionfile, 'wb')
            try:
                pickle.dump(new_version, f)
            finally:
                f.close()
302
303
def run(command, timeout=None, ignore_status=False,
        stdout_tee=None, stderr_tee=None, verbose=True, stdin=None):
    """
    Run a command on the host.

    Args:
            command: the command line string
            timeout: time limit in seconds before attempting to
                    kill the running process. The run() function
                    will take a few seconds longer than 'timeout'
                    to complete if it has to kill the process.
            ignore_status: do not raise an exception, no matter what
                    the exit code of the command is.
            stdout_tee: optional file-like object to which stdout data
                        will be written as it is generated (data will still
                        be stored in result.stdout)
            stderr_tee: likewise for stderr
            stdin: stdin to pass to the executed process

    Returns:
            a CmdResult object

    Raises:
            CmdError: the exit code of the command
                    execution was not 0
    """
    job = BgJob(command, stdout_tee, stderr_tee, verbose, stdin=stdin)
    join_bg_jobs((job,), timeout)
    if not ignore_status and job.result.exit_status:
        raise error.CmdError(command, job.result,
                             "Command returned non-zero exit status")
    return job.result
338
339
def run_parallel(commands, timeout=None, ignore_status=False,
                 stdout_tee=None, stderr_tee=None):
    """Behaves the same as run with the following exceptions:

    - commands is a list of commands to run in parallel.
    - ignore_status toggles whether or not an exception should be raised
      on any error.

    returns a list of CmdResult objects
    """
    bg_jobs = [BgJob(command, stdout_tee, stderr_tee)
               for command in commands]

    # Updates objects in bg_jobs list with their process information
    join_bg_jobs(bg_jobs, timeout)

    for bg_job in bg_jobs:
        if not ignore_status and bg_job.result.exit_status:
            # Bug fix: report the failing job's own command.  The original
            # raised with the stale loop variable 'command', which always
            # held the *last* command in the list.
            raise error.CmdError(bg_job.command, bg_job.result,
                                 "Command returned non-zero exit status")

    return [bg_job.result for bg_job in bg_jobs]
363
364
@deprecated
def run_bg(command):
    """Function deprecated. Please use BgJob class instead."""
    job = BgJob(command)
    return job.sp, job.result
370
371
def join_bg_jobs(bg_jobs, timeout=None):
    """Joins the bg_jobs with the current thread.

    Waits (up to timeout seconds, if given) for all jobs to finish,
    collects their output into each job's CmdResult and cleans up the
    pipes.  Returns the same list of bg_jobs objects that was passed in.

    Raises error.CmdError if the jobs did not complete within timeout.
    """
    # Removed the unused 'ret' variable from the original.
    timeout_error = False
    for bg_job in bg_jobs:
        bg_job.output_prepare(StringIO.StringIO(), StringIO.StringIO())

    try:
        # We are holding ends to stdin, stdout pipes
        # hence we need to be sure to close those fds no matter what
        start_time = time.time()
        timeout_error = _wait_for_commands(bg_jobs, start_time, timeout)

        for bg_job in bg_jobs:
            # Drain whatever is left in each job's stdout and stderr.
            bg_job.process_output(stdout=True, final_read=True)
            bg_job.process_output(stdout=False, final_read=True)
    finally:
        # close our ends of the pipes to the sp no matter what
        for bg_job in bg_jobs:
            bg_job.cleanup()

    if timeout_error:
        # TODO: This needs to be fixed to better represent what happens when
        # running in parallel. However this is backwards compatable, so it will
        # do for the time being.
        raise error.CmdError(bg_jobs[0].command, bg_jobs[0].result,
                             "Command(s) did not complete within %d seconds"
                             % timeout)


    return bg_jobs
406
407
def _wait_for_commands(bg_jobs, start_time, timeout):
    # This returns True if it must return due to a timeout, otherwise False.

    # To check for processes which terminate without producing any output
    # a 1 second timeout is used in select.
    SELECT_TIMEOUT = 1

    # Map each pipe back to (job, is_stdout) so a ready fd can be routed
    # to the right job/stream in the loop below.
    select_list = []
    reverse_dict = {}
    for bg_job in bg_jobs:
        select_list.append(bg_job.sp.stdout)
        select_list.append(bg_job.sp.stderr)
        reverse_dict[bg_job.sp.stdout] = (bg_job,True)
        reverse_dict[bg_job.sp.stderr] = (bg_job,False)

    if timeout:
        stop_time = start_time + timeout
        time_left = stop_time - time.time()
    else:
        time_left = None # so that select never times out
    while not timeout or time_left > 0:
        # select will return when stdout is ready (including when it is
        # EOF, that is the process has terminated).
        ready, _, _ = select.select(select_list, [], [], SELECT_TIMEOUT)

        # os.read() has to be used instead of
        # subproc.stdout.read() which will otherwise block
        for fileno in ready:
            bg_job,stdout = reverse_dict[fileno]
            bg_job.process_output(stdout)

        # Jobs whose exit status is still unknown; poll() records it for
        # those that have exited since the last pass.
        remaining_jobs = [x for x in bg_jobs if x.result.exit_status is None]
        if len(remaining_jobs) == 0:
            return False
        for bg_job in remaining_jobs:
            bg_job.result.exit_status = bg_job.sp.poll()

        if timeout:
            time_left = stop_time - time.time()

    # Kill all processes which did not complete prior to timeout
    for bg_job in [x for x in bg_jobs if x.result.exit_status is None]:
        print '* Warning: run process timeout (%s) fired' % timeout
        nuke_subprocess(bg_job.sp)
        bg_job.result.exit_status = bg_job.sp.poll()

    return True
455
456
def nuke_subprocess(subproc):
    """Kill subproc with escalating signals; return its exit status.

    Returns immediately with the exit status if the process has already
    terminated.  Otherwise sends SIGTERM then SIGKILL, polling for up to
    ~5 seconds after each.  Returns None if the process never died.
    """
    # check if the subprocess is still alive, first
    rc = subproc.poll()
    if rc is not None:
        return rc

    # the process has not terminated within timeout,
    # kill it via an escalating series of signals.
    for sig in (signal.SIGTERM, signal.SIGKILL):
        try:
            os.kill(subproc.pid, sig)
        except OSError:
            # The process may have died before we could kill it.
            pass

        for _ in range(5):
            rc = subproc.poll()
            if rc is not None:
                return rc
            time.sleep(1)
477
478
def nuke_pid(pid):
    """Kill process pid with an escalating series of signals.

    Sends SIGTERM, then SIGKILL, waiting up to ~5 seconds after each
    for the process to exit (reaped via os.waitpid, so pid must be a
    child of the current process).  Raises error.AutoservRunError if
    the process is still alive after the final signal.

    Bug fix: the original raised after the SIGTERM wait loop, so the
    SIGKILL escalation was unreachable for a process surviving SIGTERM.
    """
    signal_queue = [signal.SIGTERM, signal.SIGKILL]
    for sig in signal_queue:
        try:
            os.kill(pid, sig)
        # The process may have died before we could kill it.
        except OSError:
            pass

        try:
            for i in range(5):
                # WNOHANG polls without blocking; waitpid returns the pid
                # once the child has exited and been reaped.
                if os.waitpid(pid, os.WNOHANG)[0] == pid:
                    return
                time.sleep(1)
        # the process died before we could join it.
        except OSError:
            return

    # Still alive after every signal (including SIGKILL) was tried.
    raise error.AutoservRunError('Could not kill %d' % pid, None)
505
506
def system(command, timeout=None, ignore_status=False):
    """This function returns the exit status of command."""
    result = run(command, timeout=timeout, ignore_status=ignore_status,
                 stdout_tee=sys.stdout, stderr_tee=sys.stderr)
    return result.exit_status
511
512
def system_parallel(commands, timeout=None, ignore_status=False):
    """This function returns a list of exit statuses for the respective
    list of commands."""
    # run_parallel yields CmdResult objects, one per command, in order.
    results = run_parallel(commands, timeout=timeout,
                           ignore_status=ignore_status,
                           stdout_tee=sys.stdout, stderr_tee=sys.stderr)
    return [result.exit_status for result in results]
519
520
def system_output(command, timeout=None, ignore_status=False,
                  retain_output=False):
    """Return the stdout of command, minus a single trailing newline.

    With retain_output=True, output is also teed to sys.stdout/stderr
    as it is produced.
    """
    if retain_output:
        result = run(command, timeout=timeout, ignore_status=ignore_status,
                     stdout_tee=sys.stdout, stderr_tee=sys.stderr)
    else:
        result = run(command, timeout=timeout, ignore_status=ignore_status)
    out = result.stdout
    if out.endswith('\n'):
        out = out[:-1]
    return out
530
531
def system_output_parallel(commands, timeout=None, ignore_status=False,
                           retain_output=False):
    """Run commands in parallel and return a list of their stdout strings.

    A single trailing newline is stripped from each command's output.
    With retain_output=True, output is also teed to sys.stdout/stderr.
    """
    if retain_output:
        results = run_parallel(commands, timeout=timeout,
                               ignore_status=ignore_status,
                               stdout_tee=sys.stdout, stderr_tee=sys.stderr)
    else:
        results = run_parallel(commands, timeout=timeout,
                               ignore_status=ignore_status)
    out = []
    for result in results:
        # Bug fix: the original loop compared the *list* slice out[-1:]
        # against '\n', which can never match, so trailing newlines were
        # never actually stripped from any element.
        stdout = result.stdout
        if stdout.endswith('\n'):
            stdout = stdout[:-1]
        out.append(stdout)
    return out
544
545
def strip_unicode(input):
    """Recursively convert unicode values to plain str.

    Lists and dicts are rebuilt with their contents (and dict keys)
    converted; any other type is returned unchanged.
    """
    # isinstance() rather than type() == so subclasses are handled too.
    if isinstance(input, list):
        return [strip_unicode(element) for element in input]
    elif isinstance(input, dict):
        output = {}
        for key, value in input.items():
            output[str(key)] = strip_unicode(value)
        return output
    elif isinstance(input, unicode):
        return str(input)
    else:
        return input
558
559
def get_cpu_percentage(function, *args, **dargs):
    """Returns a tuple containing the CPU% and return value from function call.

    This function calculates the usage time by taking the difference of
    the user and system times both before and after the function call.
    CPU time of child processes is included via RUSAGE_CHILDREN.
    """
    child_pre = resource.getrusage(resource.RUSAGE_CHILDREN)
    self_pre = resource.getrusage(resource.RUSAGE_SELF)
    start = time.time()
    to_return = function(*args, **dargs)
    elapsed = time.time() - start
    self_post = resource.getrusage(resource.RUSAGE_SELF)
    child_post = resource.getrusage(resource.RUSAGE_CHILDREN)

    # Calculate CPU Percentage from the named rusage fields.  The original
    # used zip(...)[:2] positional slicing, which is opaque and breaks on
    # Python 3 where zip() is not subscriptable.
    s_user = self_post.ru_utime - self_pre.ru_utime
    s_system = self_post.ru_stime - self_pre.ru_stime
    c_user = child_post.ru_utime - child_pre.ru_utime
    c_system = child_post.ru_stime - child_pre.ru_stime
    cpu_percent = (s_user + c_user + s_system + c_system) / elapsed

    return cpu_percent, to_return
580
581
582"""
583This function is used when there is a need to run more than one
584job simultaneously starting exactly at the same time. It basically returns
585a modified control file (containing the synchronization code prepended)
586whenever it is ready to run the control file. The synchronization
587is done using barriers to make sure that the jobs start at the same time.
588
589Here is how the synchronization is done to make sure that the tests
590start at exactly the same time on the client.
591sc_bar is a server barrier and s_bar, c_bar are the normal barriers
592
593                  Job1              Job2         ......      JobN
594 Server:   |                        sc_bar
595 Server:   |                        s_bar        ......      s_bar
596 Server:   |      at.run()         at.run()      ......      at.run()
597 ----------|------------------------------------------------------
598 Client    |      sc_bar
599 Client    |      c_bar             c_bar        ......      c_bar
600 Client    |    <run test>         <run test>    ......     <run test>
601
602
603PARAMS:
604   control_file : The control file which to which the above synchronization
605                  code would be prepended to
606   host_name    : The host name on which the job is going to run
607   host_num (non negative) : A number to identify the machine so that we have
608                  different sets of s_bar_ports for each of the machines.
609   instance     : The number of the job
610   num_jobs     : Total number of jobs that are going to run in parallel with
611                  this job starting at the same time
612   port_base    : Port number that is used to derive the actual barrier ports.
613
614RETURN VALUE:
615    The modified control file.
616
617"""
618def get_sync_control_file(control, host_name, host_num,
619                          instance, num_jobs, port_base=63100):
620    sc_bar_port = port_base
621    c_bar_port = port_base
622    if host_num < 0:
623        print "Please provide a non negative number for the host"
624        return None
625    s_bar_port = port_base + 1 + host_num # The set of s_bar_ports are
626                                          # the same for a given machine
627
628    sc_bar_timeout = 180
629    s_bar_timeout = c_bar_timeout = 120
630
631    # The barrier code snippet is prepended into the conrol file
632    # dynamically before at.run() is called finally.
633    control_new = []
634
635    # jobid is the unique name used to identify the processes
636    # trying to reach the barriers
637    jobid = "%s#%d" % (host_name, instance)
638
639    rendv = []
640    # rendvstr is a temp holder for the rendezvous list of the processes
641    for n in range(num_jobs):
642        rendv.append("'%s#%d'" % (host_name, n))
643    rendvstr = ",".join(rendv)
644
645    if instance == 0:
646        # Do the setup and wait at the server barrier
647        # Clean up the tmp and the control dirs for the first instance
648        control_new.append('if os.path.exists(job.tmpdir):')
649        control_new.append("\t system('umount -f %s > /dev/null"
650                           "2> /dev/null' % job.tmpdir,"
651                           "ignore_status=True)")
652        control_new.append("\t system('rm -rf ' + job.tmpdir)")
653        control_new.append(
654            'b0 = job.barrier("%s", "sc_bar", %d, port=%d)'
655            % (jobid, sc_bar_timeout, sc_bar_port))
656        control_new.append(
657        'b0.rendevous_servers("PARALLEL_MASTER", "%s")'
658         % jobid)
659
660    elif instance == 1:
661        # Wait at the server barrier to wait for instance=0
662        # process to complete setup
663        b0 = barrier.barrier("PARALLEL_MASTER", "sc_bar", sc_bar_timeout,
664                     port=sc_bar_port)
665        b0.rendevous_servers("PARALLEL_MASTER", jobid)
666
667        if(num_jobs > 2):
668            b1 = barrier.barrier(jobid, "s_bar", s_bar_timeout,
669                         port=s_bar_port)
670            b1.rendevous(rendvstr)
671
672    else:
673        # For the rest of the clients
674        b2 = barrier.barrier(jobid, "s_bar", s_bar_timeout, port=s_bar_port)
675        b2.rendevous(rendvstr)
676
677    # Client side barrier for all the tests to start at the same time
678    control_new.append('b1 = job.barrier("%s", "c_bar", %d, port=%d)'
679                    % (jobid, c_bar_timeout, c_bar_port))
680    control_new.append("b1.rendevous(%s)" % rendvstr)
681
682    # Stick in the rest of the control file
683    control_new.append(control)
684
685    return "\n".join(control_new)
686
687
def get_arch(run_function=run):
    """
    Get the hardware architecture of the machine.
    run_function is used to execute the commands. It defaults to
    utils.run() but a custom method (if provided) should be of the
    same schema as utils.run. It should return a CmdResult object and
    throw a CmdError exception.
    """
    machine = run_function('/bin/uname -m').stdout.rstrip()
    # Fold the i386/i486/i586/i686 family into plain 'i386'.
    if re.match(r'i\d86$', machine):
        return 'i386'
    return machine
700
701
def get_num_logical_cores(run_function=None):
    """
    Get the number of cores (including hyperthreading) per cpu.
    run_function is used to execute the commands. It defaults to
    utils.run() but a custom method (if provided) should be of the
    same schema as utils.run. It should return a CmdResult object and
    throw a CmdError exception.
    """
    if run_function is None:
        # Late-bound default: look up the module-level run() at call time
        # (behaviorally identical for callers relying on the default).
        run_function = run
    coreinfo = run_function('grep "^siblings" /proc/cpuinfo').stdout.rstrip()
    # Anchor the count right after the 'siblings : ' label.  The original
    # pattern '^siblings.*(\d+)' let the greedy '.*' swallow all but the
    # final digit, so e.g. 'siblings : 16' parsed as 6.
    cores = int(re.match(r'^siblings\s*:\s*(\d+)', coreinfo).group(1))
    return cores
713
714
def merge_trees(src, dest):
    """
    Merges a source directory tree at 'src' into a destination tree at
    'dest'. If a path is a file in both trees than the file in the source
    tree is APPENDED to the one in the destination tree. If a path is
    a directory in both trees then the directories are recursively merged
    with this function. In any other case, the function will skip the
    paths that cannot be merged (instead of failing).
    """
    if not os.path.exists(src):
        # Path exists only in dest; nothing to merge.
        return
    if not os.path.exists(dest):
        # Path exists only in src: copy it over wholesale.
        if os.path.isfile(src):
            shutil.copy2(src, dest)
        else:
            shutil.copytree(src, dest, symlinks=True)
    elif os.path.isfile(src) and os.path.isfile(dest):
        # File in both trees: append src's contents to dest.
        srcfile = open(src)
        try:
            contents = srcfile.read()
        finally:
            srcfile.close()
        destfile = open(dest, "a")
        try:
            destfile.write(contents)
        finally:
            destfile.close()
    elif os.path.isdir(src) and os.path.isdir(dest):
        # Directory in both trees: merge entry by entry.
        for entry in os.listdir(src):
            merge_trees(os.path.join(src, entry), os.path.join(dest, entry))
    # Otherwise the two paths are incompatible (file vs dir): skip.
750
751
class CmdResult(object):
    """
    Command execution result.

    command:     String containing the command line itself
    exit_status: Integer exit code of the process
    stdout:      String containing stdout of the process
    stderr:      String containing stderr of the process
    duration:    Elapsed wall clock time running the process
    """


    def __init__(self, command="", stdout="", stderr="",
                 exit_status=None, duration=0):
        self.command = command
        self.exit_status = exit_status
        self.stdout = stdout
        self.stderr = stderr
        self.duration = duration


    def _labelled_output(self, label, text):
        # Return "\n<label>:\n<text>" for non-empty text, else "".
        text = text.rstrip()
        if not text:
            return ""
        return "\n%s:\n%s" % (label, text)


    def __repr__(self):
        # Wrap long command lines onto indented continuation lines.
        wrapper = textwrap.TextWrapper(width=78,
                                       initial_indent="\n    ",
                                       subsequent_indent="    ")
        return ("* Command: %s\n"
                "Exit status: %s\n"
                "Duration: %s\n"
                "%s"
                "%s"
                % (wrapper.fill(self.command), self.exit_status,
                   self.duration,
                   self._labelled_output("stdout", self.stdout),
                   self._labelled_output("stderr", self.stderr)))
793
794
class run_randomly:
    """Collects (args, dargs) entries and runs them in a random order."""

    def __init__(self, run_sequentially=False):
        # Run sequentially is for debugging control files
        self.test_list = []
        self.run_sequentially = run_sequentially


    def add(self, *args, **dargs):
        self.test_list.append((args, dargs))


    def run(self, fn):
        while self.test_list:
            # randint is always consumed so RNG state advances identically
            # whether or not we are in sequential (debug) mode.
            index = random.randint(0, len(self.test_list) - 1)
            if self.run_sequentially:
                index = 0
            args, dargs = self.test_list.pop(index)
            fn(*args, **dargs)
814
815
def import_site_symbol(path, module, name, dummy=None, modulefile=None):
    """
    Try to import site specific symbol from site specific file if it exists

    @param path full filename of the source file calling this (ie __file__)
    @param module full module name
    @param name symbol name to be imported from the site file
    @param dummy dummy value to return in case there is no symbol to import
    @param modulefile module filename

    @return site specific symbol or dummy

    @exception ImportError if the site file exists but imports fails
    """
    short_module = module[module.rfind(".") + 1:]
    if not modulefile:
        modulefile = short_module + ".py"

    # Treat a missing or empty site file as "no site implementation".
    site_file = os.path.join(os.path.dirname(path), modulefile)
    try:
        site_exists = os.path.getsize(site_file)
    except os.error:
        site_exists = False

    if not site_exists:
        logging.info("unable to import site module '%s', using non-site "
                     "implementation" % modulefile)
        return dummy

    # Import the module and pull the requested symbol out of it.
    imported = __import__(module, {}, {}, [short_module])
    return getattr(imported, name)
851
852
def import_site_class(path, module, classname, baseclass, modulefile=None):
    """
    Try to import site specific class from site specific file if it exists

    Args:
        path: full filename of the source file calling this (ie __file__)
        module: full module name
        classname: class name to be loaded from site file
        baseclass: base class object to return when no site file present or
            to mixin when site class exists but is not inherited from baseclass
        modulefile: module filename

    Returns: baseclass if site specific class does not exist, the site specific
        class if it exists and is inherited from baseclass or a mixin of the
        site specific class and baseclass when the site specific class exists
        and is not inherited from baseclass

    Raises: ImportError if the site file exists but imports fails
    """

    site_class = import_site_symbol(path, module, classname, None, modulefile)
    if not site_class:
        return baseclass
    if issubclass(site_class, baseclass):
        return site_class
    # Not derived from baseclass: mix the two together so the result
    # still provides baseclass behavior.
    return type(classname, (site_class, baseclass), {})
883
884
def import_site_function(path, module, funcname, dummy, modulefile=None):
    """
    Try to import site specific function from site specific file if it exists

    Args:
        path: full filename of the source file calling this (ie __file__)
        module: full module name
        funcname: function name to be imported from site file
        dummy: dummy function to return in case there is no function to import
        modulefile: module filename

    Returns: site specific function object or dummy

    Raises: ImportError if the site file exists but imports fails
    """

    # Thin convenience wrapper; all lookup logic lives in import_site_symbol.
    return import_site_symbol(path, module, funcname, dummy, modulefile)
902
903
def write_pid(program_name):
    """
    Try to drop <program_name>.pid in the main autotest directory.

    Args:
      program_name: prefix for file name
    """
    # The main autotest directory is two levels above this module.
    my_path = os.path.dirname(__file__)
    pid_path = os.path.abspath(os.path.join(my_path, "../.."))
    # open() raises on failure, so the original 'if pidf:' check was dead
    # code; use try/finally so the handle is closed even on a write error.
    pidf = open(os.path.join(pid_path, "%s.pid" % program_name), "w")
    try:
        pidf.write("%s\n" % os.getpid())
    finally:
        pidf.close()
918