abstract_ssh.py revision 6b05f5b1c50bee17a532b01643dc0a27a84c939d
1import os, time, types, socket, shutil, glob, logging, traceback
2from autotest_lib.client.common_lib import autotemp, error, logging_manager
3from autotest_lib.server import utils, autotest
4from autotest_lib.server.hosts import remote
5from autotest_lib.client.common_lib.global_config import global_config
6
7
# Shorthand for reading values from the global autotest configuration.
get_value = global_config.get_config_value

# Whether ssh traffic should be multiplexed over a shared master connection
# (see AbstractSSHHost.start_master_ssh).  Off unless configured.
enable_master_ssh = get_value('AUTOSERV', 'enable_master_ssh',
                              default=False, type=bool)
11
12
13def _make_ssh_cmd_default(user="root", port=22, opts='', hosts_file='/dev/null',
14                          connect_timeout=30, alive_interval=300):
15    base_command = ("/usr/bin/ssh -a -x %s -o StrictHostKeyChecking=no "
16                    "-o UserKnownHostsFile=%s -o BatchMode=yes "
17                    "-o ConnectTimeout=%d -o ServerAliveInterval=%d "
18                    "-l %s -p %d")
19    assert isinstance(connect_timeout, (int, long))
20    assert connect_timeout > 0 # can't disable the timeout
21    return base_command % (opts, hosts_file, connect_timeout,
22                           alive_interval, user, port)
23
24
# Use make_ssh_command from a site-specific site_host module when one
# provides it; the last argument is presumably the fallback used otherwise.
make_ssh_command = utils.import_site_function(
        __file__,
        "autotest_lib.server.hosts.site_host",
        "make_ssh_command",
        _make_ssh_cmd_default)


# import site specific Host class; remote.RemoteHost is presumably the
# fallback base class when no site override exists.
SiteHost = utils.import_site_class(
        __file__,
        "autotest_lib.server.hosts.site_host",
        "SiteHost",
        remote.RemoteHost)
