1import imghdr
2import io
3import os
4import pathlib
5import unittest
6import warnings
7from test.support import findfile, TESTFN, unlink
8
# Every sample image shipped in the 'imghdrdata' test-data directory, as
# (file name, type name imghdr.what() should report) pairs.  All samples are
# named "python.<extension>"; note that the extension does not always equal
# the reported type (e.g. .jpg -> 'jpeg', .ras -> 'rast', .sgi -> 'rgb').
TEST_FILES = tuple(
    ('python.' + extension, kind)
    for extension, kind in (
        ('png', 'png'),
        ('gif', 'gif'),
        ('bmp', 'bmp'),
        ('ppm', 'ppm'),
        ('pgm', 'pgm'),
        ('pbm', 'pbm'),
        ('jpg', 'jpeg'),
        ('ras', 'rast'),
        ('sgi', 'rgb'),
        ('tiff', 'tiff'),
        ('xbm', 'xbm'),
        ('webp', 'webp'),
        ('exr', 'exr'),
    )
)
24
class UnseekableIO(io.FileIO):
    """A real raw file whose position can be neither queried nor changed.

    Reading works as for a normal ``io.FileIO``, but ``seek()`` and
    ``tell()`` always raise ``io.UnsupportedOperation``.  The tests use it to
    check how ``imghdr.what()`` reacts to a stream that cannot be
    repositioned.
    """

    def seek(self, *args, **kwargs):
        # Refuse every form of repositioning, regardless of the arguments.
        raise io.UnsupportedOperation()

    def tell(self):
        # Reporting the current offset is refused as well.
        raise io.UnsupportedOperation()
31
class TestImghdr(unittest.TestCase):
    """Tests for imghdr.what() given file names, open streams and raw bytes."""

    @classmethod
    def setUpClass(cls):
        # One known-good PNG sample, shared by the tests that only need a
        # single valid image: keep both its path and its raw contents.
        cls.testfile = findfile('python.png', subdir='imghdrdata')
        with open(cls.testfile, 'rb') as stream:
            cls.testdata = stream.read()

    def tearDown(self):
        # Several tests create TESTFN; remove it after every test.
        unlink(TESTFN)

    def test_data(self):
        # Every sample must be recognised whether what() receives a path,
        # an open binary stream, or the contents as bytes / bytearray.
        # subTest reports each failing sample instead of stopping at the
        # first one.
        for name, expected in TEST_FILES:
            with self.subTest(filename=name):
                path = findfile(name, subdir='imghdrdata')
                self.assertEqual(imghdr.what(path), expected)
                with open(path, 'rb') as stream:
                    self.assertEqual(imghdr.what(stream), expected)
                with open(path, 'rb') as stream:
                    data = stream.read()
                self.assertEqual(imghdr.what(None, data), expected)
                self.assertEqual(imghdr.what(None, bytearray(data)), expected)

    def test_pathlike_filename(self):
        # A pathlib.Path is accepted wherever a str path is.
        for name, expected in TEST_FILES:
            with self.subTest(filename=name):
                path = findfile(name, subdir='imghdrdata')
                self.assertEqual(imghdr.what(pathlib.Path(path)), expected)

    def test_register_test(self):
        # A callable appended to imghdr.tests should be consulted by what().
        # The cleanup pops it again so later tests see the original list.
        def test_jumbo(h, file):
            if h.startswith(b'eggs'):
                return 'ham'
        imghdr.tests.append(test_jumbo)
        self.addCleanup(imghdr.tests.pop)
        self.assertEqual(imghdr.what(None, b'eggs'), 'ham')

    def test_file_pos(self):
        # The image starts after some junk bytes.  what() must examine the
        # stream from its current position and leave it at that position.
        with open(TESTFN, 'wb') as stream:
            stream.write(b'ababagalamaga')
            pos = stream.tell()
            stream.write(self.testdata)
        with open(TESTFN, 'rb') as stream:
            stream.seek(pos)
            self.assertEqual(imghdr.what(stream), 'png')
            self.assertEqual(stream.tell(), pos)

    def test_bad_args(self):
        # Missing arguments, a non-bytes header, and objects that are
        # neither a str path nor a stream (bytes path, raw fd) are rejected.
        with self.assertRaises(TypeError):
            imghdr.what()
        with self.assertRaises(AttributeError):
            imghdr.what(None)
        with self.assertRaises(TypeError):
            imghdr.what(self.testfile, 1)
        with self.assertRaises(AttributeError):
            imghdr.what(os.fsencode(self.testfile))
        with open(self.testfile, 'rb') as f:
            with self.assertRaises(AttributeError):
                imghdr.what(f.fileno())

    def test_invalid_headers(self):
        # Near-misses of real signatures must not be recognised as anything.
        for header in (b'\211PN\r\n',
                       b'\001\331',
                       b'\x59\xA6',
                       b'cutecat',
                       b'000000JFI',
                       b'GIF80'):
            with self.subTest(header=header):
                self.assertIsNone(imghdr.what(None, header))

    def test_string_data(self):
        # Text (str) input, whether as a text stream or as the header
        # argument, must raise TypeError rather than be matched.
        with warnings.catch_warnings():
            # Presumably comparing str against bytes signatures would emit
            # BytesWarning under `python -b`; silence it so only the
            # TypeError matters.
            warnings.simplefilter("ignore", BytesWarning)
            for name, _ in TEST_FILES:
                with self.subTest(filename=name):
                    path = findfile(name, subdir='imghdrdata')
                    with open(path, 'rb') as stream:
                        data = stream.read().decode('latin1')
                    with self.assertRaises(TypeError):
                        imghdr.what(io.StringIO(data))
                    with self.assertRaises(TypeError):
                        imghdr.what(None, data)

    def test_missing_file(self):
        # A path that does not exist surfaces the open() error unchanged.
        with self.assertRaises(FileNotFoundError):
            imghdr.what('missing')

    def test_closed_file(self):
        # Closed streams, both real files and in-memory buffers, raise
        # ValueError.
        stream = open(self.testfile, 'rb')
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)
        stream = io.BytesIO(self.testdata)
        stream.close()
        with self.assertRaises(ValueError):
            imghdr.what(stream)

    def test_unseekable(self):
        # A stream that cannot report or restore its position is rejected
        # with io.UnsupportedOperation.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
        with UnseekableIO(TESTFN, 'rb') as stream:
            with self.assertRaises(io.UnsupportedOperation):
                imghdr.what(stream)

    def test_output_stream(self):
        # A write-only stream cannot be read.  The resulting
        # io.UnsupportedOperation is an OSError subclass.
        with open(TESTFN, 'wb') as stream:
            stream.write(self.testdata)
            stream.seek(0)
            with self.assertRaises(OSError):
                imghdr.what(stream)
138
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
141