# abstract_ssh.py revision 861b2d54aec24228cdb3895dbc40062cb40cb2ad
1import os, time, types, socket, shutil, glob, logging, traceback
2from autotest_lib.client.common_lib import autotemp, error, logging_manager
3from autotest_lib.server import utils, autotest
4from autotest_lib.server.hosts import remote
5from autotest_lib.client.common_lib.global_config import global_config
6
7
# Shorthand for reading values out of the global autotest configuration.
get_value = global_config.get_config_value

# When true, ssh/scp/rsync invocations for a host share one background
# "master" ssh connection (see start_master_ssh). Stays off unless the
# AUTOSERV section of the global config turns it on.
enable_master_ssh = get_value('AUTOSERV', 'enable_master_ssh',
                              default=False, type=bool)


def make_ssh_command(user="root", port=22, opts='', hosts_file='/dev/null',
                     connect_timeout=30, alive_interval=300):
    """
    Build the base ssh command line shared by every connection to a host.

    The returned string ends with "-l <user> -p <port>"; callers append the
    hostname (and, optionally, the remote command) themselves.

    @param user: remote login name, passed with -l.
    @param port: remote ssh port, passed with -p (formatted with %d).
    @param opts: extra raw ssh options inserted right after "-a -x".
    @param hosts_file: path given to ssh as UserKnownHostsFile.
    @param connect_timeout: ssh ConnectTimeout in seconds; must be a
            positive int/long (the timeout can never be disabled).
    @param alive_interval: ssh ServerAliveInterval in seconds.

    @returns the assembled ssh command as one string.
    """
    assert isinstance(connect_timeout, (int, long))
    # A non-positive value would effectively disable the timeout.
    assert connect_timeout > 0

    pieces = [
        "/usr/bin/ssh -a -x %s" % (opts,),
        "-o StrictHostKeyChecking=no",
        "-o UserKnownHostsFile=%s" % (hosts_file,),
        "-o BatchMode=yes",
        "-o ConnectTimeout=%d" % (connect_timeout,),
        "-o ServerAliveInterval=%d" % (alive_interval,),
        "-l %s -p %d" % (user, port),
    ]
    return " ".join(pieces)


# import site specific Host class
# utils.import_site_class looks for a "SiteHost" class in
# autotest_lib.server.hosts.site_host; remote.RemoteHost is passed as the
# default (presumably used when no site-specific module is installed).
# AbstractSSHHost below derives from whatever class this resolves to.
SiteHost = utils.import_site_class(__file__,
                                   "autotest_lib.server.hosts.site_host",
                                   "SiteHost",
                                   remote.RemoteHost)


