1# Copyright (C) 2003 Python Software Foundation
2
3import unittest
4import shutil
5import tempfile
6import sys
7import stat
8import os
9import os.path
10import errno
11from os.path import splitdrive
12from distutils.spawn import find_executable, spawn
13from shutil import (_make_tarball, _make_zipfile, make_archive,
14                    register_archive_format, unregister_archive_format,
15                    get_archive_formats)
16import tarfile
17import warnings
18
19from test import test_support
20from test.test_support import TESTFN, check_warnings, captured_stdout
21
# A second scratch path, used as a copy/move target next to TESTFN.
TESTFN2 = TESTFN + "2"

# The owner/group archive tests need both grp and pwd; if either import
# fails, those tests are disabled.
try:
    import grp
    import pwd
except ImportError:
    UID_GID_SUPPORT = False
else:
    UID_GID_SUPPORT = True

# zlib is optional; tests that need compression check for None.
try:
    import zlib
except ImportError:
    zlib = None

# Without the zipfile module, fall back to an external "zip" executable:
# find_executable() yields its path (truthy) or None when it is absent.
try:
    import zipfile
except ImportError:
    ZIP_SUPPORT = find_executable('zip')
else:
    ZIP_SUPPORT = True
41
class TestShutil(unittest.TestCase):
    """Tests for shutil's tree operations and archive helpers."""

    def setUp(self):
        super(TestShutil, self).setUp()
        # Directories handed out by self.mkdtemp(); removed in tearDown().
        self.tempdirs = []

    def tearDown(self):
        super(TestShutil, self).tearDown()
        while self.tempdirs:
            d = self.tempdirs.pop()
            # Second positional argument is ignore_errors: cleanup failures
            # are tolerated only on Windows and Cygwin.
            shutil.rmtree(d, os.name in ('nt', 'cygwin'))

    def write_file(self, path, content='xxx'):
        """Writes a file in the given path.

        path can be a string or a sequence of path components, which are
        joined with os.path.join().
        """
        if isinstance(path, (list, tuple)):
            path = os.path.join(*path)
        f = open(path, 'w')
        try:
            f.write(content)
        finally:
            f.close()

    def mkdtemp(self):
        """Create a temporary directory that will be cleaned up.

        Returns the path of the directory.
        """
        d = tempfile.mkdtemp()
        self.tempdirs.append(d)
        return d

    def test_rmtree_errors(self):
        # filename did not exist when mktemp() generated it
        filename = tempfile.mktemp()
        self.assertRaises(OSError, shutil.rmtree, filename)

    # See bug #1071513 for why we don't run this on cygwin
    # and bug #1076467 for why we don't run this as root.
    if (hasattr(os, 'chmod') and sys.platform[:6] != 'cygwin'
        and not (hasattr(os, 'geteuid') and os.geteuid() == 0)):
        def test_on_error(self):
            """rmtree() reports each failure to the onerror callback."""
            self.errorState = 0
            os.mkdir(TESTFN)
            self.childpath = os.path.join(TESTFN, 'a')
            f = open(self.childpath, 'w')
            f.close()
            old_dir_mode = os.stat(TESTFN).st_mode
            old_child_mode = os.stat(self.childpath).st_mode
            # Make unwritable.
            os.chmod(self.childpath, stat.S_IREAD)
            os.chmod(TESTFN, stat.S_IREAD)
            try:
                shutil.rmtree(TESTFN, onerror=self.check_args_to_onerror)
                # Test whether onerror has actually been called.
                self.assertEqual(self.errorState, 2,
                                 "Expected call to onerror function did not happen.")
            finally:
                # Make writable again -- even when the assertion above
                # failed -- so the tree can always be removed.  The
                # directory must be restored first, otherwise the child
                # cannot be reached.
                os.chmod(TESTFN, old_dir_mode)
                os.chmod(self.childpath, old_child_mode)

                # Clean up.
                shutil.rmtree(TESTFN)

    def check_args_to_onerror(self, func, arg, exc):
        # test_on_error deliberately runs rmtree
        # on a directory that is chmod 400, which will fail.
        # This function is run when shutil.rmtree fails.
        # 99.9% of the time it initially fails to remove
        # a file in the directory, so the first time through
        # func is os.remove.
        # However, some Linux machines running ZFS on
        # FUSE experienced a failure earlier in the process
        # at os.listdir.  The first failure may legally
        # be either.
        # errorState moves 0 -> 1 on that first failure and 1 -> 2 on the
        # subsequent os.rmdir failure for TESTFN itself.
        if self.errorState == 0:
            if func is os.remove:
                self.assertEqual(arg, self.childpath)
            else:
                self.assertIs(func, os.listdir,
                              "func must be either os.remove or os.listdir")
                self.assertEqual(arg, TESTFN)
            self.assertTrue(issubclass(exc[0], OSError))
            self.errorState = 1
        else:
            self.assertEqual(func, os.rmdir)
            self.assertEqual(arg, TESTFN)
            self.assertTrue(issubclass(exc[0], OSError))
            self.errorState = 2

    def test_rmtree_dont_delete_file(self):
        # When called on a file instead of a directory, don't delete it.
        handle, path = tempfile.mkstemp()
        os.fdopen(handle).close()
        self.assertRaises(OSError, shutil.rmtree, path)
        os.remove(path)

    def test_copytree_simple(self):
        def write_data(path, data):
            with open(path, "w") as f:
                f.write(data)

        def read_data(path):
            with open(path) as f:
                return f.read()

        src_dir = tempfile.mkdtemp()
        dst_dir = os.path.join(tempfile.mkdtemp(), 'destination')

        write_data(os.path.join(src_dir, 'test.txt'), '123')

        os.mkdir(os.path.join(src_dir, 'test_dir'))
        write_data(os.path.join(src_dir, 'test_dir', 'test.txt'), '456')

        try:
            shutil.copytree(src_dir, dst_dir)
            self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test.txt')))
            self.assertTrue(os.path.isdir(os.path.join(dst_dir, 'test_dir')))
            self.assertTrue(os.path.isfile(os.path.join(dst_dir, 'test_dir',
                                                        'test.txt')))
            actual = read_data(os.path.join(dst_dir, 'test.txt'))
            self.assertEqual(actual, '123')
            actual = read_data(os.path.join(dst_dir, 'test_dir', 'test.txt'))
            self.assertEqual(actual, '456')
        finally:
            for path in (
                    os.path.join(src_dir, 'test.txt'),
                    os.path.join(dst_dir, 'test.txt'),
                    os.path.join(src_dir, 'test_dir', 'test.txt'),
                    os.path.join(dst_dir, 'test_dir', 'test.txt'),
                ):
                if os.path.exists(path):
                    os.remove(path)
            for path in (src_dir,
                    os.path.dirname(dst_dir)
                ):
                if os.path.exists(path):
                    shutil.rmtree(path)

    def test_copytree_with_exclude(self):

        def write_data(path, data):
            with open(path, "w") as f:
                f.write(data)

        # creating data
        join = os.path.join
        exists = os.path.exists
        src_dir = tempfile.mkdtemp()
        try:
            dst_dir = join(tempfile.mkdtemp(), 'destination')
            write_data(join(src_dir, 'test.txt'), '123')
            write_data(join(src_dir, 'test.tmp'), '123')
            os.mkdir(join(src_dir, 'test_dir'))
            write_data(join(src_dir, 'test_dir', 'test.txt'), '456')
            os.mkdir(join(src_dir, 'test_dir2'))
            write_data(join(src_dir, 'test_dir2', 'test.txt'), '456')
            os.mkdir(join(src_dir, 'test_dir2', 'subdir'))
            os.mkdir(join(src_dir, 'test_dir2', 'subdir2'))
            write_data(join(src_dir, 'test_dir2', 'subdir', 'test.txt'), '456')
            write_data(join(src_dir, 'test_dir2', 'subdir2', 'test.py'), '456')


            # testing glob-like patterns
            try:
                patterns = shutil.ignore_patterns('*.tmp', 'test_dir2')
                shutil.copytree(src_dir, dst_dir, ignore=patterns)
                # checking the result: some elements should not be copied
                self.assertTrue(exists(join(dst_dir, 'test.txt')))
                self.assertTrue(not exists(join(dst_dir, 'test.tmp')))
                self.assertTrue(not exists(join(dst_dir, 'test_dir2')))
            finally:
                if os.path.exists(dst_dir):
                    shutil.rmtree(dst_dir)
            try:
                patterns = shutil.ignore_patterns('*.tmp', 'subdir*')
                shutil.copytree(src_dir, dst_dir, ignore=patterns)
                # checking the result: some elements should not be copied
                self.assertTrue(not exists(join(dst_dir, 'test.tmp')))
                self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir2')))
                self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir')))
            finally:
                if os.path.exists(dst_dir):
                    shutil.rmtree(dst_dir)

            # testing callable-style
            try:
                def _filter(src, names):
                    # Ignore directories named exactly 'subdir' and any
                    # file with a '.py' extension; keep everything else.
                    res = []
                    for name in names:
                        path = os.path.join(src, name)

                        if os.path.isdir(path) and name == 'subdir':
                            res.append(name)
                        elif os.path.splitext(name)[1] == '.py':
                            res.append(name)
                    return res

                shutil.copytree(src_dir, dst_dir, ignore=_filter)

                # checking the result: some elements should not be copied
                self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir2',
                                        'test.py')))
                self.assertTrue(not exists(join(dst_dir, 'test_dir2', 'subdir')))
                # ... while the rest must be copied; this guards against a
                # filter that over-matches and hides everything.
                self.assertTrue(exists(join(dst_dir, 'test_dir2', 'test.txt')))
                self.assertTrue(exists(join(dst_dir, 'test_dir2', 'subdir2')))

            finally:
                if os.path.exists(dst_dir):
                    shutil.rmtree(dst_dir)
        finally:
            shutil.rmtree(src_dir)
            shutil.rmtree(os.path.dirname(dst_dir))

    if hasattr(os, "symlink"):
        def test_dont_copy_file_onto_link_to_itself(self):
            # bug 851123.
            os.mkdir(TESTFN)
            src = os.path.join(TESTFN, 'cheese')
            dst = os.path.join(TESTFN, 'shop')
            try:
                f = open(src, 'w')
                f.write('cheddar')
                f.close()

                os.link(src, dst)
                self.assertRaises(shutil.Error, shutil.copyfile, src, dst)
                with open(src, 'r') as f:
                    self.assertEqual(f.read(), 'cheddar')
                os.remove(dst)

                # Using `src` here would mean we end up with a symlink pointing
                # to TESTFN/TESTFN/cheese, while it should point at
                # TESTFN/cheese.
                os.symlink('cheese', dst)
                self.assertRaises(shutil.Error, shutil.copyfile, src, dst)
                with open(src, 'r') as f:
                    self.assertEqual(f.read(), 'cheddar')
                os.remove(dst)
            finally:
                try:
                    shutil.rmtree(TESTFN)
                except OSError:
                    pass

        def test_rmtree_on_symlink(self):
            # bug 1669.
            os.mkdir(TESTFN)
            try:
                src = os.path.join(TESTFN, 'cheese')
                dst = os.path.join(TESTFN, 'shop')
                os.mkdir(src)
                os.symlink(src, dst)
                self.assertRaises(OSError, shutil.rmtree, dst)
            finally:
                shutil.rmtree(TESTFN, ignore_errors=True)

    if hasattr(os, "mkfifo"):
        # Issue #3002: copyfile and copytree block indefinitely on named pipes
        def test_copyfile_named_pipe(self):
            os.mkfifo(TESTFN)
            try:
                self.assertRaises(shutil.SpecialFileError,
                                  shutil.copyfile, TESTFN, TESTFN2)
                self.assertRaises(shutil.SpecialFileError,
                                  shutil.copyfile, __file__, TESTFN)
            finally:
                os.remove(TESTFN)

        def test_copytree_named_pipe(self):
            os.mkdir(TESTFN)
            try:
                subdir = os.path.join(TESTFN, "subdir")
                os.mkdir(subdir)
                pipe = os.path.join(subdir, "mypipe")
                os.mkfifo(pipe)
                try:
                    shutil.copytree(TESTFN, TESTFN2)
                except shutil.Error as e:
                    errors = e.args[0]
                    self.assertEqual(len(errors), 1)
                    src, dst, error_msg = errors[0]
                    self.assertEqual("`%s` is a named pipe" % pipe, error_msg)
                else:
                    self.fail("shutil.Error should have been raised")
            finally:
                shutil.rmtree(TESTFN, ignore_errors=True)
                shutil.rmtree(TESTFN2, ignore_errors=True)

    @unittest.skipUnless(hasattr(os, 'chflags') and
                         hasattr(errno, 'EOPNOTSUPP') and
                         hasattr(errno, 'ENOTSUP'),
                         "requires os.chflags, EOPNOTSUPP & ENOTSUP")
    def test_copystat_handles_harmless_chflags_errors(self):
        tmpdir = self.mkdtemp()
        file1 = os.path.join(tmpdir, 'file1')
        file2 = os.path.join(tmpdir, 'file2')
        self.write_file(file1, 'xxx')
        self.write_file(file2, 'xxx')

        def make_chflags_raiser(err):
            # Build a stand-in for os.chflags that always fails with errno err.
            ex = OSError()

            def _chflags_raiser(path, flags):
                ex.errno = err
                raise ex
            return _chflags_raiser
        old_chflags = os.chflags
        try:
            for err in errno.EOPNOTSUPP, errno.ENOTSUP:
                os.chflags = make_chflags_raiser(err)
                shutil.copystat(file1, file2)
            # assert others errors break it
            os.chflags = make_chflags_raiser(errno.EOPNOTSUPP + errno.ENOTSUP)
            self.assertRaises(OSError, shutil.copystat, file1, file2)
        finally:
            os.chflags = old_chflags

    @unittest.skipUnless(zlib, "requires zlib")
    def test_make_tarball(self):
        # creating something to tar
        tmpdir = self.mkdtemp()
        self.write_file([tmpdir, 'file1'], 'xxx')
        self.write_file([tmpdir, 'file2'], 'xxx')
        os.mkdir(os.path.join(tmpdir, 'sub'))
        self.write_file([tmpdir, 'sub', 'file3'], 'xxx')

        tmpdir2 = self.mkdtemp()
        # The archive name below is passed without its drive while cwd is
        # tmpdir, so both directories must share a drive.  Decide before
        # removing tmpdir2 so that tearDown() can still clean it up.
        if splitdrive(tmpdir)[0] != splitdrive(tmpdir2)[0]:
            self.skipTest("source and target should be on same drive")
        # force shutil to create the directory
        os.rmdir(tmpdir2)

        base_name = os.path.join(tmpdir2, 'archive')

        # working with relative paths to avoid tar warnings
        old_dir = os.getcwd()
        os.chdir(tmpdir)
        try:
            _make_tarball(splitdrive(base_name)[1], '.')
        finally:
            os.chdir(old_dir)

        # check if the compressed tarball was created
        tarball = base_name + '.tar.gz'
        self.assertTrue(os.path.exists(tarball))

        # trying an uncompressed one
        base_name = os.path.join(tmpdir2, 'archive')
        old_dir = os.getcwd()
        os.chdir(tmpdir)
        try:
            _make_tarball(splitdrive(base_name)[1], '.', compress=None)
        finally:
            os.chdir(old_dir)
        tarball = base_name + '.tar'
        self.assertTrue(os.path.exists(tarball))

    def _tarinfo(self, path):
        """Return the sorted member names of the tar archive at path."""
        tar = tarfile.open(path)
        try:
            names = tar.getnames()
            names.sort()
            return tuple(names)
        finally:
            tar.close()

    def _create_files(self):
        """Build a small 'dist' tree to archive.

        Returns (tmpdir, tmpdir2, base_name): tmpdir holds 'dist', tmpdir2
        is an empty scratch dir and base_name is tmpdir2/archive.
        """
        tmpdir = self.mkdtemp()
        dist = os.path.join(tmpdir, 'dist')
        os.mkdir(dist)
        self.write_file([dist, 'file1'], 'xxx')
        self.write_file([dist, 'file2'], 'xxx')
        os.mkdir(os.path.join(dist, 'sub'))
        self.write_file([dist, 'sub', 'file3'], 'xxx')
        os.mkdir(os.path.join(dist, 'sub2'))
        tmpdir2 = self.mkdtemp()
        base_name = os.path.join(tmpdir2, 'archive')
        return tmpdir, tmpdir2, base_name

    @unittest.skipUnless(zlib, "Requires zlib")
    @unittest.skipUnless(find_executable('tar') and find_executable('gzip'),
                         'Need the tar command to run')
    def test_tarfile_vs_tar(self):
        tmpdir, tmpdir2, base_name =  self._create_files()
        old_dir = os.getcwd()
        os.chdir(tmpdir)
        try:
            _make_tarball(base_name, 'dist')
        finally:
            os.chdir(old_dir)

        # check if the compressed tarball was created
        tarball = base_name + '.tar.gz'
        self.assertTrue(os.path.exists(tarball))

        # now create another tarball using `tar`
        tarball2 = os.path.join(tmpdir, 'archive2.tar.gz')
        tar_cmd = ['tar', '-cf', 'archive2.tar', 'dist']
        gzip_cmd = ['gzip', '-f9', 'archive2.tar']
        old_dir = os.getcwd()
        os.chdir(tmpdir)
        try:
            # Only used to keep the external commands' output quiet.
            with captured_stdout():
                spawn(tar_cmd)
                spawn(gzip_cmd)
        finally:
            os.chdir(old_dir)

        self.assertTrue(os.path.exists(tarball2))
        # let's compare both tarballs
        self.assertEqual(self._tarinfo(tarball), self._tarinfo(tarball2))

        # trying an uncompressed one
        base_name = os.path.join(tmpdir2, 'archive')
        old_dir = os.getcwd()
        os.chdir(tmpdir)
        try:
            _make_tarball(base_name, 'dist', compress=None)
        finally:
            os.chdir(old_dir)
        tarball = base_name + '.tar'
        self.assertTrue(os.path.exists(tarball))

        # now for a dry_run; the .tar from the previous step is still
        # there, so this only checks that dry_run succeeds and leaves it.
        base_name = os.path.join(tmpdir2, 'archive')
        old_dir = os.getcwd()
        os.chdir(tmpdir)
        try:
            _make_tarball(base_name, 'dist', compress=None, dry_run=True)
        finally:
            os.chdir(old_dir)
        tarball = base_name + '.tar'
        self.assertTrue(os.path.exists(tarball))

    @unittest.skipUnless(zlib, "Requires zlib")
    @unittest.skipUnless(ZIP_SUPPORT, 'Need zip support to run')
    def test_make_zipfile(self):
        # creating something to zip
        tmpdir = self.mkdtemp()
        self.write_file([tmpdir, 'file1'], 'xxx')
        self.write_file([tmpdir, 'file2'], 'xxx')

        tmpdir2 = self.mkdtemp()
        # force shutil to create the directory
        os.rmdir(tmpdir2)
        base_name = os.path.join(tmpdir2, 'archive')
        _make_zipfile(base_name, tmpdir)

        # check if the zip archive was created
        archive = base_name + '.zip'
        self.assertTrue(os.path.exists(archive))


    def test_make_archive(self):
        # An unknown format name is rejected.
        tmpdir = self.mkdtemp()
        base_name = os.path.join(tmpdir, 'archive')
        self.assertRaises(ValueError, make_archive, base_name, 'xxx')

    @unittest.skipUnless(zlib, "Requires zlib")
    def test_make_archive_owner_group(self):
        # testing make_archive with owner and group, with various combinations
        # this works even if there's not gid/uid support
        if UID_GID_SUPPORT:
            group = grp.getgrgid(0)[0]
            owner = pwd.getpwuid(0)[0]
        else:
            group = owner = 'root'

        base_dir, root_dir, base_name =  self._create_files()
        base_name = os.path.join(self.mkdtemp() , 'archive')
        res = make_archive(base_name, 'zip', root_dir, base_dir, owner=owner,
                           group=group)
        self.assertTrue(os.path.exists(res))

        res = make_archive(base_name, 'zip', root_dir, base_dir)
        self.assertTrue(os.path.exists(res))

        res = make_archive(base_name, 'tar', root_dir, base_dir,
                           owner=owner, group=group)
        self.assertTrue(os.path.exists(res))

        res = make_archive(base_name, 'tar', root_dir, base_dir,
                           owner='kjhkjhkjg', group='oihohoh')
        self.assertTrue(os.path.exists(res))

    @unittest.skipUnless(zlib, "Requires zlib")
    @unittest.skipUnless(UID_GID_SUPPORT, "Requires grp and pwd support")
    def test_tarfile_root_owner(self):
        tmpdir, tmpdir2, base_name =  self._create_files()
        # Resolve the names before changing directory so a lookup failure
        # cannot leave the process in tmpdir.
        group = grp.getgrgid(0)[0]
        owner = pwd.getpwuid(0)[0]
        old_dir = os.getcwd()
        os.chdir(tmpdir)
        try:
            archive_name = _make_tarball(base_name, 'dist', compress=None,
                                         owner=owner, group=group)
        finally:
            os.chdir(old_dir)

        # check if the tarball was created
        self.assertTrue(os.path.exists(archive_name))

        # now checks the rights
        archive = tarfile.open(archive_name)
        try:
            for member in archive.getmembers():
                self.assertEqual(member.uid, 0)
                self.assertEqual(member.gid, 0)
        finally:
            archive.close()

    def test_make_archive_cwd(self):
        # make_archive must restore the working directory even when the
        # archiver raises.
        current_dir = os.getcwd()
        def _breaks(*args, **kw):
            raise RuntimeError()

        register_archive_format('xxx', _breaks, [], 'xxx file')
        try:
            try:
                make_archive('xxx', 'xxx', root_dir=self.mkdtemp())
            except Exception:
                pass
            self.assertEqual(os.getcwd(), current_dir)
        finally:
            unregister_archive_format('xxx')

    def test_register_archive_format(self):
        # The lambdas are only registered, never called, so the unbound
        # name 'x' inside them is never evaluated.
        self.assertRaises(TypeError, register_archive_format, 'xxx', 1)
        self.assertRaises(TypeError, register_archive_format, 'xxx', lambda: x,
                          1)
        self.assertRaises(TypeError, register_archive_format, 'xxx', lambda: x,
                          [(1, 2), (1, 2, 3)])

        register_archive_format('xxx', lambda: x, [(1, 2)], 'xxx file')
        formats = [name for name, params in get_archive_formats()]
        self.assertIn('xxx', formats)

        unregister_archive_format('xxx')
        formats = [name for name, params in get_archive_formats()]
        self.assertNotIn('xxx', formats)