34
35
36class AbstractSSHHost(SiteHost):
37    """
38    This class represents a generic implementation of most of the
39    framework necessary for controlling a host via ssh. It implements
40    almost all of the abstract Host methods, except for the core
41    Host.run method.
42    """
43
44    def _initialize(self, hostname, user="root", port=22, password="",
45                    *args, **dargs):
46        super(AbstractSSHHost, self)._initialize(hostname=hostname,
47                                                 *args, **dargs)
48        self.ip = socket.getaddrinfo(self.hostname, None)[0][4][0]
49        self.user = user
50        self.port = port
51        self.password = password
52        self._use_rsync = None
53        self.known_hosts_file = os.tmpfile()
54        known_hosts_fd = self.known_hosts_file.fileno()
55        self.known_hosts_fd = '/dev/fd/%s' % known_hosts_fd
56
57        """
58        Master SSH connection background job, socket temp directory and socket
59        control path option. If master-SSH is enabled, these fields will be
60        initialized by start_master_ssh when a new SSH connection is initiated.
61        """
62        self.master_ssh_job = None
63        self.master_ssh_tempdir = None
64        self.master_ssh_option = ''
65
66
67    def use_rsync(self):
68        if self._use_rsync is not None:
69            return self._use_rsync
70
71        # Check if rsync is available on the remote host. If it's not,
72        # don't try to use it for any future file transfers.
73        self._use_rsync = self._check_rsync()
74        if not self._use_rsync:
75            logging.warn("rsync not available on remote host %s -- disabled",
76                         self.hostname)
77        return self._use_rsync
78
79
80    def _check_rsync(self):
81        """
82        Check if rsync is available on the remote host.
83        """
84        try:
85            self.run("rsync --version", stdout_tee=None, stderr_tee=None)
86        except error.AutoservRunError:
87            return False
88        return True
89
90
91    def _encode_remote_paths(self, paths, escape=True):
92        """
93        Given a list of file paths, encodes it as a single remote path, in
94        the style used by rsync and scp.
95        """
96        if escape:
97            paths = [utils.scp_remote_escape(path) for path in paths]
98        return '%s@%s:"%s"' % (self.user, self.hostname, " ".join(paths))
99
100
101    def _make_rsync_cmd(self, sources, dest, delete_dest, preserve_symlinks):
102        """
103        Given a list of source paths and a destination path, produces the
104        appropriate rsync command for copying them. Remote paths must be
105        pre-encoded.
106        """
107        ssh_cmd = make_ssh_command(user=self.user, port=self.port,
108                                   opts=self.master_ssh_option,
109                                   hosts_file=self.known_hosts_fd)
110        if delete_dest:
111            delete_flag = "--delete"
112        else:
113            delete_flag = ""
114        if preserve_symlinks:
115            symlink_flag = ""
116        else:
117            symlink_flag = "-L"
118        command = "rsync %s %s --timeout=1800 --rsh='%s' -az %s %s"
119        return command % (symlink_flag, delete_flag, ssh_cmd,
120                          " ".join(sources), dest)
121
122
123    def _make_ssh_cmd(self, cmd):
124        """
125        Create a base ssh command string for the host which can be used
126        to run commands directly on the machine
127        """
128        base_cmd = make_ssh_command(user=self.user, port=self.port,
129                                    opts=self.master_ssh_option,
130                                    hosts_file=self.known_hosts_fd)
131
132        return '%s %s "%s"' % (base_cmd, self.hostname, utils.sh_escape(cmd))
133
134    def _make_scp_cmd(self, sources, dest):
135        """
136        Given a list of source paths and a destination path, produces the
137        appropriate scp command for encoding it. Remote paths must be
138        pre-encoded.
139        """
140        command = ("scp -rq %s -o StrictHostKeyChecking=no "
141                   "-o UserKnownHostsFile=%s -P %d %s '%s'")
142        return command % (self.master_ssh_option, self.known_hosts_fd,
143                          self.port, " ".join(sources), dest)
144
145
146    def _make_rsync_compatible_globs(self, path, is_local):
147        """
148        Given an rsync-style path, returns a list of globbed paths
149        that will hopefully provide equivalent behaviour for scp. Does not
150        support the full range of rsync pattern matching behaviour, only that
151        exposed in the get/send_file interface (trailing slashes).
152
153        The is_local param is flag indicating if the paths should be
154        interpreted as local or remote paths.
155        """
156
157        # non-trailing slash paths should just work
158        if len(path) == 0 or path[-1] != "/":
159            return [path]
160
161        # make a function to test if a pattern matches any files
162        if is_local:
163            def glob_matches_files(path, pattern):
164                return len(glob.glob(path + pattern)) > 0
165        else:
166            def glob_matches_files(path, pattern):
167                result = self.run("ls \"%s\"%s" % (utils.sh_escape(path),
168                                                   pattern),
169                                  stdout_tee=None, ignore_status=True)
170                return result.exit_status == 0
171
172        # take a set of globs that cover all files, and see which are needed
173        patterns = ["*", ".[!.]*"]
174        patterns = [p for p in patterns if glob_matches_files(path, p)]
175
176        # convert them into a set of paths suitable for the commandline
177        if is_local:
178            return ["\"%s\"%s" % (utils.sh_escape(path), pattern)
179                    for pattern in patterns]
180        else:
181            return [utils.scp_remote_escape(path) + pattern
182                    for pattern in patterns]
183
184
185    def _make_rsync_compatible_source(self, source, is_local):
186        """
187        Applies the same logic as _make_rsync_compatible_globs, but
188        applies it to an entire list of sources, producing a new list of
189        sources, properly quoted.
190        """
191        return sum((self._make_rsync_compatible_globs(path, is_local)
192                    for path in source), [])
193
194
195    def _set_umask_perms(self, dest):
196        """
197        Given a destination file/dir (recursively) set the permissions on
198        all the files and directories to the max allowed by running umask.
199        """
200
201        # now this looks strange but I haven't found a way in Python to _just_
202        # get the umask, apparently the only option is to try to set it
203        umask = os.umask(0)
204        os.umask(umask)
205
206        max_privs = 0777 & ~umask
207
208        def set_file_privs(filename):
209            """Sets mode of |filename|.  Assumes |filename| exists."""
210            file_stat = os.stat(filename)
211
212            file_privs = max_privs
213            # if the original file permissions do not have at least one
214            # executable bit then do not set it anywhere
215            if not file_stat.st_mode & 0111:
216                file_privs &= ~0111
217
218            os.chmod(filename, file_privs)
219
220        # try a bottom-up walk so changes on directory permissions won't cut
221        # our access to the files/directories inside it
222        for root, dirs, files in os.walk(dest, topdown=False):
223            # when setting the privileges we emulate the chmod "X" behaviour
224            # that sets to execute only if it is a directory or any of the
225            # owner/group/other already has execute right
226            for dirname in dirs:
227                os.chmod(os.path.join(root, dirname), max_privs)
228
229            # Filter out broken symlinks as we go.
230            for filename in filter(os.path.exists, files):
231                set_file_privs(os.path.join(root, filename))
232
233
234        # now set privs for the dest itself
235        if os.path.isdir(dest):
236            os.chmod(dest, max_privs)
237        else:
238            set_file_privs(dest)
239
240
241    def get_file(self, source, dest, delete_dest=False, preserve_perm=True,
242                 preserve_symlinks=False):
243        """
244        Copy files from the remote host to a local path.
245
246        Directories will be copied recursively.
247        If a source component is a directory with a trailing slash,
248        the content of the directory will be copied, otherwise, the
249        directory itself and its content will be copied. This
250        behavior is similar to that of the program 'rsync'.
251
252        Args:
253                source: either
254                        1) a single file or directory, as a string
255                        2) a list of one or more (possibly mixed)
256                                files or directories
257                dest: a file or a directory (if source contains a
258                        directory or more than one element, you must
259                        supply a directory dest)
260                delete_dest: if this is true, the command will also clear
261                             out any old files at dest that are not in the
262                             source
263                preserve_perm: tells get_file() to try to preserve the sources
264                               permissions on files and dirs
265                preserve_symlinks: try to preserve symlinks instead of
266                                   transforming them into files/dirs on copy
267
268        Raises:
269                AutoservRunError: the scp command failed
270        """
271
272        # Start a master SSH connection if necessary.
273        self.start_master_ssh()
274
275        if isinstance(source, basestring):
276            source = [source]
277        dest = os.path.abspath(dest)
278
279        # If rsync is disabled or fails, try scp.
280        try_scp = True
281        if self.use_rsync():
282            try:
283                remote_source = self._encode_remote_paths(source)
284                local_dest = utils.sh_escape(dest)
285                rsync = self._make_rsync_cmd([remote_source], local_dest,
286                                             delete_dest, preserve_symlinks)
287                utils.run(rsync)
288                try_scp = False
289            except error.CmdError, e:
290                logging.warn("trying scp, rsync failed: %s" % e)
291
292        if try_scp:
293            # scp has no equivalent to --delete, just drop the entire dest dir
294            if delete_dest and os.path.isdir(dest):
295                shutil.rmtree(dest)
296                os.mkdir(dest)
297
298            remote_source = self._make_rsync_compatible_source(source, False)
299            if remote_source:
300                # _make_rsync_compatible_source() already did the escaping
301                remote_source = self._encode_remote_paths(remote_source,
302                                                          escape=False)
303                local_dest = utils.sh_escape(dest)
304                scp = self._make_scp_cmd([remote_source], local_dest)
305                try:
306                    utils.run(scp)
307                except error.CmdError, e:
308                    raise error.AutoservRunError(e.args[0], e.args[1])
309
310        if not preserve_perm:
311            # we have no way to tell scp to not try to preserve the
312            # permissions so set them after copy instead.
313            # for rsync we could use "--no-p --chmod=ugo=rwX" but those
314            # options are only in very recent rsync versions
315            self._set_umask_perms(dest)
316
317
318    def send_file(self, source, dest, delete_dest=False,
319                  preserve_symlinks=False):
320        """
321        Copy files from a local path to the remote host.
322
323        Directories will be copied recursively.
324        If a source component is a directory with a trailing slash,
325        the content of the directory will be copied, otherwise, the
326        directory itself and its content will be copied. This
327        behavior is similar to that of the program 'rsync'.
328
329        Args:
330                source: either
331                        1) a single file or directory, as a string
332                        2) a list of one or more (possibly mixed)
333                                files or directories
334                dest: a file or a directory (if source contains a
335                        directory or more than one element, you must
336                        supply a directory dest)
337                delete_dest: if this is true, the command will also clear
338                             out any old files at dest that are not in the
339                             source
340                preserve_symlinks: controls if symlinks on the source will be
341                    copied as such on the destination or transformed into the
342                    referenced file/directory
343
344        Raises:
345                AutoservRunError: the scp command failed
346        """
347
348        # Start a master SSH connection if necessary.
349        self.start_master_ssh()
350
351        if isinstance(source, basestring):
352            source = [source]
353        remote_dest = self._encode_remote_paths([dest])
354
355        # If rsync is disabled or fails, try scp.
356        try_scp = True
357        if self.use_rsync():
358            try:
359                local_sources = [utils.sh_escape(path) for path in source]
360                rsync = self._make_rsync_cmd(local_sources, remote_dest,
361                                             delete_dest, preserve_symlinks)
362                utils.run(rsync)
363                try_scp = False
364            except error.CmdError, e:
365                logging.warn("trying scp, rsync failed: %s" % e)
366
367        if try_scp:
368            # scp has no equivalent to --delete, just drop the entire dest dir
369            if delete_dest:
370                is_dir = self.run("ls -d %s/" % dest,
371                                  ignore_status=True).exit_status == 0
372                if is_dir:
373                    cmd = "rm -rf %s && mkdir %s"
374                    cmd %= (dest, dest)
375                    self.run(cmd)
376
377            local_sources = self._make_rsync_compatible_source(source, True)
378            if local_sources:
379                scp = self._make_scp_cmd(local_sources, remote_dest)
380                try:
381                    utils.run(scp)
382                except error.CmdError, e:
383                    raise error.AutoservRunError(e.args[0], e.args[1])
384
385
386    def ssh_ping(self, timeout=60):
387        """
388        Pings remote host via ssh.
389
390        @param timeout: Time in seconds before giving up.
391                        Defaults to 60 seconds.
392        @raise AutoservSSHTimeout: If the ssh ping times out.
393        @raise AutoservSshPermissionDeniedError: If ssh ping fails due to
394                                                 permissions.
395        @raise AutoservSshPingHostError: For other AutoservRunErrors.
396        """
397        try:
398            self.run("true", timeout=timeout, connect_timeout=timeout)
399        except error.AutoservSSHTimeout:
400            msg = "Host (ssh) verify timed out (timeout = %d)" % timeout
401            raise error.AutoservSSHTimeout(msg)
402        except error.AutoservSshPermissionDeniedError:
403            #let AutoservSshPermissionDeniedError be visible to the callers
404            raise
405        except error.AutoservRunError, e:
406            # convert the generic AutoservRunError into something more
407            # specific for this context
408            raise error.AutoservSshPingHostError(e.description + '\n' +
409                                                 repr(e.result_obj))
410
411
412    def is_up(self, timeout=60):
413        """
414        Check if the remote host is up.
415
416        @param timeout: timeout in seconds.
417        @returns True if the remote host is up before the timeout expires,
418                 False otherwise.
419        """
420        try:
421            self.ssh_ping(timeout=timeout)
422        except error.AutoservError:
423            return False
424        else:
425            return True
426
427
428    def wait_up(self, timeout=None):
429        """
430        Wait until the remote host is up or the timeout expires.
431
432        In fact, it will wait until an ssh connection to the remote
433        host can be established, and getty is running.
434
435        @param timeout time limit in seconds before returning even
436            if the host is not up.
437
438        @returns True if the host was found to be up before the timeout expires,
439                 False otherwise
440        """
441        if timeout:
442            end_time = time.time() + timeout
443            current_time = time.time()
444
445        while not timeout or current_time < end_time:
446            if self.is_up(timeout=end_time - current_time):
447                try:
448                    if self.are_wait_up_processes_up():
449                        logging.debug('Host %s is now up', self.hostname)
450                        return True
451                except error.AutoservError:
452                    pass
453            time.sleep(1)
454            current_time = time.time()
455
456        logging.debug('Host %s is still down after waiting %d seconds',
457                      self.hostname, int(timeout + time.time() - end_time))
458        return False
459
460
461    def wait_down(self, timeout=None, warning_timer=None, old_boot_id=None):
462        """
463        Wait until the remote host is down or the timeout expires.
464
465        If old_boot_id is provided, this will wait until either the machine
466        is unpingable or self.get_boot_id() returns a value different from
467        old_boot_id. If the boot_id value has changed then the function
468        returns true under the assumption that the machine has shut down
469        and has now already come back up.
470
471        If old_boot_id is None then until the machine becomes unreachable the
472        method assumes the machine has not yet shut down.
473
474        Based on this definition, the 4 possible permutations of timeout
475        and old_boot_id are:
476        1. timeout and old_boot_id: wait timeout seconds for either the
477                                    host to become unpingable, or the boot id
478                                    to change. In the latter case we've rebooted
479                                    and in the former case we've only shutdown,
480                                    but both cases return True.
481        2. only timeout: wait timeout seconds for the host to become unpingable.
482                         If the host remains pingable throughout timeout seconds
483                         we return False.
484        3. only old_boot_id: wait forever until either the host becomes
485                             unpingable or the boot_id changes. Return true
486                             when either of those conditions are met.
487        4. not timeout, not old_boot_id: wait forever till the host becomes
488                                         unpingable.
489
490        @param timeout Time limit in seconds before returning even
491            if the host is still up.
492        @param warning_timer Time limit in seconds that will generate
493            a warning if the host is not down yet.
494        @param old_boot_id A string containing the result of self.get_boot_id()
495            prior to the host being told to shut down. Can be None if this is
496            not available.
497
498        @returns True if the host was found to be down, False otherwise
499        """
500        #TODO: there is currently no way to distinguish between knowing
501        #TODO: boot_id was unsupported and not knowing the boot_id.
502        current_time = time.time()
503        if timeout:
504            end_time = current_time + timeout
505
506        if warning_timer:
507            warn_time = current_time + warning_timer
508
509        if old_boot_id is not None:
510            logging.debug('Host %s pre-shutdown boot_id is %s',
511                          self.hostname, old_boot_id)
512
513        # Impose semi real-time deadline constraints, since some clients
514        # (eg: watchdog timer tests) expect strict checking of time elapsed.
515        # Each iteration of this loop is treated as though it atomically
516        # completes within current_time, this is needed because if we used
517        # inline time.time() calls instead then the following could happen:
518        #
519        # while not timeout or time.time() < end_time:      [23 < 30]
520        #    some code.                                     [takes 10 secs]
521        #    try:
522        #        new_boot_id = self.get_boot_id(timeout=end_time - time.time())
523        #                                                   [30 - 33]
524        # The last step will lead to a return True, when in fact the machine
525        # went down at 32 seconds (>30). Hence we need to pass get_boot_id
526        # the same time that allowed us into that iteration of the loop.
527        while not timeout or current_time < end_time:
528            try:
529                new_boot_id = self.get_boot_id(timeout=end_time - current_time)
530            except error.AutoservError:
531                logging.debug('Host %s is now unreachable over ssh, is down',
532                              self.hostname)
533                return True
534            else:
535                # if the machine is up but the boot_id value has changed from
536                # old boot id, then we can assume the machine has gone down
537                # and then already come back up
538                if old_boot_id is not None and old_boot_id != new_boot_id:
539                    logging.debug('Host %s now has boot_id %s and so must '
540                                  'have rebooted', self.hostname, new_boot_id)
541                    return True
542
543            if warning_timer and current_time > warn_time:
544                self.record("WARN", None, "shutdown",
545                            "Shutdown took longer than %ds" % warning_timer)
546                # Print the warning only once.
547                warning_timer = None
548                # If a machine is stuck switching runlevels
549                # This may cause the machine to reboot.
550                self.run('kill -HUP 1', ignore_status=True)
551
552            time.sleep(1)
553            current_time = time.time()
554
555        return False
556
557
558    # tunable constants for the verify & repair code
559    AUTOTEST_GB_DISKSPACE_REQUIRED = get_value("SERVER",
560                                               "gb_diskspace_required",
561                                               type=float,
562                                               default=20.0)
563
564
565    def verify_connectivity(self):
566        super(AbstractSSHHost, self).verify_connectivity()
567
568        logging.info('Pinging host ' + self.hostname)
569        self.ssh_ping()
570        logging.info("Host (ssh) %s is alive", self.hostname)
571
572        if self.is_shutting_down():
573            raise error.AutoservHostIsShuttingDownError("Host is shutting down")
574
575
576    def verify_software(self):
577        super(AbstractSSHHost, self).verify_software()
578        try:
579            self.check_diskspace(autotest.Autotest.get_install_dir(self),
580                                 self.AUTOTEST_GB_DISKSPACE_REQUIRED)
581        except error.AutoservHostError:
582            raise           # only want to raise if it's a space issue
583        except autotest.AutodirNotFoundError:
584            # autotest dir may not exist, etc. ignore
585            logging.debug('autodir space check exception, this is probably '
586                          'safe to ignore\n' + traceback.format_exc())
587
588
589    def close(self):
590        super(AbstractSSHHost, self).close()
591        self._cleanup_master_ssh()
592        self.known_hosts_file.close()
593
594
595    def _cleanup_master_ssh(self):
596        """
597        Release all resources (process, temporary directory) used by an active
598        master SSH connection.
599        """
600        # If a master SSH connection is running, kill it.
601        if self.master_ssh_job is not None:
602            utils.nuke_subprocess(self.master_ssh_job.sp)
603            self.master_ssh_job = None
604
605        # Remove the temporary directory for the master SSH socket.
606        if self.master_ssh_tempdir is not None:
607            self.master_ssh_tempdir.clean()
608            self.master_ssh_tempdir = None
609            self.master_ssh_option = ''
610
611
612    def start_master_ssh(self):
613        """
614        Called whenever a slave SSH connection needs to be initiated (e.g., by
615        run, rsync, scp). If master SSH support is enabled and a master SSH
616        connection is not active already, start a new one in the background.
617        Also, cleanup any zombie master SSH connections (e.g., dead due to
618        reboot).
619        """
620        if not enable_master_ssh:
621            return
622
623        # If a previously started master SSH connection is not running
624        # anymore, it needs to be cleaned up and then restarted.
625        if self.master_ssh_job is not None:
626            if self.master_ssh_job.sp.poll() is not None:
627                logging.info("Master ssh connection to %s is down.",
628                             self.hostname)
629                self._cleanup_master_ssh()
630
631        # Start a new master SSH connection.
632        if self.master_ssh_job is None:
633            # Create a shared socket in a temp location.
634            self.master_ssh_tempdir = autotemp.tempdir(unique_id='ssh-master')
635            self.master_ssh_option = ("-o ControlPath=%s/socket" %
636                                      self.master_ssh_tempdir.name)
637
638            # Start the master SSH connection in the background.
639            master_cmd = self.ssh_command(options="-N -o ControlMaster=yes")
640            logging.info("Starting master ssh connection '%s'" % master_cmd)
641            self.master_ssh_job = utils.BgJob(master_cmd)
642
643
644    def clear_known_hosts(self):
645        """Clears out the temporary ssh known_hosts file.
646
647        This is useful if the test SSHes to the machine, then reinstalls it,
648        then SSHes to it again.  It can be called after the reinstall to
649        reduce the spam in the logs.
650        """
651        logging.info("Clearing known hosts for host '%s', file '%s'.",
652                     self.hostname, self.known_hosts_fd)
653        # Clear out the file by opening it for writing and then closing.
654        fh = open(self.known_hosts_fd, "w")
655        fh.close()
656