31class AbstractSSHHost(SiteHost):
32    """
33    This class represents a generic implementation of most of the
34    framework necessary for controlling a host via ssh. It implements
35    almost all of the abstract Host methods, except for the core
36    Host.run method.
37    """
38
39    def _initialize(self, hostname, user="root", port=22, password="",
40                    *args, **dargs):
41        super(AbstractSSHHost, self)._initialize(hostname=hostname,
42                                                 *args, **dargs)
43        self.ip = socket.getaddrinfo(self.hostname, None)[0][4][0]
44        self.user = user
45        self.port = port
46        self.password = password
47        self._use_rsync = None
48        self.known_hosts_file = os.tmpfile()
49        known_hosts_fd = self.known_hosts_file.fileno()
50        self.known_hosts_fd = '/dev/fd/%s' % known_hosts_fd
51
52        """
53        Master SSH connection background job, socket temp directory and socket
54        control path option. If master-SSH is enabled, these fields will be
55        initialized by start_master_ssh when a new SSH connection is initiated.
56        """
57        self.master_ssh_job = None
58        self.master_ssh_tempdir = None
59        self.master_ssh_option = ''
60
61
62    def use_rsync(self):
63        if self._use_rsync is not None:
64            return self._use_rsync
65
66        # Check if rsync is available on the remote host. If it's not,
67        # don't try to use it for any future file transfers.
68        self._use_rsync = self._check_rsync()
69        if not self._use_rsync:
70            logging.warn("rsync not available on remote host %s -- disabled",
71                         self.hostname)
72        return self._use_rsync
73
74
75    def _check_rsync(self):
76        """
77        Check if rsync is available on the remote host.
78        """
79        try:
80            self.run("rsync --version", stdout_tee=None, stderr_tee=None)
81        except error.AutoservRunError:
82            return False
83        return True
84
85
86    def _encode_remote_paths(self, paths, escape=True):
87        """
88        Given a list of file paths, encodes it as a single remote path, in
89        the style used by rsync and scp.
90        """
91        if escape:
92            paths = [utils.scp_remote_escape(path) for path in paths]
93        return '%s@%s:"%s"' % (self.user, self.hostname, " ".join(paths))
94
95
96    def _make_rsync_cmd(self, sources, dest, delete_dest, preserve_symlinks):
97        """
98        Given a list of source paths and a destination path, produces the
99        appropriate rsync command for copying them. Remote paths must be
100        pre-encoded.
101        """
102        ssh_cmd = make_ssh_command(user=self.user, port=self.port,
103                                   opts=self.master_ssh_option,
104                                   hosts_file=self.known_hosts_fd)
105        if delete_dest:
106            delete_flag = "--delete"
107        else:
108            delete_flag = ""
109        if preserve_symlinks:
110            symlink_flag = ""
111        else:
112            symlink_flag = "-L"
113        command = "rsync %s %s --timeout=1800 --rsh='%s' -az %s %s"
114        return command % (symlink_flag, delete_flag, ssh_cmd,
115                          " ".join(sources), dest)
116
117
118    def _make_ssh_cmd(self, cmd):
119        """
120        Create a base ssh command string for the host which can be used
121        to run commands directly on the machine
122        """
123        base_cmd = make_ssh_command(user=self.user, port=self.port,
124                                    opts=self.master_ssh_option,
125                                    hosts_file=self.known_hosts_fd)
126
127        return '%s %s "%s"' % (base_cmd, self.hostname, utils.sh_escape(cmd))
128
129    def _make_scp_cmd(self, sources, dest):
130        """
131        Given a list of source paths and a destination path, produces the
132        appropriate scp command for encoding it. Remote paths must be
133        pre-encoded.
134        """
135        command = ("scp -rq %s -o StrictHostKeyChecking=no "
136                   "-o UserKnownHostsFile=%s -P %d %s '%s'")
137        return command % (self.master_ssh_option, self.known_hosts_fd,
138                          self.port, " ".join(sources), dest)
139
140
141    def _make_rsync_compatible_globs(self, path, is_local):
142        """
143        Given an rsync-style path, returns a list of globbed paths
144        that will hopefully provide equivalent behaviour for scp. Does not
145        support the full range of rsync pattern matching behaviour, only that
146        exposed in the get/send_file interface (trailing slashes).
147
148        The is_local param is flag indicating if the paths should be
149        interpreted as local or remote paths.
150        """
151
152        # non-trailing slash paths should just work
153        if len(path) == 0 or path[-1] != "/":
154            return [path]
155
156        # make a function to test if a pattern matches any files
157        if is_local:
158            def glob_matches_files(path, pattern):
159                return len(glob.glob(path + pattern)) > 0
160        else:
161            def glob_matches_files(path, pattern):
162                result = self.run("ls \"%s\"%s" % (utils.sh_escape(path),
163                                                   pattern),
164                                  stdout_tee=None, ignore_status=True)
165                return result.exit_status == 0
166
167        # take a set of globs that cover all files, and see which are needed
168        patterns = ["*", ".[!.]*"]
169        patterns = [p for p in patterns if glob_matches_files(path, p)]
170
171        # convert them into a set of paths suitable for the commandline
172        if is_local:
173            return ["\"%s\"%s" % (utils.sh_escape(path), pattern)
174                    for pattern in patterns]
175        else:
176            return [utils.scp_remote_escape(path) + pattern
177                    for pattern in patterns]
178
179
180    def _make_rsync_compatible_source(self, source, is_local):
181        """
182        Applies the same logic as _make_rsync_compatible_globs, but
183        applies it to an entire list of sources, producing a new list of
184        sources, properly quoted.
185        """
186        return sum((self._make_rsync_compatible_globs(path, is_local)
187                    for path in source), [])
188
189
190    def _set_umask_perms(self, dest):
191        """
192        Given a destination file/dir (recursively) set the permissions on
193        all the files and directories to the max allowed by running umask.
194        """
195
196        # now this looks strange but I haven't found a way in Python to _just_
197        # get the umask, apparently the only option is to try to set it
198        umask = os.umask(0)
199        os.umask(umask)
200
201        max_privs = 0777 & ~umask
202
203        def set_file_privs(filename):
204            file_stat = os.stat(filename)
205
206            file_privs = max_privs
207            # if the original file permissions do not have at least one
208            # executable bit then do not set it anywhere
209            if not file_stat.st_mode & 0111:
210                file_privs &= ~0111
211
212            os.chmod(filename, file_privs)
213
214        # try a bottom-up walk so changes on directory permissions won't cut
215        # our access to the files/directories inside it
216        for root, dirs, files in os.walk(dest, topdown=False):
217            # when setting the privileges we emulate the chmod "X" behaviour
218            # that sets to execute only if it is a directory or any of the
219            # owner/group/other already has execute right
220            for dirname in dirs:
221                os.chmod(os.path.join(root, dirname), max_privs)
222
223            for filename in files:
224                set_file_privs(os.path.join(root, filename))
225
226
227        # now set privs for the dest itself
228        if os.path.isdir(dest):
229            os.chmod(dest, max_privs)
230        else:
231            set_file_privs(dest)
232
233
234    def get_file(self, source, dest, delete_dest=False, preserve_perm=True,
235                 preserve_symlinks=False):
236        """
237        Copy files from the remote host to a local path.
238
239        Directories will be copied recursively.
240        If a source component is a directory with a trailing slash,
241        the content of the directory will be copied, otherwise, the
242        directory itself and its content will be copied. This
243        behavior is similar to that of the program 'rsync'.
244
245        Args:
246                source: either
247                        1) a single file or directory, as a string
248                        2) a list of one or more (possibly mixed)
249                                files or directories
250                dest: a file or a directory (if source contains a
251                        directory or more than one element, you must
252                        supply a directory dest)
253                delete_dest: if this is true, the command will also clear
254                             out any old files at dest that are not in the
255                             source
256                preserve_perm: tells get_file() to try to preserve the sources
257                               permissions on files and dirs
258                preserve_symlinks: try to preserve symlinks instead of
259                                   transforming them into files/dirs on copy
260
261        Raises:
262                AutoservRunError: the scp command failed
263        """
264
265        # Start a master SSH connection if necessary.
266        self.start_master_ssh()
267
268        if isinstance(source, basestring):
269            source = [source]
270        dest = os.path.abspath(dest)
271
272        # If rsync is disabled or fails, try scp.
273        try_scp = True
274        if self.use_rsync():
275            try:
276                remote_source = self._encode_remote_paths(source)
277                local_dest = utils.sh_escape(dest)
278                rsync = self._make_rsync_cmd([remote_source], local_dest,
279                                             delete_dest, preserve_symlinks)
280                utils.run(rsync)
281                try_scp = False
282            except error.CmdError, e:
283                logging.warn("trying scp, rsync failed: %s" % e)
284
285        if try_scp:
286            # scp has no equivalent to --delete, just drop the entire dest dir
287            if delete_dest and os.path.isdir(dest):
288                shutil.rmtree(dest)
289                os.mkdir(dest)
290
291            remote_source = self._make_rsync_compatible_source(source, False)
292            if remote_source:
293                # _make_rsync_compatible_source() already did the escaping
294                remote_source = self._encode_remote_paths(remote_source,
295                                                          escape=False)
296                local_dest = utils.sh_escape(dest)
297                scp = self._make_scp_cmd([remote_source], local_dest)
298                try:
299                    utils.run(scp)
300                except error.CmdError, e:
301                    raise error.AutoservRunError(e.args[0], e.args[1])
302
303        if not preserve_perm:
304            # we have no way to tell scp to not try to preserve the
305            # permissions so set them after copy instead.
306            # for rsync we could use "--no-p --chmod=ugo=rwX" but those
307            # options are only in very recent rsync versions
308            self._set_umask_perms(dest)
309
310
311    def send_file(self, source, dest, delete_dest=False,
312                  preserve_symlinks=False):
313        """
314        Copy files from a local path to the remote host.
315
316        Directories will be copied recursively.
317        If a source component is a directory with a trailing slash,
318        the content of the directory will be copied, otherwise, the
319        directory itself and its content will be copied. This
320        behavior is similar to that of the program 'rsync'.
321
322        Args:
323                source: either
324                        1) a single file or directory, as a string
325                        2) a list of one or more (possibly mixed)
326                                files or directories
327                dest: a file or a directory (if source contains a
328                        directory or more than one element, you must
329                        supply a directory dest)
330                delete_dest: if this is true, the command will also clear
331                             out any old files at dest that are not in the
332                             source
333                preserve_symlinks: controls if symlinks on the source will be
334                    copied as such on the destination or transformed into the
335                    referenced file/directory
336
337        Raises:
338                AutoservRunError: the scp command failed
339        """
340
341        # Start a master SSH connection if necessary.
342        self.start_master_ssh()
343
344        if isinstance(source, basestring):
345            source = [source]
346        remote_dest = self._encode_remote_paths([dest])
347
348        # If rsync is disabled or fails, try scp.
349        try_scp = True
350        if self.use_rsync():
351            try:
352                local_sources = [utils.sh_escape(path) for path in source]
353                rsync = self._make_rsync_cmd(local_sources, remote_dest,
354                                             delete_dest, preserve_symlinks)
355                utils.run(rsync)
356                try_scp = False
357            except error.CmdError, e:
358                logging.warn("trying scp, rsync failed: %s" % e)
359
360        if try_scp:
361            # scp has no equivalent to --delete, just drop the entire dest dir
362            if delete_dest:
363                is_dir = self.run("ls -d %s/" % dest,
364                                  ignore_status=True).exit_status == 0
365                if is_dir:
366                    cmd = "rm -rf %s && mkdir %s"
367                    cmd %= (dest, dest)
368                    self.run(cmd)
369
370            local_sources = self._make_rsync_compatible_source(source, True)
371            if local_sources:
372                scp = self._make_scp_cmd(local_sources, remote_dest)
373                try:
374                    utils.run(scp)
375                except error.CmdError, e:
376                    raise error.AutoservRunError(e.args[0], e.args[1])
377
378
379    def ssh_ping(self, timeout=60):
380        try:
381            self.run("true", timeout=timeout, connect_timeout=timeout)
382        except error.AutoservSSHTimeout:
383            msg = "Host (ssh) verify timed out (timeout = %d)" % timeout
384            raise error.AutoservSSHTimeout(msg)
385        except error.AutoservSshPermissionDeniedError:
386            #let AutoservSshPermissionDeniedError be visible to the callers
387            raise
388        except error.AutoservRunError, e:
389            # convert the generic AutoservRunError into something more
390            # specific for this context
391            raise error.AutoservSshPingHostError(e.description + '\n' +
392                                                 repr(e.result_obj))
393
394
395    def is_up(self):
396        """
397        Check if the remote host is up.
398
399        @returns True if the remote host is up, False otherwise
400        """
401        try:
402            self.ssh_ping()
403        except error.AutoservError:
404            return False
405        else:
406            return True
407
408
409    def wait_up(self, timeout=None):
410        """
411        Wait until the remote host is up or the timeout expires.
412
413        In fact, it will wait until an ssh connection to the remote
414        host can be established, and getty is running.
415
416        @param timeout time limit in seconds before returning even
417            if the host is not up.
418
419        @returns True if the host was found to be up, False otherwise
420        """
421        if timeout:
422            end_time = time.time() + timeout
423
424        while not timeout or time.time() < end_time:
425            if self.is_up():
426                try:
427                    if self.are_wait_up_processes_up():
428                        logging.debug('Host %s is now up', self.hostname)
429                        return True
430                except error.AutoservError:
431                    pass
432            time.sleep(1)
433
434        logging.debug('Host %s is still down after waiting %d seconds',
435                      self.hostname, int(timeout + time.time() - end_time))
436        return False
437
438
439    def wait_down(self, timeout=None, warning_timer=None, old_boot_id=None):
440        """
441        Wait until the remote host is down or the timeout expires.
442
443        If old_boot_id is provided, this will wait until either the machine
444        is unpingable or self.get_boot_id() returns a value different from
445        old_boot_id. If the boot_id value has changed then the function
446        returns true under the assumption that the machine has shut down
447        and has now already come back up.
448
449        If old_boot_id is None then until the machine becomes unreachable the
450        method assumes the machine has not yet shut down.
451
452        @param timeout Time limit in seconds before returning even
453            if the host is still up.
454        @param warning_timer Time limit in seconds that will generate
455            a warning if the host is not down yet.
456        @param old_boot_id A string containing the result of self.get_boot_id()
457            prior to the host being told to shut down. Can be None if this is
458            not available.
459
460        @returns True if the host was found to be down, False otherwise
461        """
462        #TODO: there is currently no way to distinguish between knowing
463        #TODO: boot_id was unsupported and not knowing the boot_id.
464        current_time = time.time()
465        if timeout:
466            end_time = current_time + timeout
467
468        if warning_timer:
469            warn_time = current_time + warning_timer
470
471        if old_boot_id is not None:
472            logging.debug('Host %s pre-shutdown boot_id is %s',
473                          self.hostname, old_boot_id)
474
475        while not timeout or current_time < end_time:
476            try:
477                new_boot_id = self.get_boot_id()
478            except error.AutoservError:
479                logging.debug('Host %s is now unreachable over ssh, is down',
480                              self.hostname)
481                return True
482            else:
483                # if the machine is up but the boot_id value has changed from
484                # old boot id, then we can assume the machine has gone down
485                # and then already come back up
486                if old_boot_id is not None and old_boot_id != new_boot_id:
487                    logging.debug('Host %s now has boot_id %s and so must '
488                                  'have rebooted', self.hostname, new_boot_id)
489                    return True
490
491            if warning_timer and current_time > warn_time:
492                self.record("WARN", None, "shutdown",
493                            "Shutdown took longer than %ds" % warning_timer)
494                # Print the warning only once.
495                warning_timer = None
496                # If a machine is stuck switching runlevels
497                # This may cause the machine to reboot.
498                self.run('kill -HUP 1', ignore_status=True)
499
500            time.sleep(1)
501            current_time = time.time()
502
503        return False
504
505
506    # tunable constants for the verify & repair code
507    AUTOTEST_GB_DISKSPACE_REQUIRED = get_value("SERVER",
508                                               "gb_diskspace_required",
509                                               type=int,
510                                               default=20)
511
512
513    def verify_connectivity(self):
514        super(AbstractSSHHost, self).verify_connectivity()
515
516        logging.info('Pinging host ' + self.hostname)
517        self.ssh_ping()
518        logging.info("Host (ssh) %s is alive", self.hostname)
519
520        if self.is_shutting_down():
521            raise error.AutoservHostIsShuttingDownError("Host is shutting down")
522
523
524    def verify_software(self):
525        super(AbstractSSHHost, self).verify_software()
526        try:
527            self.check_diskspace(autotest.Autotest.get_install_dir(self),
528                                 self.AUTOTEST_GB_DISKSPACE_REQUIRED)
529        except error.AutoservHostError:
530            raise           # only want to raise if it's a space issue
531        except autotest.AutodirNotFoundError:
532            # autotest dir may not exist, etc. ignore
533            logging.debug('autodir space check exception, this is probably '
534                          'safe to ignore\n' + traceback.format_exc())
535
536
537    def close(self):
538        super(AbstractSSHHost, self).close()
539        self._cleanup_master_ssh()
540        self.known_hosts_file.close()
541
542
543    def _cleanup_master_ssh(self):
544        """
545        Release all resources (process, temporary directory) used by an active
546        master SSH connection.
547        """
548        # If a master SSH connection is running, kill it.
549        if self.master_ssh_job is not None:
550            utils.nuke_subprocess(self.master_ssh_job.sp)
551            self.master_ssh_job = None
552
553        # Remove the temporary directory for the master SSH socket.
554        if self.master_ssh_tempdir is not None:
555            self.master_ssh_tempdir.clean()
556            self.master_ssh_tempdir = None
557            self.master_ssh_option = ''
558
559
560    def start_master_ssh(self):
561        """
562        Called whenever a slave SSH connection needs to be initiated (e.g., by
563        run, rsync, scp). If master SSH support is enabled and a master SSH
564        connection is not active already, start a new one in the background.
565        Also, cleanup any zombie master SSH connections (e.g., dead due to
566        reboot).
567        """
568        if not enable_master_ssh:
569            return
570
571        # If a previously started master SSH connection is not running
572        # anymore, it needs to be cleaned up and then restarted.
573        if self.master_ssh_job is not None:
574            if self.master_ssh_job.sp.poll() is not None:
575                logging.info("Master ssh connection to %s is down.",
576                             self.hostname)
577                self._cleanup_master_ssh()
578
579        # Start a new master SSH connection.
580        if self.master_ssh_job is None:
581            # Create a shared socket in a temp location.
582            self.master_ssh_tempdir = autotemp.tempdir(unique_id='ssh-master')
583            self.master_ssh_option = ("-o ControlPath=%s/socket" %
584                                      self.master_ssh_tempdir.name)
585
586            # Start the master SSH connection in the background.
587            master_cmd = self.ssh_command(options="-N -o ControlMaster=yes")
588            logging.info("Starting master ssh connection '%s'" % master_cmd)
589            self.master_ssh_job = utils.BgJob(master_cmd)
590
591
592    def clear_known_hosts(self):
593        """Clears out the temporary ssh known_hosts file.
594
595        This is useful if the test SSHes to the machine, then reinstalls it,
596        then SSHes to it again.  It can be called after the reinstall to
597        reduce the spam in the logs.
598        """
599        logging.info("Clearing known hosts for host '%s', file '%s'.",
600                     self.hostname, self.known_hosts_fd)
601        # Clear out the file by opening it for writing and then closing.
602        fh = open(self.known_hosts_fd, "w")
603        fh.close()
604