597
598
class TestMove(unittest.TestCase):
    """Tests for shutil.move() on files and directories."""

    def setUp(self):
        filename = "foo"
        self.src_dir = tempfile.mkdtemp()
        self.dst_dir = tempfile.mkdtemp()
        self.src_file = os.path.join(self.src_dir, filename)
        self.dst_file = os.path.join(self.dst_dir, filename)
        # Try to create a dir next to this test module, hoping that it is
        # not located on the same filesystem as the system tmp dir.  The
        # *_other_fs tests are skipped when this fails.
        try:
            self.dir_other_fs = tempfile.mkdtemp(
                dir=os.path.dirname(__file__))
            self.file_other_fs = os.path.join(self.dir_other_fs,
                filename)
        except OSError:
            self.dir_other_fs = None
        with open(self.src_file, "wb") as f:
            f.write("spam")

    def tearDown(self):
        # Best-effort cleanup: a test may already have moved or removed
        # some of these directories.
        for d in (self.src_dir, self.dst_dir, self.dir_other_fs):
            if d:
                shutil.rmtree(d, ignore_errors=True)

    def _skip_unless_other_fs(self):
        """Skip the running test when setUp() could not create dir_other_fs."""
        if not self.dir_other_fs:
            self.skipTest("could not create a directory next to the test "
                          "module")

    def _check_move_file(self, src, dst, real_dst):
        """Move file src to dst; check real_dst has its contents and src is gone."""
        with open(src, "rb") as f:
            contents = f.read()
        shutil.move(src, dst)
        with open(real_dst, "rb") as f:
            self.assertEqual(contents, f.read())
        self.assertFalse(os.path.exists(src))

    def _check_move_dir(self, src, dst, real_dst):
        """Move dir src to dst; check real_dst lists the same entries and src is gone."""
        contents = sorted(os.listdir(src))
        shutil.move(src, dst)
        self.assertEqual(contents, sorted(os.listdir(real_dst)))
        self.assertFalse(os.path.exists(src))

    def test_move_file(self):
        # Move a file to another location on the same filesystem.
        self._check_move_file(self.src_file, self.dst_file, self.dst_file)

    def test_move_file_to_dir(self):
        # Move a file inside an existing dir on the same filesystem.
        self._check_move_file(self.src_file, self.dst_dir, self.dst_file)

    def test_move_file_other_fs(self):
        # Move a file to another location on another filesystem.
        self._skip_unless_other_fs()
        self._check_move_file(self.src_file, self.file_other_fs,
            self.file_other_fs)

    def test_move_file_to_dir_other_fs(self):
        # Move a file inside an existing dir on another filesystem.
        self._skip_unless_other_fs()
        self._check_move_file(self.src_file, self.dir_other_fs,
            self.file_other_fs)

    def test_move_dir(self):
        # Move a dir to another location on the same filesystem.
        dst_dir = tempfile.mktemp()
        try:
            self._check_move_dir(self.src_dir, dst_dir, dst_dir)
        finally:
            shutil.rmtree(dst_dir, ignore_errors=True)

    def test_move_dir_other_fs(self):
        # Move a dir to another location on another filesystem.
        self._skip_unless_other_fs()
        dst_dir = tempfile.mktemp(dir=self.dir_other_fs)
        try:
            self._check_move_dir(self.src_dir, dst_dir, dst_dir)
        finally:
            shutil.rmtree(dst_dir, ignore_errors=True)

    def test_move_dir_to_dir(self):
        # Move a dir inside an existing dir on the same filesystem.
        self._check_move_dir(self.src_dir, self.dst_dir,
            os.path.join(self.dst_dir, os.path.basename(self.src_dir)))

    def test_move_dir_to_dir_other_fs(self):
        # Move a dir inside an existing dir on another filesystem.
        self._skip_unless_other_fs()
        self._check_move_dir(self.src_dir, self.dir_other_fs,
            os.path.join(self.dir_other_fs, os.path.basename(self.src_dir)))

    def test_existing_file_inside_dest_dir(self):
        # A file with the same name inside the destination dir already exists.
        with open(self.dst_file, "wb"):
            pass
        self.assertRaises(shutil.Error, shutil.move, self.src_file, self.dst_dir)

    def test_dont_move_dir_in_itself(self):
        # Moving a dir inside itself raises an Error.
        dst = os.path.join(self.src_dir, "bar")
        self.assertRaises(shutil.Error, shutil.move, self.src_dir, dst)

    def test_destinsrc_false_negative(self):
        # A path really nested under src must be reported as inside it.
        os.mkdir(TESTFN)
        try:
            for src, dst in [('srcdir', 'srcdir/dest')]:
                src = os.path.join(TESTFN, src)
                dst = os.path.join(TESTFN, dst)
                self.assertTrue(shutil._destinsrc(src, dst),
                             msg='_destinsrc() wrongly concluded that '
                             'dst (%s) is not in src (%s)' % (dst, src))
        finally:
            shutil.rmtree(TESTFN, ignore_errors=True)

    def test_destinsrc_false_positive(self):
        # Sibling paths that merely share a string prefix are not "inside".
        os.mkdir(TESTFN)
        try:
            for src, dst in [('srcdir', 'src/dest'), ('srcdir', 'srcdir.new')]:
                src = os.path.join(TESTFN, src)
                dst = os.path.join(TESTFN, dst)
                self.assertFalse(shutil._destinsrc(src, dst),
                            msg='_destinsrc() wrongly concluded that '
                            'dst (%s) is in src (%s)' % (dst, src))
        finally:
            shutil.rmtree(TESTFN, ignore_errors=True)
737
738
class TestCopyFile(unittest.TestCase):
    """Tests shutil.copyfile() error handling with injected fake files.

    shutil.open is temporarily replaced by a stub that returns Faux objects
    or raises, so that failures while opening or closing the source and
    destination can be simulated.
    """

    # Becomes True once shutil.open is patched, so tearDown() undoes it.
    _delete = False

    class Faux(object):
        """Minimal file/context-manager stand-in that records its usage."""
        _entered = False
        _exited_with = None
        _raised = False

        def __init__(self, raise_in_exit=False, suppress_at_exit=True):
            # raise_in_exit: make __exit__ raise IOError("Cannot close").
            # suppress_at_exit: value returned by __exit__ when it does not
            # raise (True swallows the exception that ended the block).
            self._raise_in_exit = raise_in_exit
            self._suppress_at_exit = suppress_at_exit

        def read(self, *args):
            # Always at EOF, so copying from a Faux copies nothing.
            return ''

        def write(self, data):
            # Writes are discarded; only open/close behaviour is under test.
            pass

        def __enter__(self):
            self._entered = True
            # Return self, as real file objects do, so `with ... as f`
            # binds a usable object instead of None.
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self._exited_with = exc_type, exc_val, exc_tb
            if self._raise_in_exit:
                self._raised = True
                raise IOError("Cannot close")
            return self._suppress_at_exit

    def tearDown(self):
        if self._delete:
            del shutil.open

    def _set_shutil_open(self, func):
        """Install func as shutil.open for the duration of this test."""
        shutil.open = func
        self._delete = True

    def test_w_source_open_fails(self):
        def _open(filename, mode='r'):
            if filename == 'srcfile':
                raise IOError('Cannot open "srcfile"')
            assert 0  # shouldn't reach here.

        self._set_shutil_open(_open)

        self.assertRaises(IOError, shutil.copyfile, 'srcfile', 'destfile')

    def test_w_dest_open_fails(self):

        srcfile = self.Faux()

        def _open(filename, mode='r'):
            if filename == 'srcfile':
                return srcfile
            if filename == 'destfile':
                raise IOError('Cannot open "destfile"')
            assert 0  # shouldn't reach here.

        self._set_shutil_open(_open)

        # srcfile's __exit__ suppresses the error, so copyfile returns.
        shutil.copyfile('srcfile', 'destfile')
        self.assertTrue(srcfile._entered)
        self.assertTrue(srcfile._exited_with[0] is IOError)
        self.assertEqual(srcfile._exited_with[1].args,
                         ('Cannot open "destfile"',))

    def test_w_dest_close_fails(self):

        srcfile = self.Faux()
        destfile = self.Faux(True)

        def _open(filename, mode='r'):
            if filename == 'srcfile':
                return srcfile
            if filename == 'destfile':
                return destfile
            assert 0  # shouldn't reach here.

        self._set_shutil_open(_open)

        shutil.copyfile('srcfile', 'destfile')
        self.assertTrue(srcfile._entered)
        self.assertTrue(destfile._entered)
        self.assertTrue(destfile._raised)
        self.assertTrue(srcfile._exited_with[0] is IOError)
        self.assertEqual(srcfile._exited_with[1].args,
                         ('Cannot close',))

    def test_w_source_close_fails(self):

        srcfile = self.Faux(True)
        destfile = self.Faux()

        def _open(filename, mode='r'):
            if filename == 'srcfile':
                return srcfile
            if filename == 'destfile':
                return destfile
            assert 0  # shouldn't reach here.

        self._set_shutil_open(_open)

        self.assertRaises(IOError,
                          shutil.copyfile, 'srcfile', 'destfile')
        self.assertTrue(srcfile._entered)
        self.assertTrue(destfile._entered)
        self.assertFalse(destfile._raised)
        self.assertTrue(srcfile._exited_with[0] is None)
        self.assertTrue(srcfile._raised)

    def test_move_dir_caseinsensitive(self):
        # Renames a folder to the same name
        # but a different case.

        src_dir = tempfile.mkdtemp()
        dst_dir = os.path.join(
                os.path.dirname(src_dir),
                os.path.basename(src_dir).upper())
        self.assertNotEqual(src_dir, dst_dir)

        try:
            shutil.move(src_dir, dst_dir)
            self.assertTrue(os.path.isdir(dst_dir))
        finally:
            # Remove whichever name still exists; if move() failed this
            # is still src_dir.  On a case-insensitive filesystem both
            # names refer to one directory, hence the second exists() check.
            for d in (src_dir, dst_dir):
                if os.path.exists(d):
                    os.rmdir(d)
858
859
860
def test_main():
    """Run every TestCase class defined in this module."""
    suites = (TestShutil, TestMove, TestCopyFile)
    test_support.run_unittest(*suites)


if __name__ == '__main__':
    test_main()